diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-12-06 19:54:53 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-12-06 19:54:53 +0100 |
commit | 3e266262ebe65bd5d775711b4d05bc9670a38a47 (patch) | |
tree | bf0fff187e53adb8a4f45b3d7c70c97566c1e141 /test/Main.hs | |
parent | 40a0abca1cedcdd930bb33d1874b7922443e5a8c (diff) |
UnMonoid
Diffstat (limited to 'test/Main.hs')
-rw-r--r-- | test/Main.hs | 35 |
1 files changed, 22 insertions, 13 deletions
diff --git a/test/Main.hs b/test/Main.hs index b6f9f2b..5db7ea0 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -24,6 +24,7 @@ import Hedgehog.Main import Array import AST import AST.Pretty +import AST.UnMonoid import CHAD.Top import CHAD.Types import qualified Example @@ -274,19 +275,9 @@ tests = checkParallel $ Group "AD" let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $ 42) - ,("neural", adTestGen Example.neural $ do - let tR = STScal STF64 - let genLayer nin nout = - liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin) - <*> genArray tR (ShNil `ShCons` nout) - nin <- Gen.integral (Range.linear 1 10) - n1 <- Gen.integral (Range.linear 1 10) - n2 <- Gen.integral (Range.linear 1 10) - input <- genArray tR (ShNil `ShCons` nin) - lay1 <- genLayer nin n1 - lay2 <- genLayer n1 n2 - lay3 <- genArray tR (ShNil `ShCons` n2) - return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)) + ,("neural", adTestGen Example.neural genNeural) + + ,("neural-unMonoid", adTestGen (unMonoid (simplifyFix Example.neural)) genNeural) ,("logsumexp", adTestTp (C "" 1) $ fromNamed $ lambda @(TArr N1 _) #vec $ body $ @@ -304,7 +295,11 @@ tests = checkParallel $ Group "AD" ,("gmm-wrong", withShrinks 0 $ adTestGen (Example.gmmObjective True) genGMM) + ,("gmm-wrong-unMonoid", withShrinks 0 $ adTestGen (unMonoid (simplifyFix (Example.gmmObjective True))) genGMM) + ,("gmm", withShrinks 0 $ adTestGen (Example.gmmObjective False) genGMM) + + ,("gmm-unMonoid", withShrinks 0 $ adTestGen (unMonoid (simplifyFix (Example.gmmObjective False))) genGMM) ] where genGMM = do @@ -330,5 +325,19 @@ tests = checkParallel $ Group "AD" Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons` SNil) + genNeural = do + let tR = STScal STF64 + let genLayer nin nout = + liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin) + <*> genArray tR (ShNil `ShCons` nout) + nin <- Gen.integral (Range.linear 1 10) + n1 <- Gen.integral (Range.linear 1 10) + n2 <- Gen.integral (Range.linear 1 10) + input <- genArray tR (ShNil `ShCons` nin) + lay1 <- genLayer nin n1 + lay2 <- genLayer n1 n2 + lay3 <- genArray tR (ShNil `ShCons` n2) + return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil) + main :: IO () main = defaultMain [tests] |