summaryrefslogtreecommitdiff
path: root/test/Main.hs
diff options
context:
space:
mode:
Diffstat (limited to 'test/Main.hs')
-rw-r--r--test/Main.hs32
1 files changed, 24 insertions, 8 deletions
diff --git a/test/Main.hs b/test/Main.hs
index 9ff82a1..d488ce5 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -14,6 +14,7 @@ import Control.Monad.Trans.Class (lift)
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.State
import Data.Bifunctor
+-- import qualified Data.Functor.Product as Product
import Data.Int (Int64)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
@@ -199,20 +200,24 @@ adTestGen name expr envGenerator =
input <- forAllWith (showEnv env) envGenerator
- let convGrad :: Rep (Tup (D2E env)) -> SList Value (TanE env)
- convGrad = toTanE env input . unTup vUnpair (d2e env) . Value
+ let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env)
+ unpackGrad = unTup vUnpair (d2e env) . Value
let outPrimalI = interpretOpen False input expr
outPrimal <- liftIO $ getprimalfun >>= ($ input)
diff outPrimal (closeIsh' 1e-8) outPrimalI
- let (outChad0, gradChad0) = second convGrad $ interpretOpen False input dtermChad0
- (outChadS, gradChadS) = second convGrad $ interpretOpen False input dtermChadS
- scChad = envScalars env gradChad0
- scChadS = envScalars env gradChadS
+ let (outChad0, gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0
+ (outChadS, gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS
+ gradChad0' = toTanE env input gradChad0
+ gradChadS' = toTanE env input gradChadS
+ scChad = envScalars env gradChad0'
+ scChadS = envScalars env gradChadS'
gradFwd = gradientByForward knownEnv expr input
scFwd = envScalars env gradFwd
+ -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0))
+ -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS))
-- annotate (ppExpr knownEnv expr)
-- annotate (ppExpr env dtermChad0)
-- annotate (ppExpr env dtermChadS)
@@ -248,6 +253,14 @@ term_sparse = fromNamed $ lambda #inp $ body $
let_ #c (build1 #n (#i :-> #arr ! pair nil 4)) $
idx0 (sum1i #a) + idx0 (sum1i #b) + idx0 (sum1i #c)
+term_regression_simpl1 :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
+term_regression_simpl1 = fromNamed $ lambda #q $ body $
+ idx0 $ sum1i $ build (SS SZ) (shape #q) $ #idx :->
+ let_ #j (snd_ #idx) $
+ if_ (#j .== 0)
+ (#q ! pair nil 0)
+ (if_ (#j .== #j) 1.0 2.0)
+
term_mulmatvec :: Ex [TArr N1 (TScal TF64), TArr N2 (TScal TF64)] (TScal TF64)
term_mulmatvec = fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec $ body $
idx0 $ sum1i $
@@ -309,11 +322,14 @@ tests = testGroup "AD"
idx0 $ sum1i $ minimum1i #x
,adTest "unused" $ fromNamed $ lambda @(TArr N1 (TScal TF64)) #x $ body $
- let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $
- 42
+ let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $
+ 42
,adTestTp "sparse" (C "" 5) term_sparse
+ -- Regression test for a simplifier bug (89b78d4)
+ ,adTestTp "regression-simpl1" (C "" 1) term_regression_simpl1
+
,adTestGen "neural" Example.neural genNeural
,adTestGen "neural-unMonoid" (unMonoid (simplifyFix Example.neural)) genNeural