diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/Main.hs | 56 |
1 files changed, 47 insertions, 9 deletions
diff --git a/test/Main.hs b/test/Main.hs index f3aec68..0a57cbf 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -449,11 +449,30 @@ gen_neural = do lay3 <- genArray tR (ShNil `ShCons` n2) return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil) +term_build0 :: Ex '[TArr N0 R] R +term_build0 = fromNamed $ lambda @(TArr N0 _) #x $ body $ + idx0 $ + build SZ (shape #x) $ #idx :-> #x ! #idx + term_build1_sum :: Ex '[TVec R] R term_build1_sum = fromNamed $ lambda #x $ body $ idx0 $ sum1i $ build (SS SZ) (shape #x) $ #idx :-> #x ! #idx +term_build1_idx :: Ex '[TVec R] R +term_build1_idx = fromNamed $ lambda @(TVec _) #x $ body $ + let_ #n (snd_ (shape #x)) $ + idx0 $ sum1i $ + build1 (#n `idiv` 2) (#i :-> #x ! pair nil (2 * #i)) + +term_idx_coprod :: Ex '[TVec (TEither R R)] R +term_idx_coprod = fromNamed $ lambda @(TVec (TEither R R)) #x $ body $ + let_ #n (snd_ (shape #x)) $ + idx0 $ sum1i $ build1 #n $ #i :-> + case_ (#x ! pair nil #i) + (#a :-> #a * 2) + (#b :-> #b * 3) + term_pairs :: Ex [R, R] R term_pairs = fromNamed $ lambda #x $ lambda #y $ body $ let_ #p (pair #x #y) $ @@ -516,22 +535,22 @@ tests_Compile = testGroup "Compile" ,compileTest "accum (f64,f64)" $ fromNamed $ lambda #b $ lambda #x $ body $ with @(TPair R R) (pair 0.0 0.0) $ #ac :-> - let_ #_ (if_ #b (accum (SAPFst SAPHere) (pair nil nil) 3.0 #ac) nil) $ + let_ #_ (if_ #b (accum (SAPFst SAPHere) nil 3.0 #ac) nil) $ let_ #_ (accum SAPHere nil #x #ac) $ - let_ #_ (accum (SAPSnd SAPHere) (pair nil nil) 4.0 #ac) $ + let_ #_ (accum (SAPSnd SAPHere) nil 4.0 #ac) $ nil ,compileTest "accum (Maybe (f64,f64))" $ fromNamed $ lambda #b $ lambda #x $ body $ - with @(TMaybe (TPair R R)) nothing $ #ac :-> - let_ #_ (if_ #b (accum (SAPJust (SAPFst SAPHere)) (pair nil nil) 3.0 #ac) nil) $ + with @(TMaybe (TPair R R)) (just (pair 0 0)) $ #ac :-> + let_ #_ (if_ #b (accum (SAPJust (SAPFst SAPHere)) nil 3.0 #ac) nil) $ let_ #_ (accum SAPHere nil #x #ac) $ - let_ #_ (accum (SAPJust (SAPSnd SAPHere)) (pair nil nil) 4.0 #ac) $ + let_ #_ (accum (SAPJust (SAPSnd SAPHere)) nil 4.0 #ac) $ nil ,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda @(TVec R) #x $ body $ let_ #len (snd_ (shape #x)) $ with @(TVec R) (build1 #len (#_ :-> 0)) $ #ac :-> - let_ #_ (if_ #b (accum (SAPArrIdx SAPHere) (pair (pair (pair nil 2) (build1 #len (#_ :-> nil))) nil) 6.0 #ac) + let_ #_ (if_ #b (accum (SAPArrIdx SAPHere) (pair (pair nil 2) nil) 6.0 #ac) nil) $ let_ #_ (accum SAPHere nil #x #ac) $ nil @@ -570,9 +589,7 @@ tests_AD = testGroup "AD" ,adTest "build0 const" $ fromNamed $ lambda @R #x $ body $ idx0 $ build SZ nil $ #idx :-> const_ 0.0 - ,adTest "build0" $ fromNamed $ lambda @(TArr N0 _) #x $ body $ - idx0 $ - build SZ (shape #x) $ #idx :-> #x ! #idx + ,adTest "build0" term_build0 ,adTest "build1-sum" term_build1_sum @@ -580,6 +597,27 @@ tests_AD = testGroup "AD" idx0 $ sum1i . sum1i $ build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx + ,adTest "build1-idx" term_build1_idx + + ,adTest "idx-pair" $ fromNamed $ lambda @(TVec (TPair R R)) #x $ body $ + let_ #n (snd_ (shape #x)) $ + idx0 $ sum1i $ build1 #n $ #i :-> + let_ #p (#x ! pair nil #i) $ + 3 * fst_ #p + 2 * snd_ #p + + ,adTest "idx-coprod" $ term_idx_coprod + + ,adTest "idx-coprod-pair" $ fromNamed $ lambda @(TVec R) #arr $ body $ + let_ #n (snd_ (shape #arr)) $ + let_ #b (build1 #n (#i :-> let_ #x (#arr ! pair nil #i) $ + if_ (#x .>= 1) (pair (inl (pair #x (7 * #x))) (2 * #x)) + (pair (inr (3 * #x)) (exp #x)))) $ + idx0 $ sum1i $ build1 #n $ #i :-> + let_ #p (#b ! pair nil #i) $ + case_ (fst_ #p) + (#a :-> fst_ #a * 2 + snd_ #a * snd_ #p) + (#b :-> #b * 4) + ,adTestCon "maximum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $ fromNamed $ lambda @(TMat R) #x $ body $ idx0 $ sum1i $ maximum1i #x |