From 1abb0c11efd2ba650c0a20de8047efbde2cc6adf Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Tue, 11 Mar 2025 00:22:26 +0100
Subject: test: Split adTestGen into one function per test case

This improves (compactifies) hedgehog output
---
 test/Main.hs | 157 ++++++++++++++++++++++++++++++++---------------------------
 1 file changed, 84 insertions(+), 73 deletions(-)

(limited to 'test/Main.hs')

diff --git a/test/Main.hs b/test/Main.hs
index 52bdbd0..7dbafab 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -206,79 +206,90 @@ adTestGen :: forall env. KnownEnv env
 adTestGen name expr envGenerator =
   let env = knownEnv @env
       exprS = simplifyFix expr
-  in
-  withCompiled env expr $ \primalfun ->
-  withCompiled env (simplifyFix expr) $ \primalSfun ->
-  testGroupCollapse name
-    [testProperty "compile primal" $ property $ do
-       input <- forAllWith (showEnv env) envGenerator
-
-       let outPrimalI = interpretOpen False input expr
-       outPrimalC <- liftIO $ primalfun input
-       diff outPrimalI (closeIsh' 1e-8) outPrimalC
-
-       let outPrimalSI = interpretOpen False input exprS
-       outPrimalSC <- liftIO $ primalSfun input
-       diff outPrimalSI (closeIsh' 1e-8) outPrimalSC
-
-    ,withCompiled (dne env) (dfwdDN exprS) $ \dnfun ->
-     testProperty "compile fwdAD" $ property $ do
-       input <- forAllWith (showEnv env) envGenerator
-       dinput <- forAllWith (showEnv (dne env)) $ extendDNE env input
-       let (outDNI1, outDNI2) = interpretOpen False dinput (dfwdDN expr)
-       (outDNC1, outDNC2) <- liftIO $ dnfun dinput
-       diff outDNI1 (closeIsh' 1e-8) outDNC1
-       diff outDNI2 (closeIsh' 1e-8) outDNC2
-
-    ,withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC ->
-     testProperty "chad" $ property $ do
-       annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))
-
-       let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr
-           dtermChadS = simplifyFix dtermChad0
-       let dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS
-           dtermSChadS = simplifyFix dtermSChad0
-
-       -- pack Text for less GC pressure (these values are retained for some reason)
-       diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0)))
-       diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0)))
-
-       input <- forAllWith (showEnv env) envGenerator
-       outPrimal <- liftIO $ primalSfun input
-
-       let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env)
-           unpackGrad = unTup vUnpair (d2e env) . Value
-
-       let scFwd = envScalars env $ gradientByForward fwdartifactC input
-
-       let (outChad0, gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0
-           (outChadS, gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS
-           (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False input dtermSChad0
-           (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False input dtermSChadS
-           scChad = envScalars env $ toTanE env input gradChad0
-           scChadS = envScalars env $ toTanE env input gradChadS
-           scSChad = envScalars env $ toTanE env input gradSChad0
-           scSChadS = envScalars env $ toTanE env input gradSChadS
-
-       -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0))
-       -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS))
-       -- annotate (ppExpr knownEnv expr)
-       -- annotate (ppExpr env dtermChad0)
-       -- annotate (ppExpr env dtermChadS)
-       diff outChad0 closeIsh outPrimal
-       diff outChadS closeIsh outPrimal
-       diff outSChad0 closeIsh outPrimal
-       diff outSChadS closeIsh outPrimal
-       diff scChad (\x y -> and (zipWith closeIsh x y)) scFwd
-       diff scChadS (\x y -> and (zipWith closeIsh x y)) scFwd
-       diff scSChad (\x y -> and (zipWith closeIsh x y)) scFwd
-       diff scSChadS (\x y -> and (zipWith closeIsh x y)) scFwd
-    ]
-
-  where
-    envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]
-    envScalars SNil SNil = []
-    envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs
+  in withCompiled env expr $ \primalfun ->
+     withCompiled env (simplifyFix expr) $ \primalSfun ->
+     testGroupCollapse name
+       [adTestGenPrimal env envGenerator expr exprS primalfun primalSfun
+       ,adTestGenFwd env envGenerator expr exprS
+       ,adTestGenChad env envGenerator expr exprS primalSfun]
+
+adTestGenPrimal :: SList STy env -> Gen (SList Value env)
+                -> Ex env (TScal TF64) -> Ex env (TScal TF64)
+                -> (SList Value env -> IO Double) -> (SList Value env -> IO Double)
+                -> TestTree
+adTestGenPrimal env envGenerator expr exprS primalfun primalSfun =
+  testProperty "compile primal" $ property $ do
+    input <- forAllWith (showEnv env) envGenerator
+
+    let outPrimalI = interpretOpen False input expr
+    outPrimalC <- liftIO $ primalfun input
+    diff outPrimalI (closeIsh' 1e-8) outPrimalC
+
+    let outPrimalSI = interpretOpen False input exprS
+    outPrimalSC <- liftIO $ primalSfun input
+    diff outPrimalSI (closeIsh' 1e-8) outPrimalSC
+
+adTestGenFwd :: SList STy env -> Gen (SList Value env)
+             -> Ex env (TScal TF64) -> Ex env (TScal TF64)
+             -> TestTree
+adTestGenFwd env envGenerator expr exprS =
+  withCompiled (dne env) (dfwdDN exprS) $ \dnfun ->
+    testProperty "compile fwdAD" $ property $ do
+      input <- forAllWith (showEnv env) envGenerator
+      dinput <- forAllWith (showEnv (dne env)) $ extendDNE env input
+      let (outDNI1, outDNI2) = interpretOpen False dinput (dfwdDN expr)
+      (outDNC1, outDNC2) <- liftIO $ dnfun dinput
+      diff outDNI1 (closeIsh' 1e-8) outDNC1
+      diff outDNI2 (closeIsh' 1e-8) outDNC2
+
+adTestGenChad :: forall env. SList STy env -> Gen (SList Value env)
+              -> Ex env (TScal TF64) -> Ex env (TScal TF64)
+              -> (SList Value env -> IO Double)
+              -> TestTree
+adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =
+  withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC ->
+    testProperty "chad" $ property $ do
+      annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))
+
+      let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr
+          dtermChadS = simplifyFix dtermChad0
+      let dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS
+          dtermSChadS = simplifyFix dtermSChad0
+
+      -- pack Text for less GC pressure (these values are retained for some reason)
+      diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0)))
+      diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0)))
+
+      input <- forAllWith (showEnv env) envGenerator
+      outPrimal <- liftIO $ primalSfun input
+
+      let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env)
+          unpackGrad = unTup vUnpair (d2e env) . Value
+
+      let scFwd = tanEScalars env $ gradientByForward fwdartifactC input
+
+      let (outChad0, gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0
+          (outChadS, gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS
+          (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False input dtermSChad0
+          (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False input dtermSChadS
+          scChad = tanEScalars env $ toTanE env input gradChad0
+          scChadS = tanEScalars env $ toTanE env input gradChadS
+          scSChad = tanEScalars env $ toTanE env input gradSChad0
+          scSChadS = tanEScalars env $ toTanE env input gradSChadS
+
+      -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0))
+      -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS))
+      -- annotate (ppExpr knownEnv expr)
+      -- annotate (ppExpr env dtermChad0)
+      -- annotate (ppExpr env dtermChadS)
+      diff outChad0 closeIsh outPrimal
+      diff outChadS closeIsh outPrimal
+      diff outSChad0 closeIsh outPrimal
+      diff outSChadS closeIsh outPrimal
+      diff scChad (\x y -> and (zipWith closeIsh x y)) scFwd
+      diff scChadS (\x y -> and (zipWith closeIsh x y)) scFwd
+      diff scSChad (\x y -> and (zipWith closeIsh x y)) scFwd
+      diff scSChadS (\x y -> and (zipWith closeIsh x y)) scFwd
 
 withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree
 withCompiled env expr = withResource (compile env expr) (\_ -> pure ())
-- 
cgit v1.2.3-70-g09d2