llvm-hs Kaleidoscope Tutorial

1 Add AST and parsing

The first step in our compiler is to convert the source code into some sort of data structure that we can work with. This data structure usually ends up being a tree structure: nodes of expressions built up of other expressions. For example, (1 + (3 - 2)) would be a tree where 1, 3 and 2 are leaf nodes and +/- are parent nodes.

It represents the syntax of the program. It is abstract because it leaves out details such as whitespace and parentheses, since the tree structure already captures the grouping. In other words, it is an abstract syntax tree.

In the original LLVM Kaleidoscope tutorial, this abstract syntax tree (AST) is built in two stages. The program is first lexed into separate tokens like identifiers and keywords, and those tokens are then parsed to build up the tree structure. We're not going to do that in this tutorial. Instead we'll opt for the Haskell-ier way: parser combinators.

Parser combinators let us lex and parse at the same time, simply by specifying what we expect to see. We'll be using the ReadPrec monad from Text.Read. It is built on top of the ReadP parser combinators, and it is what the Read class itself uses. In fact, we'll be able to parse our program just by calling 'read'!

The Prec in ReadPrec stands for precedence. You'll see later on that this lets us add some tricks to prefer certain patterns over others when parsing. We'll also be writing all our parsing with do notation, which I think you'll agree feels very natural to use.

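As a warm-up, here is a small, self-contained ReadP parser for bracketed number lists like [1, 2, 3]. It is not part of the compiler, and the names number, numberList and parseAll are made up for this example.

It shows the style: each combinator describes what to expect next. readP_to_S returns every way the input can be parsed, each paired with the leftover input. read and readMaybe only accept a result that consumed the whole input and is the only such result.

```haskell
import Data.Char (isDigit)
import Text.ParserCombinators.ReadP

-- One or more digits, read as an Int. munch1 consumes greedily.
number :: ReadP Int
number = read <$> munch1 isDigit

-- "[1, 2, 3]": square brackets around comma-separated numbers,
-- allowing spaces around each comma.
numberList :: ReadP [Int]
numberList = between (char '[') (char ']') $
  sepBy number (skipSpaces >> char ',' >> skipSpaces)

-- Keep only the parses that used up all of the input.
parseAll :: ReadP a -> String -> [a]
parseAll p s = [x | (x, "") <- readP_to_S p s]
```

For example, parseAll numberList "[1, 2, 3]" gives [[1,2,3]]. An unterminated "[1, 2" gives [].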
AST.hs

module AST where

import Data.Char
import Text.Read
import Text.ParserCombinators.ReadP hiding ((+++), choice)

data Expr = Num Float
          | Var String
          | BinOp BinOp Expr Expr
          | Call String [Expr]
  deriving Show

data BinOp = Add | Sub | Mul | Cmp Ordering
  deriving Show

instance Read Expr where
  readPrec = parens $ choice [ parseNum
                             , parseVar
                             , parseCall
                             , parseBinOp "<" 10 (Cmp LT)
                             , parseBinOp "+" 20 Add
                             , parseBinOp "-" 20 Sub
                             , parseBinOp "*" 40 Mul
                             ]
    where parseNum = Num <$> readPrec
          parseVar = Var <$> lift (munch1 isAlpha)
          parseBinOp s prc op = prec prc $ do
            a <- step readPrec
            lift $ do
              skipSpaces
              string s
              skipSpaces
            b <- readPrec
            return (BinOp op a b)
          parseCall = do
            func <- lift (munch1 isAlpha)
            params <- lift $ between (char '(') (char ')') $
                        sepBy (readS_to_P reads)
                              (skipSpaces >> char ',' >> skipSpaces)
            return (Call func params)

data Prototype = Prototype String [String]
  deriving Show

instance Read Prototype where
  readPrec = lift $ do
    name <- munch1 isAlpha
    params <- between (char '(') (char ')') $
                sepBy (munch1 isAlpha) skipSpaces
    -- allow whitespace between the prototype and the function body
    -- (parseVar and parseCall don't skip leading spaces themselves)
    skipSpaces
    return (Prototype name params)

data AST = Function Prototype Expr
         | Extern Prototype
         | TopLevelExpr Expr
  deriving Show

instance Read AST where
  readPrec = parseFunction +++ parseExtern +++ parseTopLevel
    where parseFunction = do
            lift $ string "def" >> skipSpaces
            Function <$> readPrec <*> readPrec
          parseExtern = do
            lift $ string "extern" >> skipSpaces
            Extern <$> readPrec
          parseTopLevel = TopLevelExpr <$> readPrec

Main.hs

-main = pure ()
+import AST
+import System.IO
+import Text.Read
+main = do
+  hPutStr stderr "ready> "
+  ast <- (readMaybe <$> getLine) :: IO (Maybe AST)
+  case ast of
+    Just x -> hPrint stderr x
+    Nothing -> hPutStrLn stderr "Couldn't parse"
+  main

1.1 Start parsing expressions

This step doesn't define the full AST yet, just the expressions we want to be able to parse. So far we only handle numbers and addition. We use the built-in ReadPrec parser combinators, which let us recursively parse the binary addition operator. They even let us reuse the existing Read instance for Float to parse the numbers!

prec and step are needed because the left operand (a) is itself parsed with readPrec. Without them, parseAdd would immediately call itself again to parse a, and get stuck recursing forever without consuming any input.

AST.hs

module AST where

import Text.Read
import Text.ParserCombinators.ReadP hiding ((+++))

data Expr = Num Float
          | Add Expr Expr
  deriving Show

instance Read Expr where
  readPrec = parseNum +++ parseAdd
    where parseNum = Num <$> readPrec
          -- use 'prec 1' and 'step' so that parsing 'a'
          -- can only go one step deep, to prevent infinite
          -- recursion
          parseAdd = prec 1 $ do
            a <- step readPrec
            lift $ do
              skipSpaces
              char '+'
              skipSpaces
            b <- readPrec
            return (Add a b)

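Here is a tiny standalone sketch showing prec and step in isolation. It runs one parser at different precedence levels using readPrec_to_S from Text.ParserCombinators.ReadPrec; lowX and runAt are made-up names.

- prec n p only runs p when the current precedence is n or lower.
- step p runs p one level higher than the current precedence.

In parseAdd, this means the left operand is parsed at precedence 2, where parseAdd itself (guarded by prec 1) can't start again. The right operand stays at 1, so 1 + 2 + 3 nests to the right, as Add (Num 1.0) (Add (Num 2.0) (Num 3.0)).

```haskell
import Text.ParserCombinators.ReadPrec (ReadPrec, lift, prec, readPrec_to_S, step)
import qualified Text.ParserCombinators.ReadP as P

-- Parses the letter 'x', but only when the current precedence is 1 or lower.
lowX :: ReadPrec Char
lowX = prec 1 (lift (P.char 'x'))

-- Run a ReadPrec parser starting at precedence level n.
runAt :: Int -> ReadPrec a -> String -> [(a, String)]
runAt n p = readPrec_to_S p n
```

runAt 0 lowX "x" succeeds and runAt 2 lowX "x" fails. Wrapping the parser in step moves the cut-off down by one: runAt 0 (step lowX) "x" still succeeds, but runAt 1 (step lowX) "x" fails.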
1.2 Parse more binary ops

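We can parameterise the precedence at which each operator is parsed. This makes * bind more tightly than + and -, and < bind more loosely than all of them. Try it out with ghci AST.hs, and compare *AST> read "1 * 2 - 3" :: Expr with *AST> read "1 + 2 - 3" :: Expr. The type annotation is needed; otherwise GHCi doesn't know which Read instance to use.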

AST.hs

 module AST where

 import Text.Read
-import Text.ParserCombinators.ReadP hiding ((+++))
+import Text.ParserCombinators.ReadP hiding ((+++), choice)

 data Expr = Num Float
-          | Add Expr Expr
+          | BinOp BinOp Expr Expr
+  deriving Show
+
+data BinOp = Add | Sub | Mul | Cmp Ordering
   deriving Show

 instance Read Expr where
-  readPrec = parseNum +++ parseAdd
+  readPrec = choice [ parseNum
+                    , parseBinOp "<" 10 (Cmp LT)
+                    , parseBinOp "+" 20 Add
+                    , parseBinOp "-" 20 Sub
+                    , parseBinOp "*" 40 Mul
+                    ]
     where parseNum = Num <$> readPrec
           -- use 'prec 1' and 'step' so that parsing 'a'
           -- can only go one step deep, to prevent infinite
           -- recursion
-          parseAdd = prec 1 $ do
+          parseBinOp s prc op = prec prc $ do
             a <- step readPrec
             lift $ do
               skipSpaces
-              char '+'
+              string s
               skipSpaces
             b <- readPrec
-            return (Add a b)
+            return (BinOp op a b)

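One thing to be aware of with parseBinOp: the right operand b is parsed with plain readPrec at the same precedence. As a result, operators of equal precedence group to the right. 1 - 2 - 3 reads as BinOp Sub (Num 1.0) (BinOp Sub (Num 2.0) (Num 3.0)), i.e. 1 - (2 - 3), which is 2 rather than -4.

If you want the usual left-associative grouping, one option is ReadP's chainl1. The sketch below is a standalone, cut-down parser, not a drop-in replacement for the tutorial's Read instance:

- It has its own Expr type with only Num, Add, Sub and Mul.
- It only handles integer literals.
- The names token, op and parseExpr are made up.

```haskell
import Data.Char (isDigit)
import Text.ParserCombinators.ReadP

-- A cut-down expression type for this sketch only.
data Expr = Num Double | BinOp Op Expr Expr
  deriving (Show, Eq)

data Op = Add | Sub | Mul
  deriving (Show, Eq)

-- Run a parser, then skip any whitespace after it.
token :: ReadP a -> ReadP a
token p = p <* skipSpaces

-- Integer literals only, to keep the sketch short.
num :: ReadP Expr
num = token $ Num . read <$> munch1 isDigit

-- An operator symbol, returning the function that builds the node.
op :: String -> Op -> ReadP (Expr -> Expr -> Expr)
op s o = token (string s) >> return (BinOp o)

-- Lowest precedence first. chainl1 folds repeated operators to the left.
expr, term, factor :: ReadP Expr
expr   = chainl1 term (op "+" Add +++ op "-" Sub)
term   = chainl1 factor (op "*" Mul)
factor = num +++ between (token (char '(')) (token (char ')')) expr

-- Succeed only on a single parse that consumes the whole input.
parseExpr :: String -> Maybe Expr
parseExpr s = case [e | (e, "") <- readP_to_S (skipSpaces *> expr) s] of
  [e] -> Just e
  _   -> Nothing
```

With this version, parseExpr "1 - 2 - 3" groups as (1 - 2) - 3. The precedence levels are expressed by which parser calls which, rather than by numbers passed to prec.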
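1.3 Parse variables in expressions

A variable is a run of one or more letters, parsed with munch1 isAlpha. isAlpha comes from Data.Char, but that import only appears in the next step's diff. Add import Data.Char now if you're compiling each step as you go.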

AST.hs

 import Text.ParserCombinators.ReadP hiding ((+++), choice)

 data Expr = Num Float
+          | Var String
           | BinOp BinOp Expr Expr
   deriving Show

...

 instance Read Expr where
   readPrec = choice [ parseNum
+                    , parseVar
                     , parseBinOp "<" 10 (Cmp LT)
                     , parseBinOp "+" 20 Add
                     , parseBinOp "-" 20 Sub
                     , parseBinOp "*" 40 Mul
                     ]
     where parseNum = Num <$> readPrec
+          parseVar = Var <$> lift (munch1 isAlpha)
           -- use 'prec 1' and 'step' so that parsing 'a'
           -- can only go one step deep, to prevent infinite
           -- recursion

1.4 Parse defs and externs

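Function definitions (def) and externs both start with a prototype: a name followed by a parenthesised, space-separated list of parameter names. So we can split that out into its own Prototype type with its own Read instance.

We also want to be able to parse top-level expressions, so we add a TopLevelExpr constructor for those in AST. Hopefully now you can see the beauty of this style of parsing. Once each piece has a Read instance, a whole function definition is just Function <$> readPrec <*> readPrec.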

AST.hs

 module AST where

+import Data.Char
 import Text.Read
 import Text.ParserCombinators.ReadP hiding ((+++), choice)

...
                     ]
     where parseNum = Num <$> readPrec
           parseVar = Var <$> lift (munch1 isAlpha)
-          -- use 'prec 1' and 'step' so that parsing 'a'
-          -- can only go one step deep, to prevent infinite
-          -- recursion
           parseBinOp s prc op = prec prc $ do
             a <- step readPrec
             lift $ do
...
             b <- readPrec
             return (BinOp op a b)

+data Prototype = Prototype String [String]
+  deriving Show
+
+instance Read Prototype where
+  readPrec = lift $ do
+    name <- munch1 isAlpha
+    params <- between (char '(') (char ')') $
+                sepBy (munch1 isAlpha) skipSpaces
+    -- allow whitespace between the prototype and the function body
+    -- (parseVar and parseCall don't skip leading spaces themselves)
+    skipSpaces
+    return (Prototype name params)
+
+data AST = Function Prototype Expr
+         | Extern Prototype
+         | TopLevelExpr Expr
+  deriving Show
+
+instance Read AST where
+  readPrec = parseFunction +++ parseExtern +++ parseTopLevel
+    where parseFunction = do
+            lift $ string "def" >> skipSpaces
+            Function <$> readPrec <*> readPrec
+          parseExtern = do
+            lift $ string "extern" >> skipSpaces
+            Extern <$> readPrec
+          parseTopLevel = TopLevelExpr <$> readPrec

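Function <$> readPrec <*> readPrec works because each readPrec is resolved by its type. The first reads a Prototype, the second an Expr, and the results are passed to Function in order.

Here is the same pattern in isolation, as a small standalone sketch with a made-up Pair type. A keyword is parsed with ReadP through lift, and the rest is handed to the built-in Read instances for Int and Bool.

```haskell
import Text.ParserCombinators.ReadPrec (lift)
import Text.Read (Read (..), readMaybe)
import qualified Text.ParserCombinators.ReadP as P

-- A made-up type for illustration, written like "pair 3 True".
data Pair = Pair Int Bool
  deriving (Show, Eq)

instance Read Pair where
  readPrec = do
    -- ReadP handles the keyword; lift embeds it in ReadPrec.
    lift (P.string "pair" >> P.skipSpaces)
    -- Each readPrec picks its instance from the field type: first Int,
    -- then Bool. Both built-in instances skip leading whitespace.
    Pair <$> readPrec <*> readPrec
```

readMaybe "pair 3 True" :: Maybe Pair gives Just (Pair 3 True). Swapping the fields, as in "pair True 3", gives Nothing.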
1.5 Parse parentheses

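Yes, it's that easy. Wrapping the parser in parens lets any expression be surrounded by parentheses. It also resets the precedence inside them, so a loosely binding expression like 2 + 3 can appear where a tightly binding one is expected. Try out the difference between read "1 * 2 + 3" :: Expr and read "1 * (2 + 3)" :: Expr.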

AST.hs

   deriving Show

 instance Read Expr where
-  readPrec = choice [ parseNum
-                    , parseVar
-                    , parseBinOp "<" 10 (Cmp LT)
-                    , parseBinOp "+" 20 Add
-                    , parseBinOp "-" 20 Sub
-                    , parseBinOp "*" 40 Mul
-                    ]
+  readPrec = parens $ choice [ parseNum
+                             , parseVar
+                             , parseBinOp "<" 10 (Cmp LT)
+                             , parseBinOp "+" 20 Add
+                             , parseBinOp "-" 20 Sub
+                             , parseBinOp "*" 40 Mul
+                             ]
     where parseNum = Num <$> readPrec
           parseVar = Var <$> lift (munch1 isAlpha)
           parseBinOp s prc op = prec prc $ do

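parens p accepts p on its own, or wrapped in any number of matching parentheses. Inside each pair it resets the precedence to the minimum.

The standalone sketch below shows both behaviours with a made-up Letter type that, like a loosely binding operator, may only appear at precedence 5 or lower. readTight is also a made-up name; parens, prec, lift and readPrec_to_S are real exports of Text.Read and Text.ParserCombinators.ReadPrec.

```haskell
import Text.ParserCombinators.ReadPrec (lift, prec, readPrec_to_S)
import Text.Read (Read (..), parens, readMaybe)
import qualified Text.ParserCombinators.ReadP as P

-- A made-up type: a single lowercase letter.
newtype Letter = Letter Char
  deriving (Show, Eq)

instance Read Letter where
  -- Only allowed at precedence 5 or lower, optionally in parentheses.
  readPrec = parens $ prec 5 $
    lift (Letter <$> (P.skipSpaces >> P.satisfy (`elem` ['a' .. 'z'])))

-- Read a Letter where precedence 10 is demanded, roughly what happens
-- to the operand of a tightly binding operator.
readTight :: String -> [Letter]
readTight s = [x | (x, "") <- readPrec_to_S readPrec 10 s]
```

readMaybe accepts "a" and "((a))" alike. readTight "a" fails because 10 is above 5, but readTight "(a)" succeeds: inside the parentheses the precedence drops back to the minimum. This is the same reason 1 * (2 + 3) parses.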
1.6 Parse call expressions

The Call constructor stores the name of the callee and the list of arguments being passed to it.

AST.hs

 data Expr = Num Float
           | Var String
           | BinOp BinOp Expr Expr
+          | Call String [Expr]
   deriving Show

 data BinOp = Add | Sub | Mul | Cmp Ordering
...
 instance Read Expr where
   readPrec = parens $ choice [ parseNum
                              , parseVar
+                             , parseCall
                              , parseBinOp "<" 10 (Cmp LT)
                              , parseBinOp "+" 20 Add
                              , parseBinOp "-" 20 Sub
...
               skipSpaces
             b <- readPrec
             return (BinOp op a b)
+          parseCall = do
+            func <- lift (munch1 isAlpha)
+            params <- lift $ between (char '(') (char ')') $
+                        sepBy (readS_to_P reads)
+                              (skipSpaces >> char ',' >> skipSpaces)
+            return (Call func params)

 data Prototype = Prototype String [String]
   deriving Show

1.7 Add basic repl

For the moment this doesn't report any parse error other than "Couldn't parse". We'll come back to this later! Also note that we're printing to stderr.

Main.hs

-main = pure ()
+import AST
+import System.IO
+import Text.Read
+main = do
+  hPutStr stderr "ready> "
+  ast <- (readMaybe <$> getLine) :: IO (Maybe AST)
+  case ast of
+    Just x -> hPrint stderr x
+    Nothing -> hPutStrLn stderr "Couldn't parse"
+  main

2 Add LLVM IR codegen

Now that we have our AST built up, it's time to start thinking about semantics. And to think about semantics, we need to start building up code that does what our AST says.

In most compilers, we don't directly convert the AST right down to machine code: Usually there's an intermediate representation involved that's somewhere between our programming language and machine code. LLVM has an intermediate representation called LLVM IR, and that's what we'll be converting our AST to.

llvm-hs provides a monadic way of building up modules and functions, with ModuleBuilder and IRBuilder respectively. To generate our code we will traverse our AST inside these monads, spitting out LLVM IR as we go along.
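To get a feel for what "traversing the AST and spitting out IR" means before bringing in llvm-hs, here is a toy sketch. It does the same kind of walk but just produces lines of LLVM-style text.

It is not how llvm-hs works internally, and its output is meant for reading, not for feeding to LLVM. The names emit, reg and predicate, and the %t0, %t1 register naming, are made up. It covers only Num, Var and BinOp with the tutorial's operators.

The core idea matches the codegen we're about to write:

- Each sub-expression becomes an operand.
- Both operands of a binary operation are generated first.
- Every instruction gets a fresh name.

```haskell
-- Cut-down copies of the tutorial's expression types (Num holds a Double).
data Expr = Num Double | Var String | BinOp BinOp Expr Expr
data BinOp = Add | Sub | Mul | Cmp Ordering

-- Name of the k-th temporary register.
reg :: Int -> String
reg k = "%t" ++ show k

-- LLVM's ordered floating-point comparison predicates.
predicate :: Ordering -> String
predicate LT = "olt"
predicate GT = "ogt"
predicate EQ = "oeq"

-- Generate code for an expression, starting at register number n.
-- Returns the operand holding the result, the instructions emitted,
-- and the next unused register number.
emit :: Int -> Expr -> (String, [String], Int)
emit n (Num x) = (show x, [], n)
emit n (Var v) = ('%' : v, [], n)
emit n (BinOp op a b) = case op of
    Add -> arith "fadd"
    Sub -> arith "fsub"
    Mul -> arith "fmul"
    -- fcmp yields an i1; convert it to 0.0 or 1.0 so every value is a double
    Cmp o -> ( r1
             , code ++ [ r0 ++ " = fcmp " ++ predicate o ++ args
                       , r1 ++ " = uitofp i1 " ++ r0 ++ " to double" ]
             , n2 + 2 )
  where
    (oa, ia, n1) = emit n a   -- left operand first
    (ob, ib, n2) = emit n1 b  -- then the right operand
    code = ia ++ ib
    r0 = reg n2
    r1 = reg (n2 + 1)
    args = " double " ++ oa ++ ", " ++ ob
    arith instr = (r0, code ++ [r0 ++ " = " ++ instr ++ args], n2 + 1)
```

For 1 + x, emit 0 returns ("%t0", ["%t0 = fadd double 1.0, %x"], 1). In llvm-hs, the IRBuilder monad takes care of the fresh names and the list of instructions for us.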

AST.hs

 import Text.Read
 import Text.ParserCombinators.ReadP hiding ((+++), choice)

-data Expr = Num Float
+data Expr = Num Double
           | Var String
           | BinOp BinOp Expr Expr
           | Call String [Expr]
...
                              , parseVar
                              , parseCall
                              , parseBinOp "<" 10 (Cmp LT)
+                             , parseBinOp ">" 10 (Cmp GT)
+                             , parseBinOp "==" 10 (Cmp EQ)
                              , parseBinOp "+" 20 Add
                              , parseBinOp "-" 20 Sub
                              , parseBinOp "*" 40 Mul

Main.hs

-import AST
+{-# LANGUAGE OverloadedStrings #-}
+
+import AST as K -- K for Kaleidoscope
+import Utils
+import Control.Monad.Trans.Reader
+import Control.Monad.IO.Class
+import Data.String
+import qualified Data.Map as Map
+import qualified Data.Text.Lazy.IO as Text
+import LLVM.AST.AddrSpace
+import LLVM.AST.Constant
+import LLVM.AST.Float
+import LLVM.AST.FloatingPointPredicate hiding (False, True)
+import LLVM.AST.Operand
+import LLVM.AST.Type as Type
+import LLVM.IRBuilder
+import LLVM.Pretty
 import System.IO
-import Text.Read
-main = do
-  hPutStr stderr "ready> "
-  ast <- (readMaybe <$> getLine) :: IO (Maybe AST)
-  case ast of
-    Just x -> hPrint stderr x
-    Nothing -> hPutStrLn stderr "Couldn't parse"
-  main
+import System.IO.Error
+import Text.Read (readMaybe)
+
+main :: IO ()
+main = buildModuleT "main" repl >>= Text.hPutStrLn stderr . ("\n" <>) . ppll
+
+repl :: ModuleBuilderT IO ()
+repl = do
+  liftIO $ hPutStr stderr "ready> "
+  mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
+  case mline of
+    Nothing -> return ()
+    Just l -> do
+      case readMaybe l of
+        Nothing -> liftIO $ hPutStrLn stderr "Couldn't parse"
+        Just ast -> do
+          hoist $ buildAST ast
+          mostRecentDef >>= liftIO . Text.hPutStrLn stderr . ppll
+      repl
+  where
+    eofHandler e
+      | isEOFError e = return Nothing
+      | otherwise = ioError e
+
+type Binds = Map.Map String Operand
+
+buildAST :: AST -> ModuleBuilder Operand
+buildAST (Function (Prototype nameStr paramStrs) body) = do
+  let n = fromString nameStr
+  function n params Type.double $ \ops -> do
+    let binds = Map.fromList (zip paramStrs ops)
+    flip runReaderT binds $ buildExpr body >>= ret
+  where params = zip (repeat Type.double) (map fromString paramStrs)
+
+buildAST (Extern (Prototype nameStr params)) =
+  extern (fromString nameStr) (replicate (length params) Type.double) Type.double
+
+buildAST (TopLevelExpr x) = function "__anon_expr" [] Type.double $
+  const $ flip runReaderT mempty $ buildExpr x >>= ret
+
+buildExpr :: Expr -> ReaderT Binds (IRBuilderT ModuleBuilder) Operand
+buildExpr (Num x) = pure $ ConstantOperand (Float (Double x))
+buildExpr (Var n) = do
+  binds <- ask
+  case binds Map.!? n of
+    Just x -> pure x
+    Nothing -> error $ "'" <> n <> "' doesn't exist in scope"
+
+buildExpr (BinOp op a b) = do
+  opA <- buildExpr a
+  opB <- buildExpr b
+  tmp <- instr opA opB
+  if isCmp
+    then uitofp tmp Type.double
+    else return tmp
+  where isCmp
+          | Cmp _ <- op = True
+          | otherwise = False
+        instr = case op of
+                  K.Add -> fadd
+                  K.Sub -> fsub
+                  K.Mul -> fmul
+                  K.Cmp LT -> fcmp OLT
+                  K.Cmp GT -> fcmp OGT
+                  K.Cmp EQ -> fcmp OEQ
+
+buildExpr (Call callee params) = do
+  paramOps <- mapM buildExpr params
+  let nam = fromString callee
+      -- get a pointer to the function
+      typ = FunctionType Type.double (replicate (length params) Type.double) False
+      ptrTyp = Type.PointerType typ (AddrSpace 0)
+      ref = GlobalReference ptrTyp nam
+  call (ConstantOperand ref) (zip paramOps (repeat []))

Utils.hs

{-|
Shoving away gross stuff into this one module.
-}
module Utils where

import Control.Monad.Trans.State
import Data.Functor.Identity
import LLVM.AST
import LLVM.IRBuilder.Module
import LLVM.IRBuilder.Internal.SnocList

-- | The definition that was added to the module most recently.
mostRecentDef :: Monad m => ModuleBuilderT m Definition
mostRecentDef = last . getSnocList . builderDefs <$> liftModuleState get

-- | Run a pure ModuleBuilder computation inside ModuleBuilderT m, for any
-- monad m, by running its underlying State and re-wrapping the result.
hoist :: Monad m => ModuleBuilder a -> ModuleBuilderT m a
hoist m = ModuleBuilderT $ StateT $
  return . runIdentity . runStateT (unModuleBuilderT m)

2.1 Begin codegen

Now the fun begins. You should start by installing the llvm-hs packages. This tutorial will be keeping things "vanilla" by just installing the packages globally rather than using a .cabal file.

$ cabal new-install --lib llvm-hs llvm-hs-pure llvm-hs-pretty --write-ghc-environment-files=always

The above command should be enough to install them, assuming you already have LLVM installed correctly. (At the time of writing llvm-hs-pretty needs to be installed from source for llvm-8.0: https://github.com/llvm-hs/llvm-hs-pretty)

Like our parsing, our code generation is also monadic (this is a Haskell tutorial). There are two monads that we can use: IRBuilder and ModuleBuilder. The former is for LLVM IR instructions and the latter for function definitions, constants and the like.

The original tutorial puts everything into one LLVM module, so we are going to follow suit here. This makes things a bit hairy, since the repl now needs to run inside a ModuleBuilderT IO. To keep our code as pure as possible, we've added a hoist function, so that our codegen code can stay inside the pure ModuleBuilder monad.

There's also a helper function to grab the most recent definition, so that we can print it out after each input, as the original tutorial does. Both helpers have been tucked away into Utils.hs.

So far we only generate code for numbers, but this paves the way for the rest of the code generation. Note that Num now holds a Double rather than a Float. Every value in Kaleidoscope is a double, and the Double constructor from LLVM.AST.Float that we use to build the constant takes a Haskell Double.

AST.hs

 import Text.Read
 import Text.ParserCombinators.ReadP hiding ((+++), choice)

-data Expr = Num Float
+data Expr = Num Double
           | Var String
           | BinOp BinOp Expr Expr
           | Call String [Expr]

Main.hs

+{-# LANGUAGE OverloadedStrings #-}
+
 import AST
+import Utils
+import Control.Monad.IO.Class
+import qualified Data.Text.Lazy.IO as Text
+import LLVM.AST.Constant
+import LLVM.AST.Float
+import LLVM.AST.Operand
+import LLVM.AST.Type as Type
+import LLVM.IRBuilder
+import LLVM.Pretty
 import System.IO
-import Text.Read
-main = do
-  hPutStr stderr "ready> "
-  ast <- (readMaybe <$> getLine) :: IO (Maybe AST)
+import Text.Read (readMaybe)
+
+main = buildModuleT "main" repl
+
+repl :: ModuleBuilderT IO ()
+repl = do
+  liftIO $ hPutStr stderr "ready> "
+  ast <- liftIO $ readMaybe <$> getLine
   case ast of
-    Just x -> hPrint stderr x
-    Nothing -> hPutStrLn stderr "Couldn't parse"
-  main
+    Nothing -> liftIO $ hPutStrLn stderr "Couldn't parse"
+    Just x -> do
+      hoist $ buildAST x
+      mostRecentDef >>= liftIO . Text.hPutStrLn stderr . ppll
+  repl
+  where
+
+buildAST :: AST -> ModuleBuilder Operand
+buildAST (TopLevelExpr x) = function "__anon_expr" [] Type.double $
+  const $ buildExpr x >>= ret
+
+buildExpr :: Expr -> IRBuilderT ModuleBuilder Operand
+buildExpr (Num x) = pure $ ConstantOperand (Float (Double x))

Utils.hs

{-|
Shoving away gross stuff into this one module.
-}
module Utils where

import Control.Monad.Trans.State
import Data.Functor.Identity
import LLVM.AST
import LLVM.IRBuilder.Module
import LLVM.IRBuilder.Internal.SnocList

-- | The definition that was added to the module most recently.
mostRecentDef :: Monad m => ModuleBuilderT m Definition
mostRecentDef = last . getSnocList . builderDefs <$> liftModuleState get

-- | Run a pure ModuleBuilder computation inside ModuleBuilderT m, for any
-- monad m, by running its underlying State and re-wrapping the result.
hoist :: Monad m => ModuleBuilder a -> ModuleBuilderT m a
hoist m = ModuleBuilderT $ StateT $
  return . runIdentity . runStateT (unModuleBuilderT m)

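The hoist in Utils.hs works because ModuleBuilder is ModuleBuilderT over Identity, with a State computation underneath. The same trick works for any pure State computation.

Here is a standalone sketch using the transformers package. hoistState, hoistState' and tick are made-up names. The second version uses transformers' mapStateT, which captures the same idea in one line.

```haskell
import Control.Monad.Trans.State
import Data.Functor.Identity (runIdentity)

-- Run a pure State computation inside StateT over any monad m:
-- unwrap the Identity result and re-wrap it with return.
hoistState :: Monad m => State s a -> StateT s m a
hoistState m = StateT (return . runIdentity . runStateT m)

-- The same thing expressed with transformers' mapStateT.
hoistState' :: Monad m => State s a -> StateT s m a
hoistState' = mapStateT (return . runIdentity)

-- A pure counter step, usable from IO via either function above.
tick :: State Int Int
tick = modify (+ 1) >> get
```

Running evalStateT (hoistState tick) 41 in IO gives 42. The pure code never mentions IO, which is exactly what keeps buildAST inside ModuleBuilder.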
2.2 Generate code for binary operations

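Also throw in the remaining comparisons (> and ==) whilst we're at it. fcmp produces a 1-bit boolean (an i1), so for comparisons we convert the result to a double (0.0 or 1.0) with uitofp, since every value in Kaleidoscope is a double.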

AST.hs

                              , parseVar
                              , parseCall
                              , parseBinOp "<" 10 (Cmp LT)
+                             , parseBinOp ">" 10 (Cmp GT)
+                             , parseBinOp "==" 10 (Cmp EQ)
                              , parseBinOp "+" 20 Add
                              , parseBinOp "-" 20 Sub
                              , parseBinOp "*" 40 Mul

Main.hs

 {-# LANGUAGE OverloadedStrings #-}

-import AST
+import AST as K -- K for Kaleidoscope
 import Utils
 import Control.Monad.IO.Class
 import qualified Data.Text.Lazy.IO as Text
 import LLVM.AST.Constant
 import LLVM.AST.Float
+import LLVM.AST.FloatingPointPredicate hiding (False, True)
 import LLVM.AST.Operand
 import LLVM.AST.Type as Type
 import LLVM.IRBuilder
...

 buildExpr :: Expr -> IRBuilderT ModuleBuilder Operand
 buildExpr (Num x) = pure $ ConstantOperand (Float (Double x))
+buildExpr (BinOp op a b) = do
+  opA <- buildExpr a
+  opB <- buildExpr b
+  tmp <- instr opA opB
+  if isCmp
+    then uitofp tmp Type.double
+    else return tmp
+  where isCmp
+          | Cmp _ <- op = True
+          | otherwise = False
+        instr = case op of
+                  K.Add -> fadd
+                  K.Sub -> fsub
+                  K.Mul -> fmul
+                  K.Cmp LT -> fcmp OLT
+                  K.Cmp GT -> fcmp OGT
+                  K.Cmp EQ -> fcmp OEQ

2.3 Generate code for functions and variables

Now that we're generating functions, variables can be bound. This means we have to keep track of which variables are in scope, namely the parameters handed to us by 'function'. To do that, we've wrapped our IRBuilderT in a ReaderT (Map String Operand).

For now if the variable doesn't exist in scope, we just crash the program. But later on we will introduce a mechanism for errors, and tidy up the transformer stack.

Main.hs

 import AST as K -- K for Kaleidoscope
 import Utils
+import Control.Monad.Trans.Reader
 import Control.Monad.IO.Class
+import Data.String
+import qualified Data.Map as Map
 import qualified Data.Text.Lazy.IO as Text
 import LLVM.AST.Constant
 import LLVM.AST.Float
...
   repl
   where

+type Binds = Map.Map String Operand
+
 buildAST :: AST -> ModuleBuilder Operand
+buildAST (Function (Prototype nameStr paramStrs) body) = do
+  let n = fromString nameStr
+  function n params Type.double $ \ops -> do
+    let binds = Map.fromList (zip paramStrs ops)
+    flip runReaderT binds $ buildExpr body >>= ret
+  where params = zip (repeat Type.double) (map fromString paramStrs)
+
 buildAST (TopLevelExpr x) = function "__anon_expr" [] Type.double $
-  const $ buildExpr x >>= ret
+  const $ flip runReaderT mempty $ buildExpr x >>= ret

-buildExpr :: Expr -> IRBuilderT ModuleBuilder Operand
+buildExpr :: Expr -> ReaderT Binds (IRBuilderT ModuleBuilder) Operand
 buildExpr (Num x) = pure $ ConstantOperand (Float (Double x))
+buildExpr (Var n) = do
+  binds <- ask
+  case binds Map.!? n of
+    Just x -> pure x
+    Nothing -> error $ "'" <> n <> "' doesn't exist in scope"
+
 buildExpr (BinOp op a b) = do
   opA <- buildExpr a
   opB <- buildExpr b

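The Binds map is ordinary Data.Map code, so the scoping logic can be tried out on its own. In this standalone sketch:

- bindParams mirrors the Map.fromList (zip paramStrs ops) line in buildAST.
- lookupVar is a hypothetical, non-crashing variant of the Var case. It returns an Either instead of calling error.

Both names are made up, and the operands are stood in for by any type a.

```haskell
import qualified Data.Map as Map

type Binds a = Map.Map String a

-- Pair each parameter name with its operand, as buildAST does.
bindParams :: [String] -> [a] -> Binds a
bindParams names ops = Map.fromList (zip names ops)

-- Look a variable up, reporting an unbound name as a Left
-- rather than crashing the whole program.
lookupVar :: Binds a -> String -> Either String a
lookupVar binds n = case Map.lookup n binds of
  Just x -> Right x
  Nothing -> Left ("'" ++ n ++ "' doesn't exist in scope")
```

One behaviour worth knowing: Map.fromList keeps the last value for a duplicated key. So a prototype like def f(x x) silently binds x to the second parameter.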
2.4 Generate code for externs

Main.hs

     flip runReaderT binds $ buildExpr body >>= ret
   where params = zip (repeat Type.double) (map fromString paramStrs)

+buildAST (Extern (Prototype nameStr params)) =
+  extern (fromString nameStr) (replicate (length params) Type.double) Type.double
+
 buildAST (TopLevelExpr x) = function "__anon_expr" [] Type.double $
   const $ flip runReaderT mempty $ buildExpr x >>= ret

2.5 Handle EOFs in the repl

This way whenever ^D is typed at the repl, the program gracefully terminates and prints out the complete module.

Main.hs

 import LLVM.IRBuilder
 import LLVM.Pretty
 import System.IO
+import System.IO.Error
 import Text.Read (readMaybe)

-main = buildModuleT "main" repl
+main :: IO ()
+main = buildModuleT "main" repl >>= Text.hPutStrLn stderr . ("\n" <>) . ppll

 repl :: ModuleBuilderT IO ()
 repl = do
   liftIO $ hPutStr stderr "ready> "
-  ast <- liftIO $ readMaybe <$> getLine
-  case ast of
-    Nothing -> liftIO $ hPutStrLn stderr "Couldn't parse"
-    Just x -> do
-      hoist $ buildAST x
-      mostRecentDef >>= liftIO . Text.hPutStrLn stderr . ppll
-  repl
-  where
+  mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
+  case mline of
+    Nothing -> return ()
+    Just l -> do
+      case readMaybe l of
+        Nothing -> liftIO $ hPutStrLn stderr "Couldn't parse"
+        Just ast -> do
+          hoist $ buildAST ast
+          mostRecentDef >>= liftIO . Text.hPutStrLn stderr . ppll
+      repl
+  where
+    eofHandler e
+      | isEOFError e = return Nothing
+      | otherwise = ioError e

 type Binds = Map.Map String Operand

2.6 Generate code for call expressions

In order to call a function in LLVM IR, we need to grab a GlobalReference to it. Currently we just make it ourselves, inferring the type of it based on what arguments the caller provided and the fact that all functions return a double.

Main.hs

 import Data.String
 import qualified Data.Map as Map
 import qualified Data.Text.Lazy.IO as Text
+import LLVM.AST.AddrSpace
 import LLVM.AST.Constant
 import LLVM.AST.Float
 import LLVM.AST.FloatingPointPredicate hiding (False, True)
...
                   K.Cmp LT -> fcmp OLT
                   K.Cmp GT -> fcmp OGT
                   K.Cmp EQ -> fcmp OEQ
+
+buildExpr (Call callee params) = do
+  paramOps <- mapM buildExpr params
+  let nam = fromString callee
+      -- get a pointer to the function
+      typ = FunctionType Type.double (replicate (length params) Type.double) False
+      ptrTyp = Type.PointerType typ (AddrSpace 0)
+      ref = GlobalReference ptrTyp nam
+  call (ConstantOperand ref) (zip paramOps (repeat []))

3 Add JIT

We have LLVM IR now, but our computers still can't run it. We could compile our code "offline" and write it to a file, but LLVM also provides frameworks for JITing: just-in-time compilation. With a JIT, code is compiled just before it is run, and we will be using one to make an interactive REPL.

Note that JITs are not the same as interpreters: An interpreter reads the program and directly computes the result. A JIT reads the program and generates more code for the computer to run, which then computes the result.
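To make the distinction concrete, here is what the interpreter side of that comparison could look like for our Expr type. It is a small tree-walking evaluator that computes a Double directly, without generating any code.

This is a standalone sketch, not part of the tutorial. The Defs association list, the function names and the error messages are assumptions. Comparisons follow the same convention as our codegen: 1.0 for true and 0.0 for false, using ordered comparisons, so NaN compares false.

```haskell
-- Copies of the tutorial's expression types, as of chapter 2.
data Expr = Num Double
          | Var String
          | BinOp BinOp Expr Expr
          | Call String [Expr]
  deriving Show

data BinOp = Add | Sub | Mul | Cmp Ordering
  deriving Show

-- Known functions: name -> (parameter names, body).
type Defs = [(String, ([String], Expr))]

-- Evaluate an expression given the function definitions and the
-- values of the variables currently in scope.
eval :: Defs -> [(String, Double)] -> Expr -> Either String Double
eval _ _ (Num x) = Right x
eval _ vars (Var n) =
  maybe (Left ("unknown variable " ++ n)) Right (lookup n vars)
eval defs vars (BinOp op a b) = do
  x <- eval defs vars a
  y <- eval defs vars b
  pure $ case op of
    Add    -> x + y
    Sub    -> x - y
    Mul    -> x * y
    Cmp LT -> fromBool (x < y)
    Cmp GT -> fromBool (x > y)
    Cmp EQ -> fromBool (x == y)
  where fromBool c = if c then 1 else 0
eval defs vars (Call f args) = case lookup f defs of
  Nothing -> Left ("unknown function " ++ f)
  Just (params, body)
    | length params /= length args ->
        Left ("wrong number of arguments to " ++ f)
    | otherwise -> do
        -- arguments are evaluated in the caller's scope...
        vals <- mapM (eval defs vars) args
        -- ...and the body only sees its own parameters
        eval defs (zip params vals) body
```

Every call to eval walks the tree again. A JIT instead pays the cost of generating and compiling machine code once, after which running it involves no tree walking at all.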

The current LLVM Kaleidoscope tutorial uses the old MC JIT framework: This tutorial will be using the fancy new OrcJIT framework. It's a bit more complicated but provides a lot more flexibility.

Main.hs

 import AST as K -- K for Kaleidoscope
 import Utils
+import Control.Monad
+import Control.Monad.Trans.Class
 import Control.Monad.Trans.Reader
 import Control.Monad.IO.Class
 import Data.String
 import qualified Data.Map as Map
 import qualified Data.Text.Lazy.IO as Text
+import Foreign.Ptr
 import LLVM.AST.AddrSpace
 import LLVM.AST.Constant
 import LLVM.AST.Float
 import LLVM.AST.FloatingPointPredicate hiding (False, True)
 import LLVM.AST.Operand
 import LLVM.AST.Type as Type
+import LLVM.Context
 import LLVM.IRBuilder
+import LLVM.Module
+import LLVM.OrcJIT
+import LLVM.OrcJIT.CompileLayer
+import LLVM.PassManager
 import LLVM.Pretty
+import LLVM.Target
 import System.IO
 import System.IO.Error
 import Text.Read (readMaybe)

+foreign import ccall "dynamic" mkFun :: FunPtr (IO Double) -> IO Double
+
+data JITEnv = JITEnv
+  { jitEnvContext :: Context
+  , jitEnvCompileLayer :: IRCompileLayer ObjectLinkingLayer
+  , jitEnvModuleKey :: ModuleKey
+  }
+
 main :: IO ()
-main = buildModuleT "main" repl >>= Text.hPutStrLn stderr . ("\n" <>) . ppll
+main =
+  withContext $ \ctx -> withHostTargetMachineDefault $ \tm ->
+    withExecutionSession $ \exSession ->
+      withSymbolResolver exSession (SymbolResolver symResolver) $ \symResolverPtr ->
+        withObjectLinkingLayer exSession (const $ pure symResolverPtr) $ \linkingLayer ->
+          withIRCompileLayer linkingLayer tm $ \compLayer ->
+            withModuleKey exSession $ \mdlKey -> do
+              let env = JITEnv ctx compLayer mdlKey
+              _ast <- runReaderT (buildModuleT "main" repl) env
+              return ()
+
+-- This can eventually be used to resolve external functions, e.g. a stdlib call
+symResolver :: MangledSymbol -> IO (Either JITSymbolError JITSymbol)
+symResolver sym = undefined

-repl :: ModuleBuilderT IO ()
+repl :: ModuleBuilderT (ReaderT JITEnv IO) ()
 repl = do
   liftIO $ hPutStr stderr "ready> "
   mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
...
       case readMaybe l of
         Nothing -> liftIO $ hPutStrLn stderr "Couldn't parse"
         Just ast -> do
-          hoist $ buildAST ast
-          mostRecentDef >>= liftIO . Text.hPutStrLn stderr . ppll
+          anon <- isAnonExpr <$> hoist (buildAST ast)
+          def <- mostRecentDef
+
+          llvmAst <- moduleSoFar "main"
+          ctx <- lift $ asks jitEnvContext
+          env <- lift ask
+          liftIO $ withModuleFromAST ctx llvmAst $ \mdl -> do
+            Text.hPutStrLn stderr $ ppll def
+            let spec = defaultCuratedPassSetSpec { optLevel = Just 3 }
+            -- this returns true if the module was modified
+            withPassManager spec $ flip runPassManager mdl
+            when anon (jit env mdl >>= hPrint stderr)
+
+          when anon (removeDef def)
       repl
   where
     eofHandler e
       | isEOFError e = return Nothing
       | otherwise = ioError e
+    isAnonExpr (ConstantOperand (GlobalReference _ "__anon_expr")) = True
+    isAnonExpr _ = False
+
+jit :: JITEnv -> Module -> IO Double
+jit JITEnv{jitEnvCompileLayer=compLayer, jitEnvModuleKey=mdlKey} mdl =
+  withModule compLayer mdlKey mdl $ do
+    mangled <- mangleSymbol compLayer "__anon_expr"
+    Right (JITSymbol fPtr _) <- findSymbolIn compLayer mdlKey mangled False
+    mkFun (castPtrToFunPtr (wordPtrToPtr fPtr))

 type Binds = Map.Map String Operand

Utils.hs

 module Utils where

 import Control.Monad.Trans.State
+import Data.ByteString.Short (ShortByteString)
 import Data.Functor.Identity
+import Data.List
 import LLVM.AST
 import LLVM.IRBuilder.Module
 import LLVM.IRBuilder.Internal.SnocList

+-- | A Module containing every definition built so far, oldest first.
+moduleSoFar :: MonadModuleBuilder m => ShortByteString -> m Module
+moduleSoFar nm = do
+  s <- liftModuleState get
+  let ds = getSnocList (builderDefs s)
+  return $ defaultModule { moduleName = nm, moduleDefinitions = ds }
+
+-- | Remove a definition from the module being built
+-- (used to drop __anon_expr once it has been run).
+removeDef :: MonadModuleBuilder m => Definition -> m ()
+removeDef def = liftModuleState (modify update)
+  where
+    -- delete from the raw (newest-first) list so that rebuilding the
+    -- SnocList keeps the remaining definitions in their original order
+    update (ModuleBuilderState defs typeDefs) =
+      let newDefs = SnocList (delete def (unSnocList defs))
+      in ModuleBuilderState newDefs typeDefs
+
 mostRecentDef :: Monad m => ModuleBuilderT m Definition
 mostRecentDef = last . getSnocList . builderDefs <$> liftModuleState get

3.1 Set up the LLVM context and optimise the module

Now that we have some LLVM IR generated, we can run a PassManager over our module to get a bunch of neat optimisations. Try it out with 3 + 2 to see some constant folding: the addition of two constants is replaced by the single constant 5. Note that the original tutorial uses a FunctionPassManager, which optimises on a per-function basis: llvm-hs doesn't expose this yet (and this all uses the legacy pass manager anyway), so for now we just optimise the entire module at the end.
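To get a feel for what constant folding means, here is a toy version of it written against a cut-down, hypothetical copy of our expression type (just numbers, variables and three operators, no calls). It only illustrates the idea: LLVM's passes work on the IR rather than on our AST, and do a great deal more than this.

```haskell
-- A cut-down, hypothetical mirror of the tutorial's Expr type
data Expr = Num Double
          | Var String
          | Add Expr Expr
          | Sub Expr Expr
          | Mul Expr Expr
  deriving (Show, Eq)

-- Fold bottom-up: simplify the operands first, then the node itself
constFold :: Expr -> Expr
constFold (Add a b) = foldBin Add (+) (constFold a) (constFold b)
constFold (Sub a b) = foldBin Sub (-) (constFold a) (constFold b)
constFold (Mul a b) = foldBin Mul (*) (constFold a) (constFold b)
constFold e = e

-- If both (already folded) operands are literals, evaluate now;
-- otherwise rebuild the node with the simplified operands
foldBin :: (Expr -> Expr -> Expr) -> (Double -> Double -> Double)
        -> Expr -> Expr -> Expr
foldBin _ f (Num x) (Num y) = Num (f x y)
foldBin mk _ a b = mk a b
```

With this, constFold (Add (Num 3) (Num 2)) gives Num 5.0, while x * (1 + 1) only has its constant subexpression folded, leaving Mul (Var "x") (Num 2.0).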

Main.hs

 import AST as K -- K for Kaleidoscope
 import Utils
+import Control.Monad.Trans.Class
 import Control.Monad.Trans.Reader
 import Control.Monad.IO.Class
 import Data.String
...
 import LLVM.AST.FloatingPointPredicate hiding (False, True)
 import LLVM.AST.Operand
 import LLVM.AST.Type as Type
+import LLVM.Context
 import LLVM.IRBuilder
+import LLVM.Module
+import LLVM.PassManager
 import LLVM.Pretty
+import LLVM.Target
 import System.IO
 import System.IO.Error
 import Text.Read (readMaybe)

 main :: IO ()
-main = buildModuleT "main" repl >>= Text.hPutStrLn stderr . ("\n" <>) . ppll
+main = do
+  withContext $ \ctx -> withHostTargetMachineDefault $ \tm -> do
+    ast <- runReaderT (buildModuleT "main" repl) ctx
+    return ()

-repl :: ModuleBuilderT IO ()
+repl :: ModuleBuilderT (ReaderT Context IO) ()
 repl = do
   liftIO $ hPutStr stderr "ready> "
   mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
...
         Just ast -> do
           hoist $ buildAST ast
           mostRecentDef >>= liftIO . Text.hPutStrLn stderr . ppll
+
+          ast <- moduleSoFar "main"
+          ctx <- lift ask
+          liftIO $ withModuleFromAST ctx ast $ \mdl -> do
+            let spec = defaultCuratedPassSetSpec { optLevel = Just 3 }
+            -- this returns true if the module was modified
+            withPassManager spec $ flip runPassManager mdl
+            Text.hPutStrLn stderr . ("\n" <>) . ppllvm =<< moduleAST mdl
       repl
   where
     eofHandler e

Utils.hs

 module Utils where

 import Control.Monad.Trans.State
+import Data.ByteString.Short (ShortByteString)
 import Data.Functor.Identity
 import LLVM.AST
 import LLVM.IRBuilder.Module
 import LLVM.IRBuilder.Internal.SnocList

+moduleSoFar :: MonadModuleBuilder m => ShortByteString -> m Module
+moduleSoFar nm = do
+  s <- liftModuleState get
+  let ds = getSnocList (builderDefs s)
+  return $ defaultModule { moduleName = nm, moduleDefinitions = ds }
+
 mostRecentDef :: Monad m => ModuleBuilderT m Definition
 mostRecentDef = last . getSnocList . builderDefs <$> liftModuleState get

3.2 Detect and remove __anon_exprs in repl

You might have noticed that when entering top-level expressions, you end up with duplicate functions named __anon_expr.1, __anon_expr.2 and so on. Since we don't want these anonymous functions to stick around, in this commit we detect them and remove them once we're done.

We also only want to run the JIT whenever the user has entered a top-level expression, so we've also sketched that out.
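Under the hood, the definitions built so far live in a SnocList inside the ModuleBuilderState. As far as we can tell from the llvm-hs-pure source, that SnocList stores elements newest-first, and getSnocList reverses them back into insertion order. The sketch below is our own self-contained model of that type (not the llvm-hs one) showing how mostRecentDef and removeDef interact with the ordering. One consequence: wrapping the result of getSnocList straight back into SnocList, as removeDef does, flips the order of the remaining definitions. That is harmless here, since LLVM doesn't care about the order of definitions in a module, but deleting from the stored list keeps the order intact.

```haskell
import Data.List (delete)

-- Our own model of a snoc list, assumed to mirror llvm-hs's SnocList:
-- elements are stored newest-first so that appending is cheap
newtype SnocList a = SnocList { unSnocList :: [a] }

snoc :: SnocList a -> a -> SnocList a
snoc (SnocList xs) x = SnocList (x : xs)

-- Oldest-first view, i.e. the order the definitions were added in
getSnocList :: SnocList a -> [a]
getSnocList = reverse . unSnocList

-- The most recently added element, like mostRecentDef
mostRecent :: SnocList a -> a
mostRecent = last . getSnocList

-- Mirrors removeDef: rebuilding from the oldest-first view reverses
-- the stored order of whatever remains
removeFlipping :: Eq a => a -> SnocList a -> SnocList a
removeFlipping x s = SnocList (delete x (getSnocList s))

-- Order-preserving variant: delete from the stored (newest-first) list
removeKeeping :: Eq a => a -> SnocList a -> SnocList a
removeKeeping x (SnocList xs) = SnocList (delete x xs)

-- Build a snoc list from items in the order they were "defined"
fromDefs :: [a] -> SnocList a
fromDefs = foldl snoc (SnocList [])
```

Note that mostRecent still finds the newest element after a flipping removal, because snoc always puts new elements at the front of the stored list.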

Main.hs

 import AST as K -- K for Kaleidoscope
 import Utils
+import Control.Monad
 import Control.Monad.Trans.Class
 import Control.Monad.Trans.Reader
 import Control.Monad.IO.Class
...
 import LLVM.Context
 import LLVM.IRBuilder
 import LLVM.Module
+import LLVM.OrcJIT
+import LLVM.OrcJIT.CompileLayer
 import LLVM.PassManager
 import LLVM.Pretty
 import LLVM.Target
...
 import System.IO.Error
 import Text.Read (readMaybe)

+data JITEnv = JITEnv
+  { jitEnvContext :: Context
+  , jitEnvCompileLayer :: IRCompileLayer ObjectLinkingLayer
+  , jitEnvModuleKey :: ModuleKey
+  }
+
 main :: IO ()
-main = do
-  withContext $ \ctx -> withHostTargetMachineDefault $ \tm -> do
-    ast <- runReaderT (buildModuleT "main" repl) ctx
-    return ()
+main =
+  withContext $ \ctx -> withHostTargetMachineDefault $ \tm ->
+    withExecutionSession $ \exSession ->
+      withSymbolResolver exSession (SymbolResolver symResolver) $ \symResolverPtr ->
+        withObjectLinkingLayer exSession (const $ pure symResolverPtr) $ \linkingLayer ->
+          withIRCompileLayer linkingLayer tm $ \compLayer ->
+            withModuleKey exSession $ \mdlKey -> do
+              let env = JITEnv ctx compLayer mdlKey
+              _ast <- runReaderT (buildModuleT "main" repl) env
+              return ()
+
+-- This can eventually be used to resolve external functions, e.g. a stdlib call
+symResolver :: MangledSymbol -> IO (Either JITSymbolError JITSymbol)
+symResolver sym = undefined

-repl :: ModuleBuilderT (ReaderT Context IO) ()
+repl :: ModuleBuilderT (ReaderT JITEnv IO) ()
 repl = do
   liftIO $ hPutStr stderr "ready> "
   mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
...
       case readMaybe l of
         Nothing ->  liftIO $ hPutStrLn stderr "Couldn't parse"
         Just ast -> do
-          hoist $ buildAST ast
-          mostRecentDef >>= liftIO . Text.hPutStrLn stderr . ppll
-
-          ast <- moduleSoFar "main"
-          ctx <- lift ask
-          liftIO $ withModuleFromAST ctx ast $ \mdl -> do
+          anon <- isAnonExpr <$> hoist (buildAST ast)
+          def <- mostRecentDef
+
+          llvmAst <- moduleSoFar "main"
+          ctx <- lift $ asks jitEnvContext
+          env <- lift ask
+          liftIO $ withModuleFromAST ctx llvmAst $ \mdl -> do
+            Text.hPutStrLn stderr $ ppll def
             let spec = defaultCuratedPassSetSpec { optLevel = Just 3 }
             -- this returns true if the module was modified
             withPassManager spec $ flip runPassManager mdl
-            Text.hPutStrLn stderr . ("\n" <>) . ppllvm =<< moduleAST mdl
+            when anon (jit env mdl >>= hPrint stderr)
+
+          when anon (removeDef def)
       repl
   where
     eofHandler e
       | isEOFError e = return Nothing
       | otherwise = ioError e
+    isAnonExpr (ConstantOperand (GlobalReference _ "__anon_expr")) = True
+    isAnonExpr _ = False
+
+jit :: JITEnv -> Module -> IO Double
+jit JITEnv{jitEnvCompileLayer=compLayer, jitEnvModuleKey=mdlKey} mdl =
+  withModule compLayer mdlKey mdl $
+    return 0

 type Binds = Map.Map String Operand

Utils.hs

 import Control.Monad.Trans.State
 import Data.ByteString.Short (ShortByteString)
 import Data.Functor.Identity
+import Data.List
 import LLVM.AST
 import LLVM.IRBuilder.Module
 import LLVM.IRBuilder.Internal.SnocList
...
   let ds = getSnocList (builderDefs s)
   return $ defaultModule { moduleName = nm, moduleDefinitions = ds }

+removeDef :: MonadModuleBuilder m => Definition -> m ()
+removeDef def = liftModuleState (modify update)
+  where
+    update (ModuleBuilderState defs typeDefs) =
+      let newDefs = SnocList (delete def (getSnocList defs))
+      in ModuleBuilderState newDefs typeDefs
+
 mostRecentDef :: Monad m => ModuleBuilderT m Definition
 mostRecentDef = last . getSnocList . builderDefs <$> liftModuleState get

3.3 Find our JIT'ed function and run it

Here is where the magic happens. We first mangle the symbol name and pass it to findSymbolIn, which is where the compilation actually happens. It hands back the raw address of the compiled function (fPtr, a WordPtr), which we convert into a FunPtr and then need to invoke ourselves.

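To call it, we have to tell Haskell the type of the function that the pointer refers to. That function lives in C-land, so we use a foreign import ccall "dynamic" declaration: the "dynamic" form turns a FunPtr into an ordinary Haskell function (here FunPtr (IO Double) -> IO Double), and ccall selects the C calling convention.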
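If the "dynamic" import is new to you, here is a self-contained sketch of the same pointer round trip without LLVM. Since there is no JIT'ed code here, we manufacture a FunPtr from a Haskell action with a "wrapper" import (our stand-in for the machine code LLVM generates), turn it into a bare WordPtr address like fPtr, and convert it back with wordPtrToPtr and castPtrToFunPtr before calling it through mkFun. The helper name callViaAddress is our own.

```haskell
{-# LANGUAGE ForeignFunctionInterface #-}
import Foreign.Ptr

-- Stand-in for JIT'ed code: wraps a Haskell action as a C function pointer
foreign import ccall "wrapper"
  mkCallback :: IO Double -> IO (FunPtr (IO Double))

-- Same shape as the tutorial's import: call a FunPtr as a Haskell function
foreign import ccall "dynamic"
  mkFun :: FunPtr (IO Double) -> IO Double

-- Round-trips a function through a raw address, as jit does with fPtr
callViaAddress :: IO Double -> IO Double
callViaAddress action = do
  fp <- mkCallback action
  let addr = ptrToWordPtr (castFunPtrToPtr fp)  -- a bare WordPtr, like fPtr
  result <- mkFun (castPtrToFunPtr (wordPtrToPtr addr))
  freeHaskellFunPtr fp                        -- "wrapper" pointers must be freed by hand
  return result
```

In GHCi, callViaAddress (return 42) evaluates to 42.0. In the tutorial there is no wrapper step, because the pointer comes from LLVM rather than from Haskell.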

Try it out!

Main.hs

 import Data.String
 import qualified Data.Map as Map
 import qualified Data.Text.Lazy.IO as Text
+import Foreign.Ptr
 import LLVM.AST.AddrSpace
 import LLVM.AST.Constant
 import LLVM.AST.Float
...
 import System.IO.Error
 import Text.Read (readMaybe)

+foreign import ccall "dynamic" mkFun :: FunPtr (IO Double) -> IO Double
+
 data JITEnv = JITEnv
   { jitEnvContext :: Context
   , jitEnvCompileLayer :: IRCompileLayer ObjectLinkingLayer
...

 jit :: JITEnv -> Module -> IO Double
 jit JITEnv{jitEnvCompileLayer=compLayer, jitEnvModuleKey=mdlKey} mdl =
-  withModule compLayer mdlKey mdl $
-    return 0
+  withModule compLayer mdlKey mdl $ do
+    mangled <- mangleSymbol compLayer "__anon_expr"
+    Right (JITSymbol fPtr _) <- findSymbolIn compLayer mdlKey mangled False
+    mkFun (castPtrToFunPtr (wordPtrToPtr fPtr))

 type Binds = Map.Map String Operand

4 Add control flow

So far our language can parse and evaluate floating point expressions. However, each expression and its subexpressions are always evaluated. To do more complex computations, we want to be able to control what gets evaluated and what doesn't.

This is known as control flow, and the two most famous imperative constructs are probably the if statement and for loop. We will add them to our language in this chapter, and finally put our boolean expressions (42 < i) to good use.

AST.hs

 import Data.Char
 import Text.Read
-import Text.ParserCombinators.ReadP hiding ((+++), choice)
+import Text.ParserCombinators.ReadP hiding ((+++), (<++), choice)

 data Expr = Num Double
           | Var String
           | BinOp BinOp Expr Expr
           | Call String [Expr]
+          | If Expr Expr Expr
+          | For String Expr Expr (Maybe Expr) Expr
   deriving Show

 data BinOp = Add | Sub | Mul | Cmp Ordering
...
   readPrec = parens $ choice [ parseNum
                              , parseVar
                              , parseCall
+                             , parseIf
+                             , parseFor
                              , parseBinOp "<" 10 (Cmp LT)
                              , parseBinOp ">" 10 (Cmp GT)
                              , parseBinOp "==" 10 (Cmp EQ)
...
           parseVar = Var <$> lift (munch1 isAlpha)
           parseBinOp s prc op = prec prc $ do
             a <- step readPrec
-            lift $ do
-              skipSpaces
-              string s
-              skipSpaces
+            spaced $ string s
             b <- readPrec
             return (BinOp op a b)
           parseCall = do
...
                         sepBy (readS_to_P reads)
                               (skipSpaces >> char ',' >> skipSpaces)
             return (Call func params)
+          parseIf = do
+            spaced $ string "if"
+            cond <- readPrec
+            spaced $ string "then"
+            thenE <- readPrec
+            spaced $ string "else"
+            elseE <- readPrec
+            return (If cond thenE elseE)
+          parseFor = do
+            spaced $ string "for"
+            identifier <- lift (munch1 isAlpha)
+            spaced $ char '='
+            start <- readPrec
+            spaced $ char ','
+            cond <- readPrec
+            stp <- (spaced (char ',') >> Just <$> step readPrec)
+                     <++ pure Nothing
+            spaced $ string "in"
+            body <- readPrec
+            return (For identifier start cond stp body)
+          spaced f = lift $ skipSpaces >> f >> skipSpaces

 data Prototype = Prototype String [String]
   deriving Show

Main.hs

 {-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE RecursiveDo #-}

 import AST as K -- K for Kaleidoscope
 import Utils
...
       ptrTyp = Type.PointerType typ (AddrSpace 0)
       ref = GlobalReference ptrTyp nam
   call (ConstantOperand ref) (zip paramOps (repeat []))
+
+buildExpr (If cond thenE elseE) = mdo
+  _ifB <- block `named` "if"
+
+  -- since everything is a double, false == 0
+  let zero = ConstantOperand (Float (Double 0))
+  condV <- buildExpr cond
+  cmp <- fcmp ONE zero condV `named` "cmp"
+
+  condBr cmp thenB elseB
+
+  thenB <- block `named` "then"
+  thenOp <- buildExpr thenE
+  br mergeB
+
+  elseB <- block `named` "else"
+  elseOp <- buildExpr elseE
+  br mergeB
+
+  mergeB <- block `named` "ifcont"
+  phi [(thenOp, thenB), (elseOp, elseB)]
+
+buildExpr (For name init cond mStep body) = mdo
+  preheaderB <- block `named` "preheader"
+
+  initV <- buildExpr init `named` "init"
+
+  -- build the condition expression with 'i' in the bindings
+  initCondV <- withReaderT (Map.insert name initV) $
+                (buildExpr cond >>= fcmp ONE zero) `named` "initcond"
+
+  -- skip the loop if we don't meet the condition with the init
+  condBr initCondV loopB afterB
+
+  loopB <- block `named` "loop"
+  i <- phi [(initV, preheaderB), (nextVar, loopB)] `named` "i"
+
+  -- build the body expression with 'i' in the bindings
+  withReaderT (Map.insert name i) $ buildExpr body `named` "body"
+
+  -- default to 1 if there's no step defined
+  stepV <- case mStep of
+    Just step -> buildExpr step
+    Nothing -> return $ ConstantOperand (Float (Double 1))
+
+  nextVar <- fadd i stepV `named` "nextvar"
+
+  let zero = ConstantOperand (Float (Double 0))
+  -- again we need 'i' in the bindings
+  condV <- withReaderT (Map.insert name i) $
+            (buildExpr cond >>= fcmp ONE zero) `named` "cond"
+  condBr condV loopB afterB
+
+  afterB <- block `named` "after"
+  -- since a for loop doesn't really have a value, return 0
+  return $ ConstantOperand (Float (Double 0))

4.1 Parse if expressions

Back to the AST.

AST.hs

           | Var String
           | BinOp BinOp Expr Expr
           | Call String [Expr]
+          | If Expr Expr Expr
   deriving Show

 data BinOp = Add | Sub | Mul | Cmp Ordering
...
   readPrec = parens $ choice [ parseNum
                              , parseVar
                              , parseCall
+                             , parseIf
                              , parseBinOp "<" 10 (Cmp LT)
                              , parseBinOp ">" 10 (Cmp GT)
                              , parseBinOp "==" 10 (Cmp EQ)
...
                         sepBy (readS_to_P reads)
                               (skipSpaces >> char ',' >> skipSpaces)
             return (Call func params)
+          parseIf = do
+            lift $ skipSpaces >> string "if" >> skipSpaces
+            cond <- readPrec
+            lift $ skipSpaces >> string "then" >> skipSpaces
+            thenE <- readPrec
+            lift $ skipSpaces >> string "else" >> skipSpaces
+            elseE <- readPrec
+            return (If cond thenE elseE)

 data Prototype = Prototype String [String]
   deriving Show

4.2 Tidy up the parsing a bit

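The split between Text.ParserCombinators.ReadPrec and Text.ParserCombinators.ReadP adds a bit of lifting cruft, so let's sweep that away with a small spaced helper that lifts a ReadP parser into ReadPrec and skips the whitespace around it.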
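Here is the same trick in isolation, as a self-contained sketch over a hypothetical mini-language with only integer literals and if/then/else (the E type and parseE are ours, not part of the tutorial). The spaced helper has the same shape as the one in AST.hs, given an explicit type signature.

```haskell
import Data.Char (isDigit)
import Text.ParserCombinators.ReadPrec (ReadPrec, lift, readPrec_to_S, (+++))
import qualified Text.ParserCombinators.ReadP as P

-- A hypothetical mini-language: integer literals and if/then/else only
data E = N Int | If E E E
  deriving (Show, Eq)

-- Lift a ReadP parser into ReadPrec, skipping whitespace on both sides
spaced :: P.ReadP a -> ReadPrec ()
spaced p = lift (P.skipSpaces >> p >> P.skipSpaces)

expr :: ReadPrec E
expr = num +++ ifE
  where
    num = N . read <$> lift (P.munch1 isDigit)
    ifE = do
      spaced (P.string "if")
      c <- expr
      spaced (P.string "then")
      t <- expr
      spaced (P.string "else")
      e <- expr
      return (If c t e)

-- Keep only the parses that consumed the whole input
parseE :: String -> [E]
parseE s = [e | (e, "") <- readPrec_to_S expr 0 s]
```

For example, parseE "if 1 then if 2 then 3 else 4 else 5" yields a single nested If, and an incomplete input such as "if 1 then 2" yields no parses at all.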

AST.hs

           parseVar = Var <$> lift (munch1 isAlpha)
           parseBinOp s prc op = prec prc $ do
             a <- step readPrec
-            lift $ do
-              skipSpaces
-              string s
-              skipSpaces
+            spaced $ string s
             b <- readPrec
             return (BinOp op a b)
           parseCall = do
...
                               (skipSpaces >> char ',' >> skipSpaces)
             return (Call func params)
           parseIf = do
-            lift $ skipSpaces >> string "if" >> skipSpaces
+            spaced $ string "if"
             cond <- readPrec
-            lift $ skipSpaces >> string "then" >> skipSpaces
+            spaced $ string "then"
             thenE <- readPrec
-            lift $ skipSpaces >> string "else" >> skipSpaces
+            spaced $ string "else"
             elseE <- readPrec
             return (If cond thenE elseE)
+          spaced f = lift $ skipSpaces >> f >> skipSpaces

 data Prototype = Prototype String [String]
   deriving Show

4.3 Generate code for if statements

Hopefully you found this addition straightforward and enjoyable, now that we have all the infrastructure in place. Let's walk through this step by step.

We start off by creating a basic block called "if". A basic block is a straight-line chunk of code with a single entry point and a single exit point, and in this block we put the condition expression. Since everything in Kaleidoscope is a double (yuck), we need to compare it to zero to get a boolean that we can actually branch on.
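The comparison used above is fcmp ONE ("ordered and not equal"), which yields false whenever either operand is NaN. The following pure Haskell function is our own model of the resulting truthiness rule (nothing from llvm-hs): zero is false, any other number is true, and NaN is false too. A plain x /= 0 in Haskell would treat NaN as true instead, which is why the model checks isNaN explicitly.

```haskell
-- Models `fcmp one 0.0, x`: true only when x is not NaN and differs from 0
truthy :: Double -> Bool
truthy x = not (isNaN x) && x /= 0

-- What a naive Haskell comparison would say instead
naiveTruthy :: Double -> Bool
naiveTruthy x = x /= 0
```

So truthy (0/0) is False while naiveTruthy (0/0) is True; both agree on ordinary numbers, including -0, which counts as false.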

After the conditional branch, we create two more basic blocks labelled "then" and "else", containing the then and else expressions respectively. Both end by branching to the final basic block, "ifcont".

Now, since LLVM IR is in SSA (static single assignment) form, every variable can only be assigned once, so the then and else branches each produce their own separate value. To pick between them, we use a phi node in the merge block: depending on which basic block control arrived from, it yields one of the two operands we've passed to it.

In this case, we use it to return the then expression if the then branch was taken, or the else expression if the else branch was taken. Phi nodes are very common in LLVM IR, and open up lots of optimisations by allowing us to stay in SSA form.
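To make the phi node concrete, here is a hand-rolled evaluation of the if/then/else control flow graph in plain Haskell. It is only a model (the names and types are ours, not llvm-hs's): each branch produces its own value, and the phi picks one based on the label of the block we arrived from.

```haskell
-- A model of a phi node: incoming (value, predecessor block label) pairs
type Incoming a = [(a, String)]

-- Pick the incoming value that belongs to the block we arrived from
evalPhi :: Incoming a -> String -> a
evalPhi incoming from =
  case lookup from [(label, v) | (v, label) <- incoming] of
    Just v  -> v
    Nothing -> error ("phi has no entry for predecessor " ++ show from)

-- Walks "if" -> ("then" | "else") -> "ifcont" for already-evaluated
-- operands, remembering which block we came from
evalIf :: Double -> Double -> Double -> Double
evalIf condV thenV elseV =
  let cameFrom = if not (isNaN condV) && condV /= 0  -- the condBr in "if"
                   then "then" else "else"
      thenOp = thenV  -- value computed in the "then" block
      elseOp = elseV  -- value computed in the "else" block
  in evalPhi [(thenOp, "then"), (elseOp, "else")] cameFrom
```

Neither branch overwrites the other's value; the choice is made only at the merge point, which is exactly what the phi in the generated IR does.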

You might have also noticed that, unlike the C++ tutorial, we didn't need to create the merge block ahead of time and move our builder around. In fact, we actually reference mergeB before it is even bound! How can this be possible? Through recursive do: the MonadFix instance on IRBuilder allows us to exploit laziness and use mdo to refer to variables that are only bound later in the block.
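The same knot-tying works in any monad with a suitably lazy MonadFix instance. Here is a self-contained sketch using the lazy State monad from transformers as a toy "IR builder" (our own construction, unrelated to llvm-hs's IRBuilder): a branch to mergeB is emitted before mergeB has been allocated, and it only works because nothing forces the label's value until after it exists.

```haskell
{-# LANGUAGE RecursiveDo #-}
import Control.Monad.Trans.State.Lazy (State, execState, get, modify, put)

-- A hypothetical toy builder: the state is the next free label number
-- and the lines emitted so far, oldest first
type Builder = State (Int, [String])

-- Append one line of "IR"
emit :: String -> Builder ()
emit line = modify (\(n, ls) -> (n, ls ++ [line]))

-- Allocate a fresh label, start a block with it, and return the label
block :: Builder String
block = do
  (n, ls) <- get
  let label = "L" ++ show n
  put (n + 1, ls ++ [label ++ ":"])
  return label

-- Branches to mergeB before it is bound: mdo uses the MonadFix instance
-- of the lazy State monad to feed the later result back in
program :: Builder ()
program = mdo
  _entry <- block
  emit ("br " ++ mergeB)
  _other <- block
  emit ("br " ++ mergeB)
  mergeB <- block
  emit "ret"

-- Run a builder from an empty state and return the emitted lines
render :: Builder () -> [String]
render b = snd (execState b (0, []))
```

render program gives ["L0:", "br L2", "L1:", "br L2", "L2:", "ret"]: both branches point at the block that was created last. If the label allocation itself depended on mergeB, the knot would never resolve, so this trick only works for values that are passed around rather than inspected.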

Main.hs

 {-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE RecursiveDo #-}

 import AST as K -- K for Kaleidoscope
 import Utils
...
       ptrTyp = Type.PointerType typ (AddrSpace 0)
       ref = GlobalReference ptrTyp nam
   call (ConstantOperand ref) (zip paramOps (repeat []))
+
+buildExpr (If cond thenE elseE) = mdo
+  _ifB <- block `named` "if"
+
+  -- since everything is a double, false == 0
+  let zero = ConstantOperand (Float (Double 0))
+  condV <- buildExpr cond
+  cmp <- fcmp ONE zero condV `named` "cmp"
+
+  condBr cmp thenB elseB
+
+  thenB <- block `named` "then"
+  thenOp <- buildExpr thenE
+  br mergeB
+
+  elseB <- block `named` "else"
+  elseOp <- buildExpr elseE
+  br mergeB
+
+  mergeB <- block `named` "ifcont"
+  phi [(thenOp, thenB), (elseOp, elseB)]

4.4 Parse for loops

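Watch out for the <++: it tries to parse whatever is on the left first, and only if that fails does it try whatever is on the right. This is different from +++ (and <|>), which keep the results of both alternatives among the possible parses. Since Text.Read re-exports ReadPrec's versions of these operators, we now hide ReadP's <++ as well, just as we already did with +++ and choice.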
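Here is the difference in isolation, using a hypothetical parser for an optional ", <digits>" suffix written in plain ReadP (not the tutorial's ReadPrec parser). The only change between the two versions is the choice operator.

```haskell
import Data.Char (isDigit)
import Text.ParserCombinators.ReadP

-- Left-biased choice: if the left branch succeeds, `pure Nothing` is never tried
stepBiased :: ReadP (Maybe Int)
stepBiased = (Just . read <$> (char ',' >> munch1 isDigit)) <++ pure Nothing

-- Symmetric choice: both branches contribute parses
stepSymmetric :: ReadP (Maybe Int)
stepSymmetric = (Just . read <$> (char ',' >> munch1 isDigit)) +++ pure Nothing
```

readP_to_S stepBiased ",2 in" returns only [(Just 2," in")], whereas readP_to_S stepSymmetric ",2 in" also contains a Nothing parse that leaves ",2 in" unconsumed: an extra alternative that the rest of the parser then has to rule out. When there is no comma, both fall back to Nothing.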

AST.hs

 import Data.Char
 import Text.Read
-import Text.ParserCombinators.ReadP hiding ((+++), choice)
+import Text.ParserCombinators.ReadP hiding ((+++), (<++), choice)

 data Expr = Num Double
           | Var String
           | BinOp BinOp Expr Expr
           | Call String [Expr]
           | If Expr Expr Expr
+          | For String Expr Expr (Maybe Expr) Expr
   deriving Show

 data BinOp = Add | Sub | Mul | Cmp Ordering
...
                              , parseVar
                              , parseCall
                              , parseIf
+                             , parseFor
                              , parseBinOp "<" 10 (Cmp LT)
                              , parseBinOp ">" 10 (Cmp GT)
                              , parseBinOp "==" 10 (Cmp EQ)
...
             spaced $ string "else"
             elseE <- readPrec
             return (If cond thenE elseE)
+          parseFor = do
+            spaced $ string "for"
+            identifier <- lift (munch1 isAlpha)
+            spaced $ char '='
+            start <- readPrec
+            spaced $ char ','
+            cond <- readPrec
+            stp <- (spaced (char ',') >> Just <$> step readPrec)
+                     <++ pure Nothing
+            spaced $ string "in"
+            body <- readPrec
+            return (For identifier start cond stp body)
           spaced f = lift $ skipSpaces >> f >> skipSpaces

 data Prototype = Prototype String [String]

4.5 Generate code for for loops

Our for loops are like C-style for loops, with an initial value, an end condition and an optional step (which defaults to 1).

We begin generating code by setting up the loop in the preheader. Here we generate the initial value and evaluate the end condition with the loop variable bound to it, skipping the loop entirely if the condition doesn't hold.

The loop itself begins with a phi node that chooses the value of the loop variable (usually called the induction variable): the initial value if we've just come from the preheader, or the incremented value if we've looped back around. We then put this into the binding environment with withReaderT whenever we generate the body or condition expressions: this lets them access the induction variable by its name, without having to change our Reader transformer into a State transformer.

Once we've generated the body, we just need to increment the induction variable by the step and check whether the condition still holds. Depending on the result, we either branch back up to the top of the loop or to the after block to exit. Note that the condition is checked against the value of the induction variable before it is incremented, so for i = 0, i < 10 runs the body for every i from 0 up to and including 10.

You might have noticed that we don't have any side-effects at the moment, so you can't tell what went on inside a loop - it always evaluates to zero. We're going to add something to do inside these loops next.
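To double-check that reading of the generated control flow, here is a pure Haskell reference model of the loop, written by us from the codegen above (it is not part of the tutorial's code). It returns the induction-variable values the body sees, and uses them to work out what a putchard-based loop would print.

```haskell
import Data.Char (chr)

-- Our own reference model of the loop the codegen builds: returns the
-- values of the induction variable that the body sees, in order
loopValues :: Double -> (Double -> Bool) -> Double -> [Double]
loopValues start cond stp
  | cond start = go start  -- "preheader": check the condition with the init value
  | otherwise  = []        -- skip straight to "after"
  where
    go i = i : if cond i            -- body runs with i, then cond is checked with the same (old) i
                 then go (i + stp)  -- branch back to "loop" with nextvar = i + step
                 else []            -- exit to "after"

-- The characters `for i = 0, i < 10 in putchard(42+i)` would print under this model
printed :: String
printed = map (\i -> chr (truncate (42 + i))) (loopValues 0 (< 10) 1)
```

Under this model, loopValues 0 (< 10) 1 is [0.0 .. 10.0], eleven iterations, so printed is "*+,-./01234". A loop whose condition fails on the initial value, such as loopValues 5 (< 3) 1, never runs its body.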

Main.hs

   mergeB <- block `named` "ifcont"
   phi [(thenOp, thenB), (elseOp, elseB)]
+
+buildExpr (For name init cond mStep body) = mdo
+  preheaderB <- block `named` "preheader"
+
+  initV <- buildExpr init `named` "init"
+
+  -- build the condition expression with 'i' in the bindings
+  initCondV <- withReaderT (Map.insert name initV) $
+                (buildExpr cond >>= fcmp ONE zero) `named` "initcond"
+
+  -- skip the loop if we don't meet the condition with the init
+  condBr initCondV loopB afterB
+
+  loopB <- block `named` "loop"
+  i <- phi [(initV, preheaderB), (nextVar, loopB)] `named` "i"
+
+  -- build the body expression with 'i' in the bindings
+  withReaderT (Map.insert name i) $ buildExpr body `named` "body"
+
+  -- default to 1 if there's no step defined
+  stepV <- case mStep of
+    Just step -> buildExpr step
+    Nothing -> return $ ConstantOperand (Float (Double 1))
+
+  nextVar <- fadd i stepV `named` "nextvar"
+
+  let zero = ConstantOperand (Float (Double 0))
+  -- again we need 'i' in the bindings
+  condV <- withReaderT (Map.insert name i) $
+            (buildExpr cond >>= fcmp ONE zero) `named` "cond"
+  condBr condV loopB afterB
+
+  afterB <- block `named` "after"
+  -- since a for loop doesn't really have a value, return 0
+  return $ ConstantOperand (Float (Double 0))

5 Add standard library

We want to be able to call a function inside our for loop that has some sort of side effect that we can observe: how about printing to stdout? Our language isn't quite sophisticated enough yet to write to files or handles, or to make system calls, but we can work around this by writing this code in another language and calling it from Kaleidoscope.

In this chapter we are going to write our standard library in C, but feel free to experiment with compiling and linking other languages: Try making raw system calls in assembly, or even writing it in Haskell!

The standard library contains one function, putchard, which takes in a double and prints a character to stdout. It's the equivalent of putchar(3), but since Kaleidoscope only works with doubles we need to cast it first.
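For comparison, here is roughly what putchard does, written as a plain Haskell function. It is only an illustration of the double-to-character conversion and the flush; the JIT calls the C version from the standard library, not this one.

```haskell
import System.IO (hFlush, stdout)

-- A Haskell sketch of stdlib.c's putchard (illustration only). Assumes x is
-- a plain ASCII code: unlike C's putchar, it doesn't wrap codes outside
-- 0-255 to an unsigned char
putchard :: Double -> IO Double
putchard x = do
  let code = truncate x :: Int  -- like the C cast (int)x: truncates toward zero
  putChar (toEnum code)
  hFlush stdout                  -- like fflush(stdout), so the REPL shows it immediately
  return (fromIntegral code)     -- putchar returns the character it wrote
```

For instance, putchard 65.9 prints A and returns 65.0.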

Once you have finished this chapter, you should be able to run the following code to print out a bunch of characters:

ready> extern putchard(x)
declare external ccc double @putchard(double)

ready> for i = 0, i < 10 in putchard(42+i)
define external ccc double @__anon_expr() {
preheader_0:
  %initcond_0 = fcmp olt double 0.000000e0, 1.000000e1
  %initcond_1 = uitofp i1 %initcond_0 to double
  %initcond_2 = fcmp one double 0.000000e0, %initcond_1
  br i1 %initcond_2, label %loop_0, label %after_0
loop_0:
  %i_0 = phi double [0.000000e0, %preheader_0], [%nextvar_0, %loop_0]
  %body_0 = fadd double 4.200000e1, %i_0
  %body_1 = call ccc double @putchard(double %body_0)
  %nextvar_0 = fadd double %i_0, 1.000000e0
  %cond_0 = fcmp olt double %i_0, 1.000000e1
  %cond_1 = uitofp i1 %cond_0 to double
  %cond_2 = fcmp one double 0.000000e0, %cond_1
  br i1 %cond_2, label %loop_0, label %after_0
after_0:
  ret double 0.000000e0
}

Resolving MangledSymbol "_putchard" to 0x10574c2d0
*+,-./01234

.gitignore

stdlib.dylib
stdlib.o
Main

Main.hs

 import LLVM.AST.Type as Type
 import LLVM.Context
 import LLVM.IRBuilder
+import LLVM.Linking
 import LLVM.Module
 import LLVM.OrcJIT
 import LLVM.OrcJIT.CompileLayer
 import LLVM.PassManager
 import LLVM.Pretty
 import LLVM.Target
+import Numeric
 import System.IO
 import System.IO.Error
 import Text.Read (readMaybe)
...
   }

 main :: IO ()
-main =
+main = do
+  loadLibraryPermanently (Just "stdlib.dylib")
   withContext $ \ctx -> withHostTargetMachineDefault $ \tm ->
     withExecutionSession $ \exSession ->
       withSymbolResolver exSession (SymbolResolver symResolver) $ \symResolverPtr ->
...

 -- This can eventually be used to resolve external functions, e.g. a stdlib call
 symResolver :: MangledSymbol -> IO (Either JITSymbolError JITSymbol)
-symResolver sym = undefined
+symResolver sym = do
+  ptr <- getSymbolAddressInProcess sym
+  putStrLn $ "Resolving " <> show sym <> " to 0x" <> showHex ptr ""
+  return (Right (JITSymbol ptr defaultJITSymbolFlags))

 repl :: ModuleBuilderT (ReaderT JITEnv IO) ()
 repl = do

Makefile

stdlib.dylib: stdlib.c
	clang -shared $< -o $@

# for statically linking the stdlib:
# make sure to change in Main.hs
#   loadLibraryPermanently (Just "stdlib.dylib")
# to
#   loadLibraryPermanently Nothing
stdlib.o: stdlib.c
	clang -c $< -o $@

Main: Main.hs stdlib.o
	ghc $^ -o $@ -optl -Wl,-exported_symbols_list,stdlib.syms \
		-no-keep-hi-files -no-keep-o-files

stdlib.c

#include <stdio.h>
// Takes a double and writes it to stdout
double putchard(double x) {
	int res = putchar((int)x);
	fflush(stdout);
	return (double)res;
}

stdlib.syms

_putchard

5.1 Start the standard library

Our standard library so far is just going to contain a single function that acts as a version of putchar(3), except it operates on doubles. We'll also flush stdout so that we can immediately see what we've printed when we're inside the repl.

stdlib.c

#include <stdio.h>
// Takes a double and writes it to stdout
double putchard(double x) {
	int res = putchar((int)x);
	fflush(stdout);
	return (double)res;
}

5.2 Add Makefile for building stdlib

There are two ways we can go about including our standard library in our language. The first method is via statically linking it into our final executable, by compiling our Haskell code and standard library object file with it. The second method is to compile the standard library into a shared library and dynamically load it at runtime.

Here we are going to do the latter, since it means we can still run our code with runghc, at the cost of having to specify the path to the library inside Main.hs. (I'll include an example later on of how to link it in the static way.)

macOS uses .dylib files for shared libraries, whereas Linux typically uses .so files. You may want to rename the rule (and the path passed to loadLibraryPermanently in Main.hs) to match the operating system you are targeting.

.gitignore

stdlib.dylib

Makefile

stdlib.dylib: stdlib.c
	clang -shared $< -o $@

5.3 Load the standard library

loadLibraryPermanently is the equivalent of dlopen(3): it loads all the code and symbols from our standard library into our process. We need to do this so that later on we can call standard library functions whilst JITing.

If you went down the statically linked route, you still need to call this with Nothing to expose the symbols to LLVM.

Main.hs

 import LLVM.AST.Type as Type
 import LLVM.Context
 import LLVM.IRBuilder
+import LLVM.Linking
 import LLVM.Module
 import LLVM.OrcJIT
 import LLVM.OrcJIT.CompileLayer
...
   }

 main :: IO ()
-main =
+main = do
+  loadLibraryPermanently (Just "stdlib.dylib")
   withContext $ \ctx -> withHostTargetMachineDefault $ \tm ->
     withExecutionSession $ \exSession ->
       withSymbolResolver exSession (SymbolResolver symResolver) $ \symResolverPtr ->

5.4 Resolve symbols

Now that the standard library is loaded into our process, we can finally begin to fill out our symbol resolver. A symbol resolver takes in a symbol and returns its address wrapped inside a JITSymbol if it can find it, or an error otherwise. (Ours is a little optimistic: it always returns Right, so a failed lookup is never reported as an error.)

You can use the symbol resolver to do much fancier stuff in OrcJIT, but for now we are just going to look up symbols that we have previously loaded into our process via loadLibraryPermanently.
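If you'd like unresolved symbols to be reported as errors, one option is to check the looked-up address before wrapping it. The sketch below is a hypothetical, LLVM-free helper in that spirit: resolveWith and showAddr are our own names, the lookupAddr argument stands in for getSymbolAddressInProcess, and we assume that a missing symbol comes back as address 0. In the real resolver you would turn the Left into a JITSymbolError.

```haskell
import Foreign.Ptr (WordPtr)
import Numeric (showHex)

-- Hypothetical helper in the spirit of symResolver: `lookupAddr` stands in
-- for getSymbolAddressInProcess, which we assume reports "not found" as 0
resolveWith :: (String -> IO WordPtr) -> String -> IO (Either String WordPtr)
resolveWith lookupAddr sym = do
  addr <- lookupAddr sym
  return $ if addr == 0
    then Left ("could not resolve " ++ sym)  -- report an error instead of a null address
    else Right addr

-- Formats an address like the tutorial's "Resolving ..." log line
showAddr :: WordPtr -> String
showAddr addr = "0x" ++ showHex addr ""
```

With a fake lookup that only knows _putchard, resolving _putchard succeeds while resolving any other name yields a Left describing the missing symbol.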

Main.hs

 import LLVM.PassManager
 import LLVM.Pretty
 import LLVM.Target
+import Numeric
 import System.IO
 import System.IO.Error
 import Text.Read (readMaybe)
...

 -- This can eventually be used to resolve external functions, e.g. a stdlib call
 symResolver :: MangledSymbol -> IO (Either JITSymbolError JITSymbol)
-symResolver sym = undefined
+symResolver sym = do
+  ptr <- getSymbolAddressInProcess sym
+  putStrLn $ "Resolving " <> show sym <> " to 0x" <> showHex ptr ""
+  return (Right (JITSymbol ptr defaultJITSymbolFlags))

 repl :: ModuleBuilderT (ReaderT JITEnv IO) ()
 repl = do

5.5 Add rules for statically linking the standard library in

If we statically link the standard library in via the Makefile, we need to make sure that our putchard symbol doesn't get stripped out: GHC passes the -dead_strip flag to the linker, which removes any unused functions. Since nothing in our program calls putchard directly, it will be stripped unless we say that we want to export it, which we do with an exported symbols list (stdlib.syms). Note that -dead_strip and -exported_symbols_list are options of the macOS linker; on Linux you'll need your linker's equivalents, and you might need to remove the underscore from _putchard.
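The leading underscore comes from symbol mangling: on macOS (Mach-O), C symbols are prefixed with an underscore, which is why the resolver logs "_putchard", while on Linux (ELF) there is no prefix. The following toy function captures just that rule for illustration; it is our own simplification, not LLVM's mangling algorithm, and other platforms have their own rules that it ignores.

```haskell
import System.Info (os)

-- Toy model of C symbol mangling (a simplification, not LLVM's algorithm):
-- Mach-O (macOS, os == "darwin") prefixes '_', ELF (e.g. Linux) doesn't;
-- other platforms' rules are not modelled here
mangleFor :: String -> String -> String
mangleFor targetOs sym
  | targetOs == "darwin" = '_' : sym
  | otherwise            = sym

-- Mangle for the operating system we're currently running on
mangleHere :: String -> String
mangleHere = mangleFor os
```

So mangleFor "darwin" "putchard" gives "_putchard", matching stdlib.syms, while mangleFor "linux" "putchard" leaves the name unchanged.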

.gitignore

 stdlib.dylib
+stdlib.o
+Main

Makefile

 stdlib.dylib: stdlib.c
 	clang -shared $< -o $@
+
+# for statically linking the stdlib:
+# make sure to change in Main.hs
+#   loadLibraryPermanently (Just "stdlib.dylib")
+# to
+#   loadLibraryPermanently Nothing
+stdlib.o: stdlib.c
+	clang -c $< -o $@
+
+Main: Main.hs stdlib.o
+	ghc $^ -o $@ -optl -Wl,-exported_symbols_list,stdlib.syms \
+		-no-keep-hi-files -no-keep-o-files

stdlib.syms

_putchard