summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2023-09-16 18:03:44 +0200
committerTom Smeding <tom@tomsmeding.com>2023-09-16 18:03:44 +0200
commit7095bcf4910e2b1525234ca8e88f4effc25315bd (patch)
tree50ca15ca9d774c739b7320e4b68ef2b600e52ab1
parent35cc10682f35dafba98000bf35191896a6432624 (diff)
Pretty print
-rw-r--r--chad-fast.cabal1
-rw-r--r--src/AST/Pretty.hs166
-rw-r--r--src/CHAD.hs5
-rw-r--r--src/Example.hs30
4 files changed, 202 insertions, 0 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index dd5bb27..ac0df0f 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -11,6 +11,7 @@ build-type: Simple
library
exposed-modules:
AST
+ AST.Pretty
AST.Weaken
CHAD
-- Compile
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
new file mode 100644
index 0000000..c1d6c88
--- /dev/null
+++ b/src/AST/Pretty.hs
@@ -0,0 +1,166 @@
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE EmptyCase #-}
+{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE TupleSections #-}
+module AST.Pretty where
+
+import Control.Monad (ap)
+import Data.List (intersperse)
+import Data.Functor.Const
+
+import AST
+
+
+data Val f env where
+ VTop :: Val f '[]
+ VPush :: f t -> Val f env -> Val f (t : env)
+
+type SVal = Val (Const String)
+
+valprj :: Val f env -> Idx env t -> f t
+valprj (VPush x _) IZ = x
+valprj (VPush _ env) (IS i) = valprj env i
+valprj VTop i = case i of {}
+
+newtype M a = M { runM :: Int -> (a, Int) }
+ deriving (Functor)
+instance Applicative M where { pure x = M (\i -> (x, i)) ; (<*>) = ap }
+instance Monad M where { M f >>= g = M (\i -> let (x, j) = f i in runM (g x) j) }
+
+genId :: M Int
+genId = M (\i -> (i, i + 1))
+
+genName :: M String
+genName = ('x' :) . show <$> genId
+
+ppExpr :: SList STy env -> Expr x env t -> String
+ppExpr senv e = fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) ""
+ where
+ mkVal :: SList STy env -> M (SVal env)
+ mkVal SNil = return VTop
+ mkVal (SCons _ v) = do
+ val <- mkVal v
+ name <- genName
+ return (VPush (Const name) val)
+
+ppExpr' :: Int -> SVal env -> Expr x env t -> M ShowS
+ppExpr' d val = \case
+ EVar _ _ i -> return $ showString $ getConst $ valprj val i
+
+ etop@ELet{} -> do
+ let collect :: SVal env -> Expr x env t -> M ([(String, ShowS)], ShowS)
+ collect val' (ELet _ rhs body) = do
+ name <- genName
+ (binds, core) <- collect (VPush (Const name) val') body
+ rhs' <- ppExpr' 0 val' rhs
+ return ((name, rhs') : binds, core)
+ collect val' e = ([],) <$> ppExpr' 0 val' e
+
+ (binds, core) <- collect val etop
+ let (open, close) = case binds of
+ [_] -> ("{ ", " }")
+ _ -> ("", "")
+ return $ showParen (d > 0) $
+ showString ("let " ++ open)
+ . foldr (.) id (intersperse (showString " ; ")
+ (map (\(name, rhs) -> showString (name ++ " = ") . rhs) binds))
+ . showString (close ++ " in ")
+ . core
+
+ EPair _ a b -> do
+ a' <- ppExpr' 0 val a
+ b' <- ppExpr' 0 val b
+ return $ showString "(" . a' . showString ", " . b' . showString ")"
+
+ EFst _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ showParen (d > 10) $ showString "fst " . e'
+
+ ESnd _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ showParen (d > 10) $ showString "snd " . e'
+
+ ENil _ -> return $ showString "()"
+
+ EInl _ _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ showParen (d > 10) $ showString "inl " . e'
+
+ EInr _ _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ showParen (d > 10) $ showString "inr " . e'
+
+ ECase _ e a b -> do
+ e' <- ppExpr' 0 val e
+ name1 <- genName
+ a' <- ppExpr' 0 (VPush (Const name1) val) a
+ name2 <- genName
+ b' <- ppExpr' 0 (VPush (Const name2) val) b
+ return $ showParen (d > 0) $
+ showString "case " . e' . showString (" of { Inl " ++ name1 ++ " -> ") . a'
+ . showString (" ; Inr " ++ name2 ++ " -> ") . b' . showString " }"
+
+ EConst _ ty v -> return $ showString $ case ty of
+ STI32 -> show v ; STI64 -> show v ; STF32 -> show v ; STF64 -> show v ; STBool -> show v
+
+ EOp _ op (EPair _ a b)
+ | (Infix, ops) <- operator op -> do
+ a' <- ppExpr' 9 val a
+ b' <- ppExpr' 9 val b
+ return $ showParen (d > 8) $ a' . showString (" " ++ ops ++ " ") . b'
+
+ EOp _ op e -> do
+ e' <- ppExpr' 11 val e
+ let ops = case operator op of
+ (Infix, s) -> "(" ++ s ++ ")"
+ (Prefix, s) -> s
+ return $ showParen (d > 10) $ showString (ops ++ " ") . e'
+
+ EMOne venv i e -> do
+ let venvlen = length (unSList venv)
+ varname = 'v' : show (venvlen - idx2int i)
+ e' <- ppExpr' 11 val e
+ return $ showParen (d > 10) $
+ showString ("one " ++ show varname ++ " ") . e'
+
+ EMScope e -> do
+ let venv = case typeOf e of STEVM v _ -> v
+ venvlen = length (unSList venv)
+ varname = 'v' : show venvlen
+ e' <- ppExpr' 11 val e
+ return $ showParen (d > 10) $
+ showString ("scope " ++ show varname ++ " ") . e'
+
+ EMReturn _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ showParen (d > 10) $ showString ("return ") . e'
+
+ EMBind a b -> do
+ a' <- ppExpr' 0 val a
+ name <- genName
+ b' <- ppExpr' 0 (VPush (Const name) val) b
+ return $ showParen (d > 10) $ a' . showString (" >>= \\" ++ name ++ " -> ") . b'
+
+ EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s)
+
+ _ -> undefined
+
+data Fixity = Prefix | Infix
+ deriving (Show)
+
+operator :: SOp a t -> (Fixity, String)
+operator OAdd{} = (Infix, "+")
+operator OMul{} = (Infix, "*")
+operator ONeg{} = (Prefix, "negate")
+operator OLt{} = (Infix, "<")
+operator OLe{} = (Infix, "<=")
+operator OEq{} = (Infix, "==")
+operator ONot = (Prefix, "not")
+
+idx2int :: Idx env t -> Int
+idx2int IZ = 0
+idx2int (IS n) = 1 + idx2int n
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 26c918e..d0358b8 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -238,6 +238,11 @@ d2op op e d = case op of
STF64 -> float
STBool -> EInl ext (STPair STNil STNil) (ENil ext)
+freezeRet :: Ret env t
+ -> (forall env'. Ex env' (D2 t)) -- the incoming cotangent value
+ -> Ex (D1E env) (TPair (D1 t) (TEVM (D2E env) TNil))
+freezeRet (Ret e0 e1 e2) d = letBinds e0 $ EPair ext e1 (ELet ext d e2)
+
drev :: SList STy env -> Ex env t -> Ret env t
drev senv = \case
EVar _ t i ->
diff --git a/src/Example.hs b/src/Example.hs
index 89b2082..99574c5 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -2,13 +2,43 @@
module Example where
import AST
+import AST.Pretty
import CHAD
bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c
bin op a b = EOp ext op (EPair ext a b)
+senv1 :: SList STy [TScal TF32, TScal TF32]
+senv1 = STScal STF32 `SCons` STScal STF32 `SCons` SNil
+
-- x y |- x * y + x
+--
+-- let x3 = (x1, x2)
+-- x4 = ((*) x3, x1)
+-- in ( (+) x4
+-- , let x5 = 1.0
+-- x6 = inr (x5, x5)
+-- in case x6 of
+-- Inl x7 -> return ()
+-- Inr x8 ->
+-- let x9 = fst x8
+-- x10 = inr (snd x3 * x9, fst x3 * x9)
+-- in case x10 of
+-- Inl x11 -> return ()
+-- Inr x12 ->
+-- let x13 = fst x12
+-- in one "v1" x13 >>= \x14 ->
+-- let x15 = snd x12
+-- in one "v2" x15 >>= \x16 ->
+-- let x17 = snd x8
+-- in one "v1" x17)
+--
+-- ( (x1 * x2) + x1
+-- , let x5 = 1.0
+-- in do one "v1" (x2 * x5)
+-- one "v2" (x1 * x5)
+-- one "v1" x5)
ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex1 =
bin (OAdd STF32)