##### λ Tony's Blog λ

## The State Monad for Scala users

Posted on November 3, 2008

Scalaz 3.2 includes support for a State data type for the Scala programming language. This data type is a monad, so it supports `flatMap` and can be used in a for-comprehension.

Below I will give a practical demonstration of why you might choose to use the State data type as a monad.

Consider a binary leaf tree data type:

```
sealed abstract class Tree[+A]
final case class Leaf[A](a: A) extends Tree[A]
final case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]
```

This is pretty simple so far. Now suppose we wanted to map a function across a `Tree` from left to right **where the result depended on the previous result**. For this example, suppose we want to number each leaf, starting from a given value and adding 1 at each leaf as we traverse from left to right. That is, the result depends on the previous result because we must add 1 to that previous result.
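To pin down what "left to right" means here, the following sketch lists the labels such a numbering should produce: it collects the leaves in order and pairs each one with consecutive integers from the starting value. `leaves` and `expectedLabels` are hypothetical helpers of mine; they only specify the labels and do not rebuild the tree.

```scala
sealed abstract class Tree[+A]
final case class Leaf[A](a: A) extends Tree[A]
final case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

// Hypothetical helper: the leaf values in left-to-right order.
def leaves[A](t: Tree[A]): List[A] = t match {
  case Leaf(a)             => List(a)
  case Branch(left, right) => leaves(left) ++ leaves(right)
}

// Hypothetical helper: the label each leaf should receive, counting up from seed.
def expectedLabels[A](t: Tree[A], seed: Int): List[(A, Int)] =
  leaves(t).zipWithIndex.map { case (a, i) => (a, seed + i) }

val example: Tree[Char] = Branch(Leaf('a'), Branch(Leaf('b'), Leaf('c')))
```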

We might implement this by passing in the integer value and returning a pair of the `Tree` as it is constructed and the integer value as it increments. Such an implementation might look like this:

```
sealed abstract class Tree[+A] {
def number(seed: Int): (Tree[(A, Int)], Int) = this match {
case Leaf(x) => (Leaf(x, seed), seed + 1)
case Branch(left, right) => left number seed match {
case (l, ls) => {
right number ls match {
case (r, rs) => (Branch(l, r), rs)
}
}
}
}
}
```

This code is pretty messy, and it would become even messier with a less trivial function to apply across the `Tree`.

The `State` data type is effectively a `Function[S, (S, A)]`, and the monad instance runs across the `A` in `Function[S, (S, A)]`, with the state type `S` held fixed. That is to say, the `flatMap` signature looks roughly like this:

```
trait State[S, A] {
val s: Function[S, (S, A)] // abstract
def flatMap[B](f: A => Function[S, (S, B)]): Function[S, (S, B)]
}
```

You might consider filling out this method signature for fun :)
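For comparison, here is one way to fill it out. It is a sketch against the simplified trait above, not Scalaz's actual definition; `state`, `tick` and `program` are hypothetical names used only for the example.

```scala
// Simplified trait from the text; NOT Scalaz's actual State.
trait State[S, A] {
  val s: Function[S, (S, A)] // abstract

  // Run this action on the incoming state, pass the computed value to f,
  // then run the function that f returns on the updated state.
  def flatMap[B](f: A => Function[S, (S, B)]): Function[S, (S, B)] =
    (s0: S) => {
      val (s1, a) = s(s0)
      f(a)(s1)
    }
}

// Hypothetical helper: wrap a plain function as a State.
def state[S, A](g: S => (S, A)): State[S, A] = new State[S, A] { val s = g }

// Hypothetical example: compute the current counter and increment it.
val tick: State[Int, Int] = state((n: Int) => (n + 1, n))

// Run tick, then multiply the new state by 10 while keeping the value tick computed.
val program: Int => (Int, Int) = tick.flatMap(a => (t: Int) => (t * 10, a))
```

Notice that the caller of `flatMap` never mentions the intermediate state `s1`; that bookkeeping lives in one place.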

The `flatMap` function allows the user to make the state change implicit rather than explicit (and messy!). The `State` data type includes a few useful methods and functions. I will only use two of them below. The first is `init`, which constructs a `State` object whose computed value is the state itself. Going with the analogy to `Function[S, (S, A)]`, the `init` function looks like this: `s => (s, s)`, and so it returns a `State[S, S]`. The second is `modify`, which applies a transform to the state; its computed value is `Unit` and is intended to be ignored.
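In the `Function[S, (S, A)]` analogy, the two helpers can be sketched as plain functions. These are not Scalaz's versions (which return `State` values); `leafStep` is a hypothetical example that performs the leaf step of the numbering by hand.

```scala
// init: the computed value is the state itself, which is left unchanged.
def init[S]: S => (S, S) = s => (s, s)

// modify: transform the state; the computed value is Unit.
def modify[S](f: S => S): S => (S, Unit) = s => (f(s), ())

// Hypothetical example: the leaf step of numbering, threaded by hand.
// Read the current counter, bump it, and return the value that was read.
def leafStep(s0: Int): (Int, Int) = {
  val (s1, n) = init[Int].apply(s0)    // s1 == s0 and n == s0
  val (s2, _) = modify[Int](_ + 1)(s1) // s2 == s0 + 1
  (s2, n)
}
```

Passing `s1` and `s2` along by hand is exactly the bookkeeping that `flatMap` takes over in a for-comprehension.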

Here is our new implementation:

```
import scalaz.State
import scalaz.State._
def numbers: State[Int, Tree[(A, Int)]] = this match {
case Leaf(x) => for(s <- init[Int];
n <- modify((_: Int) + 1))
yield Leaf((x, s + 1))
case Branch(left, right) => for(l <- left.numbers;
r <- right.numbers)
yield Branch(l, r)
}
```

This is much neater and hides the otherwise explicit recursive application of adding 1 to an integer. Following is a complete source file that can be compiled successfully against Scalaz 3.2 using the latest version of Scala.

```
sealed abstract class Tree[+A] {
def number(seed: Int): (Tree[(A, Int)], Int) = this match {
case Leaf(x) => (Leaf(x, seed), seed + 1)
case Branch(left, right) => left number seed match {
case (l, ls) => {
right number ls match {
case (r, rs) => (Branch(l, r), rs)
}
}
}
}
import scalaz.State
import scalaz.State._
def numbers: State[Int, Tree[(A, Int)]] = this match {
case Leaf(x) => for(s <- init[Int];
n <- modify((_: Int) + 1))
yield Leaf((x, s + 1))
case Branch(left, right) => for(l <- left.numbers;
r <- right.numbers)
yield Branch(l, r)
}
}
final case class Leaf[A](a: A) extends Tree[A]
final case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]
```
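If you don't have Scalaz 3.2 to hand, the same program can be sketched with a minimal hand-rolled `State` standing in for Scalaz's. Its `run`, `map`, `flatMap`, `init` and `modify` are my own stand-ins modelled on the descriptions above, not the Scalaz API. The sketch also checks that the explicit and the monadic versions label the leaves identically, starting from the seed.

```scala
// Hypothetical stand-in for Scalaz's State: a wrapped S => (S, A).
// Covariant in A so it can appear in methods of the covariant Tree[+A].
final case class State[S, +A](run: S => (S, A)) {
  def map[B](f: A => B): State[S, B] =
    State((s0: S) => { val (s1, a) = run(s0); (s1, f(a)) })

  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State((s0: S) => { val (s1, a) = run(s0); f(a).run(s1) })
}

object State {
  // The computed value is the current state, which is left unchanged.
  def init[S]: State[S, S] = State((s: S) => (s, s))

  // Transform the state; the computed value is Unit.
  def modify[S](f: S => S): State[S, Unit] = State((s: S) => (f(s), ()))
}

import State.{init, modify}

sealed abstract class Tree[+A] {
  // Explicit threading: each call returns the numbered subtree and the next seed.
  def number(seed: Int): (Tree[(A, Int)], Int) = this match {
    case Leaf(x) => (Leaf((x, seed)), seed + 1)
    case Branch(left, right) =>
      val (l, ls) = left.number(seed)
      val (r, rs) = right.number(ls)
      (Branch(l, r), rs)
  }

  // Implicit threading: flatMap and map pass the counter along for us.
  def numbers: State[Int, Tree[(A, Int)]] = this match {
    case Leaf(x) =>
      for {
        s <- init[Int]
        _ <- modify((_: Int) + 1)
      } yield Leaf((x, s))
    case Branch(left, right) =>
      for {
        l <- left.numbers
        r <- right.numbers
      } yield Branch(l, r)
  }
}
final case class Leaf[A](a: A) extends Tree[A]
final case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

// Hypothetical example tree with three leaves.
val tree: Tree[Char] = Branch(Leaf('a'), Branch(Leaf('b'), Leaf('c')))

val (explicitTree, explicitNext) = tree.number(0)      // tree first, then next seed
val (implicitNext, implicitTree) = tree.numbers.run(0) // state first, then tree

println(implicitTree)
```

Making the hand-rolled `State` covariant in `A` matters here: `numbers` is a method of the covariant `Tree[+A]`, so its result type may not mention `A` in an invariant position.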

And if you find Haskell easier to read, this might help too. It is a roughly equivalent program using the `State` data type from Haskell's mtl library (`Control.Monad.State`).

```
import Control.Monad.State.Lazy
data Tree a = Leaf a | Branch (Tree a) (Tree a)
number :: Int -> Tree a -> (Tree (a, Int),Int)
number seed (Leaf a) = (Leaf (a, seed), seed + 1)
number seed (Branch left right)
= let (l, ls) = number seed left
(r, rs) = number ls right
in
(Branch l r, rs)
numbers :: Tree a -> State Int (Tree (a, Int))
numbers (Leaf a) = do n <- get
modify (+1)
return (Leaf (a, n))
numbers (Branch l r) = do left <- numbers l
right <- numbers r
return (Branch left right)
initState :: State s s
initState = State (\s -> (s, s))
```