summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-01 10:15:29 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-01 10:15:29 +0100
commita018bcf4393a4ddac7a76ca86b3409a669e59f48 (patch)
treed57ca9747180261be8cfbb0d35b31987560b99b6
parentefae6885dae62d0525e1eb238967dc817c4df22d (diff)
test: Pull term_mulmatvec out into top-level
-rw-r--r--test/Main.hs18
1 files changed, 10 insertions, 8 deletions
diff --git a/test/Main.hs b/test/Main.hs
index 9ab09c5..ec23eaf 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -235,6 +235,15 @@ term_sparse = fromNamed $ lambda #inp $ body $
let_ #c (build1 #n (#i :-> #arr ! pair nil 4)) $
idx0 (sum1i #a) + idx0 (sum1i #b) + idx0 (sum1i #c)
+term_mulmatvec :: Ex [TArr N1 (TScal TF64), TArr N2 (TScal TF64)] (TScal TF64)
+term_mulmatvec = fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec $ body $
+ idx0 $ sum1i $
+ let_ #hei (snd_ (fst_ (shape #mat))) $
+ let_ #wid (snd_ (shape #mat)) $
+ build1 #hei $ #i :->
+ idx0 (sum1i (build1 #wid $ #j :->
+ #mat ! pair (pair nil #i) #j * #vec ! pair nil #j))
+
tests :: TestTree
tests = testGroup "AD"
[adTest "id" $ fromNamed $ lambda #x $ body $ #x
@@ -284,14 +293,7 @@ tests = testGroup "AD"
let_ #m (maximum1i #vec) $
log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m
- ,adTestTp "mulmatvec" ((C "" 0 :$ C "n" 0) :& C "n" 0) $
- fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec $ body $
- idx0 $ sum1i $
- let_ #hei (snd_ (fst_ (shape #mat))) $
- let_ #wid (snd_ (shape #mat)) $
- build1 #hei $ #i :->
- idx0 (sum1i (build1 #wid $ #j :->
- #mat ! pair (pair nil #i) #j * #vec ! pair nil #j))
+ ,adTestTp "mulmatvec" ((C "" 0 :$ C "n" 0) :& C "n" 0) term_mulmatvec
,adTestGen "gmm-wrong" (Example.gmmObjective True) genGMM