diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/Main.hs | 15 |
1 files changed, 9 insertions, 6 deletions
diff --git a/test/Main.hs b/test/Main.hs index 8da7598..5ec9dbc 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -451,6 +451,14 @@ term_build1_idx = fromNamed $ lambda @(TVec _) #x $ body $ idx0 $ sum1i $ build1 (#n `idiv` 2) (#i :-> #x ! pair nil (2 * #i)) +term_idx_coprod :: Ex '[TVec (TEither R R)] R +term_idx_coprod = fromNamed $ lambda @(TVec (TEither R R)) #x $ body $ + let_ #n (snd_ (shape #x)) $ + idx0 $ sum1i $ build1 #n $ #i :-> + case_ (#x ! pair nil #i) + (#a :-> #a * 2) + (#b :-> #b * 3) + term_pairs :: Ex [R, R] R term_pairs = fromNamed $ lambda #x $ lambda #y $ body $ let_ #p (pair #x #y) $ @@ -583,12 +591,7 @@ tests_AD = testGroup "AD" let_ #p (#x ! pair nil #i) $ 3 * fst_ #p + 2 * snd_ #p - ,adTest "idx-coprod" $ fromNamed $ lambda @(TVec (TEither R R)) #x $ body $ - let_ #n (snd_ (shape #x)) $ - idx0 $ sum1i $ build1 #n $ #i :-> - case_ (#x ! pair nil #i) - (#a :-> #a * 2) - (#b :-> #b * 3) + ,adTest "idx-coprod" $ term_idx_coprod ,adTest "idx-coprod-pair" $ fromNamed $ lambda @(TVec R) #arr $ body $ let_ #n (snd_ (shape #arr)) $ |