diff options
Diffstat (limited to 'test/Main.hs')
-rw-r--r-- | test/Main.hs | 35 |
1 files changed, 22 insertions, 13 deletions
diff --git a/test/Main.hs b/test/Main.hs index b6f9f2b..5db7ea0 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -24,6 +24,7 @@ import Hedgehog.Main import Array import AST import AST.Pretty +import AST.UnMonoid import CHAD.Top import CHAD.Types import qualified Example @@ -274,19 +275,9 @@ tests = checkParallel $ Group "AD" let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $ 42) - ,("neural", adTestGen Example.neural $ do - let tR = STScal STF64 - let genLayer nin nout = - liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin) - <*> genArray tR (ShNil `ShCons` nout) - nin <- Gen.integral (Range.linear 1 10) - n1 <- Gen.integral (Range.linear 1 10) - n2 <- Gen.integral (Range.linear 1 10) - input <- genArray tR (ShNil `ShCons` nin) - lay1 <- genLayer nin n1 - lay2 <- genLayer n1 n2 - lay3 <- genArray tR (ShNil `ShCons` n2) - return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)) + ,("neural", adTestGen Example.neural genNeural) + + ,("neural-unMonoid", adTestGen (unMonoid (simplifyFix Example.neural)) genNeural) ,("logsumexp", adTestTp (C "" 1) $ fromNamed $ lambda @(TArr N1 _) #vec $ body $ @@ -304,7 +295,11 @@ tests = checkParallel $ Group "AD" ,("gmm-wrong", withShrinks 0 $ adTestGen (Example.gmmObjective True) genGMM) + ,("gmm-wrong-unMonoid", withShrinks 0 $ adTestGen (unMonoid (simplifyFix (Example.gmmObjective True))) genGMM) + ,("gmm", withShrinks 0 $ adTestGen (Example.gmmObjective False) genGMM) + + ,("gmm-unMonoid", withShrinks 0 $ adTestGen (unMonoid (simplifyFix (Example.gmmObjective False))) genGMM) ] where genGMM = do @@ -330,5 +325,19 @@ tests = checkParallel $ Group "AD" Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons` SNil) + genNeural = do + let tR = STScal STF64 + let genLayer nin nout = + liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin) + <*> genArray tR (ShNil `ShCons` nout) + nin <- Gen.integral (Range.linear 1 10) + n1 <- Gen.integral (Range.linear 1 10) + n2 <- Gen.integral (Range.linear 1 10) + input <- genArray tR (ShNil `ShCons` nin) + lay1 <- genLayer nin n1 + lay2 <- genLayer n1 n2 + lay3 <- genArray tR (ShNil `ShCons` n2) + return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil) + main :: IO () main = defaultMain [tests] |