diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-01 10:15:29 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-01 10:15:29 +0100 |
commit | a018bcf4393a4ddac7a76ca86b3409a669e59f48 (patch) | |
tree | d57ca9747180261be8cfbb0d35b31987560b99b6 | |
parent | efae6885dae62d0525e1eb238967dc817c4df22d (diff) |
test: Pull term_mulmatvec out into top-level
-rw-r--r-- | test/Main.hs | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/test/Main.hs b/test/Main.hs index 9ab09c5..ec23eaf 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -235,6 +235,15 @@ term_sparse = fromNamed $ lambda #inp $ body $ let_ #c (build1 #n (#i :-> #arr ! pair nil 4)) $ idx0 (sum1i #a) + idx0 (sum1i #b) + idx0 (sum1i #c) +term_mulmatvec :: Ex [TArr N1 (TScal TF64), TArr N2 (TScal TF64)] (TScal TF64) +term_mulmatvec = fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec $ body $ + idx0 $ sum1i $ + let_ #hei (snd_ (fst_ (shape #mat))) $ + let_ #wid (snd_ (shape #mat)) $ + build1 #hei $ #i :-> + idx0 (sum1i (build1 #wid $ #j :-> + #mat ! pair (pair nil #i) #j * #vec ! pair nil #j)) + tests :: TestTree tests = testGroup "AD" [adTest "id" $ fromNamed $ lambda #x $ body $ #x @@ -284,14 +293,7 @@ tests = testGroup "AD" let_ #m (maximum1i #vec) $ log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m - ,adTestTp "mulmatvec" ((C "" 0 :$ C "n" 0) :& C "n" 0) $ - fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec $ body $ - idx0 $ sum1i $ - let_ #hei (snd_ (fst_ (shape #mat))) $ - let_ #wid (snd_ (shape #mat)) $ - build1 #hei $ #i :-> - idx0 (sum1i (build1 #wid $ #j :-> - #mat ! pair (pair nil #i) #j * #vec ! pair nil #j)) + ,adTestTp "mulmatvec" ((C "" 0 :$ C "n" 0) :& C "n" 0) term_mulmatvec ,adTestGen "gmm-wrong" (Example.gmmObjective True) genGMM |