
Knot-tying: why and how (and my opinions on it)

February 21, 2020 - Tagged as: en, haskell, ghc.

Suppose I have this simple language:

data Expr
  = IdE Id
  | IntE Int
  | Lam Id Expr
  | App Expr Expr
  | IfE Expr Expr Expr
  | Let Id Expr Expr

When generating code, for an identifier that stands for a lambda, I want to know the arity of the lambda, so that I can generate more efficient code. While in this language a lambda takes only one argument, if I have something like

let f = \x . \y . \z . ...
 in ...

I consider f as having arity 3.

One way to implement this is to attach this information to every Id:

data Id = Id
  { idName :: String
    -- ^ Unique name of the identifier
  , idArity :: Int
    -- ^ Arity of a lambda. 0 for non-lambdas.
  }

This way of associating information with Ids makes some things very simple. For example, if I’m generating code for this application:

f 1 2

In AST:

App (App (IdE (Id { idName = "f", idArity = 3 })) (IntE 1)) (IntE 2)

I can simply use the idArity field to see the arity of the function being applied. It doesn’t get any simpler than this.
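The post doesn’t say what the code generator does with the arity. As a purely hypothetical illustration, a backend could split an application into its head and arguments and compare the argument count with idArity to choose between a direct call and a partial application. collectArgs, CallKind and classifyCall are names made up for this sketch; the types repeat the definitions above.

```haskell
-- Types repeated from the post so this block stands alone.
data Id = Id
  { idName  :: String
  , idArity :: Int  -- arity of a lambda, 0 for non-lambdas
  } deriving (Show)

data Expr
  = IdE Id
  | IntE Int
  | Lam Id Expr
  | App Expr Expr
  | IfE Expr Expr Expr
  | Let Id Expr Expr
  deriving (Show)

-- Hypothetical helper: split an application spine into its head and its
-- arguments, so that f 1 2 becomes (f, [1, 2]).
collectArgs :: Expr -> (Expr, [Expr])
collectArgs = go []
  where
    go args (App fun arg) = go (arg : args) fun
    go args e             = (e, args)

-- Hypothetical classification a code generator could act on.
data CallKind
  = Exact         -- as many arguments as the arity: a direct call
  | Partial       -- fewer arguments: build a partial application
  | Over          -- more arguments: direct call, then apply the result
  | UnknownArity  -- head is not an Id with a known positive arity
  deriving (Eq, Show)

classifyCall :: Expr -> CallKind
classifyCall e = case collectArgs e of
  (IdE fun, args)
    | arity > 0, n == arity -> Exact
    | arity > 0, n < arity  -> Partial
    | arity > 0             -> Over
    where
      arity = idArity fun
      n     = length args
  _ -> UnknownArity
```

For the f 1 2 example above, where f has arity 3, classifyCall returns Partial. No symbol table lookup is needed: the arity is right there in the Id.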

Problem 1: redundant allocations

In a program we usually have many references to a single Id, whether it’s for a top-level function or an argument. If we allocate an Id for every occurrence, that’s a lot of redundant allocations that make the AST representation larger and hurt compiler performance.

For example, if I have this expression:

f z + f t

A naive representation of this would be

App
  (App
     (IdE Id { idName = "+" , idArity = 2 })
     (App
        (IdE Id { idName = "f" , idArity = 0 })
        (IdE Id { idName = "z" , idArity = 0 })))
  (App
     (IdE Id { idName = "f" , idArity = 0 })
     (IdE Id { idName = "t" , idArity = 0 }))

Here every occurrence of f gets its own Id, and these Ids all have the same arity: two Id heap objects for the same identifier.

A more efficient representation would be

let f = Id { idName = "f", idArity = 0 } in
App
  (App
     (IdE Id { idName = "+" , idArity = 2 })
     (App
        (IdE f)
        (IdE Id { idName = "z" , idArity = 0 })))
  (App
     (IdE f)
     (IdE Id { idName = "t" , idArity = 0 }))

Here we only have one heap object for f, and all uses refer to that one object.

This is actually not hard to fix: we maintain a map from Id names to the actual Ids. When we see a let, we add the LHS to the map. When we see an identifier, we look it up in the map. Easy.
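A minimal sketch of this fix, assuming let binders are in scope in their own RHS (a recursive let). shareIds is a hypothetical name, the types repeat the ones above, and also putting lambda arguments in the map is a choice of this sketch rather than something the text prescribes.

```haskell
import qualified Data.Map as Map

data Id = Id
  { idName  :: String
  , idArity :: Int
  } deriving (Eq, Show)

data Expr
  = IdE Id
  | IntE Int
  | Lam Id Expr
  | App Expr Expr
  | IfE Expr Expr Expr
  | Let Id Expr Expr
  deriving (Eq, Show)

-- Hypothetical pass: replace every occurrence of a bound name with the Id
-- stored at its binding site, so all occurrences point to one heap object.
-- Unbound identifiers (such as +) are left as they are.
shareIds :: Expr -> Expr
shareIds = go Map.empty
  where
    go :: Map.Map String Id -> Expr -> Expr
    go env e = case e of
      IdE i       -> IdE (Map.findWithDefault i (idName i) env)
      IntE _      -> e
      Lam x body  -> Lam x (go (Map.insert (idName x) x env) body)
      App fun arg -> App (go env fun) (go env arg)
      IfE c t f   -> IfE (go env c) (go env t) (go env f)
      -- The let binder is treated as in scope in its own RHS.
      Let x rhs body ->
        let env' = Map.insert (idName x) x env
        in Let x (go env' rhs) (go env' body)
```

After shareIds, every occurrence of a bound name is the very Id value stored at its binder, so GHC keeps one heap object per bound identifier. On its own this doesn’t survive later passes: any pass that rebuilds a binder’s Id breaks the sharing again.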

Problem 2: invalidating information during transformations

Suppose I want to implement a pass that drops unused bindings. For example:

let f = let a = e1
         in \x . e2
 in f z + f t

Here if e2 doesn’t use a I want to drop the binding:

let f = \x . e2
 in f z + f t

The AST for the original program is:

Let
  Id { idName = "f" , idArity = 0 }
  (Let
     Id { idName = "a" , idArity = 0 }
     <e1>
     (Lam Id { idName = "x" , idArity = 0 } <e2>))
  (App
     (App
        (IdE Id { idName = "+" , idArity = 2 })
        (App
           (IdE Id { idName = "f" , idArity = 0 })
           (IdE Id { idName = "z" , idArity = 0 })))
     (App
        (IdE Id { idName = "f" , idArity = 0 })
        (IdE Id { idName = "t" , idArity = 0 })))

Here’s a naive implementation of this pass:

dropUnusedBindings :: Expr -> Expr
dropUnusedBindings = snd . go Set.empty
  where
    go free_vars e0 = case e0 of

      IdE id ->
        (Set.insert (idName id) free_vars, e0)

      IntE{} ->
        (free_vars, e0)

      Lam arg body ->
        bimap (Set.delete (idName arg)) (Lam arg)
              (go free_vars body)

      App e1 e2 ->
        let
          (free1, e1') = go free_vars e1
          (free2, e2') = go free_vars e2
        in
          (Set.union free1 free2, App e1' e2')

      IfE e1 e2 e3 ->
        let
          (free1, e1') = go free_vars e1
          (free2, e2') = go free_vars e2
          (free3, e3') = go free_vars e3
        in
          (Set.unions [free1, free2, free3], IfE e1' e2' e3')

      Let bndr e1 e2 ->
        let
          (free1, e1') = first (Set.delete (idName bndr)) (go free_vars e1)
          (free2, e2') = go free_vars e2
        in
          if Set.member (idName bndr) free2
            then (Set.delete (idName bndr) (Set.union free1 free2),
                  Let (updateIdArity bndr e1') e1' e2')
            else (free2, e2')

updateIdArity :: Id -> Expr -> Id
updateIdArity id rhs = id{ idArity = countLambdas rhs }

countLambdas :: Expr -> Int
countLambdas (Lam _ rhs) = 1 + countLambdas rhs
countLambdas _ = 0

The problem with this pass is that it changes the arity of binders but doesn’t update the idArity of occurrences. Here’s what I get if I run it over the original AST:

Let
  Id { idName = "f" , idArity = 1 }
  (Lam Id { idName = "x" , idArity = 0 } <e2>)
  (App
     (App
        (IdE Id { idName = "+" , idArity = 2 })
        (App
           (IdE Id { idName = "f" , idArity = 0 })
           (IdE Id { idName = "z" , idArity = 0 })))
     (App
        (IdE Id { idName = "f" , idArity = 0 })
        (IdE Id { idName = "t" , idArity = 0 })))

Note how f, which was not a lambda binder previously, became a lambda binder with arity 1. The pass correctly updated f’s idArity in the binder position, but it did not update it in the occurrences! Indeed, in this representation it’s not easy to do this efficiently.
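One straightforward repair, sketched here with a hypothetical refreshOccurrences and assuming names are unique (as the comment on idName says), is to walk the binder’s scope again and overwrite every occurrence with the updated Id. It works, but it costs an extra traversal for every changed binder, so deeply nested lets can make the pass quadratic, and it allocates a fresh copy of the scope.

```haskell
data Id = Id { idName :: String, idArity :: Int } deriving (Eq, Show)

data Expr
  = IdE Id
  | IntE Int
  | Lam Id Expr
  | App Expr Expr
  | IfE Expr Expr Expr
  | Let Id Expr Expr
  deriving (Eq, Show)

-- Hypothetical repair step: rebuild an expression, replacing every
-- occurrence of the binder's name with the updated Id. Assumes names are
-- unique, so there is no shadowing to worry about.
refreshOccurrences :: Id -> Expr -> Expr
refreshOccurrences new = go
  where
    go e = case e of
      IdE old | idName old == idName new -> IdE new
      IdE _          -> e
      IntE _         -> e
      Lam x body     -> Lam x (go body)
      App fun arg    -> App (go fun) (go arg)
      IfE c t f      -> IfE (go c) (go t) (go f)
      Let x rhs body -> Let x (go rhs) (go body)
```

In the Let case of the naive pass this would mean building Let new (refreshOccurrences new e1') (refreshOccurrences new e2') with new = updateIdArity bndr e1', refreshing the RHS too because lets are recursive.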

Even if we solved the first problem and had only one closure for f, the updateIdArity step in this pass allocates a new Id and loses sharing. So we would end up with something like:

let f = Id { idName = "f", idArity = 0 } in
Let
  Id { idName = "f" , idArity = 1 }
  (Lam Id { idName = "x" , idArity = 0 } <e2>)
  (App
     (App
        (IdE Id { idName = "+" , idArity = 2 })
        (App
           (IdE f)
           (IdE Id { idName = "z" , idArity = 0 })))
     (App
        (IdE f)
        (IdE Id { idName = "t" , idArity = 0 })))

The arity of f at the use sites is still wrong, and we lost sharing.

Knot-tying

Knot-tying is a way of solving both of these problems in one step. I find it quite hard to explain in words, so I’ll show the code (only the interesting bits):

dropUnusedBindings :: Expr -> Expr
dropUnusedBindings =
    snd . go Map.empty Set.empty
  where
    go :: Map.Map String Id -> Set.Set String -> Expr -> (Set.Set String, Expr)
    go binders free_vars e0 = case e0 of

      IdE id ->
        (Set.insert (idName id) free_vars, IdE (fromMaybe id (Map.lookup (idName id) binders)))

      Let bndr@Id{ idName = bndr_name } e1 e2 ->
        let
          bndr' = updateIdArity bndr e1'
          binders' = Map.insert bndr_name bndr' binders
          (free1, e1') = first (Set.delete bndr_name) (go binders' free_vars e1)
          (free2, e2') = go binders' free_vars e2
        in
          if Set.member bndr_name free2
            then (Set.delete bndr_name (Set.union free1 free2),
                  Let bndr' e1' e2')
            else (free2, e2')

      ...

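The differences from the original version:

- go takes an extra argument, binders, a map from names to the updated binder Ids.
- In the IdE case, the occurrence is replaced with the Id found in binders, falling back to the original Id for identifiers not bound by a let (lambda arguments and free variables such as +). All occurrences of a let-bound name therefore share the binder’s Id.
- In the Let case, bndr' is computed from e1', e1' and e2' are computed by traversing with binders', and binders' contains bndr'. This circular definition is the knot: every occurrence gets the final, updated binder, even though that binder is only known once the traversal is done.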

This technique relies heavily on lazy evaluation. In the original example the AST is not recursive, but suppose we also want to record RHSs of let binders in Ids, to be used for inlining:

data Id = Id
  { ...
  , idUnfolding :: Maybe Expr
    -- ^ RHS of a let binding, used for inlining
  }

Now, once we implement sharing (solving problem 1), ASTs with recursive definitions will become cyclic. A simple example:

let fac = \x . if x then x * fac (x - 1) else 1 in fac 5

This will be represented as something like

pgm = Let fac_id rhs body
  where
    fac_id = Id { idName = "fac", idArity = 0, idUnfolding = Just rhs }
    rhs = Lam x_id (IfE (IdE x_id)
                        (App (App (IdE star_id) (IdE x_id))
                             (App (IdE fac_id) (App (App (IdE minus_id) (IdE x_id))
                                                    (IntE 1))))
                                  (IntE 1))
    body = App (IdE fac_id) (IntE 5)

    x_id = Id { idName = "x", idArity = 0, idUnfolding = Nothing }
    star_id = Id { idName = "*", idArity = 2, idUnfolding = Nothing }
    minus_id = Id { idName = "-", idArity = 2, idUnfolding = Nothing }

Here fac_id refers to rhs, which refers to fac_id, forming a cycle.

The knot-tying implementation of dropUnusedBindings works even in cases like this. We just need to make updateIdArity also update the unfolding when there is one ($> from Data.Functor replaces the contents of a Just and leaves Nothing unchanged):

updateIdArity :: Id -> Expr -> Id
updateIdArity id rhs =
    id{ idArity = countLambdas rhs
      , idUnfolding = idUnfolding id $> rhs }

This is a bit tricky to try out, but if I implement a Show instance for Id that doesn’t print the unfolding (to avoid looping), make fac_id’s arity 0, and call dropUnusedBindings, this is the AST I get:

Let
  (Id "fac" 1)
  (Lam
     (Id "x" 0)
     (IfE
        (IdE (Id "x" 0))
        (App
           (App (IdE (Id "*" 2)) (IdE (Id "x" 0)))
           (App
              (IdE (Id "fac" 1))
              (App (App (IdE (Id "-" 2)) (IdE (Id "x" 0))) (IntE 1))))
        (IntE 1)))
  (App (IdE (Id "fac" 1)) (IntE 5))

All uses of fac have the correct arity! Similarly, I can do something hacky like this in GHCi to check that the uses of fac inside the unfolding have the correct arity too:

ghci> let Let lhs _ _ = dropUnusedBindings pgm
ghci> putStrLn (ppShow (idUnfolding lhs))
Just
  (Lam
     (Id "x" 0)
     (IfE
        (IdE (Id "x" 0))
        (App
           (App (IdE (Id "*" 2)) (IdE (Id "x" 0)))
           (App
              (IdE (Id "fac" 1))
              (App (App (IdE (Id "-" 2)) (IdE (Id "x" 0))) (IntE 1))))
        (IntE 1)))

Nice!

… or is it?

The main problem with this technique is that it’s very difficult to understand. Even after working on different knot-tying code in GHC and implementing my own knot-tying passes, the recursive let bindings in the Let case above are still mind-boggling to me.
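Stripped of the compiler, the trick in the Let case is a lazy recursive binding where a result of a traversal is fed back into that same traversal. A minimal sketch in the spirit of Richard Bird’s classic repmin problem (replaceWithMin is a name chosen for this sketch) replaces every element of a list with the list’s minimum in a single pass:

```haskell
-- Replace every element with the minimum of the list, in one traversal.
replaceWithMin :: [Int] -> [Int]
replaceWithMin xs = result
  where
    -- The knot: 'go' builds a list whose elements are all 'm', but 'm' is
    -- the minimum computed by the very same call to 'go'. This works only
    -- because the elements of 'result' are thunks that read 'm' later.
    (result, m) = go xs

    go :: [Int] -> ([Int], Int)
    go []       = ([], maxBound)
    go (y : ys) = let (rest, m') = go ys
                  in (m : rest, min y m')
```

replaceWithMin [3, 1, 2] gives [1, 1, 1]. Making go force m before it returns, for example by returning seq m (m : rest, min y m'), turns this into an infinite loop, because m is not available until go has finished.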

Secondly, it’s really hard to reason about the evaluation order of things in knot-tying code. You might think that this shouldn’t be an issue in a purely functional implementation, but in my experience any non-trivial compiler pass, even when implemented in a purely functional style, still needs debugging. Even if it’s not buggy, you may want to trace the evaluation and print a few things to understand how the code works.

Knot-tying code makes this, which should be absolutely trivial in any reasonable code base, very difficult. If your print statements force evaluation in just the wrong places, you end up looping. For example, here’s our AST with a few bang patterns:

data Expr
  = IdE !Id
  | IntE Int
  | Lam Id Expr
  | App !Expr !Expr
  | IfE Expr !Expr Expr
  | Let Id Expr Expr

data Id = Id
  { idName :: String
  , idArity :: !Int
  }

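If you run the same program above using this AST definition, you’ll see that the pass now loops: building bndr' forces its strict idArity, so countLambdas e1' runs, and through the strict IfE and App fields that forces the occurrence of fac inside e1', which is bndr' itself. Note that I’ve removed the idUnfolding field to demonstrate that the loop is not caused by a cycle in the AST.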
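The same failure can be reproduced in a few lines. In this sketch (Var, Expr, tieLazy, tieStrict and their strict counterparts are names made up for it), a binder’s arity is computed from an RHS that mentions the binder, the same shape as bndr' in the Let case. With lazy fields the knot is fine; with a strict arity field and a strict occurrence field, building the binder forces the occurrence, which is the binder itself.

```haskell
-- Lazy version: no strictness annotations.
data Var  = Var  { varName :: String, varArity :: Int }
data Expr = Ref Var | Lam Expr | Lit Int

countLambdas :: Expr -> Int
countLambdas (Lam body) = 1 + countLambdas body
countLambdas _          = 0

-- The binder's arity depends on the RHS, and the RHS refers to the binder.
tieLazy :: (Var, Expr)
tieLazy = (v, rhs)
  where
    v   = Var "f" (countLambdas rhs)
    rhs = Lam (Ref v)

-- Strict version: bangs on the arity and on the occurrence, mirroring
-- the bangs on idArity and IdE above.
data SVar  = SVar  { svarName :: String, svarArity :: !Int }
data SExpr = SRef !SVar | SLam SExpr | SLit Int

sCountLambdas :: SExpr -> Int
sCountLambdas (SLam body) = 1 + sCountLambdas body
sCountLambdas _           = 0

-- Forcing this binder forces sCountLambdas rhs, which forces SRef v,
-- which forces v again: a loop.
tieStrict :: (SVar, SExpr)
tieStrict = (v, rhs)
  where
    v   = SVar "f" (sCountLambdas rhs)
    rhs = SLam (SRef v)
```

Evaluating varArity (fst tieLazy) gives 1, while evaluating svarArity (fst tieStrict) is expected to hang or fail with <<loop>>. In this small example, removing either one of the two bangs is enough to make the strict version terminate.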

It’s even more frustrating when what you’re debugging is a loop. You add a few prints and scratch your head wondering why none of them show up, even though the algorithm is clearly looping. What’s really happening is that the code is indeed looping, but for a different reason…

Finally, because making things more strict potentially breaks things, knot-tying makes fixing some memory leaks very hard. For example, we may have many passes on our AST, one of them being our knot-tying pass. Some of these passes may be very leaky, and instead of adding strict applications or bang patterns to dozens of places, we may want to add bangs to only a few places in the AST. But that, as demonstrated above, causes our knot-tying pass to loop.

Opinions

GHC makes extensive use of knot-tying, which has been one of my pain points since my first days contributing to GHC. I vaguely remember making those first contributions while I was a graduate student at Indiana University. I found it refreshing to be able to simply call idType to get the type of an identifier in GHC, as opposed to using a symbol table, which I’d been doing in some of the other compilers I had worked on.

At the same time, I was constantly confused that simple print statements I added in some front-end pass made the compiler loop. I had no idea what the reason could be, and no idea that the thing I found so refreshing was also what made debugging and tracing so much harder.

Suffice it to say, I don’t like knot-tying. If I needed knot-tying in my project, I’d probably reconsider how I represent my data instead. For example, if we simply used a unique number for each identifier and maintained a symbol table mapping the unique numbers to the actual Ids, we wouldn’t have cycles for recursive functions in the AST and wouldn’t need knot-tying. Updating something about an Id would be a simple update in the symbol table.
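A minimal sketch of that alternative, assuming identifiers are Int uniques and the symbol table is a Data.IntMap. IdInfo, SymTab and dropUnused are hypothetical names, and threading the table through the traversal by hand is only one possible design (a State monad would work too). Occurrences no longer carry their arity, so a code generator would look it up, for example with IntMap.lookup, instead of reading idArity directly.

```haskell
import qualified Data.IntMap.Strict as IntMap
import qualified Data.IntSet as IntSet

-- Identifiers are plain unique numbers; everything else about them lives
-- in the symbol table.
type Unique = Int

data IdInfo = IdInfo
  { infoName  :: String
  , infoArity :: Int  -- arity of a lambda, 0 for non-lambdas
  } deriving (Eq, Show)

type SymTab = IntMap.IntMap IdInfo

-- No Ids inside the AST, so recursive definitions do not create cycles.
data Expr
  = IdE Unique
  | IntE Int
  | Lam Unique Expr
  | App Expr Expr
  | IfE Expr Expr Expr
  | Let Unique Expr Expr
  deriving (Eq, Show)

countLambdas :: Expr -> Int
countLambdas (Lam _ body) = 1 + countLambdas body
countLambdas _            = 0

-- Drop unused lets; return the free variables, the new expression and the
-- symbol table with updated arities. Every occurrence reads its arity from
-- the table, so one update per binder is enough.
dropUnused :: SymTab -> Expr -> (IntSet.IntSet, Expr, SymTab)
dropUnused tab e = case e of
  IdE u  -> (IntSet.singleton u, e, tab)
  IntE _ -> (IntSet.empty, e, tab)
  Lam x body ->
    let (fv, body', tab') = dropUnused tab body
    in (IntSet.delete x fv, Lam x body', tab')
  App fun arg ->
    let (fv1, fun', tab1) = dropUnused tab fun
        (fv2, arg', tab2) = dropUnused tab1 arg
    in (IntSet.union fv1 fv2, App fun' arg', tab2)
  IfE c t f ->
    let (fv1, c', tab1) = dropUnused tab c
        (fv2, t', tab2) = dropUnused tab1 t
        (fv3, f', tab3) = dropUnused tab2 f
    in (IntSet.unions [fv1, fv2, fv3], IfE c' t' f', tab3)
  Let x rhs body ->
    let (fv1, rhs', tab1) = dropUnused tab rhs
        (fv2, body', tab2) = dropUnused tab1 body
        setArity info = info { infoArity = countLambdas rhs' }
    in if IntSet.member x fv2
         then ( IntSet.delete x (IntSet.union fv1 fv2)
              , Let x rhs' body'
              , IntMap.adjust setArity x tab2 )
         else (fv2, body', tab2)
```

On the Problem 2 example, with uniques standing in for names, dropUnused removes the binding of a and records arity 1 for f in the table; the occurrences of f need no update at all.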

Full code

-- Tried with GHC 8.6.4

{-# OPTIONS_GHC -Wall #-}

module Main where

import Data.Bifunctor
import Data.Functor
import Data.Maybe
import Prelude hiding (id)

-- containers-0.6
import qualified Data.Map as Map
import qualified Data.Set as Set

-- pretty-show-1.10
import Text.Show.Pretty

{-
data Expr
  = IdE !Id
  | IntE Int
  | Lam Id Expr
  | App !Expr !Expr
  | IfE Expr !Expr Expr
  | Let Id Expr Expr
  | Placeholder String
  deriving (Show)

data Id = Id
  { idName :: String
    -- ^ Unique name of the identifier
  , idArity :: !Int
    -- ^ Arity of a lambda. 0 for non-lambdas.
  }
-}

data Expr
  = IdE Id
  | IntE Int
  | Lam Id Expr
  | App Expr Expr
  | IfE Expr Expr Expr
  | Let Id Expr Expr
  | Placeholder String
  deriving (Show)

data Id = Id
  { idName :: String
    -- ^ Unique name of the identifier
  , idArity :: Int
    -- ^ Arity of a lambda. 0 for non-lambdas.
  , idUnfolding :: Maybe Expr
    -- ^ RHS of a binder, used for inlining
  }

instance Show Id where
  show (Id name arity _) = "(Id " ++ show name ++ " " ++ show arity ++ ")"

{-
f_id = Id { idName = "f", idArity = 0 }
a_id = Id { idName = "a", idArity = 0 }
x_id = Id { idName = "x", idArity = 0 }
z_id = Id { idName = "z", idArity = 0 }
t_id = Id { idName = "t", idArity = 0 }
plus_id = Id { idName = "+", idArity = 2 }


f_x_plus_f_y = (App (App (IdE plus_id) (App (IdE f_id) (IdE z_id)))
                     (App (IdE f_id) (IdE t_id)))

ast1 = Let f_id (Let a_id (Placeholder "e1") (Lam x_id (Placeholder "e2"))) f_x_plus_f_y

ast2 = Let a_id (Placeholder "e1")
           (Let f_id (Lam x_id (Placeholder "e2"))
                     f_x_plus_f_y)
-}

updateIdArity :: Id -> Expr -> Id
updateIdArity id rhs =
  id{ idArity = countLambdas rhs,
      idUnfolding = idUnfolding id $> rhs }

countLambdas :: Expr -> Int
countLambdas (Lam _ rhs) = 1 + countLambdas rhs
countLambdas _ = 0

dropUnusedBindings :: Expr -> Expr
dropUnusedBindings =
    snd . go Map.empty Set.empty
  where
    go :: Map.Map String Id -> Set.Set String -> Expr -> (Set.Set String, Expr)
    go binders free_vars e0 = case e0 of

      IdE id ->
        (Set.insert (idName id) free_vars, IdE (fromMaybe id (Map.lookup (idName id) binders)))

      IntE{} ->
        (free_vars, e0)

      Lam arg body ->
        bimap (Set.delete (idName arg)) (Lam arg)
              (go binders free_vars body)

      App e1 e2 ->
        let
          (free1, e1') = go binders free_vars e1
          (free2, e2') = go binders free_vars e2
        in
          (Set.union free1 free2, App e1' e2')

      IfE e1 e2 e3 ->
        let
          (free1, e1') = go binders free_vars e1
          (free2, e2') = go binders free_vars e2
          (free3, e3') = go binders free_vars e3
        in
          (Set.unions [free1, free2, free3], IfE e1' e2' e3')

      Let bndr@Id{ idName = bndr_name } e1 e2 ->
        let
          bndr' = updateIdArity bndr e1'
          binders' = Map.insert bndr_name bndr' binders
          (free1, e1') = first (Set.delete bndr_name) (go binders' free_vars e1)
          (free2, e2') = go binders' free_vars e2
        in
          if Set.member bndr_name free2
            then (Set.delete bndr_name (Set.union free1 free2),
                  Let bndr' e1' e2')
            else (free2, e2')

      Placeholder{} ->
        (free_vars, e0)

pgm :: Expr
pgm = Let fac_id rhs body
  where
    fac_id = Id { idName = "fac", idArity = 0, idUnfolding = Just rhs }
    rhs = Lam x_id (IfE (IdE x_id) (App (App (IdE star_id) (IdE x_id))
                                        (App (IdE fac_id)
                                             (App (App (IdE minus_id) (IdE x_id)) (IntE 1))))
                                   (IntE 1))
    body = App (IdE fac_id) (IntE 5)

    x_id = Id { idName = "x", idArity = 0, idUnfolding = Nothing }
    star_id = Id { idName = "*", idArity = 2, idUnfolding = Nothing }
    minus_id = Id { idName = "-", idArity = 2, idUnfolding = Nothing }

main :: IO ()
main = putStrLn (ppShow (dropUnusedBindings pgm))

Thanks to Oleg Grenrus for reading a draft of this.