
Knot-tying: two more examples, and an alternative

February 27, 2020 - Tagged as: en, haskell, ghc.

In the previous post we looked at a representation of expressions in a programming language, what the representation makes easy, and where we have to use knot-tying.

In this post I’m going to give two more examples, using the same expression representation from the previous post, and then talk about how to implement our passes using a different representation, without knot-tying.

Example: attaching typing information to Ids

Previously we attached arity and unfolding information to Ids. Now suppose that our language is typed, and up to some point our transformations rely on typing information. Similar to the arity and unfolding fields, we add one more field to Id:

data Id = Id
  { ..
  , idType :: Maybe Type
  }

The Maybe part is because when we no longer need the types we want to be able to clear the type fields to make the AST smaller. While we have only one heap object per Id, in an average program there are still a lot of different Ids, and the Type representation can get quite large, so this is worthwhile. This makes the working set smaller, which causes less GC work and improves compiler performance.

In our cyclic AST representation the only way to implement this without losing sharing is with a full pass over the entire program, using knot-tying. The code is similar to the passes in the previous post.
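To make the shape of such a pass concrete, here is a minimal sketch of how clearing the types could look in the cyclic representation. It is not code from the previous post: the name dropTypes, the placeholder Type, and the decision to rebuild unfoldings only where one was already present are all assumptions made for illustration.

```haskell
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)

-- Placeholder; assume the real Type representation is large.
data Type = Type

data Expr
  = IdE Id
  | IntE Int
  | Lam Id Expr
  | App Expr Expr
  | IfE Expr Expr Expr
  | Let Id Expr Expr

data Id = Id
  { idName :: String
  , idArity :: Int
  , idUnfolding :: Maybe Expr
  , idType :: Maybe Type
  }

-- Clear every idType in a full pass, keeping one shared Id per binder.
dropTypes :: Expr -> Expr
dropTypes = go Map.empty
  where
    go :: Map.Map String Id -> Expr -> Expr
    go ids e = case e of
      IdE i ->
        -- Bound occurrences reuse the rebuilt binder. A free variable only
        -- gets its own type cleared; its unfolding is left as it was.
        IdE (fromMaybe (i{ idType = Nothing }) (Map.lookup (idName i) ids))
      IntE{} ->
        e
      Lam arg body ->
        let arg' = arg{ idType = Nothing }
        in Lam arg' (go (Map.insert (idName arg) arg' ids) body)
      App e1 e2 ->
        App (go ids e1) (go ids e2)
      IfE e1 e2 e3 ->
        IfE (go ids e1) (go ids e2) (go ids e3)
      Let bndr rhs body ->
        let
          -- bndr' goes into the map before it is evaluated: the knot.
          ids' = Map.insert (idName bndr) bndr' ids
          rhs' = go ids' rhs
          -- Keep Nothing as Nothing; otherwise point at the new RHS.
          bndr' = bndr{ idType = Nothing, idUnfolding = rhs' <$ idUnfolding bndr }
        in
          Let bndr' rhs' (go ids' body)
```

Even for something as simple as clearing a field, the whole program is rebuilt, and the pass only terminates because Data.Map is lazy in its values.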

Example: attaching unfoldings to Ids

Remember that in the previous post we represented the AST as:

data Expr
  = IdE Id
  | IntE Int
  | Lam Id Expr
  | App Expr Expr
  | IfE Expr Expr Expr
  | Let Id Expr Expr

data Id = Id
  { idName :: String
    -- ^ Unique name of the identifier
  , idArity :: Int
    -- ^ Arity of a lambda. 0 for non-lambdas.
  , idUnfolding :: Maybe Expr
    -- ^ RHS of a binder, used for inlining
  }

In this representation, suppose I have a recursive definition like

let fac = \x . if x then x * fac (x - 1) else 1 in fac 5

At the occurrence of fac in the lambda body, I want to be able to call idUnfolding and get the definition of this lambda. So the lambda refers to the Id for fac, and fac refers to the lambda in its idUnfolding field, forming a cycle.

In this representation the only way to implement this is with knot-tying. An implementation that maintains a map from binders to their RHSs to update unfoldings of Ids in occurrence position does not work, because when we update an occurrence of the binder in its own RHS (i.e. in a recursive let) we end up invalidating the RHS that we’ve added to the map.
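One way to see the problem is to write the map-based version down and follow the unfoldings. The sketch below is my own illustration of that failure mode, not code from either post; addUnfoldingsNaive and innerUnfolding are hypothetical names.

```haskell
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)

data Expr
  = IdE Id
  | IntE Int
  | Lam Id Expr
  | App Expr Expr
  | IfE Expr Expr Expr
  | Let Id Expr Expr

data Id = Id
  { idName :: String
  , idArity :: Int
  , idUnfolding :: Maybe Expr
  }

-- Map-based attempt without a knot: the binder stored in the map can only
-- point at the ORIGINAL rhs, because the updated rhs does not exist yet.
addUnfoldingsNaive :: Expr -> Expr
addUnfoldingsNaive = go Map.empty
  where
    go :: Map.Map String Id -> Expr -> Expr
    go ids e = case e of
      IdE i ->
        IdE (fromMaybe i (Map.lookup (idName i) ids))
      IntE{} ->
        e
      Lam arg body ->
        Lam arg (go ids body)
      App e1 e2 ->
        App (go ids e1) (go ids e2)
      IfE e1 e2 e3 ->
        IfE (go ids e1) (go ids e2) (go ids e3)
      Let bndr rhs body ->
        let
          stale = bndr{ idUnfolding = Just rhs }  -- old RHS, already outdated
          rhs' = go (Map.insert (idName bndr) stale ids) rhs
          bndr' = bndr{ idUnfolding = Just rhs' }
        in
          Let bndr' rhs' (go (Map.insert (idName bndr) bndr' ids) body)

-- For an RHS of the form \x . f x, return the unfolding stored in that f.
innerUnfolding :: Expr -> Maybe Expr
innerUnfolding (Lam _ (App (IdE f) _)) = idUnfolding f
innerUnfolding _ = Nothing
```

For let f = \x . f x in f 5, the occurrence of f inside the new RHS unfolds to the old RHS, and the f inside that old RHS has no unfolding at all. So the chain ends after one step instead of forming a cycle.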

Here’s a knot-tying implementation that adds unfoldings (only the interesting bits):

addUnfoldings :: Expr -> Expr
addUnfoldings = go Map.empty
  where
    go :: Map.Map String Id -> Expr -> Expr
    go ids e = case e of

      IdE id ->
        IdE (fromMaybe id (Map.lookup (idName id) ids))

      Let bndr rhs body ->
        let
          ids' = Map.insert (idName bndr) bndr' ids
          rhs' = go ids' rhs
          bndr' = bndr{ idUnfolding = Just rhs' }
        in
          Let bndr' rhs' (go ids' body)

      ...

As before, we tie the knot in the Let case and use it in the IdE case.

It’s also possible to initialize idUnfolding fields when parsing, using monadic knot-tying (MonadFix). Full code is shown at the end of this post, but the interesting bit is when parsing lets and Ids:

parseLet :: Parser Expr
parseLet = do
    _ <- string "let"
    id_name <- parseIdName
    _ <- char '='

    (id, rhs) <- mfix $ \ ~(id_, _rhs) -> do
      modify (Map.insert id_name id_)
      rhs <- parseExpr
      return (Id{ idName = id_name, idArity = 0, idUnfolding = Just rhs }, rhs)

    _ <- string "in"
    body <- parseExpr
    return (Let id rhs body)

parseId' :: Parser Id
parseId' = do
    name <- parseIdName
    id_map <- get
    let def = Id{ idName = name, idArity = 0, idUnfolding = Nothing }
    return (fromMaybe def (Map.lookup name id_map))

The idea is very similar. When parsing a let we add a thunk for the binder, with the correct unfolding, to a map. The map is then used when parsing Ids in the RHS and body of the let.
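The same trick can be tried without a parser. The following is a small sketch under my own assumptions: occ stands in for parsing an occurrence, bindRec stands in for the mfix part of parseLet, and the Expr type is cut down to three constructors. It relies on the lazy State monad and on Data.Map being lazy in its values.

```haskell
import Control.Monad.Fix (mfix)
import Control.Monad.State (State, evalState, get, modify)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)

data Expr
  = IdE Id
  | IntE Int
  | App Expr Expr

data Id = Id
  { idName :: String
  , idUnfolding :: Maybe Expr
  }

type IdMap = Map.Map String Id

-- Stand-in for parsing an occurrence: look the name up in the current map.
occ :: String -> State IdMap Expr
occ name = do
  ids <- get
  return (IdE (fromMaybe (Id name Nothing) (Map.lookup name ids)))

-- Stand-in for the knot in parseLet: register the binder before its RHS
-- exists, then build the RHS, then define the binder in terms of the RHS.
bindRec :: String -> State IdMap Expr -> State IdMap (Id, Expr)
bindRec name mkRhs = mfix $ \ ~(id_, _) -> do
  modify (Map.insert name id_)  -- id_ is still a thunk and must not be forced
  rhs <- mkRhs
  return (Id{ idName = name, idUnfolding = Just rhs }, rhs)

-- The program fragment `f = f 1`, built by hand.
example :: (Id, Expr)
example = evalState (bindRec "f" (App <$> occ "f" <*> pure (IntE 1))) Map.empty
```

Following idUnfolding from the f in the RHS leads back to the same RHS. The lazy pattern ~(id_, _) matters here: with a strict pattern the tuple would be forced before the action has produced it.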

An alternative

A well-known way of associating information with identifiers in a compiler is by using a “symbol table”. Instead of adding information about Ids directly in the Id fields, we maintain a table (or multiple tables) that maps Ids to the relevant information. Here’s one way to do this in our language:

data Expr
  = IdE String
  ...

data IdInfo = IdInfo
  { idArity :: Int
    -- ^ Arity of a lambda. 0 for non-lambdas.
  , idUnfolding :: Maybe Expr
    -- ^ RHS of a binder, used for inlining
  }

type SymTbl = Map.Map String IdInfo

In this representation we have to refer to the table for idArity or idUnfolding. That’s slightly more work than the previous representation where we could simply use the fields of an Id, but a lot of other things become much simpler and more efficient.
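Reading from the table could look something like the sketch below. The accessors getIdArity and getIdUnfolding are not defined in the full code at the end of this post, so their definitions here are assumptions; in particular, defaulting to arity 0 and to no unfolding for Ids that are missing from the table is just one possible choice.

```haskell
import Control.Monad.State (State, evalState, gets)
import qualified Data.Map as Map

data Expr
  = IdE String
  | IntE Int
  | Lam String Expr
  | App Expr Expr
  | IfE Expr Expr Expr
  | Let String Expr Expr
  deriving (Show, Eq)

data IdInfo = IdInfo
  { idArity :: Int
  , idUnfolding :: Maybe Expr
  }

type SymTbl = Map.Map String IdInfo

-- Arity of an Id; Ids that are not in the table are treated as arity 0.
getIdArity :: String -> State SymTbl Int
getIdArity name = gets (maybe 0 idArity . Map.lookup name)

-- Unfolding of an Id, if the Id is in the table and has one.
getIdUnfolding :: String -> State SymTbl (Maybe Expr)
getIdUnfolding name = gets (\tbl -> Map.lookup name tbl >>= idUnfolding)

-- A few lookups against a one-entry table.
demo :: (Int, Int, Maybe Expr)
demo = evalState lookups tbl
  where
    lookups = (,,) <$> getIdArity "fac" <*> getIdArity "x" <*> getIdUnfolding "fac"
    tbl = Map.fromList [("fac", IdInfo 1 (Just (IntE 1)))]
```

Each accessor is a single map lookup, and nothing in the AST itself has to change when the table does.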

Here’s dropUnusedBindings in this representation (only the interesting bits, full code is at the end of this post):

dropUnusedBindings :: Expr -> State SymTbl Expr
dropUnusedBindings =
    fmap snd . go Set.empty
  where
    go :: Set.Set String -> Expr -> State SymTbl (Set.Set String, Expr)
    go free_vars e0 = case e0 of

      Let bndr e1 e2 -> do
        (free2, e2') <- go free_vars e2
        if Set.member bndr free2 then do
          (free1, e1') <- go free_vars e1
          setIdArity bndr (countLambdas e1')
          return (Set.delete bndr (Set.union free1 free2), Let bndr e1' e2')
        else
          return (free2, e2')

      ...

Our pass is now stateful (updates the symbol table) and written in monadic style. Knot-tying is gone. We update the symbol table after processing a let RHS. Because Ids no longer have the arity information we don’t need to update anything other than the symbol table.

It’s now trivial to implement addUnfoldings:

addUnfoldings :: Expr -> State SymTbl ()
addUnfoldings e0 = case e0 of

    IdE{} ->
      return ()

    IntE{} ->
      return ()

    Lam _ body ->
      addUnfoldings body

    App e1 e2 -> do
      addUnfoldings e1
      addUnfoldings e2

    IfE e1 e2 e3 -> do
      addUnfoldings e1
      addUnfoldings e2
      addUnfoldings e3

    Let bndr e1 e2 -> do
      addUnfoldings e1
      addUnfoldings e2
      setIdUnfolding bndr e1

Doing it during parsing is also trivial, and shown in the full code at the end of this post. Dropping typing information when we no longer need it is simply

dropTypes :: State SymTbl ()
dropTypes = modify (Map.map (\id_info -> id_info{ idType = Nothing }))

We could also maintain a separate table for typing information, in which case all we would have to do is stop using that table.

Easy!

Final remarks

Cyclic AST representation in a purely functional language necessitates knot-tying and relies on lazy evaluation. A well-known alternative is using symbol tables. It works across languages (does not rely on lazy evaluation) and keeps the code simple.

Cyclic representations make using the information easier, while symbol tables make updating easier. Code for updating the information is shown above and in the previous post. For using the information, compare:

-- Get the information in a cyclic representation
... (idUnfolding id) ...

-- Get the information using a symbol table
unfolding <- getIdUnfolding id

To me the monadic version is not too bad in terms of verbosity or convenience, especially because Haskell makes state passing so easy.

Some of the problems with knot-tying are explained at the end of the previous post. What I did not mention there are the efficiency problems, which are demonstrated better in this post.

At use sites, getIdArity (a map lookup) does more work than idArity (just follows a pointer). While I don’t have any benchmarks on this, I doubt that this is bad enough to make the cyclic representation and knot-tying preferable.

Examples in these two posts are inspired by GHC.

In the first post I mostly argued that knot-tying makes things more complicated, and in this post I showed that knot-tying is necessary because of the cyclic representation. If we want to do the same without knot-tying we either have to introduce mutable references (e.g. IORefs) in our AST (not shown in this post), or have to use a non-cyclic representation with symbol tables.

Between these two representations, I think the non-cyclic representation with symbol tables is the better choice.
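For completeness, the mutable-reference option could look roughly like the sketch below. This is only an illustration of the idea: mkId, getUnfolding and buildExample are hypothetical names, and I have left out the arity field to keep it short.

```haskell
import Data.IORef (IORef, newIORef, readIORef, writeIORef)

data Expr
  = IdE Id
  | IntE Int
  | Lam Id Expr
  | App Expr Expr
  | IfE Expr Expr Expr
  | Let Id Expr Expr

data Id = Id
  { idName :: String
  , idUnfolding :: IORef (Maybe Expr)
    -- ^ Mutable slot, filled in after the RHS is known
  }

-- Allocate an Id with an empty unfolding.
mkId :: String -> IO Id
mkId name = Id name <$> newIORef Nothing

getUnfolding :: Id -> IO (Maybe Expr)
getUnfolding = readIORef . idUnfolding

-- Build `let f = \x . f x in f 5`, then fill in f's unfolding in place.
buildExample :: IO Expr
buildExample = do
  f <- mkId "f"
  x <- mkId "x"
  let rhs = Lam x (App (IdE f) (IdE x))
  writeIORef (idUnfolding f) (Just rhs)  -- every occurrence of f sees this
  return (Let f rhs (App (IdE f) (IntE 5)))
```

The cycle goes through the IORef, so no laziness is needed, but every pass that reads or updates an unfolding now has to live in IO (or ST).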

Full code (knot-tying)

-- Tried with GHC 8.6.4

{-# OPTIONS_GHC -Wall #-}

{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleInstances #-}

import Data.List
import Data.Maybe
import Prelude hiding (id)

-- mtl-2.2
import Control.Monad.State

-- containers-0.6
import qualified Data.Map as Map
import qualified Data.Set as Set

-- megaparsec-7.0
import Text.Megaparsec hiding (State)
import Text.Megaparsec.Char

-- pretty-show-1.10
import Text.Show.Pretty

data Expr
  = IdE Id
  | IntE Int
  | Lam Id Expr
  | App Expr Expr
  | IfE Expr Expr Expr
  | Let Id Expr Expr
  deriving (Show)

data Id = Id
  { idName :: String
    -- ^ Unique name of the identifier
  , idArity :: Int
    -- ^ Arity of a lambda. 0 for non-lambdas.
  , idUnfolding :: Maybe Expr
    -- ^ RHS of a binder, used for inlining
  }

instance Show Id where
  show (Id name arity _) = "(Id " ++ show name ++ " " ++ show arity ++ ")"

--------------------------------------------------------------------------------
-- Initializing unfolding fields in parse time via MonadFix

type IdMap = Map.Map String Id

type Parser = ParsecT String String (State IdMap)

parseExpr :: Parser Expr
parseExpr = do
    exprs <- some $
      choice $
      map (\p -> p <* space)
        [ parseParens, parseIf, parseLam, parseInt,
          parseLet, try parseId ]
    return (foldl1' App exprs)

parseParens, parseIf, parseLam, parseInt,
  parseLet, parseId :: Parser Expr

parseParens = do
    _ <- char '('
    space
    expr <- parseExpr
    _ <- char ')'
    return expr

parseIf = do
    _ <- string "if"
    space
    condE <- parseExpr

    _ <- string "then"
    space
    thenE <- parseExpr
    _ <- string "else"
    space
    elseE <- parseExpr
    return (IfE condE thenE elseE)

parseLam = do
    _ <- char '\\'
    space
    id <- parseId'
    space
    _ <- char '.'
    space
    body <- parseExpr
    return (Lam id body)

parseInt = do
    chars <- some digitChar
    return (IntE (read chars))

parseLet = do
    _ <- string "let"
    space
    id_name <- parseIdName
    space
    _ <- char '='
    space

    (id, rhs) <- mfix $ \ ~(id_, _rhs) -> do
      modify (Map.insert id_name id_)
      rhs <- parseExpr
      return (Id{ idName = id_name, idArity = 0, idUnfolding = Just rhs }, rhs)

    _ <- string "in"
    space
    body <- parseExpr
    return (Let id rhs body)

parseId = IdE <$> parseId'

kws :: Set.Set String
kws = Set.fromList ["if", "then", "else", "let", "in"]

parseIdName :: Parser String
parseIdName = do
    name <- some letterChar
    guard (not (Set.member name kws))
    return name

parseId' :: Parser Id
parseId' = do
    name <- parseIdName
    id_map <- get
    let def = Id{ idName = name, idArity = 0, idUnfolding = Nothing }
    return (fromMaybe def (Map.lookup name id_map))

testPgm :: String -> Expr
testPgm pgm =
    case evalState (runParserT parseExpr "" pgm) Map.empty of
      Left (err_bundle :: ParseErrorBundle String String) ->
        error (errorBundlePretty err_bundle)
      Right expr ->
        expr

instance ShowErrorComponent [Char] where
    showErrorComponent x = x

--------------------------------------------------------------------------------
-- Initializing unfoldings with knot-tying

addUnfoldings :: Expr -> Expr
addUnfoldings = go Map.empty
  where
    go :: Map.Map String Id -> Expr -> Expr
    go ids e = case e of

      -- Interesting bits ------------------------------------------------------
      IdE id ->
        IdE (fromMaybe id (Map.lookup (idName id) ids))

      Let bndr rhs body ->
        let
          ids' = Map.insert (idName bndr) bndr' ids
          rhs' = go ids' rhs
          bndr' = bndr{ idUnfolding = Just rhs' }
        in
          Let bndr' rhs' (go ids' body)
      --------------------------------------------------------------------------

      IntE{} ->
        e

      Lam arg body ->
        Lam arg (go ids body)

      App e1 e2 ->
        App (go ids e1) (go ids e2)

      IfE e1 e2 e3 ->
        IfE (go ids e1) (go ids e2) (go ids e3)

Full code (symbol table)

-- Tried with GHC 8.6.4

{-# OPTIONS_GHC -Wall #-}

{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleInstances #-}

import Data.List
import Data.Maybe
import Prelude hiding (id)

-- mtl-2.2
import Control.Monad.State

-- containers-0.6
import qualified Data.Map as Map
import qualified Data.Set as Set

-- megaparsec-7.0
import Text.Megaparsec hiding (State)
import Text.Megaparsec.Char

-- pretty-show-1.10
import Text.Show.Pretty

import Debug.Trace

data Expr
  = IdE String
  | IntE Int
  | Lam String Expr
  | App Expr Expr
  | IfE Expr Expr Expr
  | Let String Expr Expr
  deriving (Show)

data IdInfo = IdInfo
  { idArity :: Int
    -- ^ Arity of a lambda. 0 for non-lambdas.
  , idUnfolding :: Maybe Expr
    -- ^ RHS of a binder, used for inlining
  , idType :: Maybe Type
    -- ^ Type of the id.
  }

data Type = Type -- Assume a large type

instance Show IdInfo where
  show (IdInfo arity _ _) = "(IdInfo " ++ show arity ++ ")"

type SymTbl = Map.Map String IdInfo

getIdInfo :: String -> State SymTbl (Maybe IdInfo)
getIdInfo id =
    Map.lookup id <$> get

setIdArity :: String -> Int -> State SymTbl ()
setIdArity id arity = modify (Map.alter alter id)
  where
    alter Nothing =
      Just IdInfo{ idArity = arity, idUnfolding = Nothing, idType = Nothing }
    alter (Just id_info) =
      Just id_info{ idArity = arity }

setIdUnfolding :: String -> Expr -> State SymTbl ()
setIdUnfolding id unfolding = modify (Map.alter alter id)
  where
    alter Nothing =
      Just IdInfo{ idUnfolding = Just unfolding, idArity = 0, idType = Nothing }
    alter (Just id_info) =
      Just id_info{ idUnfolding = Just unfolding }

countLambdas :: Expr -> Int
countLambdas (Lam _ rhs) = 1 + countLambdas rhs
countLambdas _ = 0

dropUnusedBindings :: Expr -> State SymTbl Expr
dropUnusedBindings =
    fmap snd . go Set.empty
  where
    go :: Set.Set String -> Expr -> State SymTbl (Set.Set String, Expr)
    go free_vars e0 = case e0 of

      IdE id ->
        return (Set.insert id free_vars, e0)

      IntE{} ->
        return (free_vars, e0)

      Lam arg body -> do
        (free_vars', body') <- go free_vars body
        return (Set.delete arg free_vars', Lam arg body')

      App e1 e2 -> do
        (free1, e1') <- go free_vars e1
        (free2, e2') <- go free_vars e2
        return (Set.union free1 free2, App e1' e2')

      IfE e1 e2 e3 -> do
        (free1, e1') <- go free_vars e1
        (free2, e2') <- go free_vars e2
        (free3, e3') <- go free_vars e3
        return (Set.unions [free1, free2, free3], IfE e1' e2' e3')

      Let bndr e1 e2 -> do
        (free2, e2') <- go free_vars e2
        if Set.member bndr free2 then do
          (free1, e1') <- go free_vars e1
          trace (ppShow e1') (return ())
          setIdArity bndr (countLambdas e1')
          return (Set.delete bndr (Set.union free1 free2), Let bndr e1' e2')
        else
          return (free2, e2')

addUnfoldings :: Expr -> State SymTbl ()
addUnfoldings e0 = case e0 of

    IdE{} ->
      return ()

    IntE{} ->
      return ()

    Lam _ body ->
      addUnfoldings body

    App e1 e2 -> do
      addUnfoldings e1
      addUnfoldings e2

    IfE e1 e2 e3 -> do
      addUnfoldings e1
      addUnfoldings e2
      addUnfoldings e3

    Let bndr e1 e2 -> do
      addUnfoldings e1
      addUnfoldings e2
      setIdUnfolding bndr e1

dropTypes :: State SymTbl ()
dropTypes = modify (Map.map (\id_info -> id_info{ idType = Nothing }))

pgm :: Expr
pgm = Let "fac" rhs body
  where
    rhs = Lam "x" (IfE (IdE "x") (App (App (IdE "*") (IdE "x"))
                                      (App (IdE "fac")
                                           (App (App (IdE "-") (IdE "x")) (IntE 1))))
                                 (IntE 1))
    body = App (IdE "fac") (IntE 5)

--------------------------------------------------------------------------------
-- Initializing unfolding fields in parse time, the boring way

type Parser = ParsecT String String (State SymTbl)

parseExpr :: Parser Expr
parseExpr = do
    exprs <- some $
      choice $
      map (\p -> p <* space)
        [ parseParens, parseIf, parseLam, parseInt,
          parseLet, try parseId ]
    return (foldl1' App exprs)

parseParens, parseIf, parseLam, parseInt,
  parseLet, parseId :: Parser Expr

parseParens = do
    _ <- char '('
    space
    expr <- parseExpr
    _ <- char ')'
    return expr

parseIf = do
    _ <- string "if"
    space
    condE <- parseExpr

    _ <- string "then"
    space
    thenE <- parseExpr
    _ <- string "else"
    space
    elseE <- parseExpr
    return (IfE condE thenE elseE)

parseLam = do
    _ <- char '\\'
    space
    id <- parseId'
    space
    _ <- char '.'
    space
    body <- parseExpr
    return (Lam id body)

parseInt = do
    chars <- some digitChar
    return (IntE (read chars))

parseLet = do
    _ <- string "let"
    space
    id <- parseId'
    space
    _ <- char '='
    space
    rhs <- parseExpr
    _ <- string "in"
    space
    body <- parseExpr
    lift (setIdUnfolding id rhs)
    return (Let id rhs body)

parseId = IdE <$> parseId'

kws :: Set.Set String
kws = Set.fromList ["if", "then", "else", "let", "in"]

parseId' :: Parser String
parseId' = do
    name <- some letterChar
    guard (not (Set.member name kws))
    return name

testPgm :: String -> Expr
testPgm pgm =
    case evalState (runParserT parseExpr "" pgm) Map.empty of
      Left (err_bundle :: ParseErrorBundle String String) ->
        error (errorBundlePretty err_bundle)
      Right expr ->
        expr

instance ShowErrorComponent [Char] where
    showErrorComponent x = x