λ Tony's Blog λ

Memoisation with State using Scala

Posted on February 22, 2013

Everyone has seen a naïve fibonacci implementation:

object FibNaïve {
  def fibnaïve(n: BigInt): BigInt =
    if(n <= 1)
      n
    else {
      val r = fibnaïve(n - 1)
      val s = fibnaïve(n - 2)
      r + s
    }
}

While this implementation is elegant, it is exponential in time with respect to n. For example, computing fibnaïve(4) results in unnecessary re-computation of the fibonacci values for arguments less than 4. If we unravel the recursion, computation proceeds as follows:

  fibnaïve(4)
= fibnaïve(3) + fibnaïve(2)
= (fibnaïve(2) + fibnaïve(1)) + (fibnaïve(1) + fibnaïve(0))
= ((fibnaïve(1) + fibnaïve(0)) + fibnaïve(1)) + (fibnaïve(1) + fibnaïve(0))

This algorithm calculates fibnaïve(2) twice, which ultimately results in a lot of repeated calculations, especially as n grows. What we would like to do is trade some space for that time, storing previously computed results for each n. We can achieve this by looking up the argument in a table: if the result has already been computed, we return it and carry on; if it hasn’t, we compute it by calling fibnaïve, store it in the table, then return it. This technique is called memoisation.
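The lookup-then-store idea can be sketched directly with a mutable cache. This impure shortcut is not the direction this post takes, but it makes the technique concrete (FibCached is a hypothetical name, not part of the code that follows):

```scala
import scala.collection.mutable

object FibCached {
  // The memoisation table: argument -> previously computed result.
  private val cache = mutable.Map[BigInt, BigInt]()

  def fibCached(n: BigInt): BigInt =
    if (n <= 1) n
    else cache.get(n) match {
      case Some(v) => v
      case None =>
        val r = fibCached(n - 1) + fibCached(n - 2)
        cache.update(n, r)
        r
    }
}
```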

As a first cut, let’s solve fibonacci with a helper function that passes a Map[BigInt, BigInt] around in the recursion. This map will serve as the memoisation table.

object FibMemo1 {
  type Memo = Map[BigInt, BigInt]

  def fibmemo1(n: BigInt): BigInt = {
    def fibmemoR(z: BigInt, memo: Memo): (BigInt, Memo) =
      if(z <= 1)
        (z, memo)
      else memo get z match {
        case None => {
          val (r, memo0) = fibmemoR(z - 1, memo)
          val (s, memo1) = fibmemoR(z - 2, memo0)
          (r + s, memo1)
        }
        case Some(v) => (v, memo)
      }

    fibmemoR(n, Map())._1
  }
}

We have traded space (the memoisation table) for speed; the algorithm is more efficient by not recomputing values. However, we have sacrificed the elegance of the code. How can we achieve both elegance and efficiency?

The State Monad

The previous code (fibmemo1) passed state through the computation. In other words, where we once returned a value of type A, we now accept an argument of type Memo and return the pair (A, Memo). The state in this case is a value of type Memo. We can represent this pattern as a data structure:

case class State[S, A](run: S => (A, S))
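Before moving on, it may help to run a State value by hand. A tiny, self-contained example (tick is just an illustrative name):

```scala
case class State[S, A](run: S => (A, S))

// A State value over an Int state: it returns the current counter as its
// result and increments the counter as the new state.
val tick = State[Int, Int](n => (n, n + 1))

// Running it supplies the initial state and yields (result, new state):
// tick.run(41) produces (41, 42).
```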

Our fibmemoR function which once had this type:

def fibmemoR(z: BigInt, memo: Memo): (BigInt, Memo)

…can be transformed to this type:

def fibmemoR(z: BigInt): State[Memo, BigInt]

Let’s write our new fibonacci function:

object FibMemo2 {
  type Memo = Map[BigInt, BigInt]

  def fibmemo2(n: BigInt): BigInt = {
    def fibmemoR(z: BigInt): State[Memo, BigInt] =
      State(memo =>
        if(z <= 1)
          (z, memo)
        else memo get z match {
          case None => {
            val (r, memo0) = fibmemoR(z - 1) run memo
            val (s, memo1) = fibmemoR(z - 2) run memo0
            (r + s, memo1)
          }
          case Some(v) => (v, memo)
        })

    fibmemoR(n).run(Map())._1
  }
}

Ew! This code is still rather clumsy as it manually passes the memo table around. What can we do about it? This is where the state monad is going to help us out. The state monad is going to take care of passing the state value around for us. The monad itself is implemented by three functions:

  1. The map method on State[S, A].

  2. The flatMap method on State[S, A].

  3. The insert function on the State companion object, which inserts a value while leaving the state unchanged.

I will also add three convenience functions:

  1. The eval method, which runs the State value and drops the resulting state.

  2. The get function, which takes the current state to a value: (S => A) => State[S, A]

  3. The mod function, which modifies the current state: (S => S) => State[S, Unit]

Here goes:

case class State[S, A](run: S => (A, S)) {
  // 1. the map method
  def map[B](f: A => B): State[S, B] =
    State(s => {
      val (a, t) = run(s)
      (f(a), t)
    })

  // 2. the flatMap method
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State(s => {
      val (a, t) = run(s)
      f(a) run t
    })

  // Convenience function to drop the resulting state value
  def eval(s: S): A =
    run(s)._1
}

object State {
  // 3. The insert function
  def insert[S, A](a: A): State[S, A] =
    State(s => (a, s))

  // Convenience function for taking the current state to a value
  def get[S, A](f: S => A): State[S, A] =
    State(s => (f(s), s))

  // Convenience function for modifying the current state
  def mod[S](f: S => S): State[S, Unit] =
    State(s => ((), f(s)))
}
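To see how these pieces combine, here is a small, self-contained sketch; the State definitions are repeated so the snippet stands alone, and prog is just an illustrative name:

```scala
case class State[S, A](run: S => (A, S)) {
  def map[B](f: A => B): State[S, B] =
    State(s => {
      val (a, t) = run(s)
      (f(a), t)
    })

  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State(s => {
      val (a, t) = run(s)
      f(a).run(t)
    })

  def eval(s: S): A =
    run(s)._1
}

object State {
  def insert[S, A](a: A): State[S, A] =
    State(s => (a, s))

  def get[S, A](f: S => A): State[S, A] =
    State(s => (f(s), s))

  def mod[S](f: S => S): State[S, Unit] =
    State(s => ((), f(s)))
}

// Read the current state, double it with mod, and return a label built
// from the value we originally read; flatMap threads the state for us.
val prog: State[Int, String] =
  State.get((s: Int) => s).flatMap(orig =>
    State.mod((s: Int) => s * 2).map(_ =>
      "was " + orig))
```

Running prog.run(10) reads 10, doubles the state to 20, and returns the label built from the original value.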

We can see that the flatMap method takes care of passing the state value through a given function. This is the ultimate purpose of the state monad. Specifically, the state monad allows the programmer to pass a state (S) value through a computation (A) without having to handle it manually. The map method and insert function complete the state monad.

How does our fibonacci implementation look now?

object FibMemo3 {
  type Memo = Map[BigInt, BigInt]

  def fibmemo3(n: BigInt): BigInt = {
    def fibmemoR(z: BigInt): State[Memo, BigInt] =
      if(z <= 1)
        State.insert(z)
      else
        for {
          u <- State.get((m: Memo) => m get z)
          v <- u map State.insert[Memo, BigInt] getOrElse
                 fibmemoR(z - 1) flatMap (r =>
                 fibmemoR(z - 2) flatMap (s => {
                 val t = r + s
                 State.mod((m: Memo) => m + ((z, t))) map (_ =>
                 t)
                 }))
        } yield v

    fibmemoR(n) eval Map()
  }
}

Now we have used the three state monad operations to pass the memo table through the computation for us: no more manual handling of passing that memo table through successive recursive calls.

Scala provides syntax for this type of computation, which chains calls to flatMap and map. It is written using the for and yield keywords in what is called a for-comprehension. The for-comprehension syntax desugars to the calls to flatMap and map, while allowing a more imperative-looking style. For example, where we once wrote code such as x flatMap (r => ...), we now write r <- x inside the for-comprehension.
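To make the translation concrete, here is the same desugaring shown with Option, which has flatMap and map of the same shape; the translation for State is identical:

```scala
// Written with explicit flatMap and map...
val desugared: Option[Int] =
  Some(1).flatMap(a =>
    Some(2).map(b =>
      a + b))

// ...and the equivalent for-comprehension the compiler translates
// into the calls above.
val sugared: Option[Int] =
  for {
    a <- Some(1)
    b <- Some(2)
  } yield a + b
```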

How does this look?

object FibMemo4 {
  type Memo = Map[BigInt, BigInt]

  def fibmemo4(n: BigInt): BigInt = {
    def fibmemoR(z: BigInt): State[Memo, BigInt] =
      if(z <= 1)
        State.insert(z)
      else
        for {
          u <- State.get((m: Memo) => m get z)
          v <- u map State.insert[Memo, BigInt] getOrElse (for {
                 r <- fibmemoR(z - 1)
                 s <- fibmemoR(z - 2)
                 t = r + s
                 _ <- State.mod((m: Memo) => m + ((z, t)))
               } yield t)
        } yield v

    fibmemoR(n) eval Map()
  }
}

This is a lot neater, as the memoisation table is handled by the state monad. In fact, it is starting to look like the original naïve solution. We are no longer manually handling the state transitions, which allows us to express the essence of the problem without the calculation speed blow-out.

Where you once may have used var, consider whether the state monad is instead more appropriate.
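As a parting sketch (a hypothetical fresh-name generator, not from the post): first with a var, then with the State monad threading the counter instead. The State definitions are repeated so the snippet stands alone.

```scala
case class State[S, A](run: S => (A, S)) {
  def map[B](f: A => B): State[S, B] =
    State(s => {
      val (a, t) = run(s)
      (f(a), t)
    })

  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State(s => {
      val (a, t) = run(s)
      f(a).run(t)
    })
}

// With a var: each call mutates hidden state.
var counter = 0
def freshVar(): String = { counter += 1; "x" + counter }

// With State: the counter is part of the type, threaded explicitly but
// invisibly by flatMap.
def fresh: State[Int, String] =
  State(n => ("x" + (n + 1), n + 1))

val twoNames: State[Int, (String, String)] =
  for {
    a <- fresh
    b <- fresh
  } yield (a, b)
```

The var version hides its state and is awkward to test or reuse; the State version makes both the result and the final counter visible from the outside.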