From bd5d0458017862b984b9caf0975c135d154e8515 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Fri, 18 Apr 2025 12:53:43 +0200 Subject: pretty: Print arguments of open expression --- src/AST/Pretty.hs | 15 ++++++++++++--- src/Example.hs | 2 +- src/ForwardAD.hs | 2 +- src/Interpreter.hs | 46 ++++++++++++++++++++++++++++------------------ 4 files changed, 42 insertions(+), 23 deletions(-) (limited to 'src') diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 604133b..da4f391 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -10,8 +10,9 @@ module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppTy, PrettyX(..)) where import Control.Monad (ap) -import Data.List (intersperse) +import Data.List (intersperse, intercalate) import Data.Functor.Const +import qualified Data.Functor.Product as Product import Data.String (fromString) import Prettyprinter import Prettyprinter.Render.String @@ -67,8 +68,16 @@ genNameIfUsedIn = genNameIfUsedIn' "x" pprintExpr :: (KnownEnv env, PrettyX x) => Expr x env t -> IO () pprintExpr = putStrLn . ppExpr knownEnv -ppExpr :: PrettyX x => SList f env -> Expr x env t -> String -ppExpr senv e = render $ fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) +ppExpr :: PrettyX x => SList STy env -> Expr x env t -> String +ppExpr senv e = render $ fst . flip runM 1 $ do + val <- mkVal senv + e' <- ppExpr' 0 val e + let lam = "λ" ++ intercalate " " (reverse (unSList (\(Product.Pair (Const name) ty) -> "(" ++ name ++ " : " ++ ppSTy 0 ty ++ ")") (slistZip val senv))) ++ "." + return $ group $ flatAlt + (hang 2 $ + ppString lam + <> hardline <> e') + (ppString lam <+> e') where mkVal :: SList f env -> M (SVal env) mkVal SNil = return SNil diff --git a/src/Example.hs b/src/Example.hs index 4fa8d5a..3623d03 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -158,7 +158,7 @@ neuralGo = simplifyN 20 $ ELet ext (EConst ext STF64 1.0) $ chad defaultConfig knownEnv neural - (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False argument revderiv of + (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False knownEnv argument revderiv of (primal', (((((), Just (Just dlay1_1'a, Just dlay1_1'b)), Just (Just dlay2_1'a, Just dlay2_1'b)), Just dlay3_1'), Just dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1') _ -> undefined (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0 diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index af35f91..b7036dd 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -223,7 +223,7 @@ data FwdADArtifact env t = FwdADArtifact (SList STy env) (STy t) (SList Value (D makeFwdADArtifactInterp :: SList STy env -> Ex env t -> FwdADArtifact env t makeFwdADArtifactInterp env expr = let dexpr = dfwdDN expr - in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False inp dexpr) + in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False (dne env) inp dexpr) {-# NOINLINE makeFwdADArtifactCompile #-} makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t) diff --git a/src/Interpreter.hs b/src/Interpreter.hs index ddc3479..572f2bd 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -24,6 +24,7 @@ import Control.Monad (foldM, join, when, forM_) import Data.Bitraversable (bitraverse) import Data.Char (isSpace) import Data.Functor.Identity +import qualified Data.Functor.Product as Product import Data.Int (Int64) import Data.IORef import System.IO (hPutStrLn, stderr) @@ -48,35 +49,39 @@ runAcM (AcM m) = unsafePerformIO m acmDebugLog :: String -> AcM s () acmDebugLog s = AcM (hPutStrLn stderr s) +data V t = V (STy t) (Rep t) + interpret :: Ex '[] t -> Rep t -interpret = interpretOpen False SNil +interpret = interpretOpen False SNil SNil -- | Bool: whether to trace execution with debug prints (very verbose) -interpretOpen :: Bool -> SList Value env -> Ex env t -> Rep t -interpretOpen prints env e = +interpretOpen :: Bool -> SList STy env -> SList Value env -> Ex env t -> Rep t +interpretOpen prints env venv e = runAcM $ let ?depth = 0 ?prints = prints - in interpret' env e + in interpret' (slistMap (\(Product.Pair t (Value v)) -> V t v) (slistZip env venv)) e -interpret' :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList Value env -> Ex env t -> AcM s (Rep t) +interpret' :: forall env t s. (?prints :: Bool, ?depth :: Int) + => SList V env -> Ex env t -> AcM s (Rep t) interpret' env e = do + let tenv = slistMap (\(V t _) -> t) env let dep = ?depth let lenlimit = max 20 (100 - dep) let replace a b = map (\c -> if c == a then b else c) let trunc s | length s > lenlimit = take (lenlimit - 3) (replace '\n' ' ' s) ++ "..." | otherwise = replace '\n' ' ' s - when ?prints $ acmDebugLog $ replicate dep ' ' ++ "ev: " ++ trunc (ppExpr env e) + when ?prints $ acmDebugLog $ replicate dep ' ' ++ "ev: " ++ trunc (ppExpr tenv e) res <- let ?depth = dep + 1 in interpret'Rec env e when ?prints $ acmDebugLog $ replicate dep ' ' ++ "<- " ++ showValue 0 (typeOf e) res "" return res -interpret'Rec :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList Value env -> Ex env t -> AcM s (Rep t) +interpret'Rec :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList V env -> Ex env t -> AcM s (Rep t) interpret'Rec env = \case - EVar _ _ i -> case slistIdx env i of Value x -> return x + EVar _ _ i -> case slistIdx env i of V _ x -> return x ELet _ a b -> do x <- interpret' env a - let ?depth = ?depth - 1 in interpret' (Value x `SCons` env) b + let ?depth = ?depth - 1 in interpret' (V (typeOf a) x `SCons` env) b expr | False && trace (" " ++ takeWhile (not . isSpace) (show expr)) False -> undefined EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b EFst _ e -> fst <$> interpret' env e @@ -84,18 +89,23 @@ interpret'Rec env = \case ENil _ -> return () EInl _ _ e -> Left <$> interpret' env e EInr _ _ e -> Right <$> interpret' env e - ECase _ e a b -> interpret' env e >>= \case - Left x -> interpret' (Value x `SCons` env) a - Right y -> interpret' (Value y `SCons` env) b + ECase _ e a b -> + let STEither t1 t2 = typeOf e + in interpret' env e >>= \case + Left x -> interpret' (V t1 x `SCons` env) a + Right y -> interpret' (V t2 y `SCons` env) b ENothing _ _ -> return Nothing EJust _ e -> Just <$> interpret' env e - EMaybe _ a b e -> maybe (interpret' env a) (\x -> interpret' (Value x `SCons` env) b) =<< interpret' env e + EMaybe _ a b e -> + let STMaybe t1 = typeOf e + in maybe (interpret' env a) (\x -> interpret' (V t1 x `SCons` env) b) =<< interpret' env e EConstArr _ _ _ v -> return v EBuild _ dim a b -> do sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a - arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx ixUncons dim idx) `SCons` env) b) + arrayGenerateM sh (\idx -> interpret' (V (tTup (sreplicate dim tIx)) (tupRepIdx ixUncons dim idx) `SCons` env) b) EFold1Inner _ _ a b c -> do - let f = \x y -> interpret' (Value y `SCons` Value x `SCons` env) a + let t = typeOf b + let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a x0 <- interpret' env b arr <- interpret' env c let sh `ShCons` n = arrayShape arr @@ -131,14 +141,14 @@ interpret'Rec env = \case -> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b) EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e EOp _ op e -> interpretOp op <$> interpret' env e - ECustom _ _ _ _ pr _ _ e1 e2 -> do + ECustom _ t1 t2 _ pr _ _ e1 e2 -> do e1' <- interpret' env e1 e2' <- interpret' env e2 - interpret' (Value e2' `SCons` Value e1' `SCons` SNil) pr + interpret' (V t2 e2' `SCons` V t1 e1' `SCons` SNil) pr EWith _ t e1 e2 -> do initval <- interpret' env e1 withAccum t (typeOf e2) initval $ \accum -> - interpret' (Value accum `SCons` env) e2 + interpret' (V (STAccum t) accum `SCons` env) e2 EAccum _ t p e1 e2 e3 -> do idx <- interpret' env e1 val <- interpret' env e2 -- cgit v1.2.3-70-g09d2