Stateful Functions

Motivating Example: Stacks

Let’s define the following to represent and manipulate stacks, a data structure that allows adding and removing elements only at the top:

type Stack a = [a]

push_ :: a -> Stack a -> Stack a     -- first version
pop   :: Stack a -> (a, Stack a)     -- partial: no case for []

push_ a as = a:as
pop (a:as) = (a, as)

We can test our Stack with a sequence of operations:

testStack_ :: Stack Int -> (Int, Stack Int)
testStack_ s0 =
  let
    s1 = push_ 0 s0
    s2 = push_ 1 s1
    s3 = push_ 2 s2
    s4 = push_ 3 s3
    (a,s5) = pop s4
    (b,s6) = pop s5
    s7 = push_ (a+b) s6
  in
    pop s7

> testStack_ []
(5,[1,0])

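The stack is “threaded through” a series of computations. Let’s make this even more explicit by giving push the same shape as pop, so that every operation returns a result paired with the new stack: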

push :: a -> Stack a -> ((), Stack a)
push a as = ((), a:as)

testStack :: Stack Int -> (Int, Stack Int)
testStack s0 =
  let
    (_,s1) = push 0 s0
    (_,s2) = push 1 s1
    (_,s3) = push 2 s2
    (_,s4) = push 3 s3
    (a,s5) = pop s4
    (b,s6) = pop s5
    (_,s7) = push (a+b) s6
  in
    pop s7

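What’s wrong with this? First of all, there’s a lot of syntactic “noise.” Worse yet, it’s easy to make a mistake. The intention is that each version of the stack (s1, s2, and so on) is referred to exactly once and then “discarded” in favor of the new stack returned by the next operation, but nothing enforces this: every version has the same type, so using a stale one by accident still type-checks.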
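For example, the hypothetical testStackBuggy below (an illustration, not part of the notes) passes s4 to the second pop where s5 was meant. It is a sketch assuming the definitions above, with pop given an explicit error case for the empty stack. GHC accepts it, yet the answer changes.

```haskell
type Stack a = [a]

-- push returns a unit result, so it has the same shape as pop.
push :: a -> Stack a -> ((), Stack a)
push a as = ((), a:as)

-- Pop the top element; fails on an empty stack.
pop :: Stack a -> (a, Stack a)
pop (a:as) = (a, as)
pop []     = error "pop: empty stack"

-- Hypothetical buggy variant: the second pop reuses s4 instead of s5.
-- It still type-checks, because every si has the same type, Stack Int.
testStackBuggy :: Stack Int -> (Int, Stack Int)
testStackBuggy s0 =
  let
    (_,s1) = push 0 s0
    (_,s2) = push 1 s1
    (_,s3) = push 2 s2
    (_,s4) = push 3 s3
    (a,s5) = pop s4        -- s5 is now never used
    (b,s6) = pop s4        -- bug: should be pop s5
    (_,s7) = push (a+b) s6
  in
    pop s7
```

Here testStackBuggy [] evaluates to (6,[2,1,0]) rather than (5,[1,0]): the 3 on top is popped twice and the 2 is never removed.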

Sequenced Computations or “Stateful Functions”

There is a common idiom in the code above: the threading of some object through a series of computations.

computation :: StateObject -> (T7, StateObject)
computation s0 =
  let
    (a1,s1) = f0 s0      -- f0 :: StateObject -> (T1, StateObject)
    (a2,s2) = f1 s1      -- f1 :: StateObject -> (T2, StateObject)
    (a3,s3) = f2 s2      -- f2 :: StateObject -> (T3, StateObject)
    (a4,s4) = f3 s3      -- f3 :: StateObject -> (T4, StateObject)
    (a5,s5) = f4 s4      -- f4 :: StateObject -> (T5, StateObject)
    (a6,s6) = f5 s5      -- f5 :: StateObject -> (T6, StateObject)
    (a7,s7) = f6 s6      -- f6 :: StateObject -> (T7, StateObject)
  in
    (a7,s7)

We refer to this object as “the state” even though, if you’re familiar with languages that have “mutable” or “stateful” features, there’s nothing like that here: just ordinary pure functions, used in a pattern that feels like manipulating state.
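The idiom can be captured once, even before introducing a new type. The sketch below uses a hypothetical combinator, andThen (not a name from the notes), on plain functions of type s -> (a, s): it runs the first function, then hands both its result and the updated state to the next step. Rewritten with it, the stack example no longer names any intermediate stack.

```haskell
type Stack a = [a]

push :: a -> Stack a -> ((), Stack a)
push a as = ((), a:as)

pop :: Stack a -> (a, Stack a)
pop (a:as) = (a, as)
pop []     = error "pop: empty stack"

-- Hypothetical combinator: run f, then feed its result and the
-- updated state to g. The state is threaded here, and only here.
andThen :: (s -> (a, s)) -> (a -> s -> (b, s)) -> s -> (b, s)
andThen f g s0 =
  let (a, s1) = f s0
  in  g a s1

-- The stack example again; no s1 ... s7 to keep track of.
testStackAndThen :: Stack Int -> (Int, Stack Int)
testStackAndThen =
  push 0     `andThen` \_ ->
  push 1     `andThen` \_ ->
  push 2     `andThen` \_ ->
  push 3     `andThen` \_ ->
  pop        `andThen` \a ->
  pop        `andThen` \b ->
  push (a+b) `andThen` \_ ->
  pop
```

testStackAndThen [] evaluates to (5,[1,0]), the same as testStack []. Each step has the form f `andThen` \a -> ..., which is the shape >>= will take once the function type is wrapped up and given a Monad instance.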

type StateFunc s a =
  s -> (a,s)                 -- name for function type idiom

data StateFunc s a =
  StateFunc (s -> (a,s))     -- new datatype to define instances

newtype StateFunc s a =
  StateFunc (s -> (a,s))     -- newtype: one constructor with one field

newtype StateFunc s a =
  StateFunc { runStateFunc :: s -> (a,s) }  -- ... field for unboxing

StateFunc s a means a computation that, starting with an input state of type s, produces a value of type a and an updated state of type s.

When you read StateFunc s a, think “function of type s -> (a,s)” (but boxed up in a newtype). Or think “stateful computation that produces an a” keeping in mind that there is input and output state “in the background.”
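As a tiny concrete case, here is a hypothetical counter, tick (an illustration, not part of the notes): a StateFunc Int Int whose result is the current count and whose output state is the count plus one. runStateFunc unboxes it back into an ordinary function that can be applied to an initial state.

```haskell
newtype StateFunc s a =
  StateFunc { runStateFunc :: s -> (a,s) }

-- Hypothetical example: produce the current count, then increment it.
tick :: StateFunc Int Int
tick = StateFunc $ \n -> (n, n + 1)
```

For example, runStateFunc tick 41 evaluates to (41,42): the value 41 is produced, and 42 is the state left “in the background.”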

To preview the benefits that this abstraction will provide, we are going to define

pop'  :: StateFunc (Stack a) a
push' :: a -> StateFunc (Stack a) ()

and then, because StateFunc s is a Monad, we will be able to write the previous sequence of stack operations as:

testStack' = do
  push' 0
  push' 1
  push' 2
  push' 3
  a <- pop'
  b <- pop'
  push' (a+b)
  pop'

Sequencing Stateful Functions

We can think of sequencing StateFuncs as function composition, with the appropriate plumbing to thread the state objects through the component functions.

 s0   ------  s1   ------  s2   ------  s3
----> | f0 | ----> | f1 | ----> | f2 | ---->
      ------       ------       ------
          \_________/  \_________/  \------>
            a1           a2           a3

   f0                               :: s -> (a1, s)
   f1                               :: s -> (a2, s)
   f2                               :: s -> (a3, s)
  
   f0 >>= \a1 -> f1 >>= \a2 -> f2   :: s -> (a3, s)

So, let’s define how StateFunc s forms a monadic, applicative functor.

fmap         :: (a -> b) -> StateFunc s a -> StateFunc s b
(<*>)        :: StateFunc s (a -> b) -> StateFunc s a -> StateFunc s b
(>>=)        :: StateFunc s a -> (a -> StateFunc s b) -> StateFunc s b
pure, return :: a -> StateFunc s a

Let’s define the Monad instance, and then get the Functor and Applicative instances for free from it. (Alternatively, to develop the intuition for how stateful functions work more slowly, define the instances “in order”: Functor, then Applicative, then Monad.)

import Control.Monad (ap)   -- provides ap, used for (<*>) below

instance Monad (StateFunc s) where
    -- return :: a -> StateFunc s a
    return a = StateFunc $ \s -> (a, s)

    -- (>>=) :: StateFunc s a -> (a -> StateFunc s b) -> StateFunc s b
    sa >>= f = StateFunc $ \s0 ->
      let
        (a, s1) = runStateFunc sa s0
        (b, s2) = runStateFunc (f a) s1
      in
        (b, s2)

instance Functor     (StateFunc s) where {fmap f x = pure f <*> x}
instance Applicative (StateFunc s) where {pure = return; (<*>) = ap}
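Following the alternative “in order” route, the three instances can also be written directly, without ap. This is a sketch of one standard way to do it, not the notes’ own code: each definition threads the state left to right, and return is left at its default, pure.

```haskell
newtype StateFunc s a = StateFunc { runStateFunc :: s -> (a,s) }

instance Functor (StateFunc s) where
  -- fmap :: (a -> b) -> StateFunc s a -> StateFunc s b
  -- apply f to the result; the state passes through unchanged
  fmap f sa = StateFunc $ \s0 ->
    let (a, s1) = runStateFunc sa s0
    in  (f a, s1)

instance Applicative (StateFunc s) where
  -- pure :: a -> StateFunc s a
  -- produce a without touching the state
  pure a = StateFunc $ \s -> (a, s)

  -- (<*>) :: StateFunc s (a -> b) -> StateFunc s a -> StateFunc s b
  -- run the function-producing computation first, then the argument
  sf <*> sa = StateFunc $ \s0 ->
    let (f, s1) = runStateFunc sf s0
        (a, s2) = runStateFunc sa s1
    in  (f a, s2)

instance Monad (StateFunc s) where
  -- (>>=) :: StateFunc s a -> (a -> StateFunc s b) -> StateFunc s b
  -- the second computation is chosen from the first one's result
  sa >>= f = StateFunc $ \s0 ->
    let (a, s1) = runStateFunc sa s0
    in  runStateFunc (f a) s1
```

With a hypothetical counter tick = StateFunc (\n -> (n, n + 1)), runStateFunc ((+) <$> tick <*> tick) 1 evaluates to (3,3): the left tick runs first (result 1, state 2), then the right one (result 2, state 3).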

Programming with StateFunc Stack

pop'  :: StateFunc (Stack a) a
push' :: a -> StateFunc (Stack a) ()

pop'    = StateFunc $ \(a:as) -> (a,as)    -- partial: fails on an empty stack
push' a = StateFunc $ \as -> ((), a:as)

Now, let’s go back to our long sequence of stack operations.

testStack' :: StateFunc (Stack Int) Int
testStack' = do
  push' 0
  push' 1
  push' 2
  push' 3
  a <- pop'
  b <- pop'
  push' (a+b)
  pop'

> runStateFunc testStack' []
(5,[1,0])
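Collected into one file, the definitions so far form a runnable program. This is a sketch of one way to assemble them: it uses the canonical instance style (pure defined in Applicative, fmap = liftM and (<*>) = ap from Control.Monad) rather than pure = return, and adds a hypothetical testStack'' (not from the notes) that performs the four pushes with the standard mapM_, which works for any Monad.

```haskell
import Control.Monad (ap, liftM)

type Stack a = [a]

newtype StateFunc s a = StateFunc { runStateFunc :: s -> (a,s) }

instance Functor (StateFunc s) where
  fmap = liftM                        -- derived from (>>=) below

instance Applicative (StateFunc s) where
  pure a = StateFunc $ \s -> (a, s)   -- result a, state unchanged
  (<*>)  = ap                         -- derived from (>>=) below

instance Monad (StateFunc s) where
  -- thread the state through sa, then through the continuation
  sa >>= f = StateFunc $ \s0 ->
    let (a, s1) = runStateFunc sa s0
    in  runStateFunc (f a) s1

pop' :: StateFunc (Stack a) a
pop' = StateFunc $ \(a:as) -> (a, as)   -- partial: fails on []

push' :: a -> StateFunc (Stack a) ()
push' a = StateFunc $ \as -> ((), a:as)

testStack' :: StateFunc (Stack Int) Int
testStack' = do
  push' 0
  push' 1
  push' 2
  push' 3
  a <- pop'
  b <- pop'
  push' (a+b)
  pop'

-- Hypothetical variant: mapM_ performs the pushes in list order.
testStack'' :: StateFunc (Stack Int) Int
testStack'' = do
  mapM_ push' [0, 1, 2, 3]
  a <- pop'
  b <- pop'
  push' (a + b)
  pop'
```

Both runStateFunc testStack' [] and runStateFunc testStack'' [] evaluate to (5,[1,0]). Reusing mapM_ is one of the practical payoffs of the Monad instance: the generic combinators come along for free.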

Cool!

Helper Functions

get :: StateFunc s s                            -- get state out
get = StateFunc $ \s -> (s, s)

put :: s -> StateFunc s ()                      -- set "current" state
put s' = StateFunc $ \s -> ((), s')

modify :: (s -> s) -> StateFunc s ()            -- modify the state
modify f = StateFunc $ \s -> ((), f s)

evalStateFunc :: StateFunc s a -> s -> a        -- run and return final value
evalStateFunc sa s = fst $ runStateFunc sa s

execStateFunc :: StateFunc s a -> s -> s        -- run and return final state
execStateFunc sa s = snd $ runStateFunc sa s
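To see the helpers at work, the sketch below builds a few more stack operations purely from get, put, and modify. The names size', clear', and pushAll' are hypothetical illustrations, not part of the notes; the StateFunc type and instances are repeated so that the file stands alone.

```haskell
import Control.Monad (ap, liftM)

type Stack a = [a]

newtype StateFunc s a = StateFunc { runStateFunc :: s -> (a,s) }

instance Functor (StateFunc s) where
  fmap = liftM
instance Applicative (StateFunc s) where
  pure a = StateFunc $ \s -> (a, s)
  (<*>)  = ap
instance Monad (StateFunc s) where
  sa >>= f = StateFunc $ \s0 ->
    let (a, s1) = runStateFunc sa s0
    in  runStateFunc (f a) s1

-- The helpers from the notes.
get :: StateFunc s s
get = StateFunc $ \s -> (s, s)

put :: s -> StateFunc s ()
put s' = StateFunc $ \_ -> ((), s')

modify :: (s -> s) -> StateFunc s ()
modify f = StateFunc $ \s -> ((), f s)

evalStateFunc :: StateFunc s a -> s -> a
evalStateFunc sa s = fst $ runStateFunc sa s

execStateFunc :: StateFunc s a -> s -> s
execStateFunc sa s = snd $ runStateFunc sa s

-- Hypothetical stack operations built only from the helpers.
size' :: StateFunc (Stack a) Int
size' = do
  as <- get                 -- read the stack without changing it
  return (length as)

clear' :: StateFunc (Stack a) ()
clear' = put []             -- replace the stack with an empty one

pushAll' :: [a] -> StateFunc (Stack a) ()
pushAll' xs = modify (reverse xs ++)   -- last element of xs ends on top
```

For example, execStateFunc (pushAll' [0,1,2]) [9] evaluates to [2,1,0,9], the same stack that three successive push' calls would leave, and evalStateFunc size' "abc" evaluates to 3.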

If we want to, we can redefine pop' and push' using do-notation and the helpers for “reading” and “writing” the state (Stack) in the background.

push' a = do                -- push' a =
  as <- get                 --   get >>= \as ->
  put (a:as)                --   put (a:as)

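pop' = do                   -- pop' =
  ~(a:as) <- get            --   get >>= \ ~(a:as) ->
  put as                    --     put as >>
  return a                  --     return a
-- The ~ makes the pattern irrefutable, so modern GHC does not require a
-- MonadFail instance for StateFunc; pop' on [] still fails once used.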
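Putting the get/put versions together gives a complete, runnable sketch. It assumes the canonical instance style (pure in Applicative, return left at its default), matches on the stack with case so that no MonadFail instance is needed, and adds a hypothetical popMaybe' (not from the notes) showing how get lets us inspect the stack before deciding what to do.

```haskell
import Control.Monad (ap, liftM)

type Stack a = [a]

newtype StateFunc s a = StateFunc { runStateFunc :: s -> (a,s) }

instance Functor (StateFunc s) where
  fmap = liftM
instance Applicative (StateFunc s) where
  pure a = StateFunc $ \s -> (a, s)
  (<*>)  = ap
instance Monad (StateFunc s) where
  sa >>= f = StateFunc $ \s0 ->
    let (a, s1) = runStateFunc sa s0
    in  runStateFunc (f a) s1

get :: StateFunc s s
get = StateFunc $ \s -> (s, s)

put :: s -> StateFunc s ()
put s' = StateFunc $ \_ -> ((), s')

-- push' and pop' written only with get and put.
push' :: a -> StateFunc (Stack a) ()
push' a = do
  as <- get
  put (a:as)

pop' :: StateFunc (Stack a) a
pop' = do
  stack <- get
  case stack of
    (a:as) -> do
      put as
      return a
    []     -> error "pop': empty stack"

-- Hypothetical total variant: Nothing, and no state change, on [].
popMaybe' :: StateFunc (Stack a) (Maybe a)
popMaybe' = do
  stack <- get
  case stack of
    (a:as) -> do
      put as
      return (Just a)
    []     -> return Nothing

testStack' :: StateFunc (Stack Int) Int
testStack' = do
  push' 0
  push' 1
  push' 2
  push' 3
  a <- pop'
  b <- pop'
  push' (a+b)
  pop'
```

With these definitions, runStateFunc testStack' [] still evaluates to (5,[1,0]), while runStateFunc popMaybe' [] evaluates to (Nothing,[]) instead of failing.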

Source Files