## Memoisation with State using Scala

Posted on February 22, 2013

Everyone has seen a naïve Fibonacci implementation:

```scala
object FibNaïve {
  def fibnaïve(n: BigInt): BigInt =
    if(n <= 1)
      n
    else {
      val r = fibnaïve(n - 1)
      val s = fibnaïve(n - 2)
      r + s
    }
}
```

While this implementation is elegant, its running time is exponential in `n`. For example, computing `fibnaïve(4)` unnecessarily recomputes results for smaller arguments. If we unravel the recursion, the computation proceeds as follows:

```
  fibnaïve(4)
= fibnaïve(3) + fibnaïve(2)
= (fibnaïve(2) + fibnaïve(1)) + (fibnaïve(1) + fibnaïve(0))
= ((fibnaïve(1) + fibnaïve(0)) + fibnaïve(1)) + (fibnaïve(1) + fibnaïve(0))
```
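The repeated work can be measured directly. The sketch below counts, for each argument `k`, how many times the naïve recursion is entered with `k` while computing `fibnaïve(n)`. `FibCallCount`, `calls` and `merge` are hypothetical names for this illustration; `calls` mirrors the shape of the `fibnaïve` recursion rather than instrumenting it.

```scala
object FibCallCount {
  // Add the per-argument counts in `b` to those already in `a`.
  private def merge(a: Map[Int, Int], b: Map[Int, Int]): Map[Int, Int] =
    b.foldLeft(a) { case (acc, (k, c)) => acc.updated(k, acc.getOrElse(k, 0) + c) }

  // For each argument k, the number of times fibnaïve(k) is entered while
  // computing fibnaïve(n). The recursion has exactly the shape of fibnaïve.
  def calls(n: Int): Map[Int, Int] =
    if (n <= 1) Map(n -> 1)
    else merge(merge(Map(n -> 1), calls(n - 1)), calls(n - 2))
}
```

For `n = 4` this gives one call each for 4 and 3, two for 2, three for 1 and two for 0, matching the unravelled expression above. In general the total number of calls is `2 * F(n + 1) - 1`, so the work grows as fast as the Fibonacci numbers themselves.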

The naïve algorithm computes `fibnaïve(2)` twice (and evaluates `fibnaïve(1)` three times), and this repeated work grows rapidly as `n` grows. What we would like to do is trade some space for time by storing the values we have already computed for a given `n`. We can achieve this by looking up the argument in a table: if its result has already been computed, we return it; if it hasn't, we compute the result recursively, store it in the table, and then return it. This technique is called memoisation.
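For comparison, the lookup, compute and store recipe just described can be sketched imperatively with a mutable table. This is the mutation-based style that the rest of the post works to replace; `FibMutableMemo` and `go` are illustrative names, and `scala.collection.mutable.Map` is just one possible choice of table.

```scala
import scala.collection.mutable

object FibMutableMemo {
  // Memoised Fibonacci using an explicit, mutable lookup table.
  def fib(n: BigInt): BigInt = {
    val table = mutable.Map.empty[BigInt, BigInt]

    def go(z: BigInt): BigInt =
      if (z <= 1) z
      else table.get(z) match {
        case Some(v) => v               // already computed: reuse it
        case None =>
          val v = go(z - 1) + go(z - 2) // not yet computed: compute it once,
          table(z) = v                  // store it in the table,
          v                             // and return it
      }

    go(n)
  }
}
```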

As a first cut, let’s solve Fibonacci with a helper function that passes a `Map[BigInt, BigInt]` through the recursion. This map will serve as the memoisation table.

```scala
object FibMemo1 {
  type Memo = Map[BigInt, BigInt]

  def fibmemo1(n: BigInt): BigInt = {
    def fibmemoR(z: BigInt, memo: Memo): (BigInt, Memo) =
      if(z <= 1)
        (z, memo)
      else memo get z match {
        case None => {
          // thread the table: the second call sees what the first call stored
          val (r, memo0) = fibmemoR(z - 1, memo)
          val (s, memo1) = fibmemoR(z - 2, memo0)
          val t = r + s
          (t, memo1 + ((z, t))) // store the new result before returning it
        }
        case Some(v) => (v, memo)
      }

    fibmemoR(n, Map())._1
  }
}
```

We have traded space (the memoisation table) for speed: each result is computed once and looked up thereafter. However, we have sacrificed the elegance of the code. How can we achieve both elegance and efficiency?

The previous code (`fibmemo1`) passes state through the computation. In other words, where we once returned a value of type `A`, we now accept an additional argument of type `Memo` and return the pair `(A, Memo)`. The state in this case is a value of type `Memo`. We can represent this pattern as a data structure:

```scala
case class State[S, A](run: S => (A, S))
```
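As a small illustration of the data structure, a `State` value is just a wrapped function from an input state to a result paired with an output state; nothing happens until `run` is given an initial state. The `tick` counter and `lookup` function below are hypothetical examples, not part of the post's code.

```scala
object StateValueDemo {
  case class State[S, A](run: S => (A, S))

  // Hypothetical counter: the result is the current count; the new state is count + 1.
  val tick: State[Int, Int] = State((n: Int) => (n, n + 1))

  // Hypothetical memo lookup: the result is the cached entry, if any; the state is unchanged.
  def lookup(k: BigInt): State[Map[BigInt, BigInt], Option[BigInt]] =
    State((m: Map[BigInt, BigInt]) => (m.get(k), m))
}
```

For example, `StateValueDemo.tick.run(0)` produces `(0, 1)`: result `0`, next state `1`.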

Our `fibmemoR` function, which once had this type:

```scala
def fibmemoR(z: BigInt, memo: Memo): (BigInt, Memo)
```

…can be transformed to this type:

```scala
def fibmemoR(z: BigInt): State[Memo, BigInt]
```
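This change of type loses nothing: a state-passing function and a function that returns `State` can be converted into one another. The helpers `toState` and `fromState` below are a hypothetical sketch of that correspondence, not something the post's code relies on.

```scala
object StatePassingConversion {
  case class State[S, A](run: S => (A, S))

  // Hypothetical helper: wrap an explicit state-passing function so that
  // supplying the ordinary argument yields a State value.
  def toState[S, A, B](f: (A, S) => (B, S)): A => State[S, B] =
    a => State(s => f(a, s))

  // Hypothetical helper: the inverse, unwrapping a State-returning function.
  def fromState[S, A, B](g: A => State[S, B]): (A, S) => (B, S) =
    (a, s) => g(a).run(s)
}
```

Converting one way and then back gives a function that behaves exactly like the original, which is why `fibmemoR` can switch representations without changing what it computes.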

Let’s write our new Fibonacci function:

```scala
object FibMemo2 {
  type Memo = Map[BigInt, BigInt]

  def fibmemo2(n: BigInt): BigInt = {
    def fibmemoR(z: BigInt): State[Memo, BigInt] =
      State(memo =>
        if(z <= 1)
          (z, memo)
        else memo get z match {
          case None => {
            val (r, memo0) = fibmemoR(z - 1) run memo
            // the second call must see memo0, the table left by the first call
            val (s, memo1) = fibmemoR(z - 2) run memo0
            val t = r + s
            (t, memo1 + ((z, t))) // store the new result
          }
          case Some(v) => (v, memo)
        })

    fibmemoR(n).run(Map())._1
  }
}
```

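Ew! This code is still rather clumsy because it passes the memo table around manually, and that is easy to get wrong: handing the second recursive call the stale `memo` instead of `memo0` would still compile, but would silently discard the entries added by the first call. What can we do about it? This is where the state monad helps: it takes care of passing the state value around for us. The monad itself is implemented by three functions: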

1. The `map` method on `State[S, A]`.

2. The `flatMap` method on `State[S, A]`.

3. The `insert` function on `object State`, which produces a given value while leaving the state unchanged.

I will also add three convenience functions:

1. `eval` method for running the `State` value and dropping the resulting state value.

2. `get` function for taking the current state to a value. `(S => A) => State[S, A]`

3. `mod` function for modifying the current state. `(S => S) => State[S, Unit]`

Here goes:

```scala
case class State[S, A](run: S => (A, S)) {
  // 1. the map method
  def map[B](f: A => B): State[S, B] =
    State(s => {
      val (a, t) = run(s)
      (f(a), t)
    })

  // 2. the flatMap method
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State(s => {
      val (a, t) = run(s)
      f(a) run t
    })

  // Convenience function to drop the resulting state value
  def eval(s: S): A =
    run(s)._1
}

object State {
  // 3. The insert function
  def insert[S, A](a: A): State[S, A] =
    State(s => (a, s))

  // Convenience function for taking the current state to a value
  def get[S, A](f: S => A): State[S, A] =
    State(s => (f(s), s))

  // Convenience function for modifying the current state
  def mod[S](f: S => S): State[S, Unit] =
    State(s => ((), f(s)))
}
```

We can see that the `flatMap` method takes care of passing the state value through to the given function. This is the ultimate purpose of the state monad: it lets us thread a state value (of type `S`) through a computation (producing a value of type `A`) without handling it manually. Together with `flatMap`, the `map` method and the `insert` function complete the state monad; `eval`, `get` and `mod` are conveniences built on the same representation.
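To see `flatMap` doing the threading, here is a sketch that repeats the `State` definition above (so the block compiles on its own) and sequences a hypothetical counter step, `tick`, twice. The second `tick` never mentions the state; `flatMap` hands it whatever the first one left behind.

```scala
object ThreadingDemo {
  case class State[S, A](run: S => (A, S)) {
    def map[B](f: A => B): State[S, B] =
      State(s => { val (a, t) = run(s); (f(a), t) })

    // Run this step, then feed its result and its final state into the next step.
    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State(s => { val (a, t) = run(s); f(a).run(t) })

    def eval(s: S): A = run(s)._1
  }

  object State {
    def insert[S, A](a: A): State[S, A] = State(s => (a, s))
    def get[S, A](f: S => A): State[S, A] = State(s => (f(s), s))
    def mod[S](f: S => S): State[S, Unit] = State(s => ((), f(s)))
  }

  // Hypothetical counter step: read the current count, then increment it.
  val tick: State[Int, Int] =
    State.get((n: Int) => n).flatMap(n => State.mod((m: Int) => m + 1).map(_ => n))

  // Two ticks in sequence: flatMap gives the second tick the state left by the first.
  val twoTicks: State[Int, (Int, Int)] =
    tick.flatMap(a => tick.map(b => (a, b)))
}
```

Running `twoTicks` from `0` yields the pair `(0, 1)` with final state `2`.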

How does our Fibonacci implementation look now?

```scala
object FibMemo3 {
  type Memo = Map[BigInt, BigInt]

  def fibmemo3(n: BigInt): BigInt = {
    def fibmemoR(z: BigInt): State[Memo, BigInt] =
      if(z <= 1)
        State.insert(z)
      else
        for {
          u <- State.get((m: Memo) => m get z)
          // parenthesise the fallback so flatMap does not bind to getOrElse's result
          v <- u map State.insert[Memo, BigInt] getOrElse (
                 fibmemoR(z - 1) flatMap (r =>
                 fibmemoR(z - 2) flatMap (s => {
                   val t = r + s
                   State.mod((m: Memo) => m + ((z, t))) map (_ =>
                   t)
                 })))
        } yield v

    fibmemoR(n) eval Map()
  }
}
```

Now the state monad's operations pass the memo table through the computation for us: there is no more manual handling of the table across successive recursive calls.

Scala provides syntax for computations that chain calls to `flatMap` and `map`: the for-comprehension, written with the `for` and `yield` keywords. The compiler translates a for-comprehension into calls to `flatMap` and `map`, which allows a more imperative-looking style. For example, where we once wrote code such as `x flatMap (r =>`, we now write `r <- x` inside the for-comprehension, and the final `map` becomes the `yield`. `fibmemo3` already uses one for the outer lookup; let's use it for the inner chain as well.
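The translation is mechanical. Assuming a `State` with the `map` and `flatMap` methods defined earlier (repeated here so the block compiles on its own), the two values below describe the same computation; `tick`, `explicit` and `sugared` are illustrative names.

```scala
object ForDesugarDemo {
  case class State[S, A](run: S => (A, S)) {
    def map[B](f: A => B): State[S, B] =
      State(s => { val (a, t) = run(s); (f(a), t) })
    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State(s => { val (a, t) = run(s); f(a).run(t) })
  }

  // Hypothetical counter step: yield the current count, advance the state by one.
  val tick: State[Int, Int] = State((n: Int) => (n, n + 1))

  // Hand-written chain of flatMap and map...
  val explicit: State[Int, Int] =
    tick.flatMap(r => tick.map(s => r + s))

  // ...and the for-comprehension the compiler translates into that same chain.
  val sugared: State[Int, Int] =
    for {
      r <- tick
      s <- tick
    } yield r + s
}
```

Starting from state `3`, both produce result `7` (that is, `3 + 4`) and final state `5`.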

How does the memoised Fibonacci function look with it?

```scala
object FibMemo4 {
  type Memo = Map[BigInt, BigInt]

  def fibmemo4(n: BigInt): BigInt = {
    def fibmemoR(z: BigInt): State[Memo, BigInt] =
      if(z <= 1)
        State.insert(z)
      else
        for {
          u <- State.get((m: Memo) => m get z)
          v <- u map State.insert[Memo, BigInt] getOrElse (for {
            r <- fibmemoR(z - 1)
            s <- fibmemoR(z - 2)
            t = r + s
            _ <- State.mod((m: Memo) => m + ((z, t)))
          } yield t)
        } yield v

    fibmemoR(n) eval Map()
  }
}
```

This is a lot neater, as the memoisation table is handled by the state monad. In fact, it is starting to look like the original naïve solution. We are no longer handling the state transitions manually, which lets us express the essence of the problem without the blow-out in running time.
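For experimentation, here is a self-contained sketch that gathers the `State` definition and the final algorithm into one object. It is the same algorithm as `fibmemo4`, written with dot notation and, as an assumption for this demo, returning the final memo table alongside the result so that its size can be inspected; `FibMemoDemo` and `fibWithMemo` are illustrative names.

```scala
object FibMemoDemo {
  case class State[S, A](run: S => (A, S)) {
    def map[B](f: A => B): State[S, B] =
      State(s => { val (a, t) = run(s); (f(a), t) })

    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State(s => { val (a, t) = run(s); f(a).run(t) })

    def eval(s: S): A = run(s)._1
  }

  object State {
    def insert[S, A](a: A): State[S, A] = State(s => (a, s))
    def get[S, A](f: S => A): State[S, A] = State(s => (f(s), s))
    def mod[S](f: S => S): State[S, Unit] = State(s => ((), f(s)))
  }

  type Memo = Map[BigInt, BigInt]

  // Same algorithm as fibmemo4, but returns the final memo table as well as the result.
  def fibWithMemo(n: BigInt): (BigInt, Memo) = {
    def fibmemoR(z: BigInt): State[Memo, BigInt] =
      if (z <= 1) State.insert(z)
      else
        for {
          u <- State.get((m: Memo) => m.get(z)) // cached result, if any
          v <- u.map(State.insert[Memo, BigInt]).getOrElse(for {
                 r <- fibmemoR(z - 1)
                 s <- fibmemoR(z - 2)
                 t = r + s
                 _ <- State.mod((m: Memo) => m + (z -> t)) // remember the new result
               } yield t)
        } yield v

    fibmemoR(n).run(Map())
  }
}
```

Computing `fibWithMemo(100)` gives 354224848179261915075, and the table ends up with one entry for each `z` from 2 to 100 (99 entries), since every value is computed and stored exactly once.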

Where you once may have used `var`, consider whether the state monad is more appropriate.
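As one illustration of that advice, consider numbering the elements of a list, a job often done with a mutable counter. A minimal sketch with a `State[Int, _]` counter follows; `fresh` and `label` are hypothetical names, and the `State` definition repeats the `map` and `flatMap` from above so the block stands alone.

```scala
object LabelDemo {
  case class State[S, A](run: S => (A, S)) {
    def map[B](f: A => B): State[S, B] =
      State(s => { val (a, t) = run(s); (f(a), t) })
    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State(s => { val (a, t) = run(s); f(a).run(t) })
  }

  // Hypothetical label source: hand out the current number, then advance it.
  val fresh: State[Int, Int] = State((n: Int) => (n, n + 1))

  // Pair each element with a fresh label, left to right, without a var.
  // This recursion is not stack-safe for very long lists; it only shows the shape.
  def label[A](as: List[A]): State[Int, List[(Int, A)]] = as match {
    case Nil => State((n: Int) => (List.empty[(Int, A)], n))
    case a :: rest =>
      for {
        i    <- fresh
        tail <- label(rest)
      } yield (i, a) :: tail
  }
}
```

Running `label(List("a", "b", "c"))` from `0` yields the labelled list `List((0, "a"), (1, "b"), (2, "c"))` together with the next unused label, `3`.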