summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-10-26 00:04:14 +0200
committerTom Smeding <tom@tomsmeding.com>2024-10-26 00:04:14 +0200
commitaf51ad3cdce90ac6afe4727c8713426624ebaecd (patch)
treef9cf215d8737c2fda66f94dd46f195a809865433 /test
parent6a0381f9c6cfc56ac805801bf4cefda8305ff055 (diff)
Debugging
Diffstat (limited to 'test')
-rw-r--r--test/Main.hs82
1 files changed, 40 insertions, 42 deletions
diff --git a/test/Main.hs b/test/Main.hs
index e325b64..ab01e89 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -14,15 +14,12 @@ module Main where
import Data.Bifunctor
-- import qualified Data.Dependent.Map as DMap
-- import Data.Dependent.Map (DMap)
-import Data.Foldable (toList)
-import Data.List (intercalate, intersperse)
+import Data.List (intercalate)
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import Hedgehog.Main
-import Debug.Trace
-
import Array
import AST
import AST.Pretty
@@ -34,6 +31,7 @@ import ForwardAD
import Interpreter
import Interpreter.Rep
import Language
+import Simplify
type family MapMerge env where
@@ -48,21 +46,33 @@ mapMergeOnlyMerge :: SList f env -> Select env (MapMerge env) "merge" :~: env
mapMergeOnlyMerge SNil = Refl
mapMergeOnlyMerge (_ `SCons` env) | Refl <- mapMergeOnlyMerge env = Refl
-gradientByCHAD :: forall env. SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (D2E env)
-gradientByCHAD = \env term input ->
+primalEnv :: SList STy env' -> SList STy (D1E env')
+primalEnv SNil = SNil
+primalEnv (t `SCons` env) = d1 t `SCons` primalEnv env
+
+diffCHAD :: Int -> SList STy env -> Ex env (TScal TF64)
+ -> Ex (D1E env) (TPair (TScal TF64) (Tup (D2E env)))
+diffCHAD = \simplIters env term ->
case (mapMergeNoAccum env, mapMergeOnlyMerge env) of
(Refl, Refl) ->
let descr = makeMergeDescr env
- dterm = freezeRet descr (drev descr term) (EConst ext STF64 1.0)
- input1 = toPrimalE env input
- (_out, grad) = interpretOpen input1 dterm
- in (if False then trace ("gradientByCHAD: Differentiated term:\n" ++ ppExpr (primalEnv env) dterm ++ "\n\n\n") else id) $
- unTup vUnpair (d2e env) (Value grad)
+ in case envKnown (primalEnv env) of
+ Dict -> simplifyN simplIters $ freezeRet descr (drev descr term) (EConst ext STF64 1.0)
where
makeMergeDescr :: SList STy env' -> Descr env' (MapMerge env')
makeMergeDescr SNil = DTop
makeMergeDescr (t `SCons` env) = makeMergeDescr env `DPush` (t, SMerge)
+-- In addition to the gradient, also returns the pretty-printed differentiated term.
+gradientByCHAD :: forall env. Int -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, SList Value (D2E env))
+gradientByCHAD = \simplIters env term input ->
+ case (mapMergeNoAccum env, mapMergeOnlyMerge env) of
+ (Refl, Refl) ->
+ let dterm = diffCHAD simplIters env term
+ input1 = toPrimalE env input
+ (_out, grad) = interpretOpen input1 dterm
+ in (ppExpr (primalEnv env) dterm, unTup vUnpair (d2e env) (Value grad))
+ where
toPrimalE :: SList STy env' -> SList Value env' -> SList Value (D1E env')
toPrimalE SNil SNil = SNil
toPrimalE (t `SCons` env) (Value x `SCons` inp) = Value (toPrimal t x) `SCons` toPrimalE env inp
@@ -77,12 +87,9 @@ gradientByCHAD = \env term input ->
STScal _ -> id
STAccum{} -> error "Accumulators not allowed in input program"
- primalEnv :: SList STy env' -> SList STy (D1E env')
- primalEnv SNil = SNil
- primalEnv (t `SCons` env) = d1 t `SCons` primalEnv env
-
-gradientByCHAD' :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
-gradientByCHAD' = \env term input -> toTanE env input (gradientByCHAD env term input)
+-- In addition to the gradient, also returns the pretty-printed differentiated term.
+gradientByCHAD' :: Int -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, SList Value (TanE env))
+gradientByCHAD' = \simplIters env term input -> toTanE env input <$> gradientByCHAD simplIters env term input
where
toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env)
toTanE SNil SNil SNil = SNil
@@ -183,26 +190,6 @@ genEnv (t `SCons` env) = SCons <$> genValue t <*> genEnv env
-- shapes <- DMap.traverseWithKey _ constrs
-- genEnvTemplateExact shapes env
-showValue :: Int -> STy t -> Rep t -> ShowS
-showValue _ STNil () = showString "()"
-showValue _ (STPair a b) (x, y) = showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")"
-showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Left " . showValue 11 a x
-showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Right " . showValue 11 b y
-showValue _ (STMaybe _) Nothing = showString "Nothing"
-showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x
-showValue d (STArr _ t) arr = showParen (d > 10) $
- showString "arrayFromList " . showsPrec 11 (arrayShape arr)
- . showString " ["
- . foldr (.) id (intersperse (showString ",") $ map (showValue 0 t) (toList arr))
- . showString "]"
-showValue _ (STScal sty) x = case sty of
- STF32 -> shows x
- STF64 -> shows x
- STI32 -> shows x
- STI64 -> shows x
- STBool -> shows x
-showValue _ STAccum{} _ = error "Cannot show accumulators"
-
showEnv :: SList STy env -> SList Value env -> String
showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]"
where
@@ -224,15 +211,27 @@ adTestGen expr envGenerator = property $ do
let env = knownEnv @env
input <- forAllWith (showEnv env) envGenerator
let gradFwd = gradientByForward knownEnv expr input
- gradCHAD = gradientByCHAD' knownEnv expr input
+ (ppdterm, gradCHAD) = gradientByCHAD' 0 knownEnv expr input
+ (ppdterm_S, gradCHAD_S) = gradientByCHAD' 20 knownEnv expr input
scFwd = envScalars env gradFwd
scCHAD = envScalars env gradCHAD
- diff scCHAD (\x y -> and (zipWith closeIsh x y)) scFwd
+ scCHAD_S = envScalars env gradCHAD_S
+ annotate (concat (unSList (\t -> ppTy 0 t ++ " -> ") env) ++ ppTy 0 (typeOf expr))
+ annotate (ppExpr knownEnv expr)
+ annotate ppdterm
+ annotate ppdterm_S
+ diff scCHAD (\x y -> and (zipWith closeIsh x y)) scCHAD_S
+ diff scFwd (\x y -> and (zipWith closeIsh x y)) scCHAD_S
where
envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]
envScalars SNil SNil = []
envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs
+term_build1_sum :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
+term_build1_sum = fromNamed $ lambda #x $ body $
+ idx0 $ sum1i $
+ build (SS SZ) (shape #x) $ #idx :-> #x ! #idx
+
tests :: IO Bool
tests = checkSequential $ Group "AD"
[("id", adTest $ fromNamed $ lambda #x $ body $ #x)
@@ -256,9 +255,8 @@ tests = checkSequential $ Group "AD"
idx0 $
build SZ (shape #x) $ #idx :-> #x ! #idx)
- ,("build1-sum", adTest $ fromNamed $ lambda @(TArr N1 _) #x $ body $
- idx0 $ sum1i $
- build (SS SZ) (shape #x) $ #idx :-> #x ! #idx)
+ -- :hindentstr ppExpr knownEnv $ diffCHAD 20 knownEnv term_build1_sum
+ ,("build1-sum", adTest term_build1_sum)
,("build2-sum", adTest $ fromNamed $ lambda @(TArr N2 _) #x $ body $
idx0 $ sum1i . sum1i $