From 4d456e4d34b1e4fb3725051d1b8a0c376b704692 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 30 Oct 2025 15:56:35 +0100 Subject: Implement reshape --- test/Main.hs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) (limited to 'test') diff --git a/test/Main.hs b/test/Main.hs index cb10ed6..2acc9f8 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -670,6 +670,20 @@ tests_AD = testGroup "AD" ,adTestTp "uniform-free" (C "" 0 :& ()) Example.exUniformFree + ,adTest "reshape1" $ fromNamed $ lambda @(TMat R) #a $ body $ + let_ #sh (shape #a) $ + let_ #n (snd_ #sh * snd_ (fst_ #sh)) $ + idx0 $ sum1i $ reshape (SS SZ) (pair nil #n) #a + + ,adTestTp "reshape2" (C "" 1 :$ NC) $ fromNamed $ lambda @(TMat R) #a $ body $ + let_ #sh (shape #a) $ + let_ #innern (snd_ #sh) $ + let_ #n (#innern * snd_ (fst_ #sh)) $ + let_ #flata (reshape (SS SZ) (pair nil #n) #a) $ + -- ensure the input array to EReshape is shared + idx0 $ sum1i $ + build1 #n (#i :-> #flata ! pair nil #i + #a ! pair (pair nil 0) (#i `mod_` #innern)) + ,adTest "fold-sum" $ fromNamed $ lambda @(TArr N1 R) #a $ body $ idx0 $ fold1i (#x :-> #y :-> #x + #y) 0 #a -- cgit v1.2.3-70-g09d2