diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-11 00:22:26 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-11 00:25:20 +0100 |
commit | 1abb0c11efd2ba650c0a20de8047efbde2cc6adf (patch) | |
tree | e60724dbdcb96ae4c5237989c90d8093ca772bf5 | |
parent | 41f895bb9827f1f0e422e623879a08a0d2412f35 (diff) |
test: Split adTestGen into one function per test case
This improves (compactifies) hedgehog output
-rw-r--r-- | src/ForwardAD.hs | 4 | ||||
-rw-r--r-- | test/Main.hs | 157 |
2 files changed, 88 insertions, 73 deletions
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index e867d66..af35f91 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -85,6 +85,10 @@ tanScalars (STScal STF64) x = [x] tanScalars (STScal STBool) _ = [] tanScalars STAccum{} _ = error "Accumulators not allowed in input program" +tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double] +tanEScalars SNil SNil = [] +tanEScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ tanEScalars ts xs + unzipDN :: STy t -> Rep (DN t) -> (Rep t, Rep (Tan t)) unzipDN STNil _ = ((), ()) unzipDN (STPair a b) (d1, d2) = diff --git a/test/Main.hs b/test/Main.hs index 52bdbd0..7dbafab 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -206,79 +206,90 @@ adTestGen :: forall env. KnownEnv env adTestGen name expr envGenerator = let env = knownEnv @env exprS = simplifyFix expr - in - withCompiled env expr $ \primalfun -> - withCompiled env (simplifyFix expr) $ \primalSfun -> - testGroupCollapse name - [testProperty "compile primal" $ property $ do - input <- forAllWith (showEnv env) envGenerator - - let outPrimalI = interpretOpen False input expr - outPrimalC <- liftIO $ primalfun input - diff outPrimalI (closeIsh' 1e-8) outPrimalC - - let outPrimalSI = interpretOpen False input exprS - outPrimalSC <- liftIO $ primalSfun input - diff outPrimalSI (closeIsh' 1e-8) outPrimalSC - - ,withCompiled (dne env) (dfwdDN exprS) $ \dnfun -> - testProperty "compile fwdAD" $ property $ do - input <- forAllWith (showEnv env) envGenerator - dinput <- forAllWith (showEnv (dne env)) $ extendDNE env input - let (outDNI1, outDNI2) = interpretOpen False dinput (dfwdDN expr) - (outDNC1, outDNC2) <- liftIO $ dnfun dinput - diff outDNI1 (closeIsh' 1e-8) outDNC1 - diff outDNI2 (closeIsh' 1e-8) outDNC2 - - ,withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC -> - testProperty "chad" $ property $ do - annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) - - let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr - dtermChadS = simplifyFix dtermChad0 - let dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS - dtermSChadS = simplifyFix dtermSChad0 - - -- pack Text for less GC pressure (these values are retained for some reason) - diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0))) - diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0))) - - input <- forAllWith (showEnv env) envGenerator - outPrimal <- liftIO $ primalSfun input - - let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env) - unpackGrad = unTup vUnpair (d2e env) . Value - - let scFwd = envScalars env $ gradientByForward fwdartifactC input - - let (outChad0, gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0 - (outChadS, gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS - (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False input dtermSChad0 - (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False input dtermSChadS - scChad = envScalars env $ toTanE env input gradChad0 - scChadS = envScalars env $ toTanE env input gradChadS - scSChad = envScalars env $ toTanE env input gradSChad0 - scSChadS = envScalars env $ toTanE env input gradSChadS - - -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0)) - -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS)) - -- annotate (ppExpr knownEnv expr) - -- annotate (ppExpr env dtermChad0) - -- annotate (ppExpr env dtermChadS) - diff outChad0 closeIsh outPrimal - diff outChadS closeIsh outPrimal - diff outSChad0 closeIsh outPrimal - diff outSChadS closeIsh outPrimal - diff scChad (\x y -> and (zipWith closeIsh x y)) scFwd - diff scChadS (\x y -> and (zipWith closeIsh x y)) scFwd - diff scSChad (\x y -> and (zipWith closeIsh x y)) scFwd - diff scSChadS (\x y -> and (zipWith closeIsh x y)) scFwd - ] - - where - envScalars :: SList STy env' -> SList Value (TanE env') -> [Double] - envScalars SNil SNil = [] - envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs + in withCompiled env expr $ \primalfun -> + withCompiled env (simplifyFix expr) $ \primalSfun -> + testGroupCollapse name + [adTestGenPrimal env envGenerator expr exprS primalfun primalSfun + ,adTestGenFwd env envGenerator expr exprS + ,adTestGenChad env envGenerator expr exprS primalSfun] + +adTestGenPrimal :: SList STy env -> Gen (SList Value env) + -> Ex env (TScal TF64) -> Ex env (TScal TF64) + -> (SList Value env -> IO Double) -> (SList Value env -> IO Double) + -> TestTree +adTestGenPrimal env envGenerator expr exprS primalfun primalSfun = + testProperty "compile primal" $ property $ do + input <- forAllWith (showEnv env) envGenerator + + let outPrimalI = interpretOpen False input expr + outPrimalC <- liftIO $ primalfun input + diff outPrimalI (closeIsh' 1e-8) outPrimalC + + let outPrimalSI = interpretOpen False input exprS + outPrimalSC <- liftIO $ primalSfun input + diff outPrimalSI (closeIsh' 1e-8) outPrimalSC + +adTestGenFwd :: SList STy env -> Gen (SList Value env) + -> Ex env (TScal TF64) -> Ex env (TScal TF64) + -> TestTree +adTestGenFwd env envGenerator expr exprS = + withCompiled (dne env) (dfwdDN exprS) $ \dnfun -> + testProperty "compile fwdAD" $ property $ do + input <- forAllWith (showEnv env) envGenerator + dinput <- forAllWith (showEnv (dne env)) $ extendDNE env input + let (outDNI1, outDNI2) = interpretOpen False dinput (dfwdDN expr) + (outDNC1, outDNC2) <- liftIO $ dnfun dinput + diff outDNI1 (closeIsh' 1e-8) outDNC1 + diff outDNI2 (closeIsh' 1e-8) outDNC2 + +adTestGenChad :: forall env. SList STy env -> Gen (SList Value env) + -> Ex env (TScal TF64) -> Ex env (TScal TF64) + -> (SList Value env -> IO Double) + -> TestTree +adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = + withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC -> + testProperty "chad" $ property $ do + annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) + + let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr + dtermChadS = simplifyFix dtermChad0 + let dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS + dtermSChadS = simplifyFix dtermSChad0 + + -- pack Text for less GC pressure (these values are retained for some reason) + diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0))) + diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0))) + + input <- forAllWith (showEnv env) envGenerator + outPrimal <- liftIO $ primalSfun input + + let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env) + unpackGrad = unTup vUnpair (d2e env) . Value + + let scFwd = tanEScalars env $ gradientByForward fwdartifactC input + + let (outChad0, gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0 + (outChadS, gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS + (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False input dtermSChad0 + (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False input dtermSChadS + scChad = tanEScalars env $ toTanE env input gradChad0 + scChadS = tanEScalars env $ toTanE env input gradChadS + scSChad = tanEScalars env $ toTanE env input gradSChad0 + scSChadS = tanEScalars env $ toTanE env input gradSChadS + + -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0)) + -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS)) + -- annotate (ppExpr knownEnv expr) + -- annotate (ppExpr env dtermChad0) + -- annotate (ppExpr env dtermChadS) + diff outChad0 closeIsh outPrimal + diff outChadS closeIsh outPrimal + diff outSChad0 closeIsh outPrimal + diff outSChadS closeIsh outPrimal + diff scChad (\x y -> and (zipWith closeIsh x y)) scFwd + diff scChadS (\x y -> and (zipWith closeIsh x y)) scFwd + diff scSChad (\x y -> and (zipWith closeIsh x y)) scFwd + diff scSChadS (\x y -> and (zipWith closeIsh x y)) scFwd withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree withCompiled env expr = withResource (compile env expr) (\_ -> pure ()) |