diff options
Diffstat (limited to 'test')
| -rw-r--r-- | test/Main.hs | 14 | 
1 files changed, 14 insertions, 0 deletions
diff --git a/test/Main.hs b/test/Main.hs index cb10ed6..2acc9f8 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -670,6 +670,20 @@ tests_AD = testGroup "AD"    ,adTestTp "uniform-free" (C "" 0 :& ()) Example.exUniformFree +  ,adTest "reshape1" $ fromNamed $ lambda @(TMat R) #a $ body $ +    let_ #sh (shape #a) $ +    let_ #n (snd_ #sh * snd_ (fst_ #sh)) $ +      idx0 $ sum1i $ reshape (SS SZ) (pair nil #n) #a + +  ,adTestTp "reshape2" (C "" 1 :$ NC) $ fromNamed $ lambda @(TMat R) #a $ body $ +    let_ #sh (shape #a) $ +    let_ #innern (snd_ #sh) $ +    let_ #n (#innern * snd_ (fst_ #sh)) $ +    let_ #flata (reshape (SS SZ) (pair nil #n) #a) $ +      -- ensure the input array to EReshape is shared +      idx0 $ sum1i $ +        build1 #n (#i :-> #flata ! pair nil #i + #a ! pair (pair nil 0) (#i `mod_` #innern)) +    ,adTest "fold-sum" $ fromNamed $ lambda @(TArr N1 R) #a $ body $      idx0 $ fold1i (#x :-> #y :-> #x + #y) 0 #a  | 
