diff options
-rw-r--r-- | src/Example.hs | 2 | ||||
-rw-r--r-- | src/ForwardAD.hs | 2 | ||||
-rw-r--r-- | src/Interpreter.hs | 28 | ||||
-rw-r--r-- | test/Main.hs | 2 |
4 files changed, 21 insertions, 13 deletions
diff --git a/src/Example.hs b/src/Example.hs index 6e8069c..0bc18fb 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -198,7 +198,7 @@ neuralGo = freezeRet mergeDescr (drev mergeDescr neural) (EConst ext STF64 1.0) - (primal, (((((), Right dlay1_1), Right dlay2_1), dlay3_1), dinput_1)) = interpretOpen argument revderiv + (primal, (((((), Right dlay1_1), Right dlay2_1), dlay3_1), dinput_1)) = interpretOpen False argument revderiv (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwd knownEnv neural argument 1.0 in trace (formatter (ppExpr knownEnv revderiv)) $ (primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2)) diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index 86d2fb0..67d22dd 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -215,5 +215,5 @@ drevByFwd :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SLis drevByFwd env expr input dres = let outty = typeOf expr in dnOnehotEnvs env input $ \dnInput -> - let (_, outtan) = unzipDN outty (interpretOpen dnInput (dfwdDN expr)) + let (_, outtan) = unzipDN outty (interpretOpen False dnInput (dfwdDN expr)) in dotprodTan outty outtan dres diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 3d6f33d..2c63b24 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -19,7 +19,8 @@ module Interpreter ( Value(..), ) where -import Control.Monad (foldM, join) +import Control.Monad (foldM, join, when) +import Data.Bifunctor (bimap) import Data.Char (isSpace) import Data.Kind (Type) import Data.Int (Int64) @@ -35,7 +36,6 @@ import AST.Pretty import CHAD.Types import Data import Interpreter.Rep -import Data.Bifunctor (bimap) newtype AcM s a = AcM { unAcM :: IO a } @@ -48,25 +48,33 @@ acmDebugLog :: String -> AcM s () acmDebugLog s = AcM (hPutStrLn stderr s) interpret :: Ex '[] t -> Rep t -interpret = interpretOpen SNil +interpret = interpretOpen False SNil -interpretOpen :: SList Value env -> Ex env t -> Rep t -interpretOpen env e = runAcM (let ?depth = 0 in interpret' env e) +-- | Bool: whether to trace execution with debug prints (very verbose) +interpretOpen :: Bool -> SList Value env -> Ex env t -> Rep t +interpretOpen prints env e = + runAcM $ + let ?depth = 0 + ?prints = prints + in interpret' env e -interpret' :: forall env t s. (?depth :: Int) => SList Value env -> Ex env t -> AcM s (Rep t) +interpret' :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList Value env -> Ex env t -> AcM s (Rep t) interpret' env e = do let dep = ?depth - acmDebugLog $ replicate dep ' ' ++ "ev: " ++ ppExpr env e + let lenlimit = max 20 (100 - dep) + let trunc s | length s > lenlimit = take (lenlimit - 3) s ++ "..." + | otherwise = s + when ?prints $ acmDebugLog $ replicate dep ' ' ++ "ev: " ++ trunc (ppExpr env e) res <- let ?depth = dep + 1 in interpret'Rec env e - acmDebugLog $ replicate dep ' ' ++ "<- " ++ showValue 0 (typeOf e) res "" + when ?prints $ acmDebugLog $ replicate dep ' ' ++ "<- " ++ showValue 0 (typeOf e) res "" return res -interpret'Rec :: forall env t s. (?depth :: Int) => SList Value env -> Ex env t -> AcM s (Rep t) +interpret'Rec :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList Value env -> Ex env t -> AcM s (Rep t) interpret'Rec env = \case EVar _ _ i -> case slistIdx env i of Value x -> return x ELet _ a b -> do x <- interpret' env a - interpret' (Value x `SCons` env) b + let ?depth = ?depth - 1 in interpret' (Value x `SCons` env) b expr | False && trace ("<i> " ++ takeWhile (not . isSpace) (show expr)) False -> undefined EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b EFst _ e -> fst <$> interpret' env e diff --git a/test/Main.hs b/test/Main.hs index ab01e89..c746807 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -70,7 +70,7 @@ gradientByCHAD = \simplIters env term input -> (Refl, Refl) -> let dterm = diffCHAD simplIters env term input1 = toPrimalE env input - (_out, grad) = interpretOpen input1 dterm + (_out, grad) = interpretOpen False input1 dterm in (ppExpr (primalEnv env) dterm, unTup vUnpair (d2e env) (Value grad)) where toPrimalE :: SList STy env' -> SList Value env' -> SList Value (D1E env') |