aboutsummaryrefslogtreecommitdiff
path: root/test/Main.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-10-30 15:56:35 +0100
committerTom Smeding <tom@tomsmeding.com>2025-10-30 15:56:35 +0100
commit4d456e4d34b1e4fb3725051d1b8a0c376b704692 (patch)
tree1385217efcc0b58ddb028e707e6a5a36b884ed65 /test/Main.hs
parent0e8e59c5f9af547cf1b79b9bae892e32700ace56 (diff)
Implement reshape
Diffstat (limited to 'test/Main.hs')
-rw-r--r--test/Main.hs14
1 files changed, 14 insertions, 0 deletions
diff --git a/test/Main.hs b/test/Main.hs
index cb10ed6..2acc9f8 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -670,6 +670,20 @@ tests_AD = testGroup "AD"
,adTestTp "uniform-free" (C "" 0 :& ()) Example.exUniformFree
+ ,adTest "reshape1" $ fromNamed $ lambda @(TMat R) #a $ body $
+ let_ #sh (shape #a) $
+ let_ #n (snd_ #sh * snd_ (fst_ #sh)) $
+ idx0 $ sum1i $ reshape (SS SZ) (pair nil #n) #a
+
+ ,adTestTp "reshape2" (C "" 1 :$ NC) $ fromNamed $ lambda @(TMat R) #a $ body $
+ let_ #sh (shape #a) $
+ let_ #innern (snd_ #sh) $
+ let_ #n (#innern * snd_ (fst_ #sh)) $
+ let_ #flata (reshape (SS SZ) (pair nil #n) #a) $
+ -- ensure the input array to EReshape is shared
+ idx0 $ sum1i $
+ build1 #n (#i :-> #flata ! pair nil #i + #a ! pair (pair nil 0) (#i `mod_` #innern))
+
,adTest "fold-sum" $ fromNamed $ lambda @(TArr N1 R) #a $ body $
idx0 $ fold1i (#x :-> #y :-> #x + #y) 0 #a