summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2025-04-18 12:53:43 +0200
committerTom Smeding <t.j.smeding@uu.nl>2025-04-18 12:53:43 +0200
commitbd5d0458017862b984b9caf0975c135d154e8515 (patch)
treed6306079efb457afd9d5cb52defe2b1a05c94a6e
parent0a9e6dfc1accf9dc0254f0c720f633dab6e71f42 (diff)
pretty: Print arguments of open expression
-rw-r--r--src/AST/Pretty.hs15
-rw-r--r--src/Example.hs2
-rw-r--r--src/ForwardAD.hs2
-rw-r--r--src/Interpreter.hs46
-rw-r--r--test/Main.hs18
5 files changed, 51 insertions, 32 deletions
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 604133b..da4f391 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -10,8 +10,9 @@
module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppTy, PrettyX(..)) where
import Control.Monad (ap)
-import Data.List (intersperse)
+import Data.List (intersperse, intercalate)
import Data.Functor.Const
+import qualified Data.Functor.Product as Product
import Data.String (fromString)
import Prettyprinter
import Prettyprinter.Render.String
@@ -67,8 +68,16 @@ genNameIfUsedIn = genNameIfUsedIn' "x"
pprintExpr :: (KnownEnv env, PrettyX x) => Expr x env t -> IO ()
pprintExpr = putStrLn . ppExpr knownEnv
-ppExpr :: PrettyX x => SList f env -> Expr x env t -> String
-ppExpr senv e = render $ fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1)
+ppExpr :: PrettyX x => SList STy env -> Expr x env t -> String
+ppExpr senv e = render $ fst . flip runM 1 $ do
+ val <- mkVal senv
+ e' <- ppExpr' 0 val e
+ let lam = "λ" ++ intercalate " " (reverse (unSList (\(Product.Pair (Const name) ty) -> "(" ++ name ++ " : " ++ ppSTy 0 ty ++ ")") (slistZip val senv))) ++ "."
+ return $ group $ flatAlt
+ (hang 2 $
+ ppString lam
+ <> hardline <> e')
+ (ppString lam <+> e')
where
mkVal :: SList f env -> M (SVal env)
mkVal SNil = return SNil
diff --git a/src/Example.hs b/src/Example.hs
index 4fa8d5a..3623d03 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -158,7 +158,7 @@ neuralGo =
simplifyN 20 $
ELet ext (EConst ext STF64 1.0) $
chad defaultConfig knownEnv neural
- (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False argument revderiv of
+ (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False knownEnv argument revderiv of
(primal', (((((), Just (Just dlay1_1'a, Just dlay1_1'b)), Just (Just dlay2_1'a, Just dlay2_1'b)), Just dlay3_1'), Just dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1')
_ -> undefined
(Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs
index af35f91..b7036dd 100644
--- a/src/ForwardAD.hs
+++ b/src/ForwardAD.hs
@@ -223,7 +223,7 @@ data FwdADArtifact env t = FwdADArtifact (SList STy env) (STy t) (SList Value (D
makeFwdADArtifactInterp :: SList STy env -> Ex env t -> FwdADArtifact env t
makeFwdADArtifactInterp env expr =
let dexpr = dfwdDN expr
- in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False inp dexpr)
+ in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False (dne env) inp dexpr)
{-# NOINLINE makeFwdADArtifactCompile #-}
makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t)
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index ddc3479..572f2bd 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -24,6 +24,7 @@ import Control.Monad (foldM, join, when, forM_)
import Data.Bitraversable (bitraverse)
import Data.Char (isSpace)
import Data.Functor.Identity
+import qualified Data.Functor.Product as Product
import Data.Int (Int64)
import Data.IORef
import System.IO (hPutStrLn, stderr)
@@ -48,35 +49,39 @@ runAcM (AcM m) = unsafePerformIO m
acmDebugLog :: String -> AcM s ()
acmDebugLog s = AcM (hPutStrLn stderr s)
+data V t = V (STy t) (Rep t)
+
interpret :: Ex '[] t -> Rep t
-interpret = interpretOpen False SNil
+interpret = interpretOpen False SNil SNil
-- | Bool: whether to trace execution with debug prints (very verbose)
-interpretOpen :: Bool -> SList Value env -> Ex env t -> Rep t
-interpretOpen prints env e =
+interpretOpen :: Bool -> SList STy env -> SList Value env -> Ex env t -> Rep t
+interpretOpen prints env venv e =
runAcM $
let ?depth = 0
?prints = prints
- in interpret' env e
+ in interpret' (slistMap (\(Product.Pair t (Value v)) -> V t v) (slistZip env venv)) e
-interpret' :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList Value env -> Ex env t -> AcM s (Rep t)
+interpret' :: forall env t s. (?prints :: Bool, ?depth :: Int)
+ => SList V env -> Ex env t -> AcM s (Rep t)
interpret' env e = do
+ let tenv = slistMap (\(V t _) -> t) env
let dep = ?depth
let lenlimit = max 20 (100 - dep)
let replace a b = map (\c -> if c == a then b else c)
let trunc s | length s > lenlimit = take (lenlimit - 3) (replace '\n' ' ' s) ++ "..."
| otherwise = replace '\n' ' ' s
- when ?prints $ acmDebugLog $ replicate dep ' ' ++ "ev: " ++ trunc (ppExpr env e)
+ when ?prints $ acmDebugLog $ replicate dep ' ' ++ "ev: " ++ trunc (ppExpr tenv e)
res <- let ?depth = dep + 1 in interpret'Rec env e
when ?prints $ acmDebugLog $ replicate dep ' ' ++ "<- " ++ showValue 0 (typeOf e) res ""
return res
-interpret'Rec :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList Value env -> Ex env t -> AcM s (Rep t)
+interpret'Rec :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList V env -> Ex env t -> AcM s (Rep t)
interpret'Rec env = \case
- EVar _ _ i -> case slistIdx env i of Value x -> return x
+ EVar _ _ i -> case slistIdx env i of V _ x -> return x
ELet _ a b -> do
x <- interpret' env a
- let ?depth = ?depth - 1 in interpret' (Value x `SCons` env) b
+ let ?depth = ?depth - 1 in interpret' (V (typeOf a) x `SCons` env) b
expr | False && trace ("<i> " ++ takeWhile (not . isSpace) (show expr)) False -> undefined
EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b
EFst _ e -> fst <$> interpret' env e
@@ -84,18 +89,23 @@ interpret'Rec env = \case
ENil _ -> return ()
EInl _ _ e -> Left <$> interpret' env e
EInr _ _ e -> Right <$> interpret' env e
- ECase _ e a b -> interpret' env e >>= \case
- Left x -> interpret' (Value x `SCons` env) a
- Right y -> interpret' (Value y `SCons` env) b
+ ECase _ e a b ->
+ let STEither t1 t2 = typeOf e
+ in interpret' env e >>= \case
+ Left x -> interpret' (V t1 x `SCons` env) a
+ Right y -> interpret' (V t2 y `SCons` env) b
ENothing _ _ -> return Nothing
EJust _ e -> Just <$> interpret' env e
- EMaybe _ a b e -> maybe (interpret' env a) (\x -> interpret' (Value x `SCons` env) b) =<< interpret' env e
+ EMaybe _ a b e ->
+ let STMaybe t1 = typeOf e
+ in maybe (interpret' env a) (\x -> interpret' (V t1 x `SCons` env) b) =<< interpret' env e
EConstArr _ _ _ v -> return v
EBuild _ dim a b -> do
sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a
- arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx ixUncons dim idx) `SCons` env) b)
+ arrayGenerateM sh (\idx -> interpret' (V (tTup (sreplicate dim tIx)) (tupRepIdx ixUncons dim idx) `SCons` env) b)
EFold1Inner _ _ a b c -> do
- let f = \x y -> interpret' (Value y `SCons` Value x `SCons` env) a
+ let t = typeOf b
+ let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a
x0 <- interpret' env b
arr <- interpret' env c
let sh `ShCons` n = arrayShape arr
@@ -131,14 +141,14 @@ interpret'Rec env = \case
-> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b)
EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e
EOp _ op e -> interpretOp op <$> interpret' env e
- ECustom _ _ _ _ pr _ _ e1 e2 -> do
+ ECustom _ t1 t2 _ pr _ _ e1 e2 -> do
e1' <- interpret' env e1
e2' <- interpret' env e2
- interpret' (Value e2' `SCons` Value e1' `SCons` SNil) pr
+ interpret' (V t2 e2' `SCons` V t1 e1' `SCons` SNil) pr
EWith _ t e1 e2 -> do
initval <- interpret' env e1
withAccum t (typeOf e2) initval $ \accum ->
- interpret' (Value accum `SCons` env) e2
+ interpret' (V (STAccum t) accum `SCons` env) e2
EAccum _ t p e1 e2 e3 -> do
idx <- interpret' env e1
val <- interpret' env e2
diff --git a/test/Main.hs b/test/Main.hs
index 20b4ef0..3a6bc71 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -57,7 +57,7 @@ simplifyIters iters env | Dict <- envKnown env =
gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env)))
gradientByCHAD simplIters env term input =
let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term
- (out, grad) = interpretOpen False input dterm
+ (out, grad) = interpretOpen False env input dterm
in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad)))
-- In addition to the gradient, also returns the pretty-printed differentiated term.
@@ -232,7 +232,7 @@ compileTestGen name expr envGenerator =
in withCompiled env expr $ \fun ->
testProperty name $ property $ do
input <- forAllWith (showEnv env) envGenerator
- let resI = interpretOpen False input expr
+ let resI = interpretOpen False env input expr
resC <- liftIO $ fun input
let cmp (TypedValue _ x) (TypedValue _ y) = closeIshT' 1e-8 t x y
diff (TypedValue t resI) cmp (TypedValue t resC)
@@ -269,11 +269,11 @@ adTestGenPrimal env envGenerator expr exprS primalfun primalSfun =
testProperty "compile primal" $ property $ do
input <- forAllWith (showEnv env) envGenerator
- let outPrimalI = interpretOpen False input expr
+ let outPrimalI = interpretOpen False env input expr
outPrimalC <- liftIO $ primalfun input
diff outPrimalI (closeIsh' 1e-8) outPrimalC
- let outPrimalSI = interpretOpen False input exprS
+ let outPrimalSI = interpretOpen False env input exprS
outPrimalSC <- liftIO $ primalSfun input
diff outPrimalSI (closeIsh' 1e-8) outPrimalSC
@@ -285,7 +285,7 @@ adTestGenFwd env envGenerator exprS =
testProperty "compile fwdAD" $ property $ do
input <- forAllWith (showEnv env) envGenerator
dinput <- forAllWith (showEnv (dne env)) $ extendDNE env input
- let (outDNI1, outDNI2) = interpretOpen False dinput (dfwdDN exprS)
+ let (outDNI1, outDNI2) = interpretOpen False (dne env) dinput (dfwdDN exprS)
(outDNC1, outDNC2) <- liftIO $ dnfun dinput
diff outDNI1 (closeIsh' 1e-8) outDNC1
diff outDNI2 (closeIsh' 1e-8) outDNC2
@@ -317,10 +317,10 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =
let scFwd = tanEScalars env $ gradientByForward fwdartifactC input
- let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0
- (outChadS , gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS
- (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False input dtermSChad0
- (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False input dtermSChadS
+ let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0
+ (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS
+ (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0
+ (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS
scChad = tanEScalars env $ toTanE env input gradChad0
scChadS = tanEScalars env $ toTanE env input gradChadS
scSChad = tanEScalars env $ toTanE env input gradSChad0