summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2025-04-18 12:53:43 +0200
committerTom Smeding <t.j.smeding@uu.nl>2025-04-18 12:53:43 +0200
commitbd5d0458017862b984b9caf0975c135d154e8515 (patch)
treed6306079efb457afd9d5cb52defe2b1a05c94a6e /src
parent0a9e6dfc1accf9dc0254f0c720f633dab6e71f42 (diff)
pretty: Print arguments of open expression
Diffstat (limited to 'src')
-rw-r--r--src/AST/Pretty.hs15
-rw-r--r--src/Example.hs2
-rw-r--r--src/ForwardAD.hs2
-rw-r--r--src/Interpreter.hs46
4 files changed, 42 insertions, 23 deletions
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 604133b..da4f391 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -10,8 +10,9 @@
module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppTy, PrettyX(..)) where
import Control.Monad (ap)
-import Data.List (intersperse)
+import Data.List (intersperse, intercalate)
import Data.Functor.Const
+import qualified Data.Functor.Product as Product
import Data.String (fromString)
import Prettyprinter
import Prettyprinter.Render.String
@@ -67,8 +68,16 @@ genNameIfUsedIn = genNameIfUsedIn' "x"
pprintExpr :: (KnownEnv env, PrettyX x) => Expr x env t -> IO ()
pprintExpr = putStrLn . ppExpr knownEnv
-ppExpr :: PrettyX x => SList f env -> Expr x env t -> String
-ppExpr senv e = render $ fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1)
+ppExpr :: PrettyX x => SList STy env -> Expr x env t -> String
+ppExpr senv e = render $ fst . flip runM 1 $ do
+ val <- mkVal senv
+ e' <- ppExpr' 0 val e
+ let lam = "λ" ++ intercalate " " (reverse (unSList (\(Product.Pair (Const name) ty) -> "(" ++ name ++ " : " ++ ppSTy 0 ty ++ ")") (slistZip val senv))) ++ "."
+ return $ group $ flatAlt
+ (hang 2 $
+ ppString lam
+ <> hardline <> e')
+ (ppString lam <+> e')
where
mkVal :: SList f env -> M (SVal env)
mkVal SNil = return SNil
diff --git a/src/Example.hs b/src/Example.hs
index 4fa8d5a..3623d03 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -158,7 +158,7 @@ neuralGo =
simplifyN 20 $
ELet ext (EConst ext STF64 1.0) $
chad defaultConfig knownEnv neural
- (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False argument revderiv of
+ (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False knownEnv argument revderiv of
(primal', (((((), Just (Just dlay1_1'a, Just dlay1_1'b)), Just (Just dlay2_1'a, Just dlay2_1'b)), Just dlay3_1'), Just dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1')
_ -> undefined
(Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs
index af35f91..b7036dd 100644
--- a/src/ForwardAD.hs
+++ b/src/ForwardAD.hs
@@ -223,7 +223,7 @@ data FwdADArtifact env t = FwdADArtifact (SList STy env) (STy t) (SList Value (D
makeFwdADArtifactInterp :: SList STy env -> Ex env t -> FwdADArtifact env t
makeFwdADArtifactInterp env expr =
let dexpr = dfwdDN expr
- in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False inp dexpr)
+ in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False (dne env) inp dexpr)
{-# NOINLINE makeFwdADArtifactCompile #-}
makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t)
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index ddc3479..572f2bd 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -24,6 +24,7 @@ import Control.Monad (foldM, join, when, forM_)
import Data.Bitraversable (bitraverse)
import Data.Char (isSpace)
import Data.Functor.Identity
+import qualified Data.Functor.Product as Product
import Data.Int (Int64)
import Data.IORef
import System.IO (hPutStrLn, stderr)
@@ -48,35 +49,39 @@ runAcM (AcM m) = unsafePerformIO m
acmDebugLog :: String -> AcM s ()
acmDebugLog s = AcM (hPutStrLn stderr s)
+data V t = V (STy t) (Rep t)
+
interpret :: Ex '[] t -> Rep t
-interpret = interpretOpen False SNil
+interpret = interpretOpen False SNil SNil
-- | Bool: whether to trace execution with debug prints (very verbose)
-interpretOpen :: Bool -> SList Value env -> Ex env t -> Rep t
-interpretOpen prints env e =
+interpretOpen :: Bool -> SList STy env -> SList Value env -> Ex env t -> Rep t
+interpretOpen prints env venv e =
runAcM $
let ?depth = 0
?prints = prints
- in interpret' env e
+ in interpret' (slistMap (\(Product.Pair t (Value v)) -> V t v) (slistZip env venv)) e
-interpret' :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList Value env -> Ex env t -> AcM s (Rep t)
+interpret' :: forall env t s. (?prints :: Bool, ?depth :: Int)
+ => SList V env -> Ex env t -> AcM s (Rep t)
interpret' env e = do
+ let tenv = slistMap (\(V t _) -> t) env
let dep = ?depth
let lenlimit = max 20 (100 - dep)
let replace a b = map (\c -> if c == a then b else c)
let trunc s | length s > lenlimit = take (lenlimit - 3) (replace '\n' ' ' s) ++ "..."
| otherwise = replace '\n' ' ' s
- when ?prints $ acmDebugLog $ replicate dep ' ' ++ "ev: " ++ trunc (ppExpr env e)
+ when ?prints $ acmDebugLog $ replicate dep ' ' ++ "ev: " ++ trunc (ppExpr tenv e)
res <- let ?depth = dep + 1 in interpret'Rec env e
when ?prints $ acmDebugLog $ replicate dep ' ' ++ "<- " ++ showValue 0 (typeOf e) res ""
return res
-interpret'Rec :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList Value env -> Ex env t -> AcM s (Rep t)
+interpret'Rec :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList V env -> Ex env t -> AcM s (Rep t)
interpret'Rec env = \case
- EVar _ _ i -> case slistIdx env i of Value x -> return x
+ EVar _ _ i -> case slistIdx env i of V _ x -> return x
ELet _ a b -> do
x <- interpret' env a
- let ?depth = ?depth - 1 in interpret' (Value x `SCons` env) b
+ let ?depth = ?depth - 1 in interpret' (V (typeOf a) x `SCons` env) b
expr | False && trace ("<i> " ++ takeWhile (not . isSpace) (show expr)) False -> undefined
EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b
EFst _ e -> fst <$> interpret' env e
@@ -84,18 +89,23 @@ interpret'Rec env = \case
ENil _ -> return ()
EInl _ _ e -> Left <$> interpret' env e
EInr _ _ e -> Right <$> interpret' env e
- ECase _ e a b -> interpret' env e >>= \case
- Left x -> interpret' (Value x `SCons` env) a
- Right y -> interpret' (Value y `SCons` env) b
+ ECase _ e a b ->
+ let STEither t1 t2 = typeOf e
+ in interpret' env e >>= \case
+ Left x -> interpret' (V t1 x `SCons` env) a
+ Right y -> interpret' (V t2 y `SCons` env) b
ENothing _ _ -> return Nothing
EJust _ e -> Just <$> interpret' env e
- EMaybe _ a b e -> maybe (interpret' env a) (\x -> interpret' (Value x `SCons` env) b) =<< interpret' env e
+ EMaybe _ a b e ->
+ let STMaybe t1 = typeOf e
+ in maybe (interpret' env a) (\x -> interpret' (V t1 x `SCons` env) b) =<< interpret' env e
EConstArr _ _ _ v -> return v
EBuild _ dim a b -> do
sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a
- arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx ixUncons dim idx) `SCons` env) b)
+ arrayGenerateM sh (\idx -> interpret' (V (tTup (sreplicate dim tIx)) (tupRepIdx ixUncons dim idx) `SCons` env) b)
EFold1Inner _ _ a b c -> do
- let f = \x y -> interpret' (Value y `SCons` Value x `SCons` env) a
+ let t = typeOf b
+ let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a
x0 <- interpret' env b
arr <- interpret' env c
let sh `ShCons` n = arrayShape arr
@@ -131,14 +141,14 @@ interpret'Rec env = \case
-> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b)
EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e
EOp _ op e -> interpretOp op <$> interpret' env e
- ECustom _ _ _ _ pr _ _ e1 e2 -> do
+ ECustom _ t1 t2 _ pr _ _ e1 e2 -> do
e1' <- interpret' env e1
e2' <- interpret' env e2
- interpret' (Value e2' `SCons` Value e1' `SCons` SNil) pr
+ interpret' (V t2 e2' `SCons` V t1 e1' `SCons` SNil) pr
EWith _ t e1 e2 -> do
initval <- interpret' env e1
withAccum t (typeOf e2) initval $ \accum ->
- interpret' (Value accum `SCons` env) e2
+ interpret' (V (STAccum t) accum `SCons` env) e2
EAccum _ t p e1 e2 e3 -> do
idx <- interpret' env e1
val <- interpret' env e2