λ Tony's Blog λ
The State Monad for Scala users
Posted on November 3, 2008
Scalaz 3.2 includes support for a State data type for the Scala programming language. This data type is a monad and thus supports flatMap and can be used in a for-comprehension.
Below I will give a practical demonstration of why you might choose to use the State data type as a monad.
Consider a binary leaf tree data type:
sealed abstract class Tree[+A]
final case class Leaf[A](a: A) extends Tree[A]
final case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]
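The sketch below shows the shape of this type. It builds a small three-leaf tree and lists its leaves from left to right. The helper leaves and the value example are hypothetical names introduced here for illustration; they are not part of the post.

```scala
sealed abstract class Tree[+A]
final case class Leaf[A](a: A) extends Tree[A]
final case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

// Hypothetical helper: collect leaf values in left-to-right order.
def leaves[A](t: Tree[A]): List[A] = t match {
  case Leaf(a)             => List(a)
  case Branch(left, right) => leaves(left) ++ leaves(right)
}

// 'a' on the left; 'b' and 'c' in a nested branch on the right.
val example: Tree[Char] = Branch(Leaf('a'), Branch(Leaf('b'), Leaf('c')))
```

Here leaves(example) gives List('a', 'b', 'c'). This left-to-right order is the one the numbering below follows.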
This is pretty simple so far. Now suppose we wanted to map a function across a Tree from left to right, where each result depends on the previous one. For this example, say we want to number each leaf, starting at a given value and adding 1 at each leaf as we traverse from left to right. The result depends on the previous result because we must add 1 to it.
We might implement this by passing in the integer value and returning a pair: the Tree as it is constructed, and the integer value as it increments. Such an implementation might look like this:
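sealed abstract class Tree[+A] {
  def number(seed: Int): (Tree[(A, Int)], Int) = this match {
    // A leaf takes the current counter; the next leaf gets seed + 1.
    case Leaf(x) => (Leaf((x, seed)), seed + 1)
    // Number the left subtree first, then pass its final counter to the right.
    case Branch(left, right) => left number seed match {
      case (l, ls) => {
        right number ls match {
          case (r, rs) => (Branch(l, r), rs)
        }
      }
    }
  }
}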
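To make the threading concrete, here is the same hand-threaded numbering written as a standalone function rather than a method on Tree. It is run on a hypothetical three-leaf tree with a seed of 10. Each call receives a counter and must hand the next counter back to its caller.

```scala
sealed abstract class Tree[+A]
final case class Leaf[A](a: A) extends Tree[A]
final case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

// Thread the counter by hand: every call takes a seed and returns the next one.
def number[A](t: Tree[A], seed: Int): (Tree[(A, Int)], Int) = t match {
  case Leaf(x) =>
    // Label this leaf with the current counter and hand back counter + 1.
    (Leaf((x, seed)), seed + 1)
  case Branch(left, right) =>
    // The counter left over from the left subtree seeds the right subtree.
    val (l, ls) = number(left, seed)
    val (r, rs) = number(right, ls)
    (Branch(l, r), rs)
}

// Hypothetical example tree with leaves 'a', 'b', 'c' from left to right.
val example: Tree[Char] = Branch(Leaf('a'), Branch(Leaf('b'), Leaf('c')))
val result = number(example, 10)
```

The leaves come out numbered 10, 11 and 12, and the returned counter is 13. Note the pair of val bindings in the Branch case: every intermediate counter has to be named and passed along explicitly.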
This code is pretty messy, and it would become even messier with a less trivial function to apply across the Tree.
The State data type is effectively a Function[S, (S, A)]. The monad instance ranges over the computed value A, while the state type S stays fixed. That is to say, the flatMap signature looks roughly like this:
trait State[S, A] {
val s: Function[S, (S, A)] // abstract
def flatMap[B](f: A => Function[S, (S, B)]): Function[S, (S, B)]
}
You might consider filling out this method signature for fun :)
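One way to fill it out is sketched below. It assumes a hand-rolled State case class that wraps the S => (S, A) function; this is not Scalaz's definition. Unlike the rough signature above, this flatMap takes and returns State rather than raw functions, so that calls can be chained. A map is also added, because a for-comprehension with yield needs one. The values tick and twoTicks are hypothetical examples.

```scala
// Hypothetical hand-rolled State wrapping S => (S, A); not Scalaz's code.
// A is covariant, so a State[S, Sub] can stand in for a State[S, Super].
final case class State[S, +A](run: S => (S, A)) {
  // Run this step, then pass its value and the updated state to the next step.
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State { (s0: S) =>
      val (s1, a) = run(s0)
      f(a).run(s1)
    }

  // Transform the computed value; the state passes through untouched.
  def map[B](f: A => B): State[S, B] =
    State { (s0: S) =>
      val (s1, a) = run(s0)
      (s1, f(a))
    }
}

// Hypothetical example step: return the current counter and increment it.
val tick: State[Int, Int] = State((n: Int) => (n + 1, n))

// Two ticks in sequence; flatMap threads the counter between them.
val twoTicks: State[Int, (Int, Int)] =
  tick.flatMap(a => tick.map(b => (a, b)))
```

Running twoTicks from 0 yields the values 0 and 1 and leaves the counter at 2. Neither step ever mentions the counter that the other one produced.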
The flatMap function lets the user make the state change implicit rather than explicit (and messy!). The State data type includes a few useful methods and functions, of which I will use only two below. The first is init, which constructs a State whose computed value is the state itself. Going with the analogy to Function[S, (S, A)], the init function looks like this: s => (s, s), and so it is a State[S, S]. The second function is modify, which applies a transformation to the state; its computed value (Unit) is intended to be ignored.
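Here is a minimal sketch of init and modify, again over a hand-rolled State rather than scalaz.State. The names mirror the post, but the definitions and the readThenBump example are assumptions made for illustration.

```scala
// Hypothetical hand-rolled State; not Scalaz's definition.
final case class State[S, +A](run: S => (S, A)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State { (s0: S) => val (s1, a) = run(s0); f(a).run(s1) }
  def map[B](f: A => B): State[S, B] =
    State { (s0: S) => val (s1, a) = run(s0); (s1, f(a)) }
}

object State {
  // init: the computed value is the current state, which is left unchanged.
  def init[S]: State[S, S] = State((s: S) => (s, s))

  // modify: transform the state; the computed value is Unit.
  def modify[S](f: S => S): State[S, Unit] = State((s: S) => (f(s), ()))
}

import State._

// Read the state, then bump it by one, returning the value that was read.
val readThenBump: State[Int, Int] =
  for {
    s <- init[Int]
    _ <- modify((_: Int) + 1)
  } yield s
```

Started from 5, readThenBump produces the value 5 and leaves the state at 6. This is exactly the read-then-increment step each leaf needs.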
Here is our new implementation:
import scalaz.State
import scalaz.State._
def numbers: State[Int, Tree[(A, Int)]] = this match {
  // Read the counter with init, then increment it with modify.
  case Leaf(x) => for(s <- init[Int];
                      n <- modify((_: Int) + 1))
                  yield Leaf((x, s))
  // No counter in sight: flatMap threads it from left to right.
  case Branch(left, right) => for(l <- left.numbers;
                                  r <- right.numbers)
                              yield Branch(l, r)
}
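The following self-contained sketch is for readers without Scalaz 3.2 at hand. It runs the same numbering logic against a minimal hand-rolled State: a case class wrapping S => (S, A), with flatMap, map, init and modify. Here numbers is written as a standalone function instead of a method on Tree. This is an approximation of the post's code under those assumptions, not the Scalaz implementation.

```scala
// Hypothetical hand-rolled State; not scalaz.State.
final case class State[S, +A](run: S => (S, A)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State { (s0: S) => val (s1, a) = run(s0); f(a).run(s1) }
  def map[B](f: A => B): State[S, B] =
    State { (s0: S) => val (s1, a) = run(s0); (s1, f(a)) }
}
object State {
  def init[S]: State[S, S] = State((s: S) => (s, s))
  def modify[S](f: S => S): State[S, Unit] = State((s: S) => (f(s), ()))
}
import State._

sealed abstract class Tree[+A]
final case class Leaf[A](a: A) extends Tree[A]
final case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

// The counter never appears as an argument: flatMap threads it for us.
def numbers[A](t: Tree[A]): State[Int, Tree[(A, Int)]] = t match {
  case Leaf(x) =>
    for {
      s <- init[Int]            // read the current counter
      _ <- modify((_: Int) + 1) // and advance it for the next leaf
    } yield Leaf((x, s))
  case Branch(left, right) =>
    for {
      l <- numbers(left)
      r <- numbers(right)
    } yield Branch(l, r)
}

// Hypothetical example tree with leaves 'a', 'b', 'c' from left to right.
val example: Tree[Char] = Branch(Leaf('a'), Branch(Leaf('b'), Leaf('c')))

// run returns (final state, result), following the S => (S, A) shape.
val result = numbers(example).run(0)
```

Starting from 0, the leaves are numbered 0, 1 and 2 from left to right, and the final state is 3. This matches what the explicit number method produces for the same seed.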
This is much neater, and it hides the otherwise explicit threading of the counter through each recursive call. Following is a complete source file that can be compiled successfully against Scalaz 3.2 using the latest version of Scala at the time of writing.
import scalaz.State
import scalaz.State._
sealed abstract class Tree[+A] {
  // Explicit version: the counter is threaded by hand.
  def number(seed: Int): (Tree[(A, Int)], Int) = this match {
    case Leaf(x) => (Leaf((x, seed)), seed + 1)
    case Branch(left, right) => left number seed match {
      case (l, ls) => {
        right number ls match {
          case (r, rs) => (Branch(l, r), rs)
        }
      }
    }
  }

  // State version: the counter is threaded by flatMap.
  def numbers: State[Int, Tree[(A, Int)]] = this match {
    case Leaf(x) => for(s <- init[Int];
                        n <- modify((_: Int) + 1))
                    yield Leaf((x, s))
    case Branch(left, right) => for(l <- left.numbers;
                                    r <- right.numbers)
                                yield Branch(l, r)
  }
}
final case class Leaf[A](a: A) extends Tree[A]
final case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]
And if you find Haskell easier to read, this might help too. It is a roughly equivalent program using the State type from the mtl package (module Control.Monad.State.Lazy).
import Control.Monad.State.Lazy

data Tree a = Leaf a | Branch (Tree a) (Tree a)

-- Explicit version: thread the counter by hand.
number :: Int -> Tree a -> (Tree (a, Int), Int)
number seed (Leaf a) = (Leaf (a, seed), seed + 1)
number seed (Branch left right) =
  let (l, ls) = number seed left
      (r, rs) = number ls right
  in (Branch l r, rs)

-- State version: the counter is threaded by the State monad.
numbers :: Tree a -> State Int (Tree (a, Int))
numbers (Leaf a) = do
  n <- get
  modify (+1)
  return (Leaf (a, n))
numbers (Branch l r) = do
  left <- numbers l
  right <- numbers r
  return (Branch left right)

-- The analogue of Scalaz's init (the same as get). mtl's state takes
-- s -> (value, newState), the reverse of the Scala (S, A) order.
initState :: State s s
initState = state (\s -> (s, s))