using Programming;

A Blog about some of the intrinsics related to programming and how one can get the best out of various languages.

Optimizing for Tail-Call Recursion in F#

Rewriting for Tail-Call Recursion in F#

While this post is written in F# and specifically for it (referencing IL as shown by ILDASM), the principles and practices here can be applied to any functional (or non-functional) language that supports tail-call recursion and provides optimizations for it.

Recently I was working on a project, and I had to extend the F# List type to make things a bit simpler. I wrote a method, takeThrough, that accepts a predicate and, much like takeWhile, returns elements until the predicate returns false; it then returns one more element. This was important for the code I was writing: I needed everything up to and including the first element that caused the predicate to return false.

So, I wrote this method called takeThrough:

let takeThrough(predicate)(source) =
    let rec loop sourceTemp =
        let head = sourceTemp |> List.head
        if head |> predicate = true then
            head :: (sourceTemp |> List.tail |> loop)
        else
            [head]
    loop source
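
As a quick sanity check of the intended behaviour, here is a small sketch that repeats the definition above and compares it with the built-in List.takeWhile. The sample list and the isSmall predicate are made up purely for illustration.

```fsharp
// Same definition as above, repeated so the snippet runs on its own.
let takeThrough (predicate) (source) =
    let rec loop sourceTemp =
        let head = sourceTemp |> List.head
        if head |> predicate = true then
            head :: (sourceTemp |> List.tail |> loop)
        else
            [head]
    loop source

// Illustrative input: stop once a value is no longer below 3.
let numbers = [1; 2; 3; 4; 5]
let isSmall x = x < 3

let viaTakeWhile = numbers |> List.takeWhile isSmall     // [1; 2]
let viaTakeThrough = numbers |> takeThrough isSmall       // [1; 2; 3]

printfn "%A" viaTakeWhile
printfn "%A" viaTakeThrough
```

takeWhile stops before the first failing element, while takeThrough includes it.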

The problem with this method is that it cannot be optimized for tail-call recursion in its current state.

What is Tail-Call Recursion?

In order to understand why we're talking about optimizing for tail-call recursion here, we first have to understand what tail-call recursion is.

In functional languages such as F#, loops are discouraged, and in some languages they are not available at all. To avoid them, we (usually) rewrite our methods to use some form of recursion.

If we were to take this same example in C# it might look something like:

IEnumerable<T> TakeThrough<T>(IEnumerable<T> source, Predicate<T> predicate)
{
    var continueLoop = true;

    foreach (var item in source)
    {
        if (predicate(item))
        {
            yield return item;
            continueLoop = true;
        }
        else
        {
            yield return item;
            continueLoop = false;
        }

        if (!continueLoop)
        {
            break;
        }
    }
}

(This may not be the most optimal manner to write this method in, but it guarantees success.)

So we're just looping through each item here and returning it as we go. In F# we want to avoid such a construct, so we turn to recursion. Translated to C#, the F# code we wrote above might look more like:

IEnumerable<T> TakeThrough<T>(IEnumerable<T> source, Predicate<T> predicate)
{
    return _loop(source, predicate);
}

IEnumerable<T> _loop<T>(IEnumerable<T> sourceTemp, Predicate<T> predicate)
{
    var head = sourceTemp.First();

    if (predicate(head))
    {
        var result = new List<T>();
        result.Add(head);
        result.AddRange(_loop(sourceTemp.Skip(1), predicate));
        return result;
    }
    else
    {
        return new List<T> { head };
    }
}

It's almost verbatim identical to the F# version. The problem we can see here is a StackOverflowException being thrown if source is large enough and predicate returns false late enough. This is what we're hoping to avoid with tail-call recursion.

Remember: in order for a method to be optimized for tail-call recursion, the recursive call has to be the last thing the method does.

Now you might look at that method and say "well the last thing that happens is everything is piped to loop." Not quite true. What's easy to miss is that head :: is the very last thing the method has to do.

This is an important note because loop is called first, and only then is its result given to the cons operator (::).
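
One way to see this is to give the intermediate value a name. The sketch below is the same function with the recursive call bound to a hypothetical name, rest (my addition, not part of the original code), which makes it visible that the cons only runs after loop has returned.

```fsharp
let takeThroughExplicit predicate source =
    let rec loop sourceTemp =
        let head = sourceTemp |> List.head
        if predicate head then
            // Step 1: the recursive call runs first and must return...
            let rest = sourceTemp |> List.tail |> loop
            // Step 2: ...before the cons can run, so this frame stays alive.
            head :: rest
        else
            [head]
    loop source

// For [1; 2; 3; 4] and (fun x -> x < 3) the work unwinds as:
//   1 :: (loop [2; 3; 4])
//   1 :: (2 :: (loop [3; 4]))
//   1 :: (2 :: [3])
let result = takeThroughExplicit (fun x -> x < 3) [1; 2; 3; 4]
printfn "%A" result   // [1; 2; 3]
```

Every pending head :: needs its own stack frame, which is where the overflow comes from on long inputs.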

The if/else is ugly too

Of course the other problem is the if/else construct, but that can be fixed with a match head |> predicate with and then match to each boolean value (true and false).

Right, so that's simple. Easy fix:

match head |> predicate with
| true -> head :: (sourceTemp |> List.tail |> loop)
| false -> [head]

Great. We solved the easy idiomatic issue, but how in the world do we make it tail-call recursive?

Visualizing our tail-call recursion

The first thing we have to do is determine how we can write the structure of this method so that the loop call is the last thing to happen. Ignore what it does for now. We just want to know what it would have to look like. We need a visual.

let takeThrough predicate list =
    let rec loop ... =
        ...
        loop ...
    loop ...

So we know what it should look like-ish. That's a very good start. Now we have to figure out how we can get it to that state.

So we know that our method needs to match each item with a predicate, and then return a List of all the elements that matched and the next element. So we need to accumulate a list of elements.

Notice the word accumulate. We need a variable in our loop that acts as an accumulator in this case.

Now we know our visual needs to change:

let takeThrough predicate list =
    let rec loop acc ... =
        ...
        loop newAcc ...
    loop [] ...

This looks about right. Our acc will be a List since that's what we're building, and we're going to pass newAcc to loop each time we iterate, starting with an empty list on the first call.
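
If the accumulator shape is new to you, it may help to see it on a simpler function first. This is a hypothetical example (summing a list of integers), not part of takeThrough, but it has the same before-and-after structure.

```fsharp
// Not tail-recursive: the addition happens after the recursive call returns.
let rec sumNaive list =
    match list with
    | [] -> 0
    | head :: tail -> head + sumNaive tail

// Accumulator version: the running total travels with the call,
// so the recursive call is the last thing loop does.
let sumWithAcc list =
    let rec loop acc remaining =
        match remaining with
        | [] -> acc
        | head :: tail -> loop (acc + head) tail
    loop 0 list

printfn "%d" (sumNaive [1; 2; 3; 4])     // 10
printfn "%d" (sumWithAcc [1; 2; 3; 4])   // 10
```

The only structural change is that the pending work (the addition) moves into an argument of the recursive call.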

Creating our tail-call recursion

So now that we've visualized it, we can start to write the final pieces of it.

We'll start at the final line: loop [] .... What do we know about this call? We know that we only have two parameters in the method, and one variable (well, constant, but it's a function so it will look like a variable). And that's all we need. So we'll pass our initial list to our loop because it's the only thing we have to pass.

let takeThrough predicate list =
    let rec loop acc ... =
        ...
        loop newAcc ...
    loop [] list

Now our definition for loop has to change:

let takeThrough predicate list =
    let rec loop acc listTemp =
        ...
        loop newAcc newListTemp
    loop [] list

Alright, great progress. We just have to apply our operations now. In our case, newAcc will be the appended list, and newListTemp will be listTemp stripped of its first item. Let's get the logic for head in there and work from that.

let takeThrough predicate list =
    let rec loop acc listTemp =
        let head = listTemp |> List.head
        match head |> predicate with
        | true -> ... loop newAcc newList
        | false -> ...
    loop [] list

Perfect! We're almost done. Getting newAcc and newList is easy: newAcc is just List.append acc [head], and newList is just listTemp |> List.tail.

let takeThrough predicate list =
    let rec loop acc listTemp =
        let head = listTemp |> List.head
        match head |> predicate with
        | true -> loop (List.append acc [head]) (listTemp |> List.tail)
        | false -> ...
    loop [] list

The last issue is our false condition: what do we do here?

Simple: we just return what the newAcc would have been.

let takeThrough predicate list =
    let rec loop acc listTemp =
        let head = listTemp |> List.head
        match head |> predicate with
        | true -> loop (List.append acc [head]) (listTemp |> List.tail)
        | false -> List.append acc [head]
    loop [] list

And we've achieved our goal of tail-call recursion. The very last thing in loop is a call to itself. (Remember that in this case, match is the last condition, then one of two things happens: we call loop or we return the new stuff.)
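
Two things are worth knowing about this version: List.head throws on an empty list (so a predicate that never returns false will eventually fail), and List.append copies acc on every step. The sketch below is one possible variation, not the code the IL in the next section was generated from. It pattern matches on the list, builds the accumulator in reverse with ::, and reverses once at the end. The name takeThroughSafe is made up for this example.

```fsharp
let takeThroughSafe predicate list =
    let rec loop acc remaining =
        match remaining with
        // Ran out of input before the predicate failed: return what we have.
        | [] -> List.rev acc
        // Predicate holds: keep the element and continue (tail call).
        | head :: tail when predicate head -> loop (head :: acc) tail
        // First failing element: include it, then stop.
        | head :: _ -> List.rev (head :: acc)
    loop [] list

printfn "%A" (takeThroughSafe (fun x -> x < 3) [1; 2; 3; 4; 5])   // [1; 2; 3]
printfn "%A" (takeThroughSafe (fun x -> x < 3) [1; 2])            // [1; 2]
printfn "%A" (takeThroughSafe (fun x -> x < 3) ([] : int list))   // []
```

Whether returning the whole list is the right answer when the predicate never fails depends on your use case; throwing may be preferable in some code.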

Verifying with ILDASM

If we look at the IL for this method, we'll see the following:

.method assembly static class [FSharp.Core]Microsoft.FSharp.Collections.FSharpList`1<!!T> 
        loop@4<T>(class [FSharp.Core]Microsoft.FSharp.Core.FSharpFunc`2<!!T,bool> predicate,
                  class [FSharp.Core]Microsoft.FSharp.Collections.FSharpList`1<!!T> acc,
                  class [FSharp.Core]Microsoft.FSharp.Collections.FSharpList`1<!!T> sourceTemp) cil managed
{
  // Code size       69 (0x45)
  .maxstack  6
  .locals init ([0] !!T head)
  IL_0000:  nop
  IL_0001:  ldarg.2
  IL_0002:  call       !!0 [FSharp.Core]Microsoft.FSharp.Collections.ListModule::Head<!!0>(class [FSharp.Core]Microsoft.FSharp.Collections.FSharpList`1<!!0>)
  IL_0007:  stloc.0
  IL_0008:  ldarg.0
  IL_0009:  ldloc.0
  IL_000a:  callvirt   instance !1 class [FSharp.Core]Microsoft.FSharp.Core.FSharpFunc`2<!!T,bool>::Invoke(!0)
  IL_000f:  brfalse.s  IL_0031
  IL_0011:  ldarg.0
  IL_0012:  ldarg.1
  IL_0013:  ldloc.0
  IL_0014:  call       class [FSharp.Core]Microsoft.FSharp.Collections.FSharpList`1<!0> class [FSharp.Core]Microsoft.FSharp.Collections.FSharpList`1<!!T>::get_Empty()
  IL_0019:  call       class [FSharp.Core]Microsoft.FSharp.Collections.FSharpList`1<!0> class [FSharp.Core]Microsoft.FSharp.Collections.FSharpList`1<!!T>::Cons(!0,
                                                                                                                                                                class [FSharp.Core]Microsoft.FSharp.Collections.FSharpList`1<!0>)
  IL_001e:  call       class [FSharp.Core]Microsoft.FSharp.Collections.FSharpList`1<!!0> [FSharp.Core]Microsoft.FSharp.Core.Operators::op_Append<!!0>(class [FSharp.Core]Microsoft.FSharp.Collections.FSharpList`1<!!0>,
                                                                                                                                                      class [FSharp.Core]Microsoft.FSharp.Collections.FSharpList`1<!!0>)
  IL_0023:  ldarg.2
  IL_0024:  call       class [FSharp.Core]Microsoft.FSharp.Collections.FSharpList`1<!!0> [FSharp.Core]Microsoft.FSharp.Collections.ListModule::Tail<!!0>(class [FSharp.Core]Microsoft.FSharp.Collections.FSharpList`1<!!0>)
  IL_0029:  starg.s    sourceTemp
  IL_002b:  starg.s    acc
  IL_002d:  starg.s    predicate
  IL_002f:  br.s       IL_0000
  IL_0031:  ldarg.1
  IL_0032:  ldloc.0
  IL_0033:  call       class [FSharp.Core]Microsoft.FSharp.Collections.FSharpList`1<!0> class [FSharp.Core]Microsoft.FSharp.Collections.FSharpList`1<!!T>::get_Empty()
  IL_0038:  call       class [FSharp.Core]Microsoft.FSharp.Collections.FSharpList`1<!0> class [FSharp.Core]Microsoft.FSharp.Collections.FSharpList`1<!!T>::Cons(!0,
                                                                                                                                                                class [FSharp.Core]Microsoft.FSharp.Collections.FSharpList`1<!0>)
  IL_003d:  tail.
  IL_003f:  call       class [FSharp.Core]Microsoft.FSharp.Collections.FSharpList`1<!!0> [FSharp.Core]Microsoft.FSharp.Core.Operators::op_Append<!!0>(class [FSharp.Core]Microsoft.FSharp.Collections.FSharpList`1<!!0>,
                                                                                                                                                      class [FSharp.Core]Microsoft.FSharp.Collections.FSharpList`1<!!0>)
  IL_0044:  ret
} // end of method List::loop@4

We're only concerned with our recursion:

IL_0000:  nop
...
IL_000f:  brfalse.s  IL_0031
...
IL_002f:  br.s       IL_0000
IL_0031:  ldarg.1
...
IL_003d:  tail.

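Just as we hoped, the br.s IL_0000 instruction simply jumps back to the beginning of the method instead of calling it again. (The tail. prefix at IL_003d belongs to the final op_Append call, not to loop.)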

Functionally replacing if in F#

Eliminating if from F#

Recently I was asked by a colleague if there were a better way to write a specific method this colleague was using. It was a simple method which called a couple other methods and returned a value from them. Essentially, if two conditions met specific criteria, call one of four other methods. Oh, and it was in F#.

Naturally, as C-style programmers it's easy for us to use if or switch to do what we want, but for some reason when we look at functional languages we cannot seem to reason how we should replace these two constructs with match. It must be trivial, right? We must be missing some silly detail. That's not entirely true. We're not missing anything trivial, we're just not being creative enough.

Functional languages like F# bear the advantage of being very verbose about what's going on. They're also great at implicitly typing things, and making a function read as a mathematical expression. I emphasize that last part for a reason: if we begin to look at our code as a mathematical expression instead of code, we will hopefully see what we're missing.

Let's look at a much reduced sample of the code we were working with:

type ThingType =
    | Left = 0
    | Right = 1

member private this.methodLeftOne =
    true

member private this.methodRightOne =
    false

member private this.methodLeftTwo =
    false

member private this.methodRightTwo =
    true

member this.MatchAndIf var thingType =
    match var with
    | 1 -> if thingType = ThingType.Left then this.methodLeftOne else this.methodRightOne
    | 2 -> if thingType = ThingType.Left then this.methodLeftTwo else this.methodRightTwo
    | _ -> false

My colleague was calling the MatchAndIf function which was to return a boolean value from the two parameters. The code in the other four methods was a bit more complex, but I've simplified it here so we can see how things will turn out.

So, we're looking at a pretty simple bit of code: if var and thingType are 1 and ThingType.Left respectively, return this.methodLeftOne; if they're 1 and any other ThingType value, return this.methodRightOne; and so on. Pretty easy to follow.

We have a slight inconsistency here, however. If thingType is set to an invalid value, then unexpected (well, unintended) things can happen: it gets treated as ThingType.Right. This is not ideal. Fixing it with this code would be a mess; we would now have if ... then ... else if ... then ... else .... Sure, that does what we want, but it's really ugly for F#.

Nested match

So, the first thing we might think of to rewrite it is to use a nested match. Alright, easy enough, replace the inner if with a match.

member this.NestedMatch var thingType =
    match var with
    | 1 ->
        match thingType with
        | ThingType.Left -> this.methodLeftOne
        | _ -> this.methodRightOne
    | 2 ->
        match thingType with
        | ThingType.Left -> this.methodLeftTwo
        | _ -> this.methodRightTwo
    | _ -> false

This is obviously more F#-like. It gives us a lot more peace-of-mind, right? But we didn't fix the issue above, so let's do that.

member this.NestedMatchFixed var thingType =
    match var with
    | 1 ->
        match thingType with
        | ThingType.Left -> this.methodLeftOne
        | ThingType.Right -> this.methodRightOne
        | _ -> false
    | 2 ->
        match thingType with
        | ThingType.Left -> this.methodLeftTwo
        | ThingType.Right -> this.methodRightTwo
        | _ -> false
    | _ -> false

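Wait a minute, why do we need three default (_) cases? Ah, right, because if 1 or 2 is matched, the inner match won't fall through to the outer default case. F# won't implicitly return false if we omit an inner default, either: the compiler warns about the incomplete match, and an unmatched value throws a MatchFailureException at runtime. (That's not always a bad thing.)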
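
The reason a default case matters even when both named values are covered is that an F# enum is just an integer underneath, so it can hold a value with no name. Here is a small sketch; the describe function and the value 5 are made up for illustration.

```fsharp
type ThingType =
    | Left = 0
    | Right = 1

let describe (thingType: ThingType) =
    match thingType with
    | ThingType.Left -> "left"
    | ThingType.Right -> "right"
    // Needed because an enum is just an integer underneath.
    | _ -> "not a defined ThingType"

// enum<'T> converts any int, including one with no named case.
let bogus = enum<ThingType> 5

printfn "%s" (describe ThingType.Left)   // left
printfn "%s" (describe bogus)            // not a defined ThingType
```

A discriminated union would not have this problem, but that would change the type my colleague was working with.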

Tuple match

Well, we might think to ourselves "I can just match on a Tuple instead." Indeed that's true. Let's see how that looks.

member this.TupleMatch var thingType =
    match (var, thingType) with
    | (1, ThingType.Left) -> this.methodLeftOne
    | (1, ThingType.Right) -> this.methodRightOne
    | (2, ThingType.Left) -> this.methodLeftTwo
    | (2, ThingType.Right) -> this.methodRightTwo
    | _ -> false

Alright, that's not bad. We've gotten a lot closer to our goal. But now we have things knowing about things they shouldn't. The TupleMatch method does too many things inside it. It's looking for a var of 1 or 2 and a ThingType.

Finally Isolating Everything

The only other thing we can do to fix this (which I can tell you was the best option given the context of the code I had) is to check thingType in our FinalMatch method, and pipe var to our methodLeft or methodRight method (whichever is appropriate).

member private this.methodLeft var =
    match var with
    | 1 -> true
    | _ -> false

member private this.methodRight var =
    match var with
    | 2 -> true
    | _ -> false

member this.FinalMatch var thingType =
    match thingType with
    | ThingType.Left -> var |> this.methodLeft
    | ThingType.Right -> var |> this.methodRight
    | _ -> false

Now each method is only responsible for checking and reporting the parts it cares about. We complied with the Single Responsibility Principle (SRP) and we kept it entirely functional. Each method looks only at the code it cares about; it's not worried about what the next method down the chain is doing.
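
If you want to run this yourself, here is a sketch with the members rewritten as plain functions. That rearrangement is mine, made so the snippet stands alone; the original code lived in a class. It also checks that the final version agrees with the tuple match for a handful of inputs, including an undefined enum value.

```fsharp
type ThingType =
    | Left = 0
    | Right = 1

// Tuple-match version, with the four helper results inlined.
let tupleMatch var thingType =
    match (var, thingType) with
    | (1, ThingType.Left) -> true
    | (1, ThingType.Right) -> false
    | (2, ThingType.Left) -> false
    | (2, ThingType.Right) -> true
    | _ -> false

let methodLeft var =
    match var with
    | 1 -> true
    | _ -> false

let methodRight var =
    match var with
    | 2 -> true
    | _ -> false

let finalMatch var thingType =
    match thingType with
    | ThingType.Left -> var |> methodLeft
    | ThingType.Right -> var |> methodRight
    | _ -> false

// Compare both versions over a few values, including an undefined enum value.
let thingTypes = [ ThingType.Left; ThingType.Right; enum<ThingType> 5 ]
let agree =
    [ for var in 0 .. 3 do
        for thingType in thingTypes do
            yield tupleMatch var thingType = finalMatch var thingType ]
    |> List.forall id

printfn "%b" agree   // true
```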