summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/Example.hs2
-rw-r--r--src/ForwardAD.hs2
-rw-r--r--src/Interpreter.hs28
-rw-r--r--test/Main.hs2
4 files changed, 21 insertions, 13 deletions
diff --git a/src/Example.hs b/src/Example.hs
index 6e8069c..0bc18fb 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -198,7 +198,7 @@ neuralGo =
freezeRet mergeDescr
(drev mergeDescr neural)
(EConst ext STF64 1.0)
- (primal, (((((), Right dlay1_1), Right dlay2_1), dlay3_1), dinput_1)) = interpretOpen argument revderiv
+ (primal, (((((), Right dlay1_1), Right dlay2_1), dlay3_1), dinput_1)) = interpretOpen False argument revderiv
(Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwd knownEnv neural argument 1.0
in trace (formatter (ppExpr knownEnv revderiv)) $
(primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2))
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs
index 86d2fb0..67d22dd 100644
--- a/src/ForwardAD.hs
+++ b/src/ForwardAD.hs
@@ -215,5 +215,5 @@ drevByFwd :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SLis
drevByFwd env expr input dres =
let outty = typeOf expr
in dnOnehotEnvs env input $ \dnInput ->
- let (_, outtan) = unzipDN outty (interpretOpen dnInput (dfwdDN expr))
+ let (_, outtan) = unzipDN outty (interpretOpen False dnInput (dfwdDN expr))
in dotprodTan outty outtan dres
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 3d6f33d..2c63b24 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -19,7 +19,8 @@ module Interpreter (
Value(..),
) where
-import Control.Monad (foldM, join)
+import Control.Monad (foldM, join, when)
+import Data.Bifunctor (bimap)
import Data.Char (isSpace)
import Data.Kind (Type)
import Data.Int (Int64)
@@ -35,7 +36,6 @@ import AST.Pretty
import CHAD.Types
import Data
import Interpreter.Rep
-import Data.Bifunctor (bimap)
newtype AcM s a = AcM { unAcM :: IO a }
@@ -48,25 +48,33 @@ acmDebugLog :: String -> AcM s ()
acmDebugLog s = AcM (hPutStrLn stderr s)
interpret :: Ex '[] t -> Rep t
-interpret = interpretOpen SNil
+interpret = interpretOpen False SNil
-interpretOpen :: SList Value env -> Ex env t -> Rep t
-interpretOpen env e = runAcM (let ?depth = 0 in interpret' env e)
+-- | Bool: whether to trace execution with debug prints (very verbose)
+interpretOpen :: Bool -> SList Value env -> Ex env t -> Rep t
+interpretOpen prints env e =
+ runAcM $
+ let ?depth = 0
+ ?prints = prints
+ in interpret' env e
-interpret' :: forall env t s. (?depth :: Int) => SList Value env -> Ex env t -> AcM s (Rep t)
+interpret' :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList Value env -> Ex env t -> AcM s (Rep t)
interpret' env e = do
let dep = ?depth
- acmDebugLog $ replicate dep ' ' ++ "ev: " ++ ppExpr env e
+ let lenlimit = max 20 (100 - dep)
+ let trunc s | length s > lenlimit = take (lenlimit - 3) s ++ "..."
+ | otherwise = s
+ when ?prints $ acmDebugLog $ replicate dep ' ' ++ "ev: " ++ trunc (ppExpr env e)
res <- let ?depth = dep + 1 in interpret'Rec env e
- acmDebugLog $ replicate dep ' ' ++ "<- " ++ showValue 0 (typeOf e) res ""
+ when ?prints $ acmDebugLog $ replicate dep ' ' ++ "<- " ++ showValue 0 (typeOf e) res ""
return res
-interpret'Rec :: forall env t s. (?depth :: Int) => SList Value env -> Ex env t -> AcM s (Rep t)
+interpret'Rec :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList Value env -> Ex env t -> AcM s (Rep t)
interpret'Rec env = \case
EVar _ _ i -> case slistIdx env i of Value x -> return x
ELet _ a b -> do
x <- interpret' env a
- interpret' (Value x `SCons` env) b
+ let ?depth = ?depth - 1 in interpret' (Value x `SCons` env) b
expr | False && trace ("<i> " ++ takeWhile (not . isSpace) (show expr)) False -> undefined
EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b
EFst _ e -> fst <$> interpret' env e
diff --git a/test/Main.hs b/test/Main.hs
index ab01e89..c746807 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -70,7 +70,7 @@ gradientByCHAD = \simplIters env term input ->
(Refl, Refl) ->
let dterm = diffCHAD simplIters env term
input1 = toPrimalE env input
- (_out, grad) = interpretOpen input1 dterm
+ (_out, grad) = interpretOpen False input1 dterm
in (ppExpr (primalEnv env) dterm, unTup vUnpair (d2e env) (Value grad))
where
toPrimalE :: SList STy env' -> SList Value env' -> SList Value (D1E env')