Free Tutorial (Originally for Reid Draper)
Jun 16, 2014
From nothing to a simple DSL using Free Monads. First, some preliminaries:
First we need to define the constructs of our DSL. Normally we would define our language as a ‘vanilla’ recursive type, something like:
data DSL = Print String DSL
         | GetLine (String -> DSL)
         | GetChar (Error -> DSL) (Char -> DSL)
         | Halt
This time we’ll define it with the recursion factored out into a type parameter:
data DSLf param = Print String param
                | GetLine (String -> param)
                | GetChar (Error -> param) (Char -> param)
                | Halt
So our language can print a string, read a line, read a char (with a separate continuation for the error case), and halt. Simple enough.
Let’s show that our data-type is a Functor
instance Functor DSLf where
  fmap _ Halt = Halt
  fmap f (Print str cont) = Print str (f cont)
  fmap f (GetLine next) = GetLine (f . next)
  fmap f (GetChar k1 k2) = GetChar (f . k1) (f . k2)
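To see what this instance does, here is a small check: fmap only ever touches the continuation slots (post-composing for the function-valued ones) and leaves the String payload alone. The sketch assumes `type Error = String`, which the later code (`dummyChr2 = Left "Shh"`, `Left e -> pr e`) implies but never states; `probe` is a hypothetical helper that feeds fixed inputs to a single instruction.

```haskell
-- Assumption: Error is a String synonym (implied by later code).
type Error = String

data DSLf param
  = Print String param
  | GetLine (String -> param)
  | GetChar (Error -> param) (Char -> param)
  | Halt

-- The Functor instance from the text.
instance Functor DSLf where
  fmap _ Halt = Halt
  fmap f (Print str cont) = Print str (f cont)
  fmap f (GetLine next) = GetLine (f . next)
  fmap f (GetChar k1 k2) = GetChar (f . k1) (f . k2)

-- Hypothetical helper: feed fixed inputs to one instruction and
-- return whatever its continuation produces.
probe :: DSLf Int -> Maybe Int
probe (Print _ n) = Just n
probe (GetLine k) = Just (k "abc")
probe (GetChar _ kc) = Just (kc 'x')
probe Halt = Nothing

-- Map (* 10) over one instruction of each shape, then probe it.
examples :: [Maybe Int]
examples =
  map (probe . fmap (* 10))
    [ Print "hi" 4             -- the continuation 4 becomes 40
    , GetLine length           -- length "abc" = 3, then * 10
    , GetChar length fromEnum  -- fromEnum 'x' = 120, then * 10
    , Halt                     -- no continuation to map over
    ]
```

GHC can also generate an equivalent instance: enable the DeriveFunctor extension and add `deriving Functor` to DSLf.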
Now let’s get back the data-type we would have gotten if we’d defined our language in the ‘vanilla’ style mentioned above. For this we need another data-type that provides the explicit recursion.
Free takes two types, f and a. f must have kind * -> * and a kind *. Free does for types what fix does for functions: it creates a recursive version of a type that was not recursive, adding a Pure case for the leaves, which hold a value of type a.
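The definition of Free is not reproduced in this copy of the post. The standard one (the same shape as `Free` in the free package's Control.Monad.Free), which agrees with how Pure and Free are used below, is sketched here as an assumption; `Step` is a hypothetical one-layer functor used only to show the nesting.

```haskell
-- The usual definition of Free (assumed; it matches how Pure and
-- Free are used throughout the post).
data Free f a
  = Pure a               -- a leaf: a finished computation with result a
  | Free (f (Free f a))  -- one layer of f whose slots hold more Free f a

-- A hypothetical, non-recursive one-step functor, only for illustration.
data Step next = Step String next

-- Free Step Int nests Steps to any depth and ends in a Pure leaf,
-- even though Step itself is not recursive.
twoSteps :: Free Step Int
twoSteps = Free (Step "a" (Free (Step "b" (Pure 42))))

-- Collect the labels on the way down and the result at the leaf.
walk :: Free Step a -> ([String], a)
walk (Pure a) = ([], a)
walk (Free (Step s rest)) = let (ss, a) = walk rest in (s : ss, a)
```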
If we can guarantee that f is a Functor, we can make Free f into a monad:
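-- GHC 7.10+ also requires Functor and Applicative instances (liftM, ap: Control.Monad)
instance Functor f => Functor (Free f) where fmap = liftM
instance Functor f => Applicative (Free f) where { pure = Pure; (<*>) = ap }
instance Functor f => Monad (Free f) where
  Pure a >>= f = f a
  Free x >>= f = Free (fmap (>>= f) x)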
It’s okay if you don’t see the utility of this monad instance yet; later we’ll walk through how it works.
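Before the full walkthrough, a small check of what this instance does: >>= leaves every existing Free layer in place and only replaces the Pure leaves with the rest of the program. The sketch repeats the assumed `data Free f a = Pure a | Free (f (Free f a))`, uses a hypothetical `Step` functor, and includes the Functor and Applicative instances that GHC 7.10 and later require alongside Monad.

```haskell
import Control.Monad (ap, liftM)

data Free f a = Pure a | Free (f (Free f a))

-- Hypothetical one-layer functor, only for illustration.
data Step next = Step String next

instance Functor Step where
  fmap g (Step s next) = Step s (g next)

-- The instance from the text, plus the required superclass instances.
instance Functor f => Functor (Free f) where
  fmap = liftM
instance Functor f => Applicative (Free f) where
  pure = Pure
  (<*>) = ap
instance Functor f => Monad (Free f) where
  Pure a >>= k = k a                    -- at a leaf: continue with k
  Free x >>= k = Free (fmap (>>= k) x)  -- in a layer: push k further down

step :: String -> Free Step ()
step s = Free (Step s (Pure ()))

walk :: Free Step a -> ([String], a)
walk (Pure a) = ([], a)
walk (Free (Step s rest)) = let (ss, a) = walk rest in (s : ss, a)

-- Binding keeps both existing layers and grafts k's layer onto the leaf.
bound :: Free Step Int
bound = Free (Step "a" (Free (Step "b" (Pure 41)))) >>= \n -> step "c" >> return (n + 1)
```

Here `walk bound` yields the labels "a", "b", "c" and the result 42: the original two layers survive untouched, and the continuation only runs where the Pure 41 leaf used to be.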
Now we have our ‘proper’ DSL.
At this point we could call it a day and start using our DSL. The problem is that programs would be ugly: every constructor has to be wrapped in Free (and terminated with Pure) by hand. Hello world would look like this:
hw :: DSL ()
hw = Free (Print "hello\n" (Free Halt))
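The synonym that gives `DSL ()` its meaning is not shown in this copy; `type DSL = Free DSLf` is the reading the rest of the post relies on, and this sketch assumes it (along with `type Error = String`). It also shows that hw is ordinary data: a hypothetical `printed` helper pattern-matches on it to list the strings it would print, stopping at anything that is not a Print.

```haskell
type Error = String  -- assumption, implied by later code

data DSLf param
  = Print String param
  | GetLine (String -> param)
  | GetChar (Error -> param) (Char -> param)
  | Halt

data Free f a = Pure a | Free (f (Free f a))

-- Assumed synonym: a DSL a is a program in our language returning a.
type DSL = Free DSLf

hw :: DSL ()
hw = Free (Print "hello\n" (Free Halt))

-- Hypothetical helper: hw is just a value, so we can take it apart.
-- Stops at anything other than a Print.
printed :: DSL a -> [String]
printed (Free (Print s rest)) = s : printed rest
printed _ = []
```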
Luckily we can abstract away the need to wrap everything in ‘Free’s and ‘Pure’s.
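liftFree’s definition is not shown either, but the unfolding of pr that follows (`liftFree $ Print str ()` becoming `Free (fmap Pure (Print str ()))`) implies `liftFree = Free . fmap Pure`. A sketch under that assumption, checking that lifting `Print "hello\n" ()` builds the same structure as wrapping it by hand; `render` is a hypothetical helper, and DSLf’s Functor instance is derived here rather than written out.

```haskell
{-# LANGUAGE DeriveFunctor #-}

type Error = String  -- assumption, implied by later code

data DSLf param
  = Print String param
  | GetLine (String -> param)
  | GetChar (Error -> param) (Char -> param)
  | Halt
  deriving Functor  -- equivalent to the hand-written instance

data Free f a = Pure a | Free (f (Free f a))
type DSL = Free DSLf

-- Wrap one instruction: fmap Pure puts a Pure leaf in each continuation
-- slot, and the outer Free adds a single layer on top.
liftFree :: Functor f => f a -> Free f a
liftFree = Free . fmap Pure

-- Hypothetical helper: the strings printed by a leading run of Prints,
-- plus whether that run ends in a Pure leaf.
render :: DSL () -> ([String], Bool)
render (Pure ()) = ([], True)
render (Free (Print s rest)) = let (ss, p) = render rest in (s : ss, p)
render (Free _) = ([], False)

byHand, lifted :: DSL ()
byHand = Free (Print "hello\n" (Pure ()))
lifted = liftFree (Print "hello\n" ())
```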
Now we can provide user-friendly functions for each of the constructs in our DSL, starting with printing:
pr :: String -> DSL ()
pr str = liftFree $ Print str ()
By the definition of liftFree, and then of fmap for DSLf, this unfolds to:
pr str = Free (fmap Pure (Print str ()))
pr str = Free (Print str (Pure ())) -- Ignoring laziness (which is safe in this case)
The rest follow the same pattern, so we won’t walk through them:
gt :: DSL String
gt = liftFree $ GetLine id
getC :: DSL (Either Error Char)
getC = liftFree $ GetChar Left Right
hlt :: DSL a
hlt = liftFree Halt
And because DSL is a monad (via the Monad instance for Free), we can write programs using do notation.
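The program t1 itself is missing from this copy. Working backwards from its desugaring further down (a Print, a GetLine, then a Print that uses the line), it is presumably the following; this sketch assumes that reconstruction, bundles the definitions so far (with `type Error = String`, the standard Free, and `liftFree = Free . fmap Pure` as assumptions), and uses a hypothetical `outputs` helper that runs a program with one fixed input line.

```haskell
{-# LANGUAGE DeriveFunctor #-}

import Control.Monad (ap, liftM)

type Error = String  -- assumption, implied by later code

data DSLf param
  = Print String param
  | GetLine (String -> param)
  | GetChar (Error -> param) (Char -> param)
  | Halt
  deriving Functor

data Free f a = Pure a | Free (f (Free f a))

instance Functor f => Functor (Free f) where fmap = liftM
instance Functor f => Applicative (Free f) where { pure = Pure; (<*>) = ap }
instance Functor f => Monad (Free f) where
  Pure a >>= k = k a
  Free x >>= k = Free (fmap (>>= k) x)

type DSL = Free DSLf

liftFree :: Functor f => f a -> Free f a
liftFree = Free . fmap Pure

pr :: String -> DSL ()
pr str = liftFree $ Print str ()

gt :: DSL String
gt = liftFree $ GetLine id

-- Reconstructed from the desugaring in the text (an assumption).
t1 :: DSL ()
t1 = do
  pr "hello\n"
  str <- gt
  pr ("Hello again " ++ str)

-- Hypothetical helper: answer every GetLine with the same input line
-- and collect what gets printed; stops at GetChar or Halt.
outputs :: String -> DSL a -> [String]
outputs _ (Pure _) = []
outputs input (Free (Print s rest)) = s : outputs input rest
outputs input (Free (GetLine k)) = outputs input (k input)
outputs _ (Free _) = []
```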
Here is the kicker: t1 is actually a data structure!
Taken alone, the line pr "hello\n" would look like:
Free (Print "hello\n" (Pure ()))
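The whole thing, desugared out of do notation, would look like this (do notation nests everything after the first line inside the right-hand operand of >>, hence the extra parentheses):
1) Free (Print "hello\n" (Pure ())) >> (
2) Free (GetLine (Pure . id)) >>= (\str ->
3) Free (Print ("Hello again " ++ str) (Pure ()))))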
Because of laziness, this structure stays as it is until something tries to inspect its first constructor. It looks as if the first constructor is the Free from line 1, but the topmost expression is really an application of (>>).
So if you try to inspect the first constructor, you force the application.
Let’s give lines 2 + 3 the name ‘rest’ for now.
Using the definition of >> (the default m >> k = m >>= \_ -> k) and then of >>= for Free on line 1, we get:
Free (fmap (>>= (\_ -> rest)) (Print "hello\n" (Pure ())))
If all you needed was the result of a function like ‘isFree’, which only looks at the outermost constructor, reduction would stop here.
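This stopping point is easy to observe. A sketch with a hypothetical `isFree` that inspects only the outermost constructor: applied to `pr "hello\n" >> undefined`, it returns True without ever evaluating `undefined`, because reduction stops at exactly the Free shown above. As before, `type Error = String` and the standard Free definition are assumptions.

```haskell
{-# LANGUAGE DeriveFunctor #-}

import Control.Monad (ap, liftM)

type Error = String  -- assumption, implied by later code

data DSLf param
  = Print String param
  | GetLine (String -> param)
  | GetChar (Error -> param) (Char -> param)
  | Halt
  deriving Functor

data Free f a = Pure a | Free (f (Free f a))

instance Functor f => Functor (Free f) where fmap = liftM
instance Functor f => Applicative (Free f) where { pure = Pure; (<*>) = ap }
instance Functor f => Monad (Free f) where
  Pure a >>= k = k a
  Free x >>= k = Free (fmap (>>= k) x)

type DSL = Free DSLf

pr :: String -> DSL ()
pr str = Free (fmap Pure (Print str ()))  -- liftFree, inlined

-- Hypothetical helper: inspects only the outermost constructor.
isFree :: Free f a -> Bool
isFree (Free _) = True
isFree (Pure _) = False

-- rest is undefined on purpose: if isFree forced anything beyond the
-- outer Free, evaluating this would throw instead of returning True.
lazyCheck :: Bool
lazyCheck = isFree (pr "hello\n" >> (undefined :: DSL ()))
```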
If you were trying to interpret the DSL, you would need to know which DSLf constructor sits directly inside the Free we’ve revealed. This forces the evaluation of the fmap (>>=
…
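So, using the definition of fmap for DSLf, we get:
Free (Print "hello\n" ((>>= (\_ -> rest)) (Pure ())))
Once the continuation is forced, Pure a >>= f = f a reduces it to (\_ -> rest) (), which is just rest. So the program is Free (Print "hello\n" rest): the first instruction is exposed, and rest stays unevaluated until an interpreter asks for it.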
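The same reduction can be watched from code. In this sketch (same assumptions as before: `type Error = String`, the standard Free, `liftFree = Free . fmap Pure`), a hypothetical `firstInstr` forces only the outer Free and the DSLf constructor inside it, and a hypothetical `afterPrint` peels off the first Print to reach rest. The program is the desugared t1 from the text.

```haskell
{-# LANGUAGE DeriveFunctor #-}

import Control.Monad (ap, liftM)

type Error = String  -- assumption, implied by later code

data DSLf param
  = Print String param
  | GetLine (String -> param)
  | GetChar (Error -> param) (Char -> param)
  | Halt
  deriving Functor

data Free f a = Pure a | Free (f (Free f a))

instance Functor f => Functor (Free f) where fmap = liftM
instance Functor f => Applicative (Free f) where { pure = Pure; (<*>) = ap }
instance Functor f => Monad (Free f) where
  Pure a >>= k = k a
  Free x >>= k = Free (fmap (>>= k) x)

type DSL = Free DSLf

liftFree :: Functor f => f a -> Free f a
liftFree = Free . fmap Pure

pr :: String -> DSL ()
pr str = liftFree $ Print str ()

gt :: DSL String
gt = liftFree $ GetLine id

-- The desugared program from the text.
t1 :: DSL ()
t1 = pr "hello\n" >> (gt >>= \str -> pr ("Hello again " ++ str))

-- Hypothetical helper: name the first instruction without touching
-- its continuation.
firstInstr :: DSL a -> String
firstInstr (Pure _) = "Pure"
firstInstr (Free (Print s _)) = "Print " ++ show s
firstInstr (Free (GetLine _)) = "GetLine"
firstInstr (Free (GetChar _ _)) = "GetChar"
firstInstr (Free Halt) = "Halt"

-- Hypothetical helper: the continuation of a leading Print, i.e. rest.
afterPrint :: DSL a -> Maybe (DSL a)
afterPrint (Free (Print _ rest)) = Just rest
afterPrint _ = Nothing

-- With an undefined rest, the first instruction is still visible.
peeked :: String
peeked = firstInstr (pr "hello\n" >> (undefined :: DSL ()))
```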
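Now we can write interpreters. interpDumb feeds fixed dummy values to every input instruction and collects the strings that would be printed:
dummyStr :: String
dummyStr = "This is a dummy value"
dummyChr, dummyChr2 :: Either Error Char
dummyChr = Right 'k'
dummyChr2 = Left "Shh"
interpDumb :: DSL a -> [String]
interpDumb (Pure _) = []
interpDumb (Free Halt) = []
interpDumb (Free (Print s n)) = s : interpDumb n
interpDumb (Free (GetLine k)) = interpDumb (k dummyStr)
interpDumb (Free (GetChar k1 k2)) = interpDumb $ either k1 k2 dummyChr2  -- use dummyChr for the success path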
A slightly bigger program that also reads a character:
t2 :: DSL ()
t2 = do
  str <- gt
  pr $ "Hello again " ++ take 2 str
  ch <- getC
  case ch of
    Left e -> pr e
    Right c1 -> pr $ replicate 3 c1
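Putting the pieces together, a self-contained sketch that runs t2 through interpDumb. It assumes `type Error = String` and the standard `data Free f a = Pure a | Free (f (Free f a))` with `liftFree = Free . fmap Pure`, none of which is spelled out above, and derives DSLf’s Functor instance instead of writing it by hand.

```haskell
{-# LANGUAGE DeriveFunctor #-}

import Control.Monad (ap, liftM)

type Error = String  -- assumption, implied by dummyChr2 and Left e -> pr e

data DSLf param
  = Print String param
  | GetLine (String -> param)
  | GetChar (Error -> param) (Char -> param)
  | Halt
  deriving Functor  -- same as the hand-written instance

data Free f a = Pure a | Free (f (Free f a))

instance Functor f => Functor (Free f) where fmap = liftM
instance Functor f => Applicative (Free f) where { pure = Pure; (<*>) = ap }
instance Functor f => Monad (Free f) where
  Pure a >>= k = k a
  Free x >>= k = Free (fmap (>>= k) x)

type DSL = Free DSLf

liftFree :: Functor f => f a -> Free f a
liftFree = Free . fmap Pure

pr :: String -> DSL ()
pr str = liftFree $ Print str ()

gt :: DSL String
gt = liftFree $ GetLine id

getC :: DSL (Either Error Char)
getC = liftFree $ GetChar Left Right

dummyStr :: String
dummyStr = "This is a dummy value"

dummyChr2 :: Either Error Char
dummyChr2 = Left "Shh"

interpDumb :: DSL a -> [String]
interpDumb (Pure _) = []
interpDumb (Free Halt) = []
interpDumb (Free (Print s n)) = s : interpDumb n
interpDumb (Free (GetLine k)) = interpDumb (k dummyStr)
interpDumb (Free (GetChar k1 k2)) = interpDumb $ either k1 k2 dummyChr2

t2 :: DSL ()
t2 = do
  str <- gt
  pr $ "Hello again " ++ take 2 str
  ch <- getC
  case ch of
    Left e -> pr e
    Right c1 -> pr $ replicate 3 c1

-- The dummy line starts with "Th" and the dummy char is an error,
-- so this evaluates to ["Hello again Th", "Shh"].
transcript :: [String]
transcript = interpDumb t2
```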
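And an interpreter that actually performs IO (exitSuccess comes from System.Exit, isEOF from System.IO):
interpIO :: DSL a -> IO a
interpIO (Pure a) = return a
interpIO (Free Halt) = exitSuccess
interpIO (Free (Print s n)) = putStr s >> interpIO n  -- putStr, not print, which would add quotes and escapes
interpIO (Free (GetLine k)) = getLine >>= interpIO . k
interpIO (Free (GetChar k1 k2)) = do
  -- getChar throws at end of input rather than returning a character,
  -- so check isEOF first and treat end of input like '\EOT' (Ctrl-D)
  eof <- isEOF
  ch <- if eof then return '\EOT' else getChar
  interpIO $ either k1 k2 $ if ch == '\EOT'
                              then Left "Error"
                              else Right ch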
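Because a program is just data, interpIO is only one way to read it. As a sketch of that idea, a hypothetical `interpScript` runs the same t2 against a fixed input String instead of stdin, which also makes it testable without a terminal. The usual assumptions apply: `type Error = String`, the standard Free definition, and `liftFree = Free . fmap Pure`.

```haskell
{-# LANGUAGE DeriveFunctor #-}

import Control.Monad (ap, liftM)

type Error = String  -- assumption, implied by the code above

data DSLf param
  = Print String param
  | GetLine (String -> param)
  | GetChar (Error -> param) (Char -> param)
  | Halt
  deriving Functor

data Free f a = Pure a | Free (f (Free f a))

instance Functor f => Functor (Free f) where fmap = liftM
instance Functor f => Applicative (Free f) where { pure = Pure; (<*>) = ap }
instance Functor f => Monad (Free f) where
  Pure a >>= k = k a
  Free x >>= k = Free (fmap (>>= k) x)

type DSL = Free DSLf

liftFree :: Functor f => f a -> Free f a
liftFree = Free . fmap Pure

pr :: String -> DSL ()
pr str = liftFree $ Print str ()

gt :: DSL String
gt = liftFree $ GetLine id

getC :: DSL (Either Error Char)
getC = liftFree $ GetChar Left Right

t2 :: DSL ()
t2 = do
  str <- gt
  pr $ "Hello again " ++ take 2 str
  ch <- getC
  case ch of
    Left e -> pr e
    Right c1 -> pr $ replicate 3 c1

-- Hypothetical pure interpreter: consumes a fixed input String in place
-- of stdin and returns everything the program prints.
interpScript :: String -> DSL a -> [String]
interpScript _ (Pure _) = []
interpScript _ (Free Halt) = []
interpScript inp (Free (Print s n)) = s : interpScript inp n
interpScript inp (Free (GetLine k)) =
  -- Take up to the next newline and drop the newline itself; on
  -- exhausted input this yields an empty line rather than failing.
  let (line, rest) = break (== '\n') inp
  in interpScript (drop 1 rest) (k line)
interpScript inp (Free (GetChar k1 k2)) = case inp of
  [] -> interpScript [] (k1 "Error")  -- no input left: the error path
  (c : cs) -> interpScript cs (k2 c)
```

For example, `interpScript "Reid\nz" t2` prints "Hello again Re" and then "zzz", while `interpScript "Reid\n" t2` runs out of input at getC and takes the Left "Error" branch.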