Function Memoization in Scala

23 02 2009

Function memoization is an optimization technique that avoids repeated calculation of function values which have already been calculated by a previous evaluation of the function. In this post I show how function memoization can be implemented in Scala. Although straightforward at first glance, effectively memoizing recursive functions requires some second thought. For the sake of simplicity I only discuss functions of arity one.

Assume we have a function strSqLen which calculates the square of the length of a string.

def strSqLen(s: String) = s.length*s.length

Assume further that, for some reason, evaluating the above function is processor intensive. One way to speed things up is to cache result values and look them up on subsequent invocations of strSqLen. This can be done either by the function itself or by the caller. The first approach has the drawback that each and every function has to implement caching and that clients have no control over the caching mechanism. The latter approach puts all the burden on the client programmer and tends to produce boilerplate at every call site. To overcome these issues we need a way to construct a memoized function having the same type as a given function.
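For illustration, the two approaches rejected above might look roughly like this. It is a minimal sketch only: SelfCachingStrSqLen, callerCache and cachedCall are hypothetical names, and the cache is assumed to be a plain mutable.Map filled via getOrElseUpdate.

```scala
import scala.collection.mutable

// Approach 1 (hypothetical): the function caches its own results.
// The cache is hard-wired, so callers cannot replace or disable it.
object SelfCachingStrSqLen {
  private val cache = mutable.Map.empty[String, Int]
  def apply(s: String): Int = cache.getOrElseUpdate(s, s.length * s.length)
}

// Approach 2 (hypothetical): the caller keeps the cache.
// Every client that wants caching has to repeat this lookup logic.
def strSqLen(s: String) = s.length * s.length

val callerCache = mutable.Map.empty[String, Int]
def cachedCall(s: String): Int = callerCache.getOrElseUpdate(s, strSqLen(s))
```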

In Scala a function of arity one is an instance of Function1. We can thus subclass Function1[-T, +R] with a class Memoize1[-T, +R] which represents the memoized function. (Note the concise syntactic sugar T => R, which is synonymous with the more verbose Function1[T, R].)
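As a minimal sketch of what extending Function1 involves, here is a hypothetical class SquaredLength: implementing apply is all that is needed, and its instances can then be used wherever a String => Int is expected.

```scala
// Hypothetical class extending (String => Int), i.e. Function1[String, Int].
class SquaredLength extends (String => Int) {
  def apply(s: String): Int = s.length * s.length
}

val f: Function1[String, Int] = new SquaredLength
val g: String => Int = f // T => R and Function1[T, R] denote the same type

// The instance can be passed to any higher-order function expecting String => Int.
val lengths = List("a", "ab", "abc").map(g)
```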

class Memoize1[-T, +R](f: T => R) extends (T => R) {
  import scala.collection.mutable
  private[this] val vals = mutable.Map.empty[T, R]

  def apply(x: T): R = {
    if (vals.contains(x)) {
      vals(x)
    }
    else {
      val y = f(x)
      vals += ((x, y))
      y
    }
  }
}

object Memoize1 {
  def apply[T, R](f: T => R) = new Memoize1(f)
}

When applied to an argument of type T, apply checks whether the function value is already in the cache. If so, it returns that value. If not, it calls the original function f, puts the function value into the cache and returns it. Note that the member vals, which contains the cached function values, has object-private visibility (private[this]). A less restrictive visibility would cause type checking to fail, since mutable.Map[A, B] is invariant in both its type arguments while Memoize1[-T, +R] is contravariant in T and covariant in R (as is Function1[-T, +R]). However, since vals is accessed from its containing instance only, it cannot cause problems with variance. We tell this to the compiler by declaring the field object-private.
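A sketch, assuming Scala 2.13 or Scala 3 in script form, that checks the two claims above: the underlying function runs only once per distinct argument, and the variance annotations let one memoized function stand in for another of a wider type. The apply body is condensed with mutable.Map.getOrElseUpdate, which performs the look-up, else compute and store, sequence in one call; slowSqLen, evaluations, anyHash and narrowed are hypothetical names.

```scala
import scala.collection.mutable

// Memoize1 as in the text, with apply condensed via getOrElseUpdate.
class Memoize1[-T, +R](f: T => R) extends (T => R) {
  // Object-private, so the invariant Map does not conflict with the
  // variance annotations on T and R.
  private[this] val vals = mutable.Map.empty[T, R]

  // Return the cached value for x, or compute f(x), store it and return it.
  def apply(x: T): R = vals.getOrElseUpdate(x, f(x))
}

object Memoize1 {
  def apply[T, R](f: T => R) = new Memoize1(f)
}

// Hypothetical instrumented function: counts how often it really runs.
var evaluations = 0
def slowSqLen(s: String): Int = { evaluations += 1; s.length * s.length }

val memo = Memoize1(slowSqLen)
memo("abc"); memo("abc"); memo("de")
// evaluations is now 2: the second memo("abc") was served from the cache.

// Thanks to -T and +R, Memoize1[Any, Int] is a subtype of
// Memoize1[String, Any]; with an invariant Memoize1[T, R] this
// assignment would not compile.
val anyHash: Memoize1[Any, Int] = Memoize1((x: Any) => x.hashCode)
val narrowed: Memoize1[String, Any] = anyHash
```

Whether getOrElseUpdate behaves well when the computation re-enters the same map, as the recursive examples later in this post do, depends on the Map implementation, so the explicit contains/+= form shown above is the more conservative choice there.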

We can now easily create and use a memoization of strSqLen like this:

val strSqLenMemoized = Memoize1(strSqLen)
val a = strSqLenMemoized("hello Memo")
val b = strSqLen("hello Memo")
assert(a == b)
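To check the claim that the memoized function has the same type as the original, here is a sketch with a hypothetical client, totalSqLen, that only knows it receives a String => Int. Memoize1 is repeated from above (apply condensed) so the block compiles on its own.

```scala
import scala.collection.mutable

// Memoize1 repeated from the text, apply condensed.
class Memoize1[-T, +R](f: T => R) extends (T => R) {
  private[this] val vals = mutable.Map.empty[T, R]
  def apply(x: T): R =
    if (vals.contains(x)) vals(x)
    else { val y = f(x); vals += ((x, y)); y }
}

object Memoize1 {
  def apply[T, R](f: T => R) = new Memoize1(f)
}

def strSqLen(s: String) = s.length * s.length

// Hypothetical client code: it only knows that it receives a String => Int.
def totalSqLen(words: Seq[String], f: String => Int): Int = words.map(f).sum

val words = Seq("hello", "Memo", "hello")
val plain = totalSqLen(words, strSqLen)              // 25 + 16 + 25
val memoized = totalSqLen(words, Memoize1(strSqLen)) // same result, "hello" computed once
```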

Going recursive

Memoization of recursive functions is possible but might not have the desired effect. Consider the following recursive implementation of the factorial function.

def fac(n: BigInt): BigInt = {
  if (n == 0) 1
  else n*fac(n - 1) 
}

Calculating the factorials from 200 down to 0 does not take advantage of the memoization at all.

val facMem = Memoize1(fac)
    
for (k <- 200 to 0 by -1) { 
  println(facMem(k))
}

While facMem(200) caches the factorial of 200, it does not cache the results of its intermediate recursive invocations, since the recursive calls invoke the original function instead of the memoized one. So when it comes to calculating facMem(199), memoization gains nothing: although the factorial of 199 was calculated as part of facMem(200), it was never cached, and fac recomputes the whole chain from 199 down to 0.
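The lost opportunity becomes visible when counting invocations. This sketch adds a hypothetical counter, calls, to fac; otherwise it is the fac and facMem from above, with Memoize1 repeated in condensed form so the block is self-contained.

```scala
import scala.collection.mutable

// Memoize1 repeated from the text, apply condensed.
class Memoize1[-T, +R](f: T => R) extends (T => R) {
  private[this] val vals = mutable.Map.empty[T, R]
  def apply(x: T): R =
    if (vals.contains(x)) vals(x)
    else { val y = f(x); vals += ((x, y)); y }
}

object Memoize1 {
  def apply[T, R](f: T => R) = new Memoize1(f)
}

// Plain recursive factorial, instrumented to count every invocation.
var calls = 0
def fac(n: BigInt): BigInt = {
  calls += 1
  if (n == 0) 1 else n * fac(n - 1)
}

val facMem = Memoize1(fac)

facMem(200)
val callsFor200 = calls // 201: fac(200) down to fac(0)

calls = 0
facMem(199)
val callsFor199 = calls // 200: nothing from the previous run was reused
```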

To improve this we need a more flexible implementation of fac. We want its recursive calls to be exposed to the caller so that they can be cached.

def facRec(n: BigInt, f: BigInt => BigInt): BigInt = {
  if (n == 0) 1
  else n*f(n - 1) 
}

Here the caller needs to pass a function for calculating factorials to facRec. She is therefore free to pass a memoized function to speed up recurrent invocations. However, it seems we are back to square one: we need to pass a function for calculating factorials to facRec in order for it to calculate factorials. As it turns out, this is not the case. By passing facRec to itself recursively, we can construct the desired factorial function.

var fac: BigInt => BigInt = null 
fac = facRec(_, fac(_))

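First we declare a variable fac of function type BigInt => BigInt. Then we partially apply facRec to fac, which yields the desired factorial function. Initializing fac with null is safe because fac(_) is shorthand for the function literal x => fac(x): the variable is only read when that literal is called, and by then it has already been assigned.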
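For comparison, here is a sketch of the var-based construction next to an equivalent recursive def (facDef is a hypothetical name used only for this comparison). Both tie the knot: the var by assignment after the fact, the def by referring to itself directly.

```scala
def facRec(n: BigInt, f: BigInt => BigInt): BigInt =
  if (n == 0) 1 else n * f(n - 1)

// The construction from the text: fac(_) delays reading the var until
// the function is called, by which time fac has been assigned.
var fac: BigInt => BigInt = null
fac = facRec(_, fac(_))

// Equivalent knot tied with a recursive def instead of a mutable var.
def facDef(n: BigInt): BigInt = facRec(n, facDef)

val results = (0 to 10).map(k => (fac(k), facDef(k)))
```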

We can generalize and factor out this construction process into an object Y like this:

object Y {
  def apply[T, R](f: (T, T => R) => R): (T => R) = {
    var yf: T => R = null
    yf = f(_, yf(_))
    yf
  }
}
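Y is not specific to factorials. A minimal sketch applying it to facRec and to a hypothetical open-recursive Fibonacci function, fibRec, written in the same style: the recursive calls go through the parameter f rather than to fibRec itself.

```scala
object Y {
  def apply[T, R](f: (T, T => R) => R): (T => R) = {
    var yf: T => R = null
    yf = f(_, yf(_))
    yf
  }
}

def facRec(n: BigInt, f: BigInt => BigInt): BigInt =
  if (n == 0) 1 else n * f(n - 1)

// Hypothetical second example: Fibonacci numbers in open-recursive style.
def fibRec(n: Int, f: Int => BigInt): BigInt =
  if (n < 2) n else f(n - 1) + f(n - 2)

val fac = Y(facRec)
val fib = Y(fibRec)
```

Without memoization, fib makes an exponential number of recursive calls, so larger arguments quickly become slow.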

Along the same lines we can provide functionality for creating memoized versions of recursive functions. Instead of handing f the plain recursive function for its recursive calls, we hand it a memoized version of it.

object Memoize1 {
  // ... same as above

  def Y[T, R](f: (T, T => R) => R) = {
    var yf: T => R = null
    yf = Memoize1(f(_, yf(_)))
    yf
  }
}

Using Memoize1.Y we can calculate the factorials from 200 down to 0 while taking full advantage of memoization of all intermediate recursive invocations.

def facRec(n: BigInt, f: BigInt => BigInt): BigInt = {
  if (n == 0) 1
  else n*f(n - 1) 
}
    
val fac = Memoize1.Y(facRec)
    
for (k <- 200 to 0 by -1)
  println(fac(k))
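To confirm that Memoize1.Y caches the intermediate results, this sketch counts how often the body of facRec runs (calls is a hypothetical counter added for the test). Memoize1 and Memoize1.Y are repeated in condensed form, with explicit type arguments on the inner Memoize1 call, so the block compiles on its own.

```scala
import scala.collection.mutable

class Memoize1[-T, +R](f: T => R) extends (T => R) {
  private[this] val vals = mutable.Map.empty[T, R]
  def apply(x: T): R =
    if (vals.contains(x)) vals(x)
    else { val y = f(x); vals += ((x, y)); y } // f may call back into this instance
}

object Memoize1 {
  def apply[T, R](f: T => R) = new Memoize1(f)

  // Tie the knot, handing f the memoized function for its recursive calls.
  def Y[T, R](f: (T, T => R) => R) = {
    var yf: T => R = null
    yf = Memoize1[T, R](f(_, yf(_)))
    yf
  }
}

// Count how often the body of facRec actually runs.
var calls = 0
def facRec(n: BigInt, f: BigInt => BigInt): BigInt = {
  calls += 1
  if (n == 0) 1 else n * f(n - 1)
}

val fac = Memoize1.Y(facRec)

fac(200)
val callsFor200 = calls // 201: each of 200 down to 0 computed once

for (k <- 199 to 0 by -1) fac(k)
val callsAfterAll = calls // still 201: every smaller value came from the cache
```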