summaryrefslogtreecommitdiff
path: root/test/Main.hs
diff options
context:
space:
mode:
Diffstat (limited to 'test/Main.hs')
-rw-r--r--test/Main.hs35
1 files changed, 22 insertions, 13 deletions
diff --git a/test/Main.hs b/test/Main.hs
index b6f9f2b..5db7ea0 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -24,6 +24,7 @@ import Hedgehog.Main
import Array
import AST
import AST.Pretty
+import AST.UnMonoid
import CHAD.Top
import CHAD.Types
import qualified Example
@@ -274,19 +275,9 @@ tests = checkParallel $ Group "AD"
let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $
42)
- ,("neural", adTestGen Example.neural $ do
- let tR = STScal STF64
- let genLayer nin nout =
- liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin)
- <*> genArray tR (ShNil `ShCons` nout)
- nin <- Gen.integral (Range.linear 1 10)
- n1 <- Gen.integral (Range.linear 1 10)
- n2 <- Gen.integral (Range.linear 1 10)
- input <- genArray tR (ShNil `ShCons` nin)
- lay1 <- genLayer nin n1
- lay2 <- genLayer n1 n2
- lay3 <- genArray tR (ShNil `ShCons` n2)
- return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil))
+ ,("neural", adTestGen Example.neural genNeural)
+
+ ,("neural-unMonoid", adTestGen (unMonoid (simplifyFix Example.neural)) genNeural)
,("logsumexp", adTestTp (C "" 1) $
fromNamed $ lambda @(TArr N1 _) #vec $ body $
@@ -304,7 +295,11 @@ tests = checkParallel $ Group "AD"
,("gmm-wrong", withShrinks 0 $ adTestGen (Example.gmmObjective True) genGMM)
+ ,("gmm-wrong-unMonoid", withShrinks 0 $ adTestGen (unMonoid (simplifyFix (Example.gmmObjective True))) genGMM)
+
,("gmm", withShrinks 0 $ adTestGen (Example.gmmObjective False) genGMM)
+
+ ,("gmm-unMonoid", withShrinks 0 $ adTestGen (unMonoid (simplifyFix (Example.gmmObjective False))) genGMM)
]
where
genGMM = do
@@ -330,5 +325,19 @@ tests = checkParallel $ Group "AD"
Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons`
SNil)
+ genNeural = do
+ let tR = STScal STF64
+ let genLayer nin nout =
+ liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin)
+ <*> genArray tR (ShNil `ShCons` nout)
+ nin <- Gen.integral (Range.linear 1 10)
+ n1 <- Gen.integral (Range.linear 1 10)
+ n2 <- Gen.integral (Range.linear 1 10)
+ input <- genArray tR (ShNil `ShCons` nin)
+ lay1 <- genLayer nin n1
+ lay2 <- genLayer n1 n2
+ lay3 <- genArray tR (ShNil `ShCons` n2)
+ return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)
+
main :: IO ()
main = defaultMain [tests]