summaryrefslogtreecommitdiff
path: root/test/Main.hs
diff options
context:
space:
mode:
Diffstat (limited to 'test/Main.hs')
-rw-r--r--test/Main.hs29
1 files changed, 20 insertions, 9 deletions
diff --git a/test/Main.hs b/test/Main.hs
index 1b83a2e..d79e63f 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -435,11 +435,22 @@ gen_neural = do
lay3 <- genArray tR (ShNil `ShCons` n2)
return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)
+term_build0 :: Ex '[TArr N0 R] R
+term_build0 = fromNamed $ lambda @(TArr N0 _) #x $ body $
+ idx0 $
+ build SZ (shape #x) $ #idx :-> #x ! #idx
+
term_build1_sum :: Ex '[TVec R] R
term_build1_sum = fromNamed $ lambda #x $ body $
idx0 $ sum1i $
build (SS SZ) (shape #x) $ #idx :-> #x ! #idx
+term_build1_idx :: Ex '[TVec R] R
+term_build1_idx = fromNamed $ lambda @(TVec _) #x $ body $
+ let_ #n (snd_ (shape #x)) $
+ idx0 $ sum1i $
+ build1 (#n `idiv` 2) (#i :-> #x ! pair nil (2 * #i))
+
term_pairs :: Ex [R, R] R
term_pairs = fromNamed $ lambda #x $ lambda #y $ body $
let_ #p (pair #x #y) $
@@ -502,22 +513,22 @@ tests_Compile = testGroup "Compile"
,compileTest "accum (f64,f64)" $ fromNamed $ lambda #b $ lambda #x $ body $
with @(TPair R R) (pair 0.0 0.0) $ #ac :->
- let_ #_ (if_ #b (accum (SAPFst SAPHere) (pair nil nil) 3.0 #ac) nil) $
+ let_ #_ (if_ #b (accum (SAPFst SAPHere) nil 3.0 #ac) nil) $
let_ #_ (accum SAPHere nil #x #ac) $
- let_ #_ (accum (SAPSnd SAPHere) (pair nil nil) 4.0 #ac) $
+ let_ #_ (accum (SAPSnd SAPHere) nil 4.0 #ac) $
nil
,compileTest "accum (Maybe (f64,f64))" $ fromNamed $ lambda #b $ lambda #x $ body $
- with @(TMaybe (TPair R R)) nothing $ #ac :->
- let_ #_ (if_ #b (accum (SAPJust (SAPFst SAPHere)) (pair nil nil) 3.0 #ac) nil) $
+ with @(TMaybe (TPair R R)) (just (pair 0 0)) $ #ac :->
+ let_ #_ (if_ #b (accum (SAPJust (SAPFst SAPHere)) nil 3.0 #ac) nil) $
let_ #_ (accum SAPHere nil #x #ac) $
- let_ #_ (accum (SAPJust (SAPSnd SAPHere)) (pair nil nil) 4.0 #ac) $
+ let_ #_ (accum (SAPJust (SAPSnd SAPHere)) nil 4.0 #ac) $
nil
,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda @(TVec R) #x $ body $
let_ #len (snd_ (shape #x)) $
with @(TVec R) (build1 #len (#_ :-> 0)) $ #ac :->
- let_ #_ (if_ #b (accum (SAPArrIdx SAPHere) (pair (pair (pair nil 2) (build1 #len (#_ :-> nil))) nil) 6.0 #ac)
+ let_ #_ (if_ #b (accum (SAPArrIdx SAPHere) (pair (pair nil 2) nil) 6.0 #ac)
nil) $
let_ #_ (accum SAPHere nil #x #ac) $
nil
@@ -556,9 +567,7 @@ tests_AD = testGroup "AD"
,adTest "build0 const" $ fromNamed $ lambda @R #x $ body $
idx0 $ build SZ nil $ #idx :-> const_ 0.0
- ,adTest "build0" $ fromNamed $ lambda @(TArr N0 _) #x $ body $
- idx0 $
- build SZ (shape #x) $ #idx :-> #x ! #idx
+ ,adTest "build0" term_build0
,adTest "build1-sum" term_build1_sum
@@ -566,6 +575,8 @@ tests_AD = testGroup "AD"
idx0 $ sum1i . sum1i $
build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx
+ ,adTest "build1-idx" term_build1_idx
+
,adTestCon "maximum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
fromNamed $ lambda @(TMat R) #x $ body $
idx0 $ sum1i $ maximum1i #x