From d1b2e2c3a3cdaf49ff5e4bae6fe9b0612c3779c2 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 18 Jun 2025 00:00:11 +0200 Subject: Tests pass, should check if output is sensible --- test/Main.hs | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) (limited to 'test/Main.hs') diff --git a/test/Main.hs b/test/Main.hs index 1b83a2e..d79e63f 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -435,11 +435,22 @@ gen_neural = do lay3 <- genArray tR (ShNil `ShCons` n2) return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil) +term_build0 :: Ex '[TArr N0 R] R +term_build0 = fromNamed $ lambda @(TArr N0 _) #x $ body $ + idx0 $ + build SZ (shape #x) $ #idx :-> #x ! #idx + term_build1_sum :: Ex '[TVec R] R term_build1_sum = fromNamed $ lambda #x $ body $ idx0 $ sum1i $ build (SS SZ) (shape #x) $ #idx :-> #x ! #idx +term_build1_idx :: Ex '[TVec R] R +term_build1_idx = fromNamed $ lambda @(TVec _) #x $ body $ + let_ #n (snd_ (shape #x)) $ + idx0 $ sum1i $ + build1 (#n `idiv` 2) (#i :-> #x ! pair nil (2 * #i)) + term_pairs :: Ex [R, R] R term_pairs = fromNamed $ lambda #x $ lambda #y $ body $ let_ #p (pair #x #y) $ @@ -502,22 +513,22 @@ tests_Compile = testGroup "Compile" ,compileTest "accum (f64,f64)" $ fromNamed $ lambda #b $ lambda #x $ body $ with @(TPair R R) (pair 0.0 0.0) $ #ac :-> - let_ #_ (if_ #b (accum (SAPFst SAPHere) (pair nil nil) 3.0 #ac) nil) $ + let_ #_ (if_ #b (accum (SAPFst SAPHere) nil 3.0 #ac) nil) $ let_ #_ (accum SAPHere nil #x #ac) $ - let_ #_ (accum (SAPSnd SAPHere) (pair nil nil) 4.0 #ac) $ + let_ #_ (accum (SAPSnd SAPHere) nil 4.0 #ac) $ nil ,compileTest "accum (Maybe (f64,f64))" $ fromNamed $ lambda #b $ lambda #x $ body $ - with @(TMaybe (TPair R R)) nothing $ #ac :-> - let_ #_ (if_ #b (accum (SAPJust (SAPFst SAPHere)) (pair nil nil) 3.0 #ac) nil) $ + with @(TMaybe (TPair R R)) (just (pair 0 0)) $ #ac :-> + let_ #_ (if_ #b (accum (SAPJust (SAPFst SAPHere)) nil 3.0 #ac) nil) $ let_ #_ (accum SAPHere nil #x #ac) $ - let_ #_ (accum (SAPJust (SAPSnd SAPHere)) (pair nil nil) 4.0 #ac) $ + let_ #_ (accum (SAPJust (SAPSnd SAPHere)) nil 4.0 #ac) $ nil ,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda @(TVec R) #x $ body $ let_ #len (snd_ (shape #x)) $ with @(TVec R) (build1 #len (#_ :-> 0)) $ #ac :-> - let_ #_ (if_ #b (accum (SAPArrIdx SAPHere) (pair (pair (pair nil 2) (build1 #len (#_ :-> nil))) nil) 6.0 #ac) + let_ #_ (if_ #b (accum (SAPArrIdx SAPHere) (pair (pair nil 2) nil) 6.0 #ac) nil) $ let_ #_ (accum SAPHere nil #x #ac) $ nil @@ -556,9 +567,7 @@ tests_AD = testGroup "AD" ,adTest "build0 const" $ fromNamed $ lambda @R #x $ body $ idx0 $ build SZ nil $ #idx :-> const_ 0.0 - ,adTest "build0" $ fromNamed $ lambda @(TArr N0 _) #x $ body $ - idx0 $ - build SZ (shape #x) $ #idx :-> #x ! #idx + ,adTest "build0" term_build0 ,adTest "build1-sum" term_build1_sum @@ -566,6 +575,8 @@ tests_AD = testGroup "AD" idx0 $ sum1i . sum1i $ build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx + ,adTest "build1-idx" term_build1_idx + ,adTestCon "maximum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $ fromNamed $ lambda @(TMat R) #x $ body $ idx0 $ sum1i $ maximum1i #x -- cgit v1.2.3-70-g09d2 From 2b00a57f565a42b1079a071e2db630ba22c7120d Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 18 Jun 2025 00:07:48 +0200 Subject: TODO deep zero in accum + fix warnings --- test/Main.hs | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) (limited to 'test/Main.hs') diff --git a/test/Main.hs b/test/Main.hs index d79e63f..8da7598 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -577,6 +577,30 @@ tests_AD = testGroup "AD" ,adTest "build1-idx" term_build1_idx + ,adTest "idx-pair" $ fromNamed $ lambda @(TVec (TPair R R)) #x $ body $ + let_ #n (snd_ (shape #x)) $ + idx0 $ sum1i $ build1 #n $ #i :-> + let_ #p (#x ! pair nil #i) $ + 3 * fst_ #p + 2 * snd_ #p + + ,adTest "idx-coprod" $ fromNamed $ lambda @(TVec (TEither R R)) #x $ body $ + let_ #n (snd_ (shape #x)) $ + idx0 $ sum1i $ build1 #n $ #i :-> + case_ (#x ! pair nil #i) + (#a :-> #a * 2) + (#b :-> #b * 3) + + ,adTest "idx-coprod-pair" $ fromNamed $ lambda @(TVec R) #arr $ body $ + let_ #n (snd_ (shape #arr)) $ + let_ #b (build1 #n (#i :-> let_ #x (#arr ! pair nil #i) $ + if_ (#x .>= 1) (pair (inl (pair #x (7 * #x))) (2 * #x)) + (pair (inr (3 * #x)) (exp #x)))) $ + idx0 $ sum1i $ build1 #n $ #i :-> + let_ #p (#b ! pair nil #i) $ + case_ (fst_ #p) + (#a :-> fst_ #a * 2 + snd_ #a * snd_ #p) + (#b :-> #b * 4) + ,adTestCon "maximum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $ fromNamed $ lambda @(TMat R) #x $ body $ idx0 $ sum1i $ maximum1i #x -- cgit v1.2.3-70-g09d2 From fe80b31555c27f038b20eb84eb1e747781d7c76b Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 18 Jun 2025 10:11:12 +0200 Subject: Don't destroy effects in sparse plus --- src/AST/Sparse.hs | 20 ++++++++++++-------- src/CHAD.hs | 3 ++- test/Main.hs | 15 +++++++++------ 3 files changed, 23 insertions(+), 15 deletions(-) (limited to 'test/Main.hs') diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs index f0a1f2a..0c5bdb0 100644 --- a/src/AST/Sparse.hs +++ b/src/AST/Sparse.hs @@ -66,6 +66,9 @@ withInj2 (Inj f) (Inj g) k = Inj (k f g) withInj2 Noinj _ _ = Noinj withInj2 _ Noinj _ = Noinj +use :: Ex env a -> Ex env b -> Ex env b +use a b = elet a $ weakenExpr WSink b + -- | This function produces quadratically-sized code in the presence of nested -- dynamic sparsity. TODO can this be improved? sparsePlusS @@ -77,16 +80,17 @@ sparsePlusS -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3) -> r) -> r --- nil override -sparsePlusS _ _ SMTNil _ _ k = k SpAbsent (Inj $ \_ -> ENil ext) (Inj $ \_ -> ENil ext) (\_ _ -> ENil ext) +-- nil override (but don't destroy effects!) +sparsePlusS _ _ SMTNil _ _ k = + k SpAbsent (Inj $ \a -> use a $ ENil ext) (Inj $ \b -> use b $ ENil ext) (\a b -> use a $ use b $ ENil ext) -- simplifications sparsePlusS req1 req2 t (SpSparse SpAbsent) sp2 k = sparsePlusS req1 req2 t SpAbsent sp2 $ \sp3 minj1 minj2 plus -> - k sp3 (withInj minj1 $ \inj1 -> \_ -> inj1 (ENil ext)) minj2 (\_ b -> plus (ENil ext) b) + k sp3 (withInj minj1 $ \inj1 -> \a -> use a $ inj1 (ENil ext)) minj2 (\a b -> use a $ plus (ENil ext) b) sparsePlusS req1 req2 t sp1 (SpSparse SpAbsent) k = sparsePlusS req1 req2 t sp1 SpAbsent $ \sp3 minj1 minj2 plus -> - k sp3 minj1 (withInj minj2 $ \inj2 -> \_ -> inj2 (ENil ext)) (\a _ -> plus a (ENil ext)) + k sp3 minj1 (withInj minj2 $ \inj2 -> \b -> use b $ inj2 (ENil ext)) (\a b -> use b $ plus a (ENil ext)) sparsePlusS req1 req2 t (SpSparse (SpSparse sp1)) sp2 k = let ta = applySparse sp1 (fromSMTy t) in @@ -144,13 +148,13 @@ sparsePlusS _ _ t sp1 sp2 k = k (spDense t) (Inj id) (Inj id) (\a b -> EPlus ext t a b) -- handle absents -sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\_ b -> b) +sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\a b -> use a $ b) sparsePlusS ST _ t SpAbsent sp2 k = - k (SpSparse sp2) (Inj $ \_ -> ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\_ b -> EJust ext b) + k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b) -sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a _ -> a) +sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a b -> use b $ a) sparsePlusS _ ST t sp1 SpAbsent k = - k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \_ -> ENothing ext (applySparse sp1 (fromSMTy t))) (\a _ -> EJust ext a) + k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a) -- double sparse yields sparse sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k = diff --git a/src/CHAD.hs b/src/CHAD.hs index 3399de2..9a08457 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -404,7 +404,8 @@ subenvPlus :: SBool req1 -> SBool req2 -> (forall e. Ex e (Tup env1) -> Ex e (Tup env2) -> Ex e (Tup env3)) -> r) -> r -subenvPlus _ _ SNil SETop SETop k = k SETop (Inj id) (Inj id) (\_ _ -> ENil ext) +-- don't destroy effects! +subenvPlus _ _ SNil SETop SETop k = k SETop (Inj id) (Inj id) (\a b -> use a $ use b $ ENil ext) subenvPlus req1 req2 (SCons _ env) (SENo sub1) (SENo sub2) k = subenvPlus req1 req2 env sub1 sub2 $ \sub3 s31 s32 pl -> diff --git a/test/Main.hs b/test/Main.hs index 8da7598..5ec9dbc 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -451,6 +451,14 @@ term_build1_idx = fromNamed $ lambda @(TVec _) #x $ body $ idx0 $ sum1i $ build1 (#n `idiv` 2) (#i :-> #x ! pair nil (2 * #i)) +term_idx_coprod :: Ex '[TVec (TEither R R)] R +term_idx_coprod = fromNamed $ lambda @(TVec (TEither R R)) #x $ body $ + let_ #n (snd_ (shape #x)) $ + idx0 $ sum1i $ build1 #n $ #i :-> + case_ (#x ! pair nil #i) + (#a :-> #a * 2) + (#b :-> #b * 3) + term_pairs :: Ex [R, R] R term_pairs = fromNamed $ lambda #x $ lambda #y $ body $ let_ #p (pair #x #y) $ @@ -583,12 +591,7 @@ tests_AD = testGroup "AD" let_ #p (#x ! pair nil #i) $ 3 * fst_ #p + 2 * snd_ #p - ,adTest "idx-coprod" $ fromNamed $ lambda @(TVec (TEither R R)) #x $ body $ - let_ #n (snd_ (shape #x)) $ - idx0 $ sum1i $ build1 #n $ #i :-> - case_ (#x ! pair nil #i) - (#a :-> #a * 2) - (#b :-> #b * 3) + ,adTest "idx-coprod" $ term_idx_coprod ,adTest "idx-coprod-pair" $ fromNamed $ lambda @(TVec R) #arr $ body $ let_ #n (snd_ (shape #arr)) $ -- cgit v1.2.3-70-g09d2 From a45bf0fd84d8e604613e9e557ae80143f1a41004 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 18 Jun 2025 11:25:13 +0200 Subject: test: Test both default and accum configs --- test/Main.hs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) (limited to 'test/Main.hs') diff --git a/test/Main.hs b/test/Main.hs index 5ec9dbc..3847920 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -305,7 +305,9 @@ adTestGen name expr envGenerator = testGroupCollapse name [adTestGenPrimal env envGenerator expr exprS primalfun primalSfun ,adTestGenFwd env envGenerator exprS - ,adTestGenChad env envGenerator expr exprS primalSfun] + ,testGroup "chad" + [adTestGenChad "default" defaultConfig env envGenerator expr exprS primalSfun + ,adTestGenChad "accum" (chcSetAccum defaultConfig) env envGenerator expr exprS primalSfun]] adTestGenPrimal :: SList STy env -> Gen (SList Value env) -> Ex env R -> Ex env R @@ -336,19 +338,19 @@ adTestGenFwd env envGenerator exprS = diff outDNI1 (closeIsh' 1e-8) outDNC1 diff outDNI2 (closeIsh' 1e-8) outDNC2 -adTestGenChad :: forall env. SList STy env -> Gen (SList Value env) +adTestGenChad :: forall env. String -> CHADConfig -> SList STy env -> Gen (SList Value env) -> Ex env R -> Ex env R -> (SList Value env -> IO Double) -> TestTree -adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = - let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr +adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- envKnown env = + let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env expr dtermChadS = simplifyFix dtermChad0 - dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS + dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env exprS dtermSChadS = simplifyFix dtermSChad0 in withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC -> withCompiled env (simplifyFix (unMonoid dtermSChadS)) $ \dcompSChadS -> - testProperty "chad" $ property $ do + testProperty testname $ property $ do annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) -- pack Text for less GC pressure (these values are retained for some reason) -- cgit v1.2.3-70-g09d2 From 6d25e87e6f703395038d23aaff225aa502283519 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 18 Jun 2025 13:53:00 +0200 Subject: test: Diligently check UnMonoid correctness --- test/Main.hs | 58 +++++++++++++++++++++++++++++++++++----------------------- 1 file changed, 35 insertions(+), 23 deletions(-) (limited to 'test/Main.hs') diff --git a/test/Main.hs b/test/Main.hs index 3847920..0a57cbf 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -345,17 +345,21 @@ adTestGenChad :: forall env. String -> CHADConfig -> SList STy env -> Gen (SList adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- envKnown env = let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env expr dtermChadS = simplifyFix dtermChad0 + dtermChadSUS = simplifyFix $ unMonoid dtermChadS dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env exprS dtermSChadS = simplifyFix dtermSChad0 + dtermSChadSUS = simplifyFix $ unMonoid dtermSChadS in withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC -> - withCompiled env (simplifyFix (unMonoid dtermSChadS)) $ \dcompSChadS -> + withCompiled env dtermSChadSUS $ \dcompSChadSUS -> testProperty testname $ property $ do annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) - -- pack Text for less GC pressure (these values are retained for some reason) + -- check simplifier convergence; pack Text for less GC pressure (these values are retained for some reason) diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0))) + diff (T.pack (ppExpr env dtermChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid $ simplifyN 20 dtermChad0))) diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0))) + diff (T.pack (ppExpr env dtermSChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid $ simplifyN 20 dtermSChad0))) input <- forAllWith (showEnv env) envGenerator outPrimal <- evalIO $ primalSfun input @@ -365,17 +369,21 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e let tansFwd = TypedEnv (tanenv env) $ gradientByForward fwdartifactC input - let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0 - (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS - (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0 - (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS - tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0 - tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS - tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0 - tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS - - (outCompSChadS, gradCompSChadS) <- second unpackGrad <$> evalIO (dcompSChadS input) - let tansCompSChadS = TypedEnv (tanenv env) $ toTanE env input gradCompSChadS + let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0 + (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS + (outChadSUS , gradChadSUS) = second unpackGrad $ interpretOpen False env input dtermChadSUS + (outSChad0 , gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0 + (outSChadS , gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS + (outSChadSUS, gradSChadSUS) = second unpackGrad $ interpretOpen False env input dtermSChadSUS + tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0 + tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS + tansChadSUS = TypedEnv (tanenv env) $ toTanE env input gradChadSUS + tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0 + tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS + tansSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradSChadSUS + + (outCompSChadSUS, gradCompSChadSUS) <- second unpackGrad <$> evalIO (dcompSChadSUS input) + let tansCompSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradCompSChadSUS -- annotate (showEnv (d2e env) gradChad0) -- annotate (showEnv (d2e env) gradChadS) @@ -383,17 +391,21 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e -- annotate (ppExpr env dtermChad0) -- annotate (ppExpr env dtermChadS) annotate (ppExpr env (simplifyFix (unMonoid dtermSChadS))) - diff outChad0 closeIsh outPrimal - diff outChadS closeIsh outPrimal - diff outSChad0 closeIsh outPrimal - diff outSChadS closeIsh outPrimal - diff outCompSChadS closeIsh outPrimal + diff outChad0 closeIsh outPrimal + diff outChadS closeIsh outPrimal + diff outChadSUS closeIsh outPrimal + diff outSChad0 closeIsh outPrimal + diff outSChadS closeIsh outPrimal + diff outSChadSUS closeIsh outPrimal + diff outCompSChadSUS closeIsh outPrimal let closeIshE' e1 e2 = closeIshE (tanenv env) (unTypedEnv e1) (unTypedEnv e2) - diff tansChad closeIshE' tansFwd - diff tansChadS closeIshE' tansFwd - diff tansSChad closeIshE' tansFwd - diff tansSChadS closeIshE' tansFwd - diff tansCompSChadS closeIshE' tansFwd + diff tansChad closeIshE' tansFwd + diff tansChadS closeIshE' tansFwd + diff tansChadSUS closeIshE' tansFwd + diff tansSChad closeIshE' tansFwd + diff tansSChadS closeIshE' tansFwd + diff tansSChadSUS closeIshE' tansFwd + diff tansCompSChadSUS closeIshE' tansFwd withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree withCompiled env expr = withResource (compile env expr) (\_ -> pure ()) -- cgit v1.2.3-70-g09d2