## Function Memoization in Scala

23 02 2009

Function memoization is an optimization technique that avoids repeated calculation of function values which have already been computed by a previous evaluation of the function. In this post I show how function memoization can be implemented in Scala. Although straightforward at first glance, effectively memoizing recursive functions requires some second thoughts. For the sake of simplicity I only discuss functions of arity one.

Assume we have a function strSqLen which calculates the square of the length of a string.

```scala
def strSqLen(s: String) = s.length * s.length
```

Assume further that, for some reason, evaluating the above function is processor intensive. One way to speed things up is to cache result values and to look them up on subsequent invocations of strSqLen. This can be done either by the function itself or by the caller. The first approach has the drawback that each and every function has to implement caching and that clients have no control over the caching mechanism. The latter approach puts all the burden on the client programmer and produces potential boilerplate. To overcome these issues we need a way to construct a memoized function having the same type as a given function.

In Scala a function of arity one is an instance of Function1. We can thus subclass Function1[-T, +R] with Memoize1[-T, +R], which represents the memoized function. (Note the concise syntactic sugar T => R, which is synonymous with the more verbose Function1[T, R].)

```scala
class Memoize1[-T, +R](f: T => R) extends (T => R) {
  import scala.collection.mutable
  private[this] val vals = mutable.Map.empty[T, R]

  def apply(x: T): R = {
    if (vals.contains(x)) {
      vals(x)
    }
    else {
      val y = f(x)
      vals += ((x, y))
      y
    }
  }
}

object Memoize1 {
  def apply[T, R](f: T => R) = new Memoize1(f)
}
```

When applied to an argument of type T, apply checks whether the function value is in the cache. If so, it returns that value. If not, it calls the original function f, puts the function value into the cache and returns it. Note that the member vals which contains the cached function values has object private visibility (private[this]). A less restrictive visibility would cause type checking to fail since mutable.Map[A, B] is invariant in both its type arguments while Memoize1[-T, +R] is contravariant in T and covariant in R (as is Function1[-T, +R]). However since vals is accessed from its containing instance only, it cannot cause problems with variance. We tell this to the compiler by declaring the field object private.
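To see the two variance annotations at work, here is a small sketch (the value names are illustrative): a memoizer accepting Any arguments can stand in wherever a String => Int is expected.

```scala
import scala.collection.mutable

class Memoize1[-T, +R](f: T => R) extends (T => R) {
  private[this] val vals = mutable.Map.empty[T, R]

  def apply(x: T): R = {
    if (vals.contains(x)) vals(x)
    else {
      val y = f(x)
      vals += ((x, y))
      y
    }
  }
}

// Contravariance in T: Memoize1[Any, Int] is a subtype of
// String => Int, because Any is a supertype of String.
val anyLen = new Memoize1((x: Any) => x.toString.length)
val asStringFn: String => Int = anyLen
println(asStringFn("hello")) // 5
```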

We can now easily create and use a memoization of strSqLen like this:

```scala
val strSqLenMemoized = Memoize1(strSqLen)
val a = strSqLenMemoized("hello Memo")
val b = strSqLen("hello Memo")
assert(a == b)
```

### Going recursive

Memoization of recursive functions is possible but might not have the desired effect. Consider the following recursive implementation of the factorial function.

```scala
def fac(n: BigInt): BigInt = {
  if (n == 0) 1
  else n * fac(n - 1)
}
```

Calculating the factorials from 200 down to 0 does not take advantage of the memoization at all.

```scala
val facMem = Memoize1(fac)

for (k <- 200 to 0 by -1) {
  println(facMem(k))
}
```

While facMem(200) caches the factorial of 200, it does not cache the results of its intermediate recursive invocations, since the recursive calls invoke the original function instead of the memoized one. When it comes to calculating facMem(199) there is thus no gain from memoization: although the value was calculated before as an intermediate result, it was never put into the cache.
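The missing cache hits can be made visible with a call counter (a sketch; callCount is illustrative and not part of the original code):

```scala
import scala.collection.mutable

// the memoizer from above, condensed to getOrElseUpdate
class Memoize1[-T, +R](f: T => R) extends (T => R) {
  private[this] val vals = mutable.Map.empty[T, R]
  def apply(x: T): R = vals.getOrElseUpdate(x, f(x))
}

var callCount = 0

def fac(n: BigInt): BigInt = {
  callCount += 1
  if (n == 0) 1 else n * fac(n - 1)
}

val facMem = new Memoize1(fac)
facMem(10)
println(callCount) // 11: fac(10) down to fac(0), the recursion bypasses the cache
facMem(9)
println(callCount) // 21: fac(9) was never cached, so ten more calls
```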

To improve this we need a more flexible implementation of fac: we want its recursive calls to be exposed to the outside, so that they can be routed through the cache.

```scala
def facRec(n: BigInt, f: BigInt => BigInt): BigInt = {
  if (n == 0) 1
  else n * f(n - 1)
}
```

Here the caller needs to pass a function for calculating factorials to facRec. She is therefore free to pass a memoized function to speed up recurrent invocations. However, it seems we are back to square one: we need to pass a function for calculating factorials to facRec in order for it to calculate factorials. As it turns out, this is not the case. By passing facRec to itself recursively, we can construct the desired factorial function.

```scala
var fac: BigInt => BigInt = null
fac = facRec(_, fac(_))
```

First we declare the function fac to map from BigInt to BigInt. Then we partially apply facRec to fac, which yields the desired factorial function. This works because fac(_) eta-expands to a closure that reads the variable fac only when it is applied, which happens after the assignment has completed.

We can generalize and factor out this construction process into an object Y like this:

```scala
object Y {
  def apply[T, R](f: (T, T => R) => R): (T => R) = {
    var yf: T => R = null
    yf = f(_, yf(_))
    yf
  }
}
```
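A quick check that Y indeed ties the recursive knot (not yet memoized):

```scala
object Y {
  def apply[T, R](f: (T, T => R) => R): (T => R) = {
    var yf: T => R = null
    yf = f(_, yf(_))
    yf
  }
}

def facRec(n: BigInt, f: BigInt => BigInt): BigInt =
  if (n == 0) 1 else n * f(n - 1)

val fac = Y(facRec)
println(fac(5)) // 120
```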

Along the same lines we can provide functionality for creating memoized versions of recursive functions. Instead of just recursively invoking the passed function, we pass a memoized version of it.

```scala
object Memoize1 {
  // ... same as above

  def Y[T, R](f: (T, T => R) => R) = {
    var yf: T => R = null
    yf = Memoize1(f(_, yf(_)))
    yf
  }
}
```

Using Memoize1.Y we can calculate the factorials from 200 down to 0 while taking full advantage of memoization for all intermediate recursive invocations.

```scala
def facRec(n: BigInt, f: BigInt => BigInt): BigInt = {
  if (n == 0) 1
  else n * f(n - 1)
}

val fac = Memoize1.Y(facRec)

for (k <- 200 to 0 by -1)
  println(fac(k))
```

### 29 responses

24 02 2009

Thanks for the post! Well written.

I feel this should be made available in the std libs.

24 02 2009

I tend to prefer a version of the Y-combinator which requires a curried function as its parameter. I think this leads to a slightly cleaner syntax at times:

def apply[T, R](f: (T => R) => (T) => R): (T => R) = {
  var yf: T => R = null
  yf = f(yf(_))
  yf
}

def facRec(f: BigInt => BigInt)(n: BigInt): BigInt = {
  if (n == 0) 1
  else n * f(n - 1)
}

Now we can be arguably more clean (though, less explicit) in our non-memoized fixpoint for facRec:

val fac: BigInt=>BigInt = facRec(fac(_))

In this, the fixpoint is implicit in Scala’s definition of a value, obviating the need to define it separately in the Y object.

24 02 2009

Daniel,

Thanks for the feedback. Using curried functions is cleaner indeed. I should have thought of that 😉

The less explicit version of the fixpoint gives me some troubles.

`val fac: BigInt => BigInt = facRec(fac)(_)`

results in the compiler complaining "forward reference extends over definition of value fac".

Using def instead of val works however:

`def fac: BigInt => BigInt = facRec(fac)(_)`

This helps getting rid of the var Y:

```scala
object Y {
  def apply[T, R](f: (T => R) => T => R): (T => R) = {
    def yf: T => R = f(yf)(_)
    yf
  }
}
```

7 09 2010

Note that using `def yf` doesn’t work in the memoized version (it results in a new memoized function being generated for each recursive call, and therefore effectively no memoization). `lazy val yf` will work though
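A runnable sketch of this lazy val variant (my reconstruction of the suggestion, using the curried signature from the earlier comment):

```scala
import scala.collection.mutable

class Memoize1[-T, +R](f: T => R) extends (T => R) {
  private[this] val vals = mutable.Map.empty[T, R]
  def apply(x: T): R = vals.getOrElseUpdate(x, f(x))
}

object Memoize1 {
  // memoizing fixpoint for a curried recursive function
  def Y[T, R](f: (T => R) => T => R): T => R = {
    // lazy val ensures yf is created exactly once, so every
    // recursive call goes through the same memoized instance
    lazy val yf: T => R = new Memoize1(x => f(yf)(x))
    yf
  }
}

def facRec(f: BigInt => BigInt)(n: BigInt): BigInt =
  if (n == 0) 1 else n * f(n - 1)

val fac = Memoize1.Y(facRec)
println(fac(20)) // 2432902008176640000
```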

16 12 2014

Per http://mvanier.livejournal.com/2897.html, the impure `Y` implemented here “will only work in a lazy language”, which might explain why `lazy val` is needed here.

17 12 2014

Ignore my comment, I made a mistake about `Y` and `yf`. Though `yf` is cheating (see discussion at http://goo.gl/TJMO2h), `Y` doesn’t have any problem.

24 08 2015

Thanks for the awesome post!

And thanks Daniel Spiewak for your input.

I have created a memoized fibonacci on the basis of this post – input greatly appreciated as I am new to Scala and I know this is a bad implementation.

http://pastebin.com/S4aR8VZV

2 03 2009

I would like to mention your work on memo functions in a forthcoming book. How should I refer to it? Please contact me at the e-mail given.

11 06 2009

Nice post. I have adapted your code and included it in a little helper package of mine. I hope you don’t mind.

11 06 2009

Thanks 😉 Go ahead using my code. I’m happy if it is useful.

11 06 2009

Thanks dude.

24 08 2009

There is a memoization pattern from Sygneca as well: http://scala.sygneca.com/patterns/memoization

24 08 2009

Thanks for the link. However, as far as I can see that approach does not cope in the same way with recursive functions.

19 11 2009

[…] the same for the same inputs. You can memoize functions in different ways. There’s a post on Michid’s Weblog about a more general solution, a memoizing class which wraps existing functions to give you a […]

28 04 2010

Thanks for the post. There is also a parallel discussion in the F# for Scientists book.

4 08 2010

class Memoize1[-T, +R](f: T => R) extends (T => R) {
  private[this] val vals = scala.collection.mutable.Map.empty[T, R]
  def apply(x: T): R = vals.getOrElseUpdate(x, f(x))
}

28 02 2011

A shorter version:

case class Memoize1[-T, +R](f: T => R) extends (T => R) {
  import scala.collection.mutable
  private[this] val vals = mutable.Map.empty[T, R]

  def apply(x: T): R = vals.getOrElseUpdate(x, f(x))
}

object RecursiveMemomizedFunction {
  def apply[T, R](fRec: (T, T => R) => R): (T => R) = {
    def f(n: T): R = fRec(n, n => f(n))
    Memoize1(f)
  }
}

object MemomizeTest2 extends Application {
  def facRec(n: BigInt, f: BigInt => BigInt): BigInt = {
    println("facRec " + n)
    if (n == 0) 1 else n * f(n - 1)
  }
  var fac = RecursiveMemomizedFunction(facRec)

  println(fac(5))
  println(fac(5))
}

29 05 2011

[…] Ideas and examples are based on the blog post found here on memoization: https://michid.wordpress.com/2009/02/23/function_mem/ […]

5 02 2012

Fenn Stefan, your version is not equivalent. The memoization function is only applied at the outermost level, meaning that it only stores the specific value that it is called with, and doesn’t store or use any other values. A simple illustration is calling the function with:
println(fac(50))
println(fac(49))
println(fac(51))
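One possible repair (a sketch of my own, not from this thread) is to route the recursive callback through the memoized instance itself, for example via a lazy val; the calls counter below merely demonstrates the caching:

```scala
import scala.collection.mutable

case class Memoize1[-T, +R](f: T => R) extends (T => R) {
  private[this] val vals = mutable.Map.empty[T, R]
  def apply(x: T): R = vals.getOrElseUpdate(x, f(x))
}

object RecursiveMemoizedFunction {
  def apply[T, R](fRec: (T, T => R) => R): T => R = {
    // mf is created once; the callback passed to fRec is mf itself,
    // so all intermediate results end up in the same cache
    lazy val mf: T => R = Memoize1(n => fRec(n, mf))
    mf
  }
}

var calls = 0
def facRec(n: BigInt, f: BigInt => BigInt): BigInt = {
  calls += 1
  if (n == 0) 1 else n * f(n - 1)
}

val fac = RecursiveMemoizedFunction(facRec)
fac(50)
fac(49) // served entirely from the cache
println(calls) // 51: facRec ran once for each n from 50 down to 0
```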

27 09 2013

vals + ((x, y))

should be

vals += ((x, y))

Other than this, looks good. Great blog thanks 🙂

27 09 2013

Fixed, thanks for finding this!

5 12 2014

Why not just pass the mem map instead of passing function?

7 12 2014

A map is much like a partial function. So this wouldn’t be much of a difference after all.

16 12 2014

Michid, I think for better understanding the sentence “when it comes to calculating facMem(199) there is no gain from memoization here since – although calculated before – facMem(199) is not cached” could be rewritten to “when it comes to calculating fac(199) there is no gain from memoization here since – although calculated before – fac(198) is not cached”, because the real problem in memoizing recursive function is in the body of the function it’s calling the original unmemoized version. Of course this implies what you said originally was correct – only facMem(200) is in the cache, but even when calculating a single facMem(200) instead of all the way down to 0, it’s very slow because fac(199) will be calculated 1 time, fac(198) 2 times, fac(197) 3 times, fac(196) 5 times, fac(195) 8 times, fac(194) 13 times, and so on.

22 03 2015

Hi Michid, I’m just about to write an article about FP building blocks in different programming languages on my blog (non commercial) and I’d like to know if it’s ok for you if I referenced your memoization example there?

Cheers,

Micha

23 03 2015