summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/Main.hs15
1 files changed, 9 insertions, 6 deletions
diff --git a/test/Main.hs b/test/Main.hs
index 8da7598..5ec9dbc 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -451,6 +451,14 @@ term_build1_idx = fromNamed $ lambda @(TVec _) #x $ body $
idx0 $ sum1i $
build1 (#n `idiv` 2) (#i :-> #x ! pair nil (2 * #i))
+term_idx_coprod :: Ex '[TVec (TEither R R)] R
+term_idx_coprod = fromNamed $ lambda @(TVec (TEither R R)) #x $ body $
+ let_ #n (snd_ (shape #x)) $
+ idx0 $ sum1i $ build1 #n $ #i :->
+ case_ (#x ! pair nil #i)
+ (#a :-> #a * 2)
+ (#b :-> #b * 3)
+
term_pairs :: Ex [R, R] R
term_pairs = fromNamed $ lambda #x $ lambda #y $ body $
let_ #p (pair #x #y) $
@@ -583,12 +591,7 @@ tests_AD = testGroup "AD"
let_ #p (#x ! pair nil #i) $
3 * fst_ #p + 2 * snd_ #p
- ,adTest "idx-coprod" $ fromNamed $ lambda @(TVec (TEither R R)) #x $ body $
- let_ #n (snd_ (shape #x)) $
- idx0 $ sum1i $ build1 #n $ #i :->
- case_ (#x ! pair nil #i)
- (#a :-> #a * 2)
- (#b :-> #b * 3)
+ ,adTest "idx-coprod" $ term_idx_coprod
,adTest "idx-coprod-pair" $ fromNamed $ lambda @(TVec R) #arr $ body $
let_ #n (snd_ (shape #arr)) $