Function memoization is an optimization technique that avoids recalculating function values which have already been calculated by a previous evaluation of the function. In this post I show how function memoization can be implemented in Scala. Although straightforward at first glance, effectively memoizing recursive functions requires some further thought. For the sake of simplicity I only discuss functions of arity one.

Assume we have a function *strSqLen* which calculates the square of the length of a string.

    def strSqLen(s: String) = s.length * s.length

Assume further that, for some reason, evaluating the above function is processor intensive. One way to speed things up is to cache result values and to look them up on subsequent invocations of *strSqLen*. This can be done either by the function itself or by the caller. The first approach has the drawback that each and every function has to implement caching and that clients have no control over the caching mechanism. The latter approach puts all the burden on the client programmer and tends to produce boilerplate. To overcome these issues we need a way to construct a memoized function having the same type as a given function.

In Scala a function of arity one is an instance of *Function1*. We can thus sub-class *Function1[-T, +R]* to *Memoize1[-T, +R]*, which represents the memoized function. (Note the concise syntactic sugar *T => R*, which is a synonym for the more verbose *Function1[T, R]*.)

    class Memoize1[-T, +R](f: T => R) extends (T => R) {
      import scala.collection.mutable
      private[this] val vals = mutable.Map.empty[T, R]

      def apply(x: T): R = {
        if (vals.contains(x)) {
          vals(x)
        }
        else {
          val y = f(x)
          vals += ((x, y))
          y
        }
      }
    }

    object Memoize1 {
      def apply[T, R](f: T => R) = new Memoize1(f)
    }

When applied to an argument of type *T*, *apply* checks whether the function value is in the cache. If so, it returns that value. If not, it calls the original function *f*, puts the function value into the cache and returns it. Note that the member *vals* which contains the cached function values has object private visibility (*private[this]*). A less restrictive visibility would cause type checking to fail since *mutable.Map[A, B]* is invariant in both its type arguments while *Memoize1[-T, +R]* is contravariant in *T* and covariant in *R* (as is *Function1[-T, +R]*). However since *vals* is accessed from its containing instance only, it cannot cause problems with variance. We tell this to the compiler by declaring the field object private.
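To see the cache at work, here is a minimal sketch that counts how often the wrapped function is actually evaluated. The counter *evaluations* and the object *CacheDemo* are illustrative assumptions that do not appear in the code above; *Memoize1* is repeated only so that the snippet compiles on its own.

```scala
import scala.collection.mutable

// Memoize1 as defined above, repeated so that this sketch is self-contained.
class Memoize1[-T, +R](f: T => R) extends (T => R) {
  private[this] val vals = mutable.Map.empty[T, R]

  def apply(x: T): R =
    if (vals.contains(x)) vals(x)
    else {
      val y = f(x)
      vals += ((x, y))
      y
    }
}

object CacheDemo {
  // Hypothetical counter: how many times the wrapped function really runs.
  var evaluations = 0

  def strSqLen(s: String): Int = {
    evaluations += 1
    s.length * s.length
  }

  val memoized = new Memoize1[String, Int](s => strSqLen(s))

  // Returns both results together with the number of real evaluations.
  def run(): (Int, Int, Int) = {
    val first  = memoized("hello Memo") // cache miss: strSqLen is evaluated
    val second = memoized("hello Memo") // cache hit: strSqLen is not called again
    (first, second, evaluations)
  }
}
```

Both calls yield the same value, while *strSqLen* itself runs only once.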

We can now easily create and use a memoization of *strSqLen* like this:

    val strSqLenMemoized = Memoize1(strSqLen)
    val a = strSqLenMemoized("hello Memo")
    val b = strSqLen("hello Memo")
    assert(a == b)

### Going recursive

Memoization of recursive functions is possible but might not have the desired effect. Consider the following recursive implementation of the factorial function.

    def fac(n: BigInt): BigInt = {
      if (n == 0) 1
      else n*fac(n - 1)
    }

Calculating the factorials from 200 down to 0 does not take advantage of the memoization at all.

    val facMem = Memoize1(fac)
    for (k <- 200 to 0 by -1) {
      println(facMem(k))
    }

While *facMem(200)* caches the factorial of 200, it does not cache the results of its intermediate recursive invocations, since the recursive calls invoke the original function instead of the memoized one. When it comes to calculating *facMem(199)*, there is no gain from memoization since, although its value was calculated before, *facMem(199)* is not cached.
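The effect can be made visible by counting evaluations. The following is a minimal sketch under the assumption that a simple counter is good enough for illustration; *evaluations* and *NaiveRecursionDemo* are hypothetical names, and small arguments are used instead of 200 to keep the numbers easy to follow.

```scala
import scala.collection.mutable

// Memoize1 as defined above, repeated so that this sketch is self-contained.
class Memoize1[-T, +R](f: T => R) extends (T => R) {
  private[this] val vals = mutable.Map.empty[T, R]

  def apply(x: T): R =
    if (vals.contains(x)) vals(x)
    else {
      val y = f(x)
      vals += ((x, y))
      y
    }
}

object NaiveRecursionDemo {
  // Hypothetical counter: how many times the body of fac runs.
  var evaluations = 0

  def fac(n: BigInt): BigInt = {
    evaluations += 1
    if (n == 0) 1 else n * fac(n - 1) // recursive call bypasses the cache
  }

  val facMem = new Memoize1[BigInt, BigInt](n => fac(n))

  // Returns the evaluation count after facMem(5) and after a following facMem(4).
  def run(): (Int, Int) = {
    facMem(5) // fac runs for 5, 4, 3, 2, 1, 0; only the result for 5 is cached
    val afterFirst = evaluations
    facMem(4) // not in the cache, so fac runs again for 4, 3, 2, 1, 0
    (afterFirst, evaluations)
  }
}
```

The second call recomputes everything from scratch, which is exactly the missed opportunity described above.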

To improve this we need a more flexible implementation of *fac*. We want its recursive calls to be exposed to the outside so that they can be cached.

    def facRec(n: BigInt, f: BigInt => BigInt): BigInt = {
      if (n == 0) 1
      else n*f(n - 1)
    }

Here the caller needs to pass a function for calculating factorials to *facRec*. She is therefore free to pass a memoized function to speed up recurrent invocations. However, it seems we are back to square one: we need to pass a function for calculating factorials to *facRec* in order for it to calculate factorials. As it turns out this is not the case. By passing *facRec* to itself recursively, we can construct the desired factorial function.

    var fac: BigInt => BigInt = null
    fac = facRec(_, fac(_))

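First we declare the function *fac* to map from *BigInt* to *BigInt*. Then we partially apply *facRec* to *fac*, which yields the desired factorial function. The initial *null* is harmless: *fac* is only dereferenced inside the function literal when that literal is called, and by then the assignment has taken place.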

We can generalize and factor out this construction process into an object *Y* like this:

    object Y {
      def apply[T, R](f: (T, T => R) => R): (T => R) = {
        var yf: T => R = null
        yf = f(_, yf(_))
        yf
      }
    }
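As a usage sketch, *Y* can be applied to *facRec* to obtain a plain (non-memoized) factorial function. The object *YDemo* is an illustrative assumption, and the type arguments are spelled out explicitly to keep type inference out of the picture.

```scala
// Y as defined above, repeated so that this sketch is self-contained.
object Y {
  def apply[T, R](f: (T, T => R) => R): (T => R) = {
    var yf: T => R = null
    yf = f(_, yf(_)) // yf is only dereferenced when the function is called
    yf
  }
}

object YDemo {
  def facRec(n: BigInt, f: BigInt => BigInt): BigInt =
    if (n == 0) 1 else n * f(n - 1)

  // Tie the knot: the function passed to facRec is the result itself.
  val fac: BigInt => BigInt = Y[BigInt, BigInt]((n, f) => facRec(n, f))

  def run(): BigInt = fac(5)
}
```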

Along the same lines we can provide functionality for creating memoized versions of recursive functions. Instead of just recursively invoking the passed function, we pass a memoized version of it.

    object Memoize1 {
      // ... same as above

      def Y[T, R](f: (T, T => R) => R) = {
        var yf: T => R = null
        yf = Memoize1(f(_, yf(_)))
        yf
      }
    }

Using *Memoize1.Y* we can calculate the factorials from 200 down to 0 while taking full advantage of memoization of all intermediate recursive invocations.

    def facRec(n: BigInt, f: BigInt => BigInt): BigInt = {
      if (n == 0) 1
      else n*f(n - 1)
    }

    val fac = Memoize1.Y(facRec)
    for (k <- 200 to 0 by -1)
      println(fac(k))
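To check that the intermediate results really are cached, the evaluations can be counted again. This is a minimal sketch, not a benchmark: *evaluations* and *MemoizedRecursionDemo* are hypothetical names, the type arguments and lambdas inside *Y* are written out explicitly as a precaution, and small arguments are used instead of 200.

```scala
import scala.collection.mutable

// Memoize1 as in the post, repeated so that the sketch compiles on its own.
class Memoize1[-T, +R](f: T => R) extends (T => R) {
  private[this] val vals = mutable.Map.empty[T, R]

  def apply(x: T): R =
    if (vals.contains(x)) vals(x)
    else {
      val y = f(x)
      vals += ((x, y))
      y
    }
}

object Memoize1 {
  def apply[T, R](f: T => R): Memoize1[T, R] = new Memoize1(f)

  // Same construction as above, with explicit type arguments and lambdas.
  def Y[T, R](f: (T, T => R) => R): T => R = {
    var yf: T => R = null
    yf = Memoize1[T, R](x => f(x, y => yf(y)))
    yf
  }
}

object MemoizedRecursionDemo {
  // Hypothetical counter: how many times the body of facRec runs.
  var evaluations = 0

  def facRec(n: BigInt, f: BigInt => BigInt): BigInt = {
    evaluations += 1
    if (n == 0) 1 else n * f(n - 1) // f is the memoized function
  }

  val fac: BigInt => BigInt = Memoize1.Y[BigInt, BigInt]((n, f) => facRec(n, f))

  // Returns the evaluation count after fac(5) and after a following fac(4).
  def run(): (Int, Int) = {
    fac(5) // facRec runs for 5, 4, 3, 2, 1, 0 and every result is cached
    val afterFirst = evaluations
    fac(4) // served from the cache, no further evaluation
    (afterFirst, evaluations)
  }
}
```

In contrast to the plain *Memoize1(fac)*, the second call causes no additional evaluations.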

HRJ(03:13:48) :Thanks for the post! Well written.

I feel this should be made available in the std libs.

Daniel Spiewak(03:34:06) :I tend to prefer a version of the Y-combinator which requires a curried function as its parameter. I think this leads to a slightly cleaner syntax at times:

    def apply[T, R](f: (T => R) => (T) => R): (T => R) = {
      var yf: T => R = null
      yf = f(yf(_))
      yf
    }

    def facRec(f: BigInt => BigInt)(n: BigInt): BigInt = {
      if (n == 0) 1
      else n*f(n - 1)
    }

Now we can be arguably more clean (though, less explicit) in our non-memoized fixpoint for facRec:

val fac: BigInt=>BigInt = facRec(fac(_))

In this, the fixpoint is implicit in Scala’s definition of a value, obviating the need to define it separately in the Y object.

michid(10:18:56) :Daniel,

Thanks for the feedback. Using curried functions is cleaner indeed. I should have thought of that ;-)

The less explicit version of the fixpoint gives me some trouble.

`val fac: BigInt => BigInt = facRec(fac)(_)`

results in the compiler complaining `forward reference extends over definition of value fac`. Using `def` instead of `val` works, however:

`def fac: BigInt => BigInt = facRec(fac)(_)`

This helps to get rid of the `var` in `Y`:

    object Y {
      def apply[T, R](f: (T => R) => T => R): (T => R) = {
        def yf: T => R = f(yf)(_)
        yf
      }
    }

Aaron Novstrup(19:27:28) :Note that using `def yf` doesn’t work in the memoized version (it results in a new memoized function being generated for each recursive call, and therefore effectively no memoization). `lazy val yf` will work though
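A minimal sketch of the difference between the two variants, counting evaluations for *fac(5)* followed by *fac(4)*. The names *CurriedY*, *viaDef*, *viaLazyVal* and *CurriedYDemo* are illustrative assumptions and do not appear in the post or the comments.

```scala
import scala.collection.mutable

// Memoize1 as in the post, repeated so that the sketch compiles on its own.
class Memoize1[-T, +R](f: T => R) extends (T => R) {
  private[this] val vals = mutable.Map.empty[T, R]

  def apply(x: T): R =
    if (vals.contains(x)) vals(x)
    else {
      val y = f(x)
      vals += ((x, y))
      y
    }
}

object CurriedY {
  // `def`: every reference to yf builds a fresh Memoize1 with an empty cache.
  def viaDef[T, R](f: (T => R) => T => R): T => R = {
    def yf: T => R = new Memoize1[T, R](x => f(yf)(x))
    yf
  }

  // `lazy val`: one Memoize1 is built on first use and shared by all calls.
  def viaLazyVal[T, R](f: (T => R) => T => R): T => R = {
    lazy val yf: T => R = new Memoize1[T, R](x => f(yf)(x))
    yf
  }
}

object CurriedYDemo {
  // Hypothetical counters: how often each factorial body runs.
  var defCount = 0
  var lazyCount = 0

  val facViaDef: BigInt => BigInt = CurriedY.viaDef[BigInt, BigInt] { f => n =>
    defCount += 1
    if (n == 0) BigInt(1) else n * f(n - 1)
  }

  val facViaLazyVal: BigInt => BigInt = CurriedY.viaLazyVal[BigInt, BigInt] { f => n =>
    lazyCount += 1
    if (n == 0) BigInt(1) else n * f(n - 1)
  }

  // Returns (evaluations with def, evaluations with lazy val).
  def run(): (Int, Int) = {
    facViaDef(5); facViaDef(4)         // intermediate results are lost
    facViaLazyVal(5); facViaLazyVal(4) // intermediate results are reused
    (defCount, lazyCount)
  }
}
```

With `def`, only the outermost wrapper keeps its cache, so the second call recomputes everything; with `lazy val`, the second call is a pure cache hit.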

lcn(21:34:53) :Per http://mvanier.livejournal.com/2897.html, the impure `Y` implemented here “will only work in a lazy language”, which might explain why `lazy val` is needed here.

lcn(04:21:19) :Ignore my comment, I made a mistake about `Y` and `yf`. Though `yf` is cheating (see discussion at http://goo.gl/TJMO2h), `Y` doesn’t have any problem.

asyropoulos(14:14:50) :I would like to mention your work on memo functions in a forthcoming book. How should I refer to it? Please contact me at the e-mail given.

Sean(11:32:15) :Nice post. I have adapted your code and included it in a little helper package of mine. I hope you don’t mind.

michid(11:43:34) :Thanks ;-) Go ahead using my code. I’m happy if it is useful.

Sean(14:19:42) :Thanks dude.

Alexander Azarov(14:59:47) :There is a memoization pattern from Sygneca as well: http://scala.sygneca.com/patterns/memoization

michid(17:33:19) :Thanks for the link. However, as far as I can see that approach does not cope in the same way with recursive functions.

String Distance and Refactoring in Scala « Matt Malone’s Old-Fashioned Software Development Blog(05:19:11) :[…] the same for the same inputs. You can memoize functions in different ways. There’s a post on Michid’s Weblog about a more general solution, a memoizing class which wraps existing functions to give you a […]

Sam Merat(07:29:12) :Thanks for the post. There is also a parallel discussion in the F# for Scientists book.

0x89(23:12:19) :

    class Memoize1[-T, +R](f: T => R) extends (T => R) {
      private[this] val vals = scala.collection.mutable.Map.empty[T, R]
      def apply(x: T): R = vals.getOrElseUpdate(x, f(x))
    }

Fenn Stefan(08:38:05) :A shorter version:

    case class Memoize1[-T, +R](f: T => R) extends (T => R) {
      import scala.collection.mutable
      private[this] val vals = mutable.Map.empty[T, R]
      def apply(x: T): R = vals.getOrElseUpdate(x, f(x))
    }

    object RecursiveMemomizedFunction {
      def apply[T, R](fRec: (T, T => R) => R): (T => R) = {
        def f(n: T): R = fRec(n, n => f(n))
        Memoize1(f)
      }
    }

    object MemomizeTest2 extends App {
      def facRec(n: BigInt, f: BigInt => BigInt): BigInt = {
        println("facRec " + n)
        if (n == 0) 1 else n * f(n - 1)
      }

      var fac = RecursiveMemomizedFunction(facRec)
      println(fac(5))
      println(fac(5))
    }

Scala: Hello Memo Fibonacci « Back To The Code(22:29:03) :[…] Ideas and examples are based on the blog post found here on memoization: http://michid.wordpress.com/2009/02/23/function_mem/ […]

Melvin(10:36:53) :Fenn Stefan, your version is not equivalent. The memoization function is only applied at the outermost level, meaning that it only stores the specific value that it is called with, and doesn’t store or use any other values. A simple illustration is calling the function with:

    println(fac(50))
    println(fac(49))
    println(fac(51))

sam(16:10:55) :vals + ((x, y))

should be

vals += ((x, y))

Other than this, looks good. Great blog thanks :)

michid(16:14:41) :Fixed, thanks for finding this!

Lev Alexandrovich Neiman(22:52:11) :Why not just pass the mem map instead of passing function?

michid(14:41:57) :A map is much like a partial function. So this wouldn’t be much of a difference after all.

lcn(22:47:01) :Michid, I think for better understanding the sentence “when it comes to calculating facMem(199) there is no gain from memoization here since – although calculated before – facMem(199) is not cached” could be rewritten to “when it comes to calculating fac(199) there is no gain from memoization here since – although calculated before – fac(198) is not cached”, because the real problem in memoizing a recursive function is that in the body of the function it’s calling the original unmemoized version. Of course this implies what you said originally was correct – only facMem(200) is in the cache, but even when calculating a single facMem(200) instead of all the way down to 0, it’s very slow because fac(199) will be calculated 1 time, fac(198) 2 times, fac(197) 3 times, fac(196) 5 times, fac(195) 8 times, fac(194) 13 times, and so on.