Function memoization is an optimization technique that avoids recomputing function values which an earlier evaluation of the function has already produced. In this post I show how function memoization can be implemented in Scala. Although straightforward at first glance, memoizing recursive functions effectively requires some second thoughts. For the sake of simplicity I only discuss functions of arity one.

Assume we have a function *strSqLen* which calculates the square of the length of a string.

    def strSqLen(s: String) = s.length * s.length

Assume further that, for some reason, evaluating the above function is processor intensive. One way to speed things up is to cache result values and to look them up on subsequent invocations of *strSqLen*. This can be done either by the function itself or by the caller. The first approach has the drawback that each and every function has to implement caching and that clients have no control over the caching mechanism. The latter approach puts all the burden on the client programmer and tends to produce boilerplate. To overcome these issues we need a way to construct a memoized function that has the same type as a given function.

In Scala a function of arity one is an instance of *Function1*. We can thus subclass *Function1[-T, +R]* to obtain *Memoize1[-T, +R]*, which represents the memoized function. (Note the concise syntactic sugar *T => R*, which is a synonym for the more verbose *Function1[T, R]*.)

    class Memoize1[-T, +R](f: T => R) extends (T => R) {
      import scala.collection.mutable
      private[this] val vals = mutable.Map.empty[T, R]

      def apply(x: T): R = {
        if (vals.contains(x)) {
          vals(x)
        }
        else {
          val y = f(x)
          vals += ((x, y))
          y
        }
      }
    }

    object Memoize1 {
      def apply[T, R](f: T => R) = new Memoize1(f)
    }

When applied to an argument of type *T*, *apply* checks whether the function value is in the cache. If so, it returns that value. If not, it calls the original function *f*, puts the function value into the cache and returns it. Note that the member *vals*, which contains the cached function values, has object private visibility (*private[this]*). A less restrictive visibility would cause type checking to fail since *mutable.Map[A, B]* is invariant in both its type arguments while *Memoize1[-T, +R]* is contravariant in *T* and covariant in *R* (as is *Function1[-T, +R]*). However, since *vals* is accessed from its containing instance only, it cannot cause problems with variance. We tell this to the compiler by declaring the field object private.

We can now easily create and use a memoization of *strSqLen* like this:

    val strSqLenMemoized = Memoize1(strSqLen)
    val a = strSqLenMemoized("hello Memo")
    val b = strSqLen("hello Memo")
    assert(a == b)
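To see the cache at work, one can count how often the wrapped function actually runs. The following is a minimal sketch, not part of the original code: *Memo* is a compact, invariant stand-in for *Memoize1* (repeated so the snippet compiles on its own), and the *calls* counter and the *CacheHitDemo* object are hypothetical names introduced only for this illustration.

```scala
import scala.collection.mutable

// Compact, invariant stand-in for the Memoize1 class above
// (an assumption made so that this snippet is self-contained).
class Memo[T, R](f: T => R) extends (T => R) {
  private val vals = mutable.Map.empty[T, R]

  def apply(x: T): R =
    if (vals.contains(x)) vals(x)
    else {
      val y = f(x)
      vals += ((x, y))
      y
    }
}

object CacheHitDemo {
  // Hypothetical counter: incremented each time the body of strSqLen runs.
  var calls = 0

  def strSqLen(s: String): Int = {
    calls += 1
    s.length * s.length
  }

  val strSqLenMemoized = new Memo[String, Int](strSqLen)

  def main(args: Array[String]): Unit = {
    println(strSqLenMemoized("hello Memo")) // computed by strSqLen
    println(strSqLenMemoized("hello Memo")) // served from the cache
    println(calls)                          // body of strSqLen ran once
  }
}
```

Both calls return the same value, but the counter shows that *strSqLen* itself was evaluated only for the first one.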

### Going recursive

Memoization of recursive functions is possible but might not have the desired effect. Consider the following recursive implementation of the factorial function.

    def fac(n: BigInt): BigInt = {
      if (n == 0) 1
      else n * fac(n - 1)
    }

Calculating the factorials from 200 down to 0 does not take advantage of the memoization at all.

    val facMem = Memoize1(fac)
    for (k <- 200 to 0 by -1) {
      println(facMem(k))
    }

While *facMem(200)* caches the factorial of 200, it does not cache the results of its intermediate recursive invocations, since the recursive calls invoke the original function instead of the memoized one. When it comes to calculating *facMem(199)*, memoization brings no gain: although the factorial of 199 has been calculated before, it was never put into the cache.
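The effect can be made visible by counting how often the body of *fac* runs. The sketch below is only an illustration under a few assumptions: *memoize* is a simplified, hypothetical helper standing in for *Memoize1*, the *calls* counter and *callsFor* are names invented for the demonstration, and small arguments are used instead of 200 to keep the numbers readable.

```scala
import scala.collection.mutable

object NaiveRecursionDemo {
  // Simplified stand-in for Memoize1 (hypothetical helper for this sketch).
  def memoize[T, R](f: T => R): T => R = {
    val vals = mutable.Map.empty[T, R]
    (x: T) =>
      if (vals.contains(x)) vals(x)
      else {
        val y = f(x)
        vals += ((x, y))
        y
      }
  }

  /** Counts how often the body of fac runs for facMem(5), facMem(4), facMem(5). */
  def run(): (Int, Int, Int) = {
    var calls = 0

    def fac(n: BigInt): BigInt = {
      calls += 1
      if (n == 0) 1 else n * fac(n - 1) // recursion bypasses the cache
    }

    val facMem = memoize[BigInt, BigInt](fac)

    def callsFor(n: Int): Int = {
      val before = calls
      facMem(BigInt(n))
      calls - before
    }

    (callsFor(5), callsFor(4), callsFor(5))
  }

  def main(args: Array[String]): Unit =
    println(run()) // prints (6,5,0)
}
```

Only the outermost argument ends up in the cache: the second evaluation of 5 is free, but 4 has to be recomputed from scratch even though it was calculated on the way to 5.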

To improve this we need a more flexible implementation of *fac*. We want its recursive calls to be exposed to the outside so that they become available for caching.

    def facRec(n: BigInt, f: BigInt => BigInt): BigInt = {
      if (n == 0) 1
      else n * f(n - 1)
    }

Here the caller needs to pass a function for calculating factorials to *facRec*. She is therefore free to pass a memoized function to speed up recurrent invocations. However, it seems we are back to square one: we need to pass a function for calculating factorials to *facRec* in order for it to calculate factorials. As it turns out, this is not the case. By passing *facRec* to itself recursively, we can construct the desired factorial function.

    var fac: BigInt => BigInt = null
    fac = facRec(_, fac(_))

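First we declare the function *fac* as a mapping from *BigInt* to *BigInt*. Then we partially apply *facRec* to *fac*, which yields the desired factorial function. This works because the inner *fac(_)* is itself a function literal: it reads the variable *fac* only when it is called, and by then the assignment has completed.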
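Written without the placeholder syntax, the construction may be easier to follow. The snippet below is a sketch of the same two lines with the underscores expanded into explicit function literals; the parameter names *n* and *m* and the wrapping *KnotTyingDemo* object are introduced here purely for illustration.

```scala
object KnotTyingDemo {
  def facRec(n: BigInt, f: BigInt => BigInt): BigInt =
    if (n == 0) 1 else n * f(n - 1)

  // Declared first so that the function literal below can refer to it.
  var fac: BigInt => BigInt = null

  // Expanded form of: fac = facRec(_, fac(_))
  // The inner literal (m => fac(m)) looks up the variable fac only when it is
  // invoked, so it never sees the initial null.
  fac = (n: BigInt) => facRec(n, (m: BigInt) => fac(m))

  def main(args: Array[String]): Unit =
    println(fac(BigInt(5))) // prints 120
}
```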

We can generalize and factor out this construction process into an object *Y* like this:

    object Y {
      def apply[T, R](f: (T, T => R) => R): (T => R) = {
        var yf: T => R = null
        yf = f(_, yf(_))
        yf
      }
    }
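As a usage sketch, the same construction can turn any function of the shape *(T, T => R) => R* into an ordinary recursive function. In the snippet below the helper *y* repeats the construction of *Y* so that the code compiles on its own; *sumRec* and the *YDemo* object are hypothetical examples that do not appear elsewhere in this post.

```scala
object YDemo {
  // Same construction as the Y object above, with the placeholders written out.
  def y[T, R](f: (T, T => R) => R): T => R = {
    var yf: T => R = null
    yf = (x: T) => f(x, (z: T) => yf(z))
    yf
  }

  def facRec(n: BigInt, f: BigInt => BigInt): BigInt =
    if (n == 0) 1 else n * f(n - 1)

  // Hypothetical second example: the sum 1 + 2 + ... + n.
  def sumRec(n: Int, f: Int => Int): Int =
    if (n == 0) 0 else n + f(n - 1)

  val fac: BigInt => BigInt = y[BigInt, BigInt](facRec)
  val sum: Int => Int = y[Int, Int](sumRec)

  def main(args: Array[String]): Unit = {
    println(fac(BigInt(5))) // prints 120
    println(sum(10))        // prints 55
  }
}
```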

Along the same lines we can provide functionality for creating memoized versions of recursive functions. Instead of just recursively invoking the passed function, we pass a memoized version of it.

    object Memoize1 {
      // ... same as above

      def Y[T, R](f: (T, T => R) => R) = {
        var yf: T => R = null
        yf = Memoize1(f(_, yf(_)))
        yf
      }
    }

Using *Memoize1.Y* we can calculate the factorials from 200 down to 0 while taking full advantage of memoization of all intermediate recursive invocations.

    def facRec(n: BigInt, f: BigInt => BigInt): BigInt = {
      if (n == 0) 1
      else n * f(n - 1)
    }

    val fac = Memoize1.Y(facRec)
    for (k <- 200 to 0 by -1)
      println(fac(k))
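Counting evaluations again shows the difference from the plain memoization of the recursive *fac*. This is a sketch under stated assumptions: *memoY* is a simplified, invariant stand-in for *Memoize1.Y* with the cache inlined, and the *calls* counter and the *MemoYDemo* object are hypothetical names used only for this demonstration.

```scala
import scala.collection.mutable

object MemoYDemo {
  // Simplified stand-in for Memoize1.Y: the function handed to f is the
  // memoized function itself, so recursive calls go through the cache.
  def memoY[T, R](f: (T, T => R) => R): T => R = {
    val vals = mutable.Map.empty[T, R]
    var yf: T => R = null
    yf = (x: T) =>
      if (vals.contains(x)) vals(x)
      else {
        val y = f(x, yf)
        vals += ((x, y))
        y
      }
    yf
  }

  /** Returns fac(5), the evaluations needed for it, and the extra ones for fac(4). */
  def run(): (BigInt, Int, Int) = {
    var calls = 0

    def facRec(n: BigInt, f: BigInt => BigInt): BigInt = {
      calls += 1
      if (n == 0) 1 else n * f(n - 1)
    }

    val fac = memoY[BigInt, BigInt](facRec)

    val result = fac(BigInt(5))
    val afterFirst = calls
    fac(BigInt(4)) // already cached as an intermediate result of fac(5)
    (result, afterFirst, calls - afterFirst)
  }

  def main(args: Array[String]): Unit =
    println(run()) // prints (120,6,0)
}
```

The first call evaluates *facRec* once per argument from 5 down to 0, and every later call for one of those arguments is answered from the cache.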