diff options
author | Tom Smeding <tom@tomsmeding.com> | 2023-09-16 18:03:44 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2023-09-16 18:03:44 +0200 |
commit | 7095bcf4910e2b1525234ca8e88f4effc25315bd (patch) | |
tree | 50ca15ca9d774c739b7320e4b68ef2b600e52ab1 | |
parent | 35cc10682f35dafba98000bf35191896a6432624 (diff) |
Pretty print
-rw-r--r-- | chad-fast.cabal | 1 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 166 | ||||
-rw-r--r-- | src/CHAD.hs | 5 | ||||
-rw-r--r-- | src/Example.hs | 30 |
4 files changed, 202 insertions, 0 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index dd5bb27..ac0df0f 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -11,6 +11,7 @@ build-type: Simple library exposed-modules: AST + AST.Pretty AST.Weaken CHAD -- Compile diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs new file mode 100644 index 0000000..c1d6c88 --- /dev/null +++ b/src/AST/Pretty.hs @@ -0,0 +1,166 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE TupleSections #-} +module AST.Pretty where + +import Control.Monad (ap) +import Data.List (intersperse) +import Data.Functor.Const + +import AST + + +data Val f env where + VTop :: Val f '[] + VPush :: f t -> Val f env -> Val f (t : env) + +type SVal = Val (Const String) + +valprj :: Val f env -> Idx env t -> f t +valprj (VPush x _) IZ = x +valprj (VPush _ env) (IS i) = valprj env i +valprj VTop i = case i of {} + +newtype M a = M { runM :: Int -> (a, Int) } + deriving (Functor) +instance Applicative M where { pure x = M (\i -> (x, i)) ; (<*>) = ap } +instance Monad M where { M f >>= g = M (\i -> let (x, j) = f i in runM (g x) j) } + +genId :: M Int +genId = M (\i -> (i, i + 1)) + +genName :: M String +genName = ('x' :) . show <$> genId + +ppExpr :: SList STy env -> Expr x env t -> String +ppExpr senv e = fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) "" + where + mkVal :: SList STy env -> M (SVal env) + mkVal SNil = return VTop + mkVal (SCons _ v) = do + val <- mkVal v + name <- genName + return (VPush (Const name) val) + +ppExpr' :: Int -> SVal env -> Expr x env t -> M ShowS +ppExpr' d val = \case + EVar _ _ i -> return $ showString $ getConst $ valprj val i + + etop@ELet{} -> do + let collect :: SVal env -> Expr x env t -> M ([(String, ShowS)], ShowS) + collect val' (ELet _ rhs body) = do + name <- genName + (binds, core) <- collect (VPush (Const name) val') body + rhs' <- ppExpr' 0 val' rhs + return ((name, rhs') : binds, core) + collect val' e = ([],) <$> ppExpr' 0 val' e + + (binds, core) <- collect val etop + let (open, close) = case binds of + [_] -> ("{ ", " }") + _ -> ("", "") + return $ showParen (d > 0) $ + showString ("let " ++ open) + . foldr (.) id (intersperse (showString " ; ") + (map (\(name, rhs) -> showString (name ++ " = ") . rhs) binds)) + . showString (close ++ " in ") + . core + + EPair _ a b -> do + a' <- ppExpr' 0 val a + b' <- ppExpr' 0 val b + return $ showString "(" . a' . showString ", " . b' . showString ")" + + EFst _ e -> do + e' <- ppExpr' 11 val e + return $ showParen (d > 10) $ showString "fst " . e' + + ESnd _ e -> do + e' <- ppExpr' 11 val e + return $ showParen (d > 10) $ showString "snd " . e' + + ENil _ -> return $ showString "()" + + EInl _ _ e -> do + e' <- ppExpr' 11 val e + return $ showParen (d > 10) $ showString "inl " . e' + + EInr _ _ e -> do + e' <- ppExpr' 11 val e + return $ showParen (d > 10) $ showString "inr " . e' + + ECase _ e a b -> do + e' <- ppExpr' 0 val e + name1 <- genName + a' <- ppExpr' 0 (VPush (Const name1) val) a + name2 <- genName + b' <- ppExpr' 0 (VPush (Const name2) val) b + return $ showParen (d > 0) $ + showString "case " . e' . showString (" of { Inl " ++ name1 ++ " -> ") . a' + . showString (" ; Inr " ++ name2 ++ " -> ") . b' . showString " }" + + EConst _ ty v -> return $ showString $ case ty of + STI32 -> show v ; STI64 -> show v ; STF32 -> show v ; STF64 -> show v ; STBool -> show v + + EOp _ op (EPair _ a b) + | (Infix, ops) <- operator op -> do + a' <- ppExpr' 9 val a + b' <- ppExpr' 9 val b + return $ showParen (d > 8) $ a' . showString (" " ++ ops ++ " ") . b' + + EOp _ op e -> do + e' <- ppExpr' 11 val e + let ops = case operator op of + (Infix, s) -> "(" ++ s ++ ")" + (Prefix, s) -> s + return $ showParen (d > 10) $ showString (ops ++ " ") . e' + + EMOne venv i e -> do + let venvlen = length (unSList venv) + varname = 'v' : show (venvlen - idx2int i) + e' <- ppExpr' 11 val e + return $ showParen (d > 10) $ + showString ("one " ++ show varname ++ " ") . e' + + EMScope e -> do + let venv = case typeOf e of STEVM v _ -> v + venvlen = length (unSList venv) + varname = 'v' : show venvlen + e' <- ppExpr' 11 val e + return $ showParen (d > 10) $ + showString ("scope " ++ show varname ++ " ") . e' + + EMReturn _ e -> do + e' <- ppExpr' 11 val e + return $ showParen (d > 10) $ showString ("return ") . e' + + EMBind a b -> do + a' <- ppExpr' 0 val a + name <- genName + b' <- ppExpr' 0 (VPush (Const name) val) b + return $ showParen (d > 10) $ a' . showString (" >>= \\" ++ name ++ " -> ") . b' + + EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s) + + _ -> undefined + +data Fixity = Prefix | Infix + deriving (Show) + +operator :: SOp a t -> (Fixity, String) +operator OAdd{} = (Infix, "+") +operator OMul{} = (Infix, "*") +operator ONeg{} = (Prefix, "negate") +operator OLt{} = (Infix, "<") +operator OLe{} = (Infix, "<=") +operator OEq{} = (Infix, "==") +operator ONot = (Prefix, "not") + +idx2int :: Idx env t -> Int +idx2int IZ = 0 +idx2int (IS n) = 1 + idx2int n diff --git a/src/CHAD.hs b/src/CHAD.hs index 26c918e..d0358b8 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -238,6 +238,11 @@ d2op op e d = case op of STF64 -> float STBool -> EInl ext (STPair STNil STNil) (ENil ext) +freezeRet :: Ret env t + -> (forall env'. Ex env' (D2 t)) -- the incoming cotangent value + -> Ex (D1E env) (TPair (D1 t) (TEVM (D2E env) TNil)) +freezeRet (Ret e0 e1 e2) d = letBinds e0 $ EPair ext e1 (ELet ext d e2) + drev :: SList STy env -> Ex env t -> Ret env t drev senv = \case EVar _ t i -> diff --git a/src/Example.hs b/src/Example.hs index 89b2082..99574c5 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -2,13 +2,43 @@ module Example where import AST +import AST.Pretty import CHAD bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c bin op a b = EOp ext op (EPair ext a b) +senv1 :: SList STy [TScal TF32, TScal TF32] +senv1 = STScal STF32 `SCons` STScal STF32 `SCons` SNil + -- x y |- x * y + x +-- +-- let x3 = (x1, x2) +-- x4 = ((*) x3, x1) +-- in ( (+) x4 +-- , let x5 = 1.0 +-- x6 = inr (x5, x5) +-- in case x6 of +-- Inl x7 -> return () +-- Inr x8 -> +-- let x9 = fst x8 +-- x10 = inr (snd x3 * x9, fst x3 * x9) +-- in case x10 of +-- Inl x11 -> return () +-- Inr x12 -> +-- let x13 = fst x12 +-- in one "v1" x13 >>= \x14 -> +-- let x15 = snd x12 +-- in one "v2" x15 >>= \x16 -> +-- let x17 = snd x8 +-- in one "v1" x17) +-- +-- ( (x1 * x2) + x1 +-- , let x5 = 1.0 +-- in do one "v1" (x2 * x5) +-- one "v2" (x1 * x5) +-- one "v1" x5) ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32) ex1 = bin (OAdd STF32) |