diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-10-26 00:04:14 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-10-26 00:04:14 +0200 |
commit | af51ad3cdce90ac6afe4727c8713426624ebaecd (patch) | |
tree | f9cf215d8737c2fda66f94dd46f195a809865433 /test | |
parent | 6a0381f9c6cfc56ac805801bf4cefda8305ff055 (diff) |
Debugging
Diffstat (limited to 'test')
-rw-r--r-- | test/Main.hs | 82 |
1 files changed, 40 insertions, 42 deletions
diff --git a/test/Main.hs b/test/Main.hs index e325b64..ab01e89 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -14,15 +14,12 @@ module Main where import Data.Bifunctor -- import qualified Data.Dependent.Map as DMap -- import Data.Dependent.Map (DMap) -import Data.Foldable (toList) -import Data.List (intercalate, intersperse) +import Data.List (intercalate) import Hedgehog import qualified Hedgehog.Gen as Gen import qualified Hedgehog.Range as Range import Hedgehog.Main -import Debug.Trace - import Array import AST import AST.Pretty @@ -34,6 +31,7 @@ import ForwardAD import Interpreter import Interpreter.Rep import Language +import Simplify type family MapMerge env where @@ -48,21 +46,33 @@ mapMergeOnlyMerge :: SList f env -> Select env (MapMerge env) "merge" :~: env mapMergeOnlyMerge SNil = Refl mapMergeOnlyMerge (_ `SCons` env) | Refl <- mapMergeOnlyMerge env = Refl -gradientByCHAD :: forall env. SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (D2E env) -gradientByCHAD = \env term input -> +primalEnv :: SList STy env' -> SList STy (D1E env') +primalEnv SNil = SNil +primalEnv (t `SCons` env) = d1 t `SCons` primalEnv env + +diffCHAD :: Int -> SList STy env -> Ex env (TScal TF64) + -> Ex (D1E env) (TPair (TScal TF64) (Tup (D2E env))) +diffCHAD = \simplIters env term -> case (mapMergeNoAccum env, mapMergeOnlyMerge env) of (Refl, Refl) -> let descr = makeMergeDescr env - dterm = freezeRet descr (drev descr term) (EConst ext STF64 1.0) - input1 = toPrimalE env input - (_out, grad) = interpretOpen input1 dterm - in (if False then trace ("gradientByCHAD: Differentiated term:\n" ++ ppExpr (primalEnv env) dterm ++ "\n\n\n") else id) $ - unTup vUnpair (d2e env) (Value grad) + in case envKnown (primalEnv env) of + Dict -> simplifyN simplIters $ freezeRet descr (drev descr term) (EConst ext STF64 1.0) where makeMergeDescr :: SList STy env' -> Descr env' (MapMerge env') makeMergeDescr SNil = DTop makeMergeDescr (t `SCons` env) = makeMergeDescr env `DPush` (t, SMerge) +-- In addition to the gradient, also returns the pretty-printed differentiated term. +gradientByCHAD :: forall env. Int -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, SList Value (D2E env)) +gradientByCHAD = \simplIters env term input -> + case (mapMergeNoAccum env, mapMergeOnlyMerge env) of + (Refl, Refl) -> + let dterm = diffCHAD simplIters env term + input1 = toPrimalE env input + (_out, grad) = interpretOpen input1 dterm + in (ppExpr (primalEnv env) dterm, unTup vUnpair (d2e env) (Value grad)) + where toPrimalE :: SList STy env' -> SList Value env' -> SList Value (D1E env') toPrimalE SNil SNil = SNil toPrimalE (t `SCons` env) (Value x `SCons` inp) = Value (toPrimal t x) `SCons` toPrimalE env inp @@ -77,12 +87,9 @@ gradientByCHAD = \env term input -> STScal _ -> id STAccum{} -> error "Accumulators not allowed in input program" - primalEnv :: SList STy env' -> SList STy (D1E env') - primalEnv SNil = SNil - primalEnv (t `SCons` env) = d1 t `SCons` primalEnv env - -gradientByCHAD' :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env) -gradientByCHAD' = \env term input -> toTanE env input (gradientByCHAD env term input) +-- In addition to the gradient, also returns the pretty-printed differentiated term. +gradientByCHAD' :: Int -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, SList Value (TanE env)) +gradientByCHAD' = \simplIters env term input -> toTanE env input <$> gradientByCHAD simplIters env term input where toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env) toTanE SNil SNil SNil = SNil @@ -183,26 +190,6 @@ genEnv (t `SCons` env) = SCons <$> genValue t <*> genEnv env -- shapes <- DMap.traverseWithKey _ constrs -- genEnvTemplateExact shapes env -showValue :: Int -> STy t -> Rep t -> ShowS -showValue _ STNil () = showString "()" -showValue _ (STPair a b) (x, y) = showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")" -showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Left " . showValue 11 a x -showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Right " . showValue 11 b y -showValue _ (STMaybe _) Nothing = showString "Nothing" -showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x -showValue d (STArr _ t) arr = showParen (d > 10) $ - showString "arrayFromList " . showsPrec 11 (arrayShape arr) - . showString " [" - . foldr (.) id (intersperse (showString ",") $ map (showValue 0 t) (toList arr)) - . showString "]" -showValue _ (STScal sty) x = case sty of - STF32 -> shows x - STF64 -> shows x - STI32 -> shows x - STI64 -> shows x - STBool -> shows x -showValue _ STAccum{} _ = error "Cannot show accumulators" - showEnv :: SList STy env -> SList Value env -> String showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" where @@ -224,15 +211,27 @@ adTestGen expr envGenerator = property $ do let env = knownEnv @env input <- forAllWith (showEnv env) envGenerator let gradFwd = gradientByForward knownEnv expr input - gradCHAD = gradientByCHAD' knownEnv expr input + (ppdterm, gradCHAD) = gradientByCHAD' 0 knownEnv expr input + (ppdterm_S, gradCHAD_S) = gradientByCHAD' 20 knownEnv expr input scFwd = envScalars env gradFwd scCHAD = envScalars env gradCHAD - diff scCHAD (\x y -> and (zipWith closeIsh x y)) scFwd + scCHAD_S = envScalars env gradCHAD_S + annotate (concat (unSList (\t -> ppTy 0 t ++ " -> ") env) ++ ppTy 0 (typeOf expr)) + annotate (ppExpr knownEnv expr) + annotate ppdterm + annotate ppdterm_S + diff scCHAD (\x y -> and (zipWith closeIsh x y)) scCHAD_S + diff scFwd (\x y -> and (zipWith closeIsh x y)) scCHAD_S where envScalars :: SList STy env' -> SList Value (TanE env') -> [Double] envScalars SNil SNil = [] envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs +term_build1_sum :: Ex '[TArr N1 (TScal TF64)] (TScal TF64) +term_build1_sum = fromNamed $ lambda #x $ body $ + idx0 $ sum1i $ + build (SS SZ) (shape #x) $ #idx :-> #x ! #idx + tests :: IO Bool tests = checkSequential $ Group "AD" [("id", adTest $ fromNamed $ lambda #x $ body $ #x) @@ -256,9 +255,8 @@ tests = checkSequential $ Group "AD" idx0 $ build SZ (shape #x) $ #idx :-> #x ! #idx) - ,("build1-sum", adTest $ fromNamed $ lambda @(TArr N1 _) #x $ body $ - idx0 $ sum1i $ - build (SS SZ) (shape #x) $ #idx :-> #x ! #idx) + -- :hindentstr ppExpr knownEnv $ diffCHAD 20 knownEnv term_build1_sum + ,("build1-sum", adTest term_build1_sum) ,("build2-sum", adTest $ fromNamed $ lambda @(TArr N2 _) #x $ body $ idx0 $ sum1i . sum1i $ |