summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/AST.hs6
-rw-r--r--src/AST/Count.hs2
-rw-r--r--src/AST/Pretty.hs5
-rw-r--r--src/CHAD.hs14
-rw-r--r--src/ForwardAD/DualNumbers.hs2
-rw-r--r--src/Interpreter.hs7
-rw-r--r--src/Language.hs4
-rw-r--r--src/Language/AST.hs4
-rw-r--r--src/Simplify.hs4
-rw-r--r--test/Main.hs10
10 files changed, 41 insertions, 17 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 94c8537..e2702ab 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -83,7 +83,7 @@ data Expr x env t where
EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t))
EBuild1 :: x (TArr (S Z) t) -> Expr x env TIx -> Expr x (TIx : env) t -> Expr x env (TArr (S Z) t)
EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t)
- EFold1Inner :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
+ EFold1Inner :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t)
EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t)
@@ -186,7 +186,7 @@ typeOf = \case
EConstArr _ n t _ -> STArr n (STScal t)
EBuild1 _ _ e -> STArr (SS SZ) (typeOf e)
EBuild _ n _ e -> STArr n (typeOf e)
- EFold1Inner _ _ e | STArr (SS n) t <- typeOf e -> STArr n t
+ EFold1Inner _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t
ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
EUnit _ e -> STArr SZ (typeOf e)
EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t
@@ -260,7 +260,7 @@ subst' f w = \case
EConstArr x n t a -> EConstArr x n t a
EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b)
EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b)
- EFold1Inner x a b -> EFold1Inner x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b)
+ EFold1Inner x a b c -> EFold1Inner x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
ESum1Inner x e -> ESum1Inner x (subst' f w e)
EUnit x e -> EUnit x (subst' f w e)
EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b)
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index ad68685..dbec446 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -116,7 +116,7 @@ occCountGeneral onehot unpush alter many = go WId
EConstArr{} -> mempty
EBuild1 _ a b -> re a <> many (re1 b)
EBuild _ _ a b -> re a <> many (re1 b)
- EFold1Inner _ a b -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b
+ EFold1Inner _ a b c -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b <> re c
ESum1Inner _ e -> re e
EUnit _ e -> re e
EReplicate1Inner _ a b -> re a <> re b
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 8f1fe67..d811912 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -124,14 +124,15 @@ ppExpr' d val = \case
return $ showParen (d > 10) $
showString "build " . a' . showString (" (\\" ++ name ++ " -> ") . e' . showString ")"
- EFold1Inner _ a b -> do
+ EFold1Inner _ a b c -> do
name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a
name2 <- genNameIfUsedIn (typeOf a) IZ a
a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a
b' <- ppExpr' 11 val b
+ c' <- ppExpr' 11 val c
return $ showParen (d > 10) $
showString ("fold1i (\\" ++ name1 ++ " " ++ name2 ++ " -> ") . a'
- . showString ") " . b'
+ . showString ") " . b' . showString " " . c'
ESum1Inner _ e -> do
e' <- ppExpr' 11 val e
diff --git a/src/CHAD.hs b/src/CHAD.hs
index d05e77f..786de07 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1037,6 +1037,19 @@ drev des = \case
(ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $
weakenExpr (WCopy WSink) e2)
+ EReplicate1Inner _ en e
+ -- We're allowed to ignore en2 here because the output of 'ei' is discrete.
+ | Rets binds (RetPair en1 _ _ `SCons` RetPair e1 sub e2 `SCons` SNil)
+ <- retConcat des $ drev des en `SCons` drev des e `SCons` SNil
+ , let STArr ndim eltty = typeOf e ->
+ Ret (binds `BPush` (d1 (typeOf e), e1))
+ (weakenExpr WSink $ EReplicate1Inner ext en1 e1)
+ sub
+ (ELet ext (EFold1Inner ext (EPlus eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ))
+ (EZero eltty)
+ (EVar ext (STArr (SS ndim) (d2 eltty)) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink)) e2)
+
EIdx0 _ e
| Ret e0 e1 sub e2 <- drev des e
, STArr _ t <- typeOf e ->
@@ -1097,7 +1110,6 @@ drev des = \case
weakenExpr (WCopy (WSink .> WSink)) e2)
-- These should be the next to be implemented, I think
- EReplicate1Inner{} -> err_unsupported "EReplicate1Inner"
EFold1Inner{} -> err_unsupported "EFold1Inner"
ENothing{} -> err_unsupported "ENothing"
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index 3e45ce7..a93b8e6 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -153,7 +153,7 @@ dfwdDN = \case
EBuild1 _ a b -> EBuild1 ext (dfwdDN a) (dfwdDN b)
EBuild _ n a b
| Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b)
- EFold1Inner _ a b -> EFold1Inner ext (dfwdDN a) (dfwdDN b)
+ EFold1Inner _ a b c -> EFold1Inner ext (dfwdDN a) (dfwdDN b) (dfwdDN c)
ESum1Inner _ e ->
let STArr n (STScal t) = typeOf e
pairty = (STPair (STScal t) (STScal t))
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 8ce1b0e..b818eb0 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -75,11 +75,12 @@ interpret' env = \case
EBuild _ dim a b -> do
sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a
arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx ixUncons dim idx) `SCons` env) b)
- EFold1Inner _ a b -> do
+ EFold1Inner _ a b c -> do
let f = \x y -> interpret' (Value y `SCons` Value x `SCons` env) a
- arr <- interpret' env b
+ x0 <- interpret' env b
+ arr <- interpret' env c
let sh `ShCons` n = arrayShape arr
- arrayGenerateM sh $ \idx -> foldl1M f [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
+ arrayGenerateM sh $ \idx -> foldM f x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
ESum1Inner _ e -> do
arr <- interpret' env e
let STArr _ (STScal t) = typeOf e
diff --git a/src/Language.hs b/src/Language.hs
index 80de713..c2b844e 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -63,8 +63,8 @@ build1 a (v :-> b) = NEBuild1 a v b
build :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> (Var name (Tup (Replicate n TIx)) :-> NExpr ('(name, Tup (Replicate n TIx)) : env) t) -> NExpr env (TArr n t)
build n a (v :-> b) = NEBuild n a v b
-fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
-fold1i (v1 :-> v2 :-> e1) e2 = NEFold1Inner v1 v2 e1 e2
+fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
+fold1i (v1 :-> v2 :-> e1) e2 e3 = NEFold1Inner v1 v2 e1 e2 e3
sum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
sum1i e = NESum1Inner e
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
index af5a5a2..0945dd9 100644
--- a/src/Language/AST.hs
+++ b/src/Language/AST.hs
@@ -43,7 +43,7 @@ data NExpr env t where
NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t))
NEBuild1 :: NExpr env TIx -> Var name TIx -> NExpr ('(name, TIx) : env) t -> NExpr env (TArr (S Z) t)
NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t)
- NEFold1Inner :: Var name1 t -> Var name2 t -> NExpr ('(name2, t) : '(name1, t) : env) t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
+ NEFold1Inner :: Var name1 t -> Var name2 t -> NExpr ('(name2, t) : '(name1, t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
NEUnit :: NExpr env t -> NExpr env (TArr Z t)
NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t)
@@ -123,7 +123,7 @@ fromNamedExpr val = \case
NEConstArr n t x -> EConstArr ext n t x
NEBuild1 a n b -> EBuild1 ext (go a) (lambda val n b)
NEBuild k a n b -> EBuild ext k (go a) (lambda val n b)
- NEFold1Inner n1 n2 a b -> EFold1Inner ext (lambda2 val n1 n2 a) (go b)
+ NEFold1Inner n1 n2 a b c -> EFold1Inner ext (lambda2 val n1 n2 a) (go b) (go c)
NESum1Inner e -> ESum1Inner ext (go e)
NEUnit e -> EUnit ext (go e)
NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b)
diff --git a/src/Simplify.hs b/src/Simplify.hs
index 2007585..3f4c8e3 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -87,7 +87,7 @@ simplify' = \case
EConstArr _ n t v -> EConstArr ext n t v
EBuild1 _ a b -> EBuild1 ext (simplify' a) (simplify' b)
EBuild _ n a b -> EBuild ext n (simplify' a) (simplify' b)
- EFold1Inner _ a b -> EFold1Inner ext (simplify' a) (simplify' b)
+ EFold1Inner _ a b c -> EFold1Inner ext (simplify' a) (simplify' b) (simplify' c)
ESum1Inner _ e -> ESum1Inner ext (simplify' e)
EUnit _ e -> EUnit ext (simplify' e)
EReplicate1Inner _ a b -> EReplicate1Inner ext (simplify' a) (simplify' b)
@@ -129,7 +129,7 @@ hasAdds = \case
EConstArr _ _ _ _ -> False
EBuild1 _ a b -> hasAdds a || hasAdds b
EBuild _ _ a b -> hasAdds a || hasAdds b
- EFold1Inner _ a b -> hasAdds a || hasAdds b
+ EFold1Inner _ a b c -> hasAdds a || hasAdds b || hasAdds c
ESum1Inner _ e -> hasAdds e
EUnit _ e -> hasAdds e
EReplicate1Inner _ a b -> hasAdds a || hasAdds b
diff --git a/test/Main.hs b/test/Main.hs
index d90d9cd..f779352 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -229,8 +229,18 @@ tests :: IO Bool
tests = checkParallel $ Group "AD"
[("id", adTest $ fromNamed $ lambda #x $ body $ #x)
+ ,("idx0", adTest $ fromNamed $ lambda #x $ body $ idx0 #x)
+
,("sum-vec", adTest $ fromNamed $ lambda #x $ body $ idx0 (sum1i #x))
+ ,("sum-replicate", adTest $ fromNamed $ lambda #x $ body $
+ idx0 $ sum1i $ replicate1i 10 #x)
+
+ ,("pairs", adTest $ fromNamed $ lambda #x $ lambda #y $ body $
+ let_ #p (pair #x #y) $
+ let_ #q (pair (snd_ #p * fst_ #p + #y) #x) $
+ fst_ #q * #x + snd_ #q * fst_ #p)
+
,("build0", adTest $ fromNamed $ lambda @(TArr N0 _) #x $ body $
idx0 $
build SZ (shape #x) $ #idx :-> #x ! #idx)