summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-06-18 00:07:48 +0200
committerTom Smeding <tom@tomsmeding.com>2025-06-18 00:13:32 +0200
commit2b00a57f565a42b1079a071e2db630ba22c7120d (patch)
treea7039adddc9c56b6d5791ec61baa9fb8aa564a0b
parentd1b2e2c3a3cdaf49ff5e4bae6fe9b0612c3779c2 (diff)
TODO deep zero in accum + fix warnings
-rw-r--r--test/Main.hs24
1 files changed, 24 insertions, 0 deletions
diff --git a/test/Main.hs b/test/Main.hs
index d79e63f..8da7598 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -577,6 +577,30 @@ tests_AD = testGroup "AD"
,adTest "build1-idx" term_build1_idx
+ ,adTest "idx-pair" $ fromNamed $ lambda @(TVec (TPair R R)) #x $ body $
+ let_ #n (snd_ (shape #x)) $
+ idx0 $ sum1i $ build1 #n $ #i :->
+ let_ #p (#x ! pair nil #i) $
+ 3 * fst_ #p + 2 * snd_ #p
+
+ ,adTest "idx-coprod" $ fromNamed $ lambda @(TVec (TEither R R)) #x $ body $
+ let_ #n (snd_ (shape #x)) $
+ idx0 $ sum1i $ build1 #n $ #i :->
+ case_ (#x ! pair nil #i)
+ (#a :-> #a * 2)
+ (#b :-> #b * 3)
+
+ ,adTest "idx-coprod-pair" $ fromNamed $ lambda @(TVec R) #arr $ body $
+ let_ #n (snd_ (shape #arr)) $
+ let_ #b (build1 #n (#i :-> let_ #x (#arr ! pair nil #i) $
+ if_ (#x .>= 1) (pair (inl (pair #x (7 * #x))) (2 * #x))
+ (pair (inr (3 * #x)) (exp #x)))) $
+ idx0 $ sum1i $ build1 #n $ #i :->
+ let_ #p (#b ! pair nil #i) $
+ case_ (fst_ #p)
+ (#a :-> fst_ #a * 2 + snd_ #a * snd_ #p)
+ (#b :-> #b * 4)
+
,adTestCon "maximum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
fromNamed $ lambda @(TMat R) #x $ body $
idx0 $ sum1i $ maximum1i #x