summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-04 23:30:16 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-04 23:30:16 +0100
commit27491702baaffcd3ce8bef9ca8d06ee3b453540b (patch)
tree9e59444ffd283406517e5a77fe5e14a3f614b0dd
parent89b78d480a88559a8a9064eeafa60af345db4f2d (diff)
Add regression test for HEAD^
-rw-r--r--src/Data.hs4
-rw-r--r--test/Main.hs32
2 files changed, 28 insertions, 8 deletions
diff --git a/src/Data.hs b/src/Data.hs
index 00790fe..155eeb3 100644
--- a/src/Data.hs
+++ b/src/Data.hs
@@ -43,6 +43,10 @@ unSList :: (forall t. f t -> a) -> SList f list -> [a]
unSList _ SNil = []
unSList f (x `SCons` l) = f x : unSList f l
+showSList :: (forall t. Int -> f t -> String) -> SList f list -> String
+showSList _ SNil = "SNil"
+showSList f (x `SCons` l) = f 11 x ++ " `SCons` " ++ showSList f l
+
sappend :: SList f l1 -> SList f l2 -> SList f (Append l1 l2)
sappend SNil l = l
sappend (SCons x xs) l = SCons x (sappend xs l)
diff --git a/test/Main.hs b/test/Main.hs
index 9ff82a1..d488ce5 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -14,6 +14,7 @@ import Control.Monad.Trans.Class (lift)
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.State
import Data.Bifunctor
+-- import qualified Data.Functor.Product as Product
import Data.Int (Int64)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
@@ -199,20 +200,24 @@ adTestGen name expr envGenerator =
input <- forAllWith (showEnv env) envGenerator
- let convGrad :: Rep (Tup (D2E env)) -> SList Value (TanE env)
- convGrad = toTanE env input . unTup vUnpair (d2e env) . Value
+ let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env)
+ unpackGrad = unTup vUnpair (d2e env) . Value
let outPrimalI = interpretOpen False input expr
outPrimal <- liftIO $ getprimalfun >>= ($ input)
diff outPrimal (closeIsh' 1e-8) outPrimalI
- let (outChad0, gradChad0) = second convGrad $ interpretOpen False input dtermChad0
- (outChadS, gradChadS) = second convGrad $ interpretOpen False input dtermChadS
- scChad = envScalars env gradChad0
- scChadS = envScalars env gradChadS
+ let (outChad0, gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0
+ (outChadS, gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS
+ gradChad0' = toTanE env input gradChad0
+ gradChadS' = toTanE env input gradChadS
+ scChad = envScalars env gradChad0'
+ scChadS = envScalars env gradChadS'
gradFwd = gradientByForward knownEnv expr input
scFwd = envScalars env gradFwd
+ -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0))
+ -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS))
-- annotate (ppExpr knownEnv expr)
-- annotate (ppExpr env dtermChad0)
-- annotate (ppExpr env dtermChadS)
@@ -248,6 +253,14 @@ term_sparse = fromNamed $ lambda #inp $ body $
let_ #c (build1 #n (#i :-> #arr ! pair nil 4)) $
idx0 (sum1i #a) + idx0 (sum1i #b) + idx0 (sum1i #c)
+term_regression_simpl1 :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
+term_regression_simpl1 = fromNamed $ lambda #q $ body $
+ idx0 $ sum1i $ build (SS SZ) (shape #q) $ #idx :->
+ let_ #j (snd_ #idx) $
+ if_ (#j .== 0)
+ (#q ! pair nil 0)
+ (if_ (#j .== #j) 1.0 2.0)
+
term_mulmatvec :: Ex [TArr N1 (TScal TF64), TArr N2 (TScal TF64)] (TScal TF64)
term_mulmatvec = fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec $ body $
idx0 $ sum1i $
@@ -309,11 +322,14 @@ tests = testGroup "AD"
idx0 $ sum1i $ minimum1i #x
,adTest "unused" $ fromNamed $ lambda @(TArr N1 (TScal TF64)) #x $ body $
- let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $
- 42
+ let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $
+ 42
,adTestTp "sparse" (C "" 5) term_sparse
+ -- Regression test for a simplifier bug (89b78d4)
+ ,adTestTp "regression-simpl1" (C "" 1) term_regression_simpl1
+
,adTestGen "neural" Example.neural genNeural
,adTestGen "neural-unMonoid" (unMonoid (simplifyFix Example.neural)) genNeural