summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-11 00:22:26 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-11 00:25:20 +0100
commit1abb0c11efd2ba650c0a20de8047efbde2cc6adf (patch)
treee60724dbdcb96ae4c5237989c90d8093ca772bf5
parent41f895bb9827f1f0e422e623879a08a0d2412f35 (diff)
test: Split adTestGen into one function per test case
This improves (compactifies) hedgehog output
-rw-r--r--src/ForwardAD.hs4
-rw-r--r--test/Main.hs157
2 files changed, 88 insertions, 73 deletions
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs
index e867d66..af35f91 100644
--- a/src/ForwardAD.hs
+++ b/src/ForwardAD.hs
@@ -85,6 +85,10 @@ tanScalars (STScal STF64) x = [x]
tanScalars (STScal STBool) _ = []
tanScalars STAccum{} _ = error "Accumulators not allowed in input program"
+tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double]
+tanEScalars SNil SNil = []
+tanEScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ tanEScalars ts xs
+
unzipDN :: STy t -> Rep (DN t) -> (Rep t, Rep (Tan t))
unzipDN STNil _ = ((), ())
unzipDN (STPair a b) (d1, d2) =
diff --git a/test/Main.hs b/test/Main.hs
index 52bdbd0..7dbafab 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -206,79 +206,90 @@ adTestGen :: forall env. KnownEnv env
adTestGen name expr envGenerator =
let env = knownEnv @env
exprS = simplifyFix expr
- in
- withCompiled env expr $ \primalfun ->
- withCompiled env (simplifyFix expr) $ \primalSfun ->
- testGroupCollapse name
- [testProperty "compile primal" $ property $ do
- input <- forAllWith (showEnv env) envGenerator
-
- let outPrimalI = interpretOpen False input expr
- outPrimalC <- liftIO $ primalfun input
- diff outPrimalI (closeIsh' 1e-8) outPrimalC
-
- let outPrimalSI = interpretOpen False input exprS
- outPrimalSC <- liftIO $ primalSfun input
- diff outPrimalSI (closeIsh' 1e-8) outPrimalSC
-
- ,withCompiled (dne env) (dfwdDN exprS) $ \dnfun ->
- testProperty "compile fwdAD" $ property $ do
- input <- forAllWith (showEnv env) envGenerator
- dinput <- forAllWith (showEnv (dne env)) $ extendDNE env input
- let (outDNI1, outDNI2) = interpretOpen False dinput (dfwdDN expr)
- (outDNC1, outDNC2) <- liftIO $ dnfun dinput
- diff outDNI1 (closeIsh' 1e-8) outDNC1
- diff outDNI2 (closeIsh' 1e-8) outDNC2
-
- ,withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC ->
- testProperty "chad" $ property $ do
- annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))
-
- let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr
- dtermChadS = simplifyFix dtermChad0
- let dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS
- dtermSChadS = simplifyFix dtermSChad0
-
- -- pack Text for less GC pressure (these values are retained for some reason)
- diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0)))
- diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0)))
-
- input <- forAllWith (showEnv env) envGenerator
- outPrimal <- liftIO $ primalSfun input
-
- let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env)
- unpackGrad = unTup vUnpair (d2e env) . Value
-
- let scFwd = envScalars env $ gradientByForward fwdartifactC input
-
- let (outChad0, gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0
- (outChadS, gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS
- (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False input dtermSChad0
- (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False input dtermSChadS
- scChad = envScalars env $ toTanE env input gradChad0
- scChadS = envScalars env $ toTanE env input gradChadS
- scSChad = envScalars env $ toTanE env input gradSChad0
- scSChadS = envScalars env $ toTanE env input gradSChadS
-
- -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0))
- -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS))
- -- annotate (ppExpr knownEnv expr)
- -- annotate (ppExpr env dtermChad0)
- -- annotate (ppExpr env dtermChadS)
- diff outChad0 closeIsh outPrimal
- diff outChadS closeIsh outPrimal
- diff outSChad0 closeIsh outPrimal
- diff outSChadS closeIsh outPrimal
- diff scChad (\x y -> and (zipWith closeIsh x y)) scFwd
- diff scChadS (\x y -> and (zipWith closeIsh x y)) scFwd
- diff scSChad (\x y -> and (zipWith closeIsh x y)) scFwd
- diff scSChadS (\x y -> and (zipWith closeIsh x y)) scFwd
- ]
-
- where
- envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]
- envScalars SNil SNil = []
- envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs
+ in withCompiled env expr $ \primalfun ->
+ withCompiled env (simplifyFix expr) $ \primalSfun ->
+ testGroupCollapse name
+ [adTestGenPrimal env envGenerator expr exprS primalfun primalSfun
+ ,adTestGenFwd env envGenerator expr exprS
+ ,adTestGenChad env envGenerator expr exprS primalSfun]
+
+adTestGenPrimal :: SList STy env -> Gen (SList Value env)
+ -> Ex env (TScal TF64) -> Ex env (TScal TF64)
+ -> (SList Value env -> IO Double) -> (SList Value env -> IO Double)
+ -> TestTree
+adTestGenPrimal env envGenerator expr exprS primalfun primalSfun =
+ testProperty "compile primal" $ property $ do
+ input <- forAllWith (showEnv env) envGenerator
+
+ let outPrimalI = interpretOpen False input expr
+ outPrimalC <- liftIO $ primalfun input
+ diff outPrimalI (closeIsh' 1e-8) outPrimalC
+
+ let outPrimalSI = interpretOpen False input exprS
+ outPrimalSC <- liftIO $ primalSfun input
+ diff outPrimalSI (closeIsh' 1e-8) outPrimalSC
+
+adTestGenFwd :: SList STy env -> Gen (SList Value env)
+ -> Ex env (TScal TF64) -> Ex env (TScal TF64)
+ -> TestTree
+adTestGenFwd env envGenerator expr exprS =
+ withCompiled (dne env) (dfwdDN exprS) $ \dnfun ->
+ testProperty "compile fwdAD" $ property $ do
+ input <- forAllWith (showEnv env) envGenerator
+ dinput <- forAllWith (showEnv (dne env)) $ extendDNE env input
+ let (outDNI1, outDNI2) = interpretOpen False dinput (dfwdDN expr)
+ (outDNC1, outDNC2) <- liftIO $ dnfun dinput
+ diff outDNI1 (closeIsh' 1e-8) outDNC1
+ diff outDNI2 (closeIsh' 1e-8) outDNC2
+
+adTestGenChad :: forall env. SList STy env -> Gen (SList Value env)
+ -> Ex env (TScal TF64) -> Ex env (TScal TF64)
+ -> (SList Value env -> IO Double)
+ -> TestTree
+adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =
+ withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC ->
+ testProperty "chad" $ property $ do
+ annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))
+
+ let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr
+ dtermChadS = simplifyFix dtermChad0
+ let dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS
+ dtermSChadS = simplifyFix dtermSChad0
+
+ -- pack Text for less GC pressure (these values are retained for some reason)
+ diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0)))
+ diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0)))
+
+ input <- forAllWith (showEnv env) envGenerator
+ outPrimal <- liftIO $ primalSfun input
+
+ let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env)
+ unpackGrad = unTup vUnpair (d2e env) . Value
+
+ let scFwd = tanEScalars env $ gradientByForward fwdartifactC input
+
+ let (outChad0, gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0
+ (outChadS, gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS
+ (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False input dtermSChad0
+ (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False input dtermSChadS
+ scChad = tanEScalars env $ toTanE env input gradChad0
+ scChadS = tanEScalars env $ toTanE env input gradChadS
+ scSChad = tanEScalars env $ toTanE env input gradSChad0
+ scSChadS = tanEScalars env $ toTanE env input gradSChadS
+
+ -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0))
+ -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS))
+ -- annotate (ppExpr knownEnv expr)
+ -- annotate (ppExpr env dtermChad0)
+ -- annotate (ppExpr env dtermChadS)
+ diff outChad0 closeIsh outPrimal
+ diff outChadS closeIsh outPrimal
+ diff outSChad0 closeIsh outPrimal
+ diff outSChadS closeIsh outPrimal
+ diff scChad (\x y -> and (zipWith closeIsh x y)) scFwd
+ diff scChadS (\x y -> and (zipWith closeIsh x y)) scFwd
+ diff scSChad (\x y -> and (zipWith closeIsh x y)) scFwd
+ diff scSChadS (\x y -> and (zipWith closeIsh x y)) scFwd
withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree
withCompiled env expr = withResource (compile env expr) (\_ -> pure ())