Function Memoization in Scala

23 02 2009

Function memoization is an optimization technique that avoids recomputing function values which a previous evaluation of the function has already calculated. In this post I show how function memoization can be implemented in Scala. Although straightforward at first glance, effectively memoizing recursive functions requires some second thoughts. For the sake of simplicity I only discuss functions of arity one.

Assume we have a function strSqLen which calculates the square of the length of a string.

def strSqLen(s: String) = s.length*s.length

Assume further that, for some reason, evaluating the above function is processor intensive. One way to speed things up is to cache result values and to look them up on subsequent invocations of strSqLen. This can be done either by the function itself or by the caller. The first approach has the drawback that each and every function has to implement caching and that clients have no control over the caching mechanism. The latter approach puts all the burden on the client programmer and potentially produces boilerplate. To overcome these issues we need a way to construct a memoized function having the same type as a given function.
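
To make the first approach concrete, here is a minimal sketch of strSqLen caching its own results. The names SelfCaching, cache and strSqLenCached are made up for this illustration. Every function written this way repeats the same lookup logic, and callers cannot inspect, clear or replace the cache.

```scala
import scala.collection.mutable

object SelfCaching {
  // Hypothetical cache owned by the function itself; callers have no access to it.
  private val cache = mutable.Map.empty[String, Int]

  def strSqLenCached(s: String): Int =
    cache.get(s) match {
      case Some(v) => v                 // seen before: reuse the stored value
      case None =>
        val v = s.length * s.length     // the (supposedly expensive) computation
        cache(s) = v                    // remember it for subsequent calls
        v
    }
}
```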

In Scala a function of arity one is an instance of Function1. We can thus sub-class Function1[-T, +R] as Memoize1[-T, +R], which represents the memoized function. (Note the concise syntactic sugar T => R, which is a synonym for the more verbose Function1[T, R].)

class Memoize1[-T, +R](f: T => R) extends (T => R) {
  import scala.collection.mutable
  private[this] val vals = mutable.Map.empty[T, R]

  def apply(x: T): R = {
    if (vals.contains(x)) {
      vals(x)
    }
    else {
      val y = f(x)
      vals += ((x, y))
      y
    }
  }
}

object Memoize1 {
  def apply[T, R](f: T => R) = new Memoize1(f)
}

When applied to an argument of type T, apply checks whether the function value is in the cache. If so, it returns that value. If not, it calls the original function f, puts the function value into the cache and returns it. Note that the member vals which contains the cached function values has object private visibility (private[this]). A less restrictive visibility would cause type checking to fail since mutable.Map[A, B] is invariant in both its type arguments while Memoize1[-T, +R] is contravariant in T and covariant in R (as is Function1[-T, +R]). However since vals is accessed from its containing instance only, it cannot cause problems with variance. We tell this to the compiler by declaring the field object private.
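
One way to see the cache at work is to count how often the wrapped function really runs. The sketch below repeats the class from above so that it is self-contained. The object CacheHitDemo and the counter calls are illustrative additions, and strSqLen is written as a function value here purely for convenience.

```scala
import scala.collection.mutable

class Memoize1[-T, +R](f: T => R) extends (T => R) {
  private[this] val vals = mutable.Map.empty[T, R]

  def apply(x: T): R =
    if (vals.contains(x)) vals(x)
    else {
      val y = f(x)
      vals += ((x, y))
      y
    }
}

object CacheHitDemo {
  var calls = 0  // illustrative counter: how often the wrapped function was evaluated

  val strSqLen: String => Int = s => { calls += 1; s.length * s.length }
  val strSqLenMemoized = new Memoize1(strSqLen)
}
```

Applying CacheHitDemo.strSqLenMemoized twice to the same string should leave calls at 1, because the second application is answered from the map.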

We can now easily create and use a memoization of strSqLen like this:

val strSqLenMemoized = Memoize1(strSqLen)
val a = strSqLenMemoized("hello Memo")
val b = strSqLen("hello Memo")
assert(a == b)

Going recursive

Memoization of recursive functions is possible but might not have the desired effect. Consider the following recursive implementation of the factorial function.

def fac(n: BigInt): BigInt = {
  if (n == 0) 1
  else n*fac(n - 1) 
}

Calculating the factorials from 200 down to 0 does not take advantage of the memoization at all.

val facMem = Memoize1(fac)
    
for (k <- 200 to 0 by -1) { 
  println(facMem(k))
}

While facMem(200) caches the factorial of 200, it does not cache the results of its intermediate recursive invocations, since the recursive calls invoke the original function instead of the memoized one. When it comes to calculating facMem(199) there is no gain from memoization: although the factorial of 199 was calculated before as an intermediate result, it was never put into the cache.
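
The effect can be made visible by counting invocations of fac. This is a small sketch with smaller numbers than in the loop above. NaiveRecursionDemo and calls are names invented for the illustration, and Memoize1 is repeated to keep the block self-contained.

```scala
import scala.collection.mutable

class Memoize1[-T, +R](f: T => R) extends (T => R) {
  private[this] val vals = mutable.Map.empty[T, R]

  def apply(x: T): R =
    if (vals.contains(x)) vals(x)
    else {
      val y = f(x)
      vals += ((x, y))
      y
    }
}

object NaiveRecursionDemo {
  var calls = 0  // illustrative counter: number of times the body of fac runs

  def fac(n: BigInt): BigInt = {
    calls += 1
    if (n == 0) 1
    else n * fac(n - 1)  // recursion goes to fac itself and bypasses the cache
  }

  val facMem = new Memoize1[BigInt, BigInt](n => fac(n))
}
```

Evaluating facMem(5) runs the body of fac six times (for 5 down to 0). A following facMem(4) runs it another five times, since only the result for 5 was stored.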

To improve this we need a more flexible implementation of fac. We want its recursive calls to be exposed to the outside so that they can be cached.

def facRec(n: BigInt, f: BigInt => BigInt): BigInt = {
  if (n == 0) 1
  else n*f(n - 1) 
}

Here the caller needs to pass a function for calculating factorials to facRec. She is therefore free to pass a memoized function to speed up recurrent invocations. However, it seems we are back to square one: we need to pass a function for calculating factorials to facRec in order for it to calculate factorials. As it turns out this is not the case. By passing facRec to itself recursively, we can construct the desired factorial function.

var fac: BigInt => BigInt = null 
fac = facRec(_, fac(_))

First we declare function fac to map from BigInt to BigInt. Then we partially apply facRec to fac which yields the desired factorial function.
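
It may help to see the placeholder syntax written out. The sketch below expands fac = facRec(_, fac(_)) into explicit lambdas. The wrapper object FixpointByHand is only there to make the snippet self-contained. Because the variable fac is read inside the lambdas, that is at call time and not at assignment time, the initial null is never dereferenced.

```scala
object FixpointByHand {
  def facRec(n: BigInt, f: BigInt => BigInt): BigInt =
    if (n == 0) 1
    else n * f(n - 1)

  var fac: BigInt => BigInt = null

  // Equivalent to `fac = facRec(_, fac(_))`, with both placeholders spelled out.
  // The inner lambda looks up `fac` only when it is invoked, after the assignment is done.
  fac = (n: BigInt) => facRec(n, (m: BigInt) => fac(m))
}
```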

We can generalize and factor out this construction process into an object Y like this:

object Y {
  def apply[T, R](f: (T, T => R) => R): (T => R) = {
    var yf: T => R = null
    yf = f(_, yf(_))
    yf
  }
}
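
Since Y is generic, it is not tied to factorials. As a hedged illustration, the sketch below uses it on a Fibonacci function written in the same open-recursion style. The names fibRec, fib and YDemo are assumptions of this example. No caching is involved yet, so this is still exponentially slow for larger arguments.

```scala
object Y {
  def apply[T, R](f: (T, T => R) => R): (T => R) = {
    var yf: T => R = null
    yf = f(_, yf(_))
    yf
  }
}

object YDemo {
  // Open-recursive Fibonacci: the recursive calls go through the passed-in function f.
  def fibRec(n: Int, f: Int => BigInt): BigInt =
    if (n < 2) n
    else f(n - 1) + f(n - 2)

  val fib: Int => BigInt = Y[Int, BigInt]((n, f) => fibRec(n, f))
}
```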

Along the same lines we can provide functionality for creating memoized versions of recursive functions. Instead of just recursively invoking the passed function, we pass a memoized version of it.

object Memoize1 {
  // ... same as above

  def Y[T, R](f: (T, T => R) => R) = {
    var yf: T => R = null
    yf = Memoize1(f(_, yf(_)))
    yf
  }
}

Using Memoize1.Y we can calculate the factorials from 200 down to 0 while taking full advantage of memoization of all intermediate recursive invocations.

def facRec(n: BigInt, f: BigInt => BigInt): BigInt = {
  if (n == 0) 1
  else n*f(n - 1) 
}
    
val fac = Memoize1.Y(facRec)
    
for (k <- 200 to 0 by -1) 
  println(fac(k))
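
To check that every level of the recursion now goes through the cache, we can again count how often the body of facRec runs. This is a self-contained sketch. The counter calls and the object MemoYDemo are illustrative, and Memoize1.Y is written with explicit type arguments and lambdas instead of the placeholder syntax used above.

```scala
import scala.collection.mutable

class Memoize1[-T, +R](f: T => R) extends (T => R) {
  private[this] val vals = mutable.Map.empty[T, R]

  def apply(x: T): R =
    if (vals.contains(x)) vals(x)
    else {
      val y = f(x)
      vals += ((x, y))
      y
    }
}

object Memoize1 {
  def apply[T, R](f: T => R): Memoize1[T, R] = new Memoize1(f)

  def Y[T, R](f: (T, T => R) => R): T => R = {
    var yf: T => R = null
    // Recursive calls are routed back through the single memoized instance yf.
    yf = new Memoize1[T, R](n => f(n, m => yf(m)))
    yf
  }
}

object MemoYDemo {
  var calls = 0  // illustrative counter: number of times the body of facRec runs

  def facRec(n: BigInt, f: BigInt => BigInt): BigInt = {
    calls += 1
    if (n == 0) 1
    else n * f(n - 1)
  }

  val fac: BigInt => BigInt = Memoize1.Y[BigInt, BigInt]((n, f) => facRec(n, f))
}
```

Here fac(5) runs the body six times and caches the results for 0 to 5. A following fac(4) is answered from the cache, and fac(6) needs exactly one more evaluation.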

23 responses

24 02 2009
HRJ

Thanks for the post! Well written.

I feel this should be made available in the std libs.

24 02 2009
Daniel Spiewak

I tend to prefer a version of the Y-combinator which requires a curried function as its parameter. I think this leads to a slightly cleaner syntax at times:

def apply[T, R](f: (T => R) => (T) => R): (T => R) = {
  var yf: T => R = null
  yf = f(yf(_))(_)
  yf
}

def facRec(f: BigInt => BigInt)(n: BigInt): BigInt = {
  if (n == 0) 1
  else n*f(n - 1)
}

Now we can be arguably more clean (though, less explicit) in our non-memoized fixpoint for facRec:

val fac: BigInt=>BigInt = facRec(fac(_))

In this, the fixpoint is implicit in Scala’s definition of a value, obviating the need to define it separately in the Y object.

24 02 2009
michid

Daniel,

Thanks for the feedback. Using curried functions is cleaner indeed. I should have thought of that ;-)

The less explicit version of the fixpoint gives me some troubles.

val fac: BigInt => BigInt = facRec(fac)(_)

results in the compiler complaining "forward reference extends over definition of value fac".

Using def instead of val works however:

def fac: BigInt => BigInt = facRec(fac)(_)

This helps to get rid of the var in Y:


object Y {
  def apply[T, R](f: (T => R) => T => R): (T => R) = {
    def yf: T => R = f(yf)(_)
    yf
  }
}

7 09 2010
Aaron Novstrup

Note that using `def yf` doesn’t work in the memoized version (it results in a new memoized function being generated for each recursive call, and therefore effectively no memoization). `lazy val yf` will work though.
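
A minimal sketch of the `lazy val` variant for the curried form, under the assumption that Memoize1 is the class from the post (repeated here for self-containment). The names MemoizeCurried, LazyValDemo and calls are made up for the illustration. A `lazy val` is initialized once, so all recursive calls share one Memoize1 instance and therefore one cache.

```scala
import scala.collection.mutable

class Memoize1[-T, +R](f: T => R) extends (T => R) {
  private[this] val vals = mutable.Map.empty[T, R]

  def apply(x: T): R =
    if (vals.contains(x)) vals(x)
    else {
      val y = f(x)
      vals += ((x, y))
      y
    }
}

object MemoizeCurried {
  def Y[T, R](f: (T => R) => T => R): T => R = {
    // Evaluated once on first access; with `def` each reference to yf
    // would build a fresh Memoize1 with an empty cache.
    lazy val yf: T => R = new Memoize1[T, R](x => f(yf)(x))
    yf
  }
}

object LazyValDemo {
  var calls = 0  // illustrative counter: number of times the body of facRec runs

  def facRec(f: BigInt => BigInt)(n: BigInt): BigInt = {
    calls += 1
    if (n == 0) 1
    else n * f(n - 1)
  }

  val fac: BigInt => BigInt = MemoizeCurried.Y[BigInt, BigInt](f => n => facRec(f)(n))
}
```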

16 12 2014
lcn

Per http://mvanier.livejournal.com/2897.html, the impure `Y` implemented here “will only work in a lazy language”, which might explain why `lazy val` is needed here.

17 12 2014
lcn

Ignore my comment, I made a mistake about `Y` and `yf`. Though `yf` is cheating (see discussion at http://goo.gl/TJMO2h), `Y` doesn’t have any problem.

2 03 2009
asyropoulos

I would like to mention your work on memo functions in a forthcoming book. How should I refer to it? Please contact me at the e-mail given.

11 06 2009
Sean

Nice post. I have adapted your code and included it in a little helper package of mine. I hope you don’t mind.

11 06 2009
michid

Thanks ;-) Go ahead using my code. I’m happy if it is useful.

11 06 2009
Sean

Thanks dude.

24 08 2009
Alexander Azarov

There is a memoization pattern from Sygneca as well: http://scala.sygneca.com/patterns/memoization

24 08 2009
michid

Thanks for the link. However, as far as I can see that approach does not cope in the same way with recursive functions.

19 11 2009
String Distance and Refactoring in Scala « Matt Malone’s Old-Fashioned Software Development Blog

[…] the same for the same inputs. You can memoize functions in different ways. There’s a post on Michid’s Weblog about a more general solution, a memoizing class which wraps existing functions to give you a […]

28 04 2010
Sam Merat

Thanks for the post. There is also a parallel discussion in the F# for Scientists book.

4 08 2010
0x89

class Memoize1[-T, +R](f: T => R) extends (T => R) {
private[this] val vals = scala.collection.mutable.Map.empty[T, R]
def apply(x: T): R = vals.getOrElseUpdate(x, f(x))
}

28 02 2011
Fenn Stefan

A shorter Version:

case class Memoize1[-T, +R](f: T => R) extends (T => R) {
import scala.collection.mutable
private[this] val vals = mutable.Map.empty[T, R]

def apply(x: T): R = vals.getOrElseUpdate(x, f(x))
}

object RecursiveMemomizedFunction {
def apply[T, R](fRec: (T, T => R) => R): (T => R) = {
def f(n: T): R = fRec(n, n => f(n))
Memoize1(f)
}
}

object MemomizeTest2 extends Application {
def facRec(n: BigInt, f: BigInt => BigInt): BigInt = {
println("facRec " + n)
if (n == 0) 1 else n * f(n - 1)
}
var fac = RecursiveMemomizedFunction(facRec)

println(fac(5))
println(fac(5))
}

29 05 2011
Scala: Hello Memo Fibonacci « Back To The Code

[…] Ideas and examples are based on the blog post found here on memoization: http://michid.wordpress.com/2009/02/23/function_mem/ […]

5 02 2012
Melvin

Fenn Stefan, your version is not equivalent. The memoization function is only applied at the outermost level, meaning that it only stores the specific value that it is called with, and doesn’t store or use any other values. A simple illustration is calling the function with:
println(fac(50))
println(fac(49))
println(fac(51))
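
The difference can be quantified with a call counter. The sketch below is an illustration, not code from the thread: OuterOnly mirrors the construction from the comment above using the plain Memoize1 class, Memoize1.Y is the version from the post written with explicit lambdas, and MelvinCheck, count and calls are invented names.

```scala
import scala.collection.mutable

class Memoize1[-T, +R](f: T => R) extends (T => R) {
  private[this] val vals = mutable.Map.empty[T, R]

  def apply(x: T): R =
    if (vals.contains(x)) vals(x)
    else {
      val y = f(x)
      vals += ((x, y))
      y
    }
}

object Memoize1 {
  def Y[T, R](f: (T, T => R) => R): T => R = {
    var yf: T => R = null
    yf = new Memoize1[T, R](n => f(n, m => yf(m)))  // recursion goes through the cache
    yf
  }
}

object OuterOnly {
  def apply[T, R](fRec: (T, T => R) => R): T => R = {
    def f(n: T): R = fRec(n, m => f(m))             // recursion bypasses the cache
    new Memoize1[T, R](n => f(n))
  }
}

object MelvinCheck {
  var calls = 0  // illustrative counter: number of times the body of facRec runs

  def facRec(n: BigInt, f: BigInt => BigInt): BigInt = {
    calls += 1
    if (n == 0) 1 else n * f(n - 1)
  }

  // Runs the three calls from the comment and reports how often facRec was evaluated.
  def count(fac: BigInt => BigInt): Int = {
    calls = 0
    fac(BigInt(50)); fac(BigInt(49)); fac(BigInt(51))
    calls
  }

  val outerOnly: Int = count(OuterOnly[BigInt, BigInt]((n, f) => facRec(n, f)))
  val everyLevel: Int = count(Memoize1.Y[BigInt, BigInt]((n, f) => facRec(n, f)))
}
```

With memoization only at the outermost level the three calls evaluate facRec 51 + 50 + 52 = 153 times. With Memoize1.Y the count is 51 + 0 + 1 = 52.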

27 09 2013
sam

vals + ((x, y))

should be

vals += ((x, y))

Other than this, looks good. Great blog thanks :)

27 09 2013
michid

Fixed, thanks for finding this!

5 12 2014
Lev Alexandrovich Neiman

Why not just pass the mem map instead of passing function?

7 12 2014
michid

A map is much like a partial function. So this wouldn’t be much of a difference after all.

16 12 2014
lcn

Michid, I think for better understanding the sentence “when it comes to calculating facMem(199) there is no gain from memoization here since – although calculated before – facMem(199) is not cached” could be rewritten to “when it comes to calculating fac(199) there is no gain from memoization here since – although calculated before – fac(198) is not cached”, because the real problem in memoizing a recursive function is that the body of the function calls the original unmemoized version. Of course this implies what you said originally was correct – only facMem(200) is in the cache, but even when calculating a single facMem(200) instead of all the way down to 0, it’s very slow because fac(199) will be calculated 1 time, fac(198) 2 times, fac(197) 3 times, fac(196) 5 times, fac(195) 8 times, fac(194) 13 times, and so on.
