summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2024-12-12 11:44:16 +0100
committerTom Smeding <t.j.smeding@uu.nl>2024-12-12 11:44:16 +0100
commitfad10d5a218f935d47e8b9dc41256a30b4ec540d (patch)
treefdbfa5049306025d06147e523338ffeb18c2c916
parenta2d7ddd2230b7f42fe46eb33ea6dee8eb7080fdc (diff)
Add sparse test
-rw-r--r--test/Main.hs11
1 files changed, 11 insertions, 0 deletions
diff --git a/test/Main.hs b/test/Main.hs
index 5db7ea0..b234aa2 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -237,6 +237,15 @@ term_pairs = fromNamed $ lambda #x $ lambda #y $ body $
let_ #q (pair (snd_ #p * fst_ #p + #y) #x) $
fst_ #q * #x + snd_ #q * fst_ #p
+term_sparse :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
+term_sparse = fromNamed $ lambda #inp $ body $
+ let_ #n (snd_ (shape #inp)) $
+ let_ #arr (build1 #n (#i :-> #inp ! pair nil #i)) $
+ let_ #a (build1 #n (#i :-> #arr ! pair nil 2)) $
+ let_ #b (build1 #n (#i :-> #arr ! pair nil 3)) $
+ let_ #c (build1 #n (#i :-> #arr ! pair nil 4)) $
+ idx0 (sum1i #a) + idx0 (sum1i #b) + idx0 (sum1i #c)
+
tests :: IO Bool
tests = checkParallel $ Group "AD"
[("id", adTest $ fromNamed $ lambda #x $ body $ #x)
@@ -275,6 +284,8 @@ tests = checkParallel $ Group "AD"
let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $
42)
+ ,("sparse", adTestTp (C "" 5) term_sparse)
+
,("neural", adTestGen Example.neural genNeural)
,("neural-unMonoid", adTestGen (unMonoid (simplifyFix Example.neural)) genNeural)