• More examples of tail recursion

    We introduced tail recursion in our previous post. Let’s see some more examples of it. We’ll start with a simple one, very similar to our introductory example, to illustrate the win that tail-call optimization by the compiler provides. Consider the function lst_rev, introduced in one of our earlier posts, which takes a list lst and returns the reversed list. Here it is for your reference:

    let rec lst_rev lst =
      match lst with
      | [] -> []
      | h::t ->
         (lst_rev t) @ [h];;

    Recall that in OCaml, @ is the operator for list concatenation. Let’s work out what the execution of this function on our list [1; 2; 3; 4] looks like:

    	lst_rev [1; 2; 3; 4]
    = 	lst_rev [2; 3; 4] @ [1]
    = 	lst_rev [3; 4] @ [2] @ [1]
    = 	lst_rev [4] @ [3] @ [2] @ [1]
    = 	lst_rev [] @ [4] @ [3] @ [2] @ [1]
    =	[] @ [4] @ [3] @ [2] @ [1]
    = 	[4; 3; 2; 1]
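    The trace hides two costs. Every pending @ [h] has to wait on the stack for the deeper call, and OCaml’s @ builds a fresh copy of its left operand, so the call on a k-element list copies the k-1 elements returned by the call below it. A rough C++ model of this (an illustration, not OCaml’s actual implementation; a std::vector plus an index i stands in for the list, and the copies counter is a hypothetical addition) makes the count visible:

```cpp
#include <cstddef>
#include <vector>

// Rough model of lst_rev on lst[i..]: reverse the tail first, then "@ [h]".
// `copies` (hypothetical instrumentation) counts elements copied by the
// simulated @, which, like OCaml's @, rebuilds its whole left operand.
std::vector<int> lst_rev(const std::vector<int>& lst, std::size_t i, long& copies) {
    if (i == lst.size()) return {};                          // [] -> []
    std::vector<int> rev_t = lst_rev(lst, i + 1, copies);    // lst_rev t must finish first
    std::vector<int> result;
    for (int x : rev_t) {                                    // copy the left operand of @
        result.push_back(x);
        copies++;
    }
    result.push_back(lst[i]);                                // ... @ [h]
    return result;
}
```

    For [1; 2; 3; 4] this returns [4; 3; 2; 1] with copies = 0 + 1 + 2 + 3 = 6; in general n(n-1)/2, which grows quadratically with the length of the list.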
    

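    Each pending @ [h] has to wait until the recursive call below it returns, so every call stays on the stack until the very end. Here is our tail-recursive version of the same function. Let’s call it tr_lst_rev.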

    let tr_lst_rev lst =
      let rec aux_lst_rev lst acc =
        match lst with
        | [] -> acc
        | h::t -> aux_lst_rev t (h::acc)
      in
      aux_lst_rev lst [];;

    Again, very similar to our previous example, we use an accumulator acc to collect the intermediate results: at each step the head h is consed onto acc, so acc always holds the part of the list seen so far, already reversed.

    	tr_lst_rev [1; 2; 3; 4]
    =	aux_lst_rev [1; 2; 3; 4] []
    =	aux_lst_rev [2; 3; 4]    (1::[])
    =	aux_lst_rev [3; 4]       (2::[1])
    =	aux_lst_rev [4]          (3::[2; 1])
    =	aux_lst_rev []           (4::[3; 2; 1])
    =	[4; 3; 2; 1]
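    Since the call to aux_lst_rev is the last thing each step does, nothing is left to do when it returns, and the compiler can reuse the current frame. Written out by hand, that is just a loop. Here is a C++ sketch of it, assuming std::vector stands in for the input list and std::deque for acc, with push_front playing the role of :: (the function name mirrors the OCaml one for illustration only):

```cpp
#include <deque>
#include <vector>

// Sketch of the loop that tail-recursive aux_lst_rev amounts to.
// Each iteration is one recursive step: take the head h, put it in
// front of acc (h::acc), and carry on with the tail.
std::deque<int> tr_lst_rev(const std::vector<int>& lst) {
    std::deque<int> acc;          // aux_lst_rev lst []
    for (int h : lst) {           // | h::t -> ...
        acc.push_front(h);        //   h::acc, O(1) like OCaml's ::
    }
    return acc;                   // | [] -> acc
}
```

    Each iteration corresponds to one line of the trace, and for {1, 2, 3, 4} the result is {4, 3, 2, 1}.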
    
    
  • Introducing tail recursion

    Now that we are already thinking of recursive function calls whenever possible, it is a nice time to introduce tail recursion. Let’s consider the function lst_sum that we worked out in one of our earlier posts, repeated here for convenience:

    let rec lst_sum lst = 
      match lst with
      | [] -> 0
      | h::t -> h + lst_sum t;;

    When we call this function on the list l1 = [1; 2; 3; 4; 5], this is roughly how the evaluation unfolds; each pending addition needs its own stack frame:

    	lst_sum [1; 2; 3; 4; 5]
    =	1 + lst_sum [2; 3; 4; 5]
    =	1 + 2 + lst_sum [3; 4; 5]
    =	1 + 2 + 3 + lst_sum [4; 5]
    =	1 + 2 + 3 + 4 + lst_sum [5]
    =	1 + 2 + 3 + 4 + 5 + lst_sum []
    =	1 + 2 + 3 + 4 + 5 + 0
    =	15
    
    

    Note that it performs the actual additions only after all the recursive calls have been made: each 1 + ..., 2 + ... waits for the call to its right to return. This means we need to keep every call on the stack (which means spending storage for the function’s local variables, return address, etc.) until all the recursive calls return, so the stack grows linearly with the length of the list. For a long enough list, the program runs out of stack space.
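    To see the stack growth concretely, here is a rough C++ model of lst_sum that also reports how deep the recursion gets. It assumes a std::vector plus an index i stands in for the list; depth and max_depth are hypothetical instrumentation, not part of the original function:

```cpp
#include <cstddef>
#include <vector>

// Model of lst_sum on lst[i..]: h + lst_sum t.
// `depth` is the number of active calls (this one included); `max_depth`
// records the deepest point, i.e. how many frames were alive at once.
long lst_sum(const std::vector<int>& lst, std::size_t i, int depth, int& max_depth) {
    if (depth > max_depth) max_depth = depth;
    if (i == lst.size()) return 0;                               // [] -> 0
    return lst[i] + lst_sum(lst, i + 1, depth + 1, max_depth);   // the + waits for the call
}
```

    Called as lst_sum(l1, 0, 1, max_depth) on {1, 2, 3, 4, 5}, it returns 15 and leaves max_depth at 6: one frame per element plus one for the empty list, all alive at the same time, just like the six lst_sum calls in the trace.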

    Let’s see a tail-recursive version of the same function. Let’s call it tr_lst_sum.

    let tr_lst_sum lst =
      let rec aux_lst_sum lst acc =
        match lst with
        | [] -> acc
        | h::t -> aux_lst_sum t (acc + h)
      in
      aux_lst_sum lst 0;;

    We have an auxiliary recursive function aux_lst_sum and pass it an extra parameter acc (for accumulator). The idea is to add each element into acc as soon as we reach it, so the recursive call is the very last thing the function does (a tail call) and nothing is left waiting for it on the stack. This is what the execution of our tail-recursive version looks like:

    	tr_lst_sum [1; 2; 3; 4; 5]
    = 	aux_lst_sum [1; 2; 3; 4; 5] 0
    = 	aux_lst_sum [2; 3; 4; 5]    1
    = 	aux_lst_sum [3; 4; 5] 	    3
    = 	aux_lst_sum [4; 5] 	    6
    = 	aux_lst_sum [5] 	    10
    = 	aux_lst_sum [] 		    15
    =	15
    

    Note that the accumulator acc holds each intermediate result as soon as it is available. Because the recursive call is in tail position, OCaml reuses the current stack frame for it instead of pushing a new one, so only one frame is in use however long the list is. This saves a lot of space and, more importantly, avoids stack overflow.
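    The same shape in C++, as a sketch: aux_lst_sum below is a direct model of the OCaml helper (again assuming a std::vector plus an index stands in for the list), and aux_lst_sum_loop is the loop that eliminating the tail call amounts to. The C++ standard does not require compilers to perform this optimization, so in C++ the explicit loop is the safe form; the OCaml version relies on OCaml turning the tail call into a jump.

```cpp
#include <cstddef>
#include <vector>

// Tail-recursive model of aux_lst_sum: the recursive call is the last
// thing the function does, so no addition is left waiting for it.
long aux_lst_sum(const std::vector<int>& lst, std::size_t i, long acc) {
    if (i == lst.size()) return acc;                 // [] -> acc
    return aux_lst_sum(lst, i + 1, acc + lst[i]);    // tail call
}

// The same computation with the tail call replaced by a jump back to
// the top, which is what tail-call optimization amounts to.
long aux_lst_sum_loop(const std::vector<int>& lst) {
    long acc = 0;
    for (std::size_t i = 0; i < lst.size(); i++) {
        acc += lst[i];                               // acc + h
    }
    return acc;
}
```

    Both return 15 for l1 = {1, 2, 3, 4, 5}, and acc takes the values 0, 1, 3, 6, 10, 15 shown in the trace.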

    Hope you enjoyed this!

  • Computing factors of a number

    Let’s look at a quick algorithm to compute all factors of a number N. For example, if N is 12, we want all the factors [1, 2, 3, 4, 6, 12]. If N is, say, 36, we want [1, 2, 3, 4, 6, 9, 12, 18, 36]. We observe that:

    1. 1 and N itself are always factors, of course.
    2. Factors always occur in pairs (i, N/i): for N = 12 these are (1, 12), (2, 6) and (3, 4). If N is a perfect square, one pair is (sqrt(N), sqrt(N)), e.g. (6, 6) for N = 36.
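    Observation 2 can be checked directly: in each pair (i, N/i) the smaller member i satisfies i * i <= N, so stepping i up while i * i <= N finds every pair exactly once. A small C++ sketch, assuming N >= 1 (factor_pairs is a hypothetical helper, not part of the post’s algorithm):

```cpp
#include <utility>
#include <vector>

// Hypothetical helper: the factor pairs (i, N / i) of N with i * i <= N,
// so the first member of each pair is at most sqrt(N). Assumes N >= 1.
std::vector<std::pair<int, int>> factor_pairs(int N) {
    std::vector<std::pair<int, int>> pairs;
    for (int i = 1; (long long)i * i <= N; i++) {   // integer test, no sqrt needed
        if (N % i == 0) {
            pairs.push_back({i, N / i});
        }
    }
    return pairs;
}
```

    For N = 12 it yields (1, 12), (2, 6), (3, 4); for N = 36 it yields (1, 36), (2, 18), (3, 12), (4, 9), (6, 6).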

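    From the above we realize that we don’t need to iterate through the numbers all the way up to N. In each pair, the smaller factor is at most sqrt(N), so it is enough to check candidates i up to and including sqrt(N) and add N/i alongside each i we find. Here’s the algorithm for computing all the factors: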

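    #include <vector>

    // Returns all factors of N (N >= 1) in ascending order.
    std::vector<int> all_factors (int N) {
        std::vector<int> small, large;   // small: i <= sqrt(N); large: N/i >= sqrt(N)
        // i * i <= N avoids floating-point sqrt and includes sqrt(N) itself
        for (int i = 1; (long long)i * i <= N; i++) {
            if ((N % i) == 0) {
                small.push_back(i);
                if (i != N / i)          // perfect square: add sqrt(N) only once
                    large.push_back(N / i);
            }
        }
        // large holds the partners in descending order; append them reversed
        small.insert(small.end(), large.rbegin(), large.rend());
        return small;
    }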
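    For comparison, here is the approach the observations let us avoid: trying every candidate from 1 to N, which takes N steps instead of about sqrt(N). It is slow but easy to trust, so it is handy for spot-checking a faster version. A sketch (all_factors_brute is a hypothetical name):

```cpp
#include <vector>

// Brute-force reference: tries every i from 1 to N (O(N) steps instead of
// O(sqrt(N))), returning the factors in ascending order. Assumes N >= 1.
std::vector<int> all_factors_brute(int N) {
    std::vector<int> result;
    for (int i = 1; i <= N; i++) {
        if (N % i == 0) result.push_back(i);
    }
    return result;
}
```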

    That was a super short post!

  • Computing maximum subarray sum

    Let’s look at one of my favourite algorithm problems from my younger days: computing the maximum subarray sum of a given array. That is, given an array, find the maximum sum of any contiguous subarray. For example, if the array is [-1; 2; 6; 4; 2], the maximum subarray sum is 14, contributed by the subarray [2; 6; 4; 2]. As in the code that follows, we allow the empty subarray (sum 0), so the answer is never negative. Let’s look at how to solve this.

    The straightforward approach is to go through all possible subarrays, compute their sums and pick the maximum.

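    #include <algorithm>
    #include <vector>

    int max_subarray_sum_On3 (const std::vector<int>& arr) {
        int res = 0;                        // best sum so far (empty subarray = 0)
        int n = arr.size();
        for (int a = 0; a < n; a++) {       // left end of the window
            for (int b = a; b < n; b++) {   // right end of the window
                int sum = 0;
                for (int c = a; c <= b; c++) {
                    sum += arr[c];
                }
                res = std::max(res, sum);
            }
        }
        return res;
    }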

    Here, a and b denote the window of the subarray. We use the loop with the index c to compute the sum of the subarray defined by the window between a and b, and then update the saved result res with the larger of res and this window’s sum. The time complexity of this algorithm is O(n^3).

    Let’s do it a little better. What if we compute the sum at the same time as we extend the window to the right - i.e., move b to the right (the second for loop that increments b)? Let’s try that.

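    #include <algorithm>
    #include <vector>

    int max_subarray_sum_On2 (const std::vector<int>& arr) {
        int res = 0;
        int n = arr.size();
        for (int a = 0; a < n; a++) {
            int sum = 0;                    // sum of arr[a..b], extended as b moves right
            for (int b = a; b < n; b++) {
                sum += arr[b];
                res = std::max(res, sum);
            }
        }
        return res;
    }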

    We reset sum to 0 before entering the inner loop; as b moves right, we add arr[b] to it and update the saved result (max-so-far) with the larger of the two. The time complexity of this algorithm is O(n^2).

    Can we do better? I initially thought ‘No’. And it took me a while to convince myself of the simple and elegant algorithm by Joseph Kadane.

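    #include <algorithm>
    #include <vector>

    int max_subarray_sum_3 (const std::vector<int>& arr) {
        int res = 0, sum = 0;   // res: best so far; sum: best sum ending at index a
        int n = arr.size();
        for (int a = 0; a < n; a++) {
            sum = std::max(arr[a], sum + arr[a]);   // start afresh at a, or extend
            res = std::max(res, sum);
        }
        return res;
    }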

    The idea is to treat index a (in the only loop) as the right end of a subarray: sum holds the largest sum of a subarray that ends exactly at a. Such a subarray is either the element arr[a] on its own, or the best subarray ending at a-1 extended by arr[a], whose sum is sum + arr[a]; we keep the larger of the two. res holds the maximum value seen so far, and we update it by comparing it against the new sum after each step. Since we iterate over the array only once (one for loop), the time complexity of this algorithm is O(n). Isn’t it awesome?
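    The choice between arr[a] and sum + arr[a] is exactly where a new candidate subarray starts. The sketch below makes that visible by also recording where the best subarray starts and ends. The MaxSubarray struct and max_subarray_with_range are hypothetical names, and as in max_subarray_sum_3, the empty subarray (sum 0, reported here as start = end = -1) is allowed:

```cpp
#include <vector>

// Best sum plus the inclusive index range achieving it;
// start = end = -1 stands for the empty subarray (sum 0).
struct MaxSubarray {
    int sum;
    int start;
    int end;
};

// Kadane's algorithm, extended to remember where the best subarray lies.
MaxSubarray max_subarray_with_range(const std::vector<int>& arr) {
    MaxSubarray best{0, -1, -1};
    int sum = 0;        // best sum of a subarray ending at index a
    int cur_start = 0;  // index where that subarray begins
    for (int a = 0; a < (int)arr.size(); a++) {
        if (sum + arr[a] > arr[a]) {
            sum += arr[a];           // extending the previous subarray is better
        } else {
            sum = arr[a];            // starting a new subarray at a is at least as good
            cur_start = a;
        }
        if (sum > best.sum) {
            best = {sum, cur_start, a};   // new overall best
        }
    }
    return best;
}
```

    For [-1; 2; 6; 4; 2] it returns sum 14 with start 1 and end 4, i.e. the subarray [2; 6; 4; 2].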

    I don’t know how clearly I managed to explain it. Someday, I will add intuitive images working through the above three algorithms and see if that makes it any clearer.
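    If Kadane’s algorithm is hard to convince yourself of, a cheap way to build confidence is to compare it against the O(n^2) version on many random arrays. A sketch, with both functions restated so the block stands alone; the helper kadane_agrees_with_On2, the fixed seed, and the length and value ranges are arbitrary, hypothetical choices:

```cpp
#include <algorithm>
#include <random>
#include <vector>

// O(n^2) version: extend each window to the right, keeping a running sum.
int max_subarray_sum_On2(const std::vector<int>& arr) {
    int res = 0;
    int n = arr.size();
    for (int a = 0; a < n; a++) {
        int sum = 0;
        for (int b = a; b < n; b++) {
            sum += arr[b];
            res = std::max(res, sum);
        }
    }
    return res;
}

// Kadane's O(n) version.
int max_subarray_sum_3(const std::vector<int>& arr) {
    int res = 0, sum = 0;
    for (int x : arr) {
        sum = std::max(x, sum + x);   // best sum of a subarray ending here
        res = std::max(res, sum);     // best sum seen so far
    }
    return res;
}

// Hypothetical helper: returns true if both versions agree on `trials`
// random arrays (fixed seed, lengths 0..12, values -10..10).
bool kadane_agrees_with_On2(int trials) {
    std::mt19937 gen(42);
    std::uniform_int_distribution<int> len(0, 12);
    std::uniform_int_distribution<int> val(-10, 10);
    for (int t = 0; t < trials; t++) {
        int n = len(gen);
        std::vector<int> arr(n);
        for (int& x : arr) x = val(gen);
        if (max_subarray_sum_On2(arr) != max_subarray_sum_3(arr)) return false;
    }
    return true;
}
```

    Keeping the arrays short and the values small makes ties and all-negative arrays common, which is where off-by-one mistakes tend to hide.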

  • Learning to count set bits

    Let’s take a break from recursion and learn to count the number of bits set in an integer. For example, the integer 9 has two set bits. We will, of course, deal with unsigned integers and count the set bits in their binary representation (the language that our computers understand and speak). We already saw how to convert an integer to binary in one of our previous posts. You may want to refer to that in case you want to quickly refresh your memory.

    Let’s work out the first technique we would use. Let’s take the integer 9. Its binary representation is 1001, so there are two set bits. This looks very intuitive: start from the right-most bit, keep a counter, and count the set bits while right-shifting our bit string. We just need a counter to keep track of the count and two operations: a right shift and an operation to identify a set bit. Let’s see how to do that.

    1. We initialize a counter, call it count, to 0.
    2. We do a bitwise & with 1. This tells us whether the right-most bit is set. If it is, we increment count.
    3. Right-shift n and repeat until n becomes 0 (the terminating condition for our loop).
       n       n&1     count
       1001    0001    1
       0100    0000    1
       0010    0000    1
       0001    0001    2
       0000    -       2
    

    final result = 2

    Let’s write a quick C function:

    unsigned int count_set_bits (unsigned int n) 
    {
        unsigned int count = 0;
        while (n) {
            count += n & 1;
            n >>= 1;
        }
        return count;
    }

    There’s a better way to do this, thanks to Brian Kernighan’s algorithm. He observed that subtracting 1 flips the right-most set bit and all the bits below it. For example, subtracting 1 from 1010 gives 1001: the last two bits (the trailing 0 and the right-most set bit 1) got flipped. Now, if we perform a bitwise & of this result with the original number, the flipped bits all become 0 while the bits above them are unchanged, so we effectively unset the right-most set bit. Let’s work out an example to understand it better. Let’s start again with 1001 (9).

      n       n-1    n&(n-1)      count
    1001     1000     1000          1
    1000     0111     0000          2
    0000      --       --           2
    
    

    final result = 2. Note that we loop only as many times as the number of set bits, unlike the previous case.
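    The claim about subtracting 1 is easy to check on concrete values: n ^ (n-1) shows exactly which bits flipped, and n & (n-1) shows what survives. A small C++ sketch (both function names are hypothetical; the binary literals in the usage need C++14):

```cpp
#include <cstdint>

// Bits flipped by subtracting 1 (for n != 0): the right-most set bit
// and every 0 below it.
std::uint32_t bits_flipped_by_decrement(std::uint32_t n) {
    return n ^ (n - 1);
}

// Brian Kernighan's step: n & (n-1) clears the right-most set bit of n.
std::uint32_t clear_rightmost_set_bit(std::uint32_t n) {
    return n & (n - 1);
}
```

    For 1010, the flipped mask is 0011 and 1010 & 1001 is 1000: the right-most set bit is cleared and everything above it is unchanged.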

    Let’s write a quick C function:

    unsigned int BK_count_set_bits (unsigned int n) 
    {
        unsigned int count = 0;
        while (n) {
            n &= (n-1);
            count++;
        }
        return count;
    }
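    To make sure both functions agree, here is a sketch that checks them against std::bitset’s count() for every n below some limit. The two functions are restated in C++ (with comments) so the block stands alone; counters_agree and the limit are arbitrary, hypothetical choices:

```cpp
#include <bitset>
#include <climits>
#include <cstddef>

// Shift-and-test version: looks at every bit up to the highest set one.
unsigned int count_set_bits(unsigned int n) {
    unsigned int count = 0;
    while (n) {
        count += n & 1;   // add the lowest bit (0 or 1)
        n >>= 1;          // shift the next bit into the lowest position
    }
    return count;
}

// Brian Kernighan's version: one iteration per set bit.
unsigned int BK_count_set_bits(unsigned int n) {
    unsigned int count = 0;
    while (n) {
        n &= (n - 1);     // clear the right-most set bit
        count++;
    }
    return count;
}

// Hypothetical check: compares both against std::bitset::count()
// for every n in [0, limit).
bool counters_agree(unsigned int limit) {
    for (unsigned int n = 0; n < limit; n++) {
        std::size_t expected = std::bitset<sizeof(unsigned int) * CHAR_BIT>(n).count();
        if (count_set_bits(n) != expected || BK_count_set_bits(n) != expected) {
            return false;
        }
    }
    return true;
}
```

    counters_agree(1u << 16) checks all 16-bit values. In C++20, std::popcount from <bit> does the same job for unsigned types.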

    That’s all for now. Let me get my head around other bits and pieces before my next post.