summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/Language.hs3
-rw-r--r--test/Main.hs23
2 files changed, 26 insertions, 0 deletions
diff --git a/src/Language.hs b/src/Language.hs
index cf7cc4c..4ed4eaa 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -149,6 +149,9 @@ infixl 9 !
shape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx))
shape = NEShape
+length_ :: NExpr env (TArr N1 t) -> NExpr env TIx
+length_ e = snd_ (shape e)
+
oper :: SOp a t -> NExpr env a -> NExpr env t
oper = NEOp
diff --git a/test/Main.hs b/test/Main.hs
index 9dad35e..0cad9c9 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -425,6 +425,27 @@ term_mulmatvec = fromNamed $ lambda #mat $ lambda #vec $ body $
idx0 (sum1i (build1 #wid $ #j :->
#mat ! pair (pair nil #i) #j * #vec ! pair nil #j))
+term_arr_rebind :: Ex '[I64, TVec R] R
+term_arr_rebind = fromNamed $ lambda #a $ lambda #k $ body $
+ let_ #n (if_ (#k .< length_ #a) #k (length_ #a)) $
+ let_ #b (build1 #n (#i :-> #a ! pair nil #i)) $
+ let_ #p (if_ (#n `mod_` 2 .== 1)
+ (pair #a #b)
+ (pair (map_ (#x :-> #x + 1) #a) #b)) $
+ if_ (#n `mod_` 3 .== 1)
+ (idx0 (sum1i (snd_ #p)))
+ (let_ #b' (snd_ #p) $
+ idx0 (sum1i #b') * idx0 (sum1i (map_ (#x :-> 2 * #x) #b')))
+
+-- This simplifies away to a pointless test, but is helpful for debugging what
+-- term_arr_rebind is supposed to test in a REPL
+term_arr_rebind_simple :: Ex '[TVec R] R
+term_arr_rebind_simple = fromNamed $ lambda #a $ body $
+ let_ #b (build1 (length_ #a) (#i :-> 5 * (#a ! pair nil #i))) $
+ let_ #c #b $
+ let_ #d #c $
+ idx0 (sum1i #d)
+
tests_Compile :: TestTree
tests_Compile = testGroup "Compile"
[compileTest "accum f64" $ fromNamed $ lambda #b $ lambda #x $ body $
@@ -517,6 +538,8 @@ tests_AD = testGroup "AD"
#L ! pair (pair nil 0) 0 * #L ! pair (pair nil 0) 0))))
42
+ ,adTest "arr-rebind" term_arr_rebind
+
,adTestGen "neural" Example.neural gen_neural
,adTestGen "neural-unMonoid" (unMonoid (simplifyFix Example.neural)) gen_neural