diff options
-rw-r--r-- | src/AST.hs | 6 | ||||
-rw-r--r-- | src/AST/Count.hs | 2 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 5 | ||||
-rw-r--r-- | src/CHAD.hs | 14 | ||||
-rw-r--r-- | src/ForwardAD/DualNumbers.hs | 2 | ||||
-rw-r--r-- | src/Interpreter.hs | 7 | ||||
-rw-r--r-- | src/Language.hs | 4 | ||||
-rw-r--r-- | src/Language/AST.hs | 4 | ||||
-rw-r--r-- | src/Simplify.hs | 4 | ||||
-rw-r--r-- | test/Main.hs | 10 |
10 files changed, 41 insertions, 17 deletions
@@ -83,7 +83,7 @@ data Expr x env t where EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t)) EBuild1 :: x (TArr (S Z) t) -> Expr x env TIx -> Expr x (TIx : env) t -> Expr x env (TArr (S Z) t) EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t) - EFold1Inner :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) + EFold1Inner :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) @@ -186,7 +186,7 @@ typeOf = \case EConstArr _ n t _ -> STArr n (STScal t) EBuild1 _ _ e -> STArr (SS SZ) (typeOf e) EBuild _ n _ e -> STArr n (typeOf e) - EFold1Inner _ _ e | STArr (SS n) t <- typeOf e -> STArr n t + EFold1Inner _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t EUnit _ e -> STArr SZ (typeOf e) EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t @@ -260,7 +260,7 @@ subst' f w = \case EConstArr x n t a -> EConstArr x n t a EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b) EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b) - EFold1Inner x a b -> EFold1Inner x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) + EFold1Inner x a b c -> EFold1Inner x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) ESum1Inner x e -> ESum1Inner x (subst' f w e) EUnit x e -> EUnit x (subst' f w e) EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b) diff --git a/src/AST/Count.hs b/src/AST/Count.hs index ad68685..dbec446 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -116,7 +116,7 @@ occCountGeneral onehot unpush alter many = go WId EConstArr{} -> mempty EBuild1 _ a b -> re a <> many (re1 b) EBuild _ _ a b -> re a <> many (re1 b) - EFold1Inner _ a b -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b + EFold1Inner _ a b c -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b <> re c ESum1Inner _ e -> re e EUnit _ e -> re e EReplicate1Inner _ a b -> re a <> re b diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 8f1fe67..d811912 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -124,14 +124,15 @@ ppExpr' d val = \case return $ showParen (d > 10) $ showString "build " . a' . showString (" (\\" ++ name ++ " -> ") . e' . showString ")" - EFold1Inner _ a b -> do + EFold1Inner _ a b c -> do name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a name2 <- genNameIfUsedIn (typeOf a) IZ a a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a b' <- ppExpr' 11 val b + c' <- ppExpr' 11 val c return $ showParen (d > 10) $ showString ("fold1i (\\" ++ name1 ++ " " ++ name2 ++ " -> ") . a' - . showString ") " . b' + . showString ") " . b' . showString " " . c' ESum1Inner _ e -> do e' <- ppExpr' 11 val e diff --git a/src/CHAD.hs b/src/CHAD.hs index d05e77f..786de07 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1037,6 +1037,19 @@ drev des = \case (ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $ weakenExpr (WCopy WSink) e2) + EReplicate1Inner _ en e + -- We're allowed to ignore en2 here because the output of 'ei' is discrete. + | Rets binds (RetPair en1 _ _ `SCons` RetPair e1 sub e2 `SCons` SNil) + <- retConcat des $ drev des en `SCons` drev des e `SCons` SNil + , let STArr ndim eltty = typeOf e -> + Ret (binds `BPush` (d1 (typeOf e), e1)) + (weakenExpr WSink $ EReplicate1Inner ext en1 e1) + sub + (ELet ext (EFold1Inner ext (EPlus eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) + (EZero eltty) + (EVar ext (STArr (SS ndim) (d2 eltty)) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e2) + EIdx0 _ e | Ret e0 e1 sub e2 <- drev des e , STArr _ t <- typeOf e -> @@ -1097,7 +1110,6 @@ drev des = \case weakenExpr (WCopy (WSink .> WSink)) e2) -- These should be the next to be implemented, I think - EReplicate1Inner{} -> err_unsupported "EReplicate1Inner" EFold1Inner{} -> err_unsupported "EFold1Inner" ENothing{} -> err_unsupported "ENothing" diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index 3e45ce7..a93b8e6 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -153,7 +153,7 @@ dfwdDN = \case EBuild1 _ a b -> EBuild1 ext (dfwdDN a) (dfwdDN b) EBuild _ n a b | Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b) - EFold1Inner _ a b -> EFold1Inner ext (dfwdDN a) (dfwdDN b) + EFold1Inner _ a b c -> EFold1Inner ext (dfwdDN a) (dfwdDN b) (dfwdDN c) ESum1Inner _ e -> let STArr n (STScal t) = typeOf e pairty = (STPair (STScal t) (STScal t)) diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 8ce1b0e..b818eb0 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -75,11 +75,12 @@ interpret' env = \case EBuild _ dim a b -> do sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx ixUncons dim idx) `SCons` env) b) - EFold1Inner _ a b -> do + EFold1Inner _ a b c -> do let f = \x y -> interpret' (Value y `SCons` Value x `SCons` env) a - arr <- interpret' env b + x0 <- interpret' env b + arr <- interpret' env c let sh `ShCons` n = arrayShape arr - arrayGenerateM sh $ \idx -> foldl1M f [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + arrayGenerateM sh $ \idx -> foldM f x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] ESum1Inner _ e -> do arr <- interpret' env e let STArr _ (STScal t) = typeOf e diff --git a/src/Language.hs b/src/Language.hs index 80de713..c2b844e 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -63,8 +63,8 @@ build1 a (v :-> b) = NEBuild1 a v b build :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> (Var name (Tup (Replicate n TIx)) :-> NExpr ('(name, Tup (Replicate n TIx)) : env) t) -> NExpr env (TArr n t) build n a (v :-> b) = NEBuild n a v b -fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) -fold1i (v1 :-> v2 :-> e1) e2 = NEFold1Inner v1 v2 e1 e2 +fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) +fold1i (v1 :-> v2 :-> e1) e2 e3 = NEFold1Inner v1 v2 e1 e2 e3 sum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) sum1i e = NESum1Inner e diff --git a/src/Language/AST.hs b/src/Language/AST.hs index af5a5a2..0945dd9 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -43,7 +43,7 @@ data NExpr env t where NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t)) NEBuild1 :: NExpr env TIx -> Var name TIx -> NExpr ('(name, TIx) : env) t -> NExpr env (TArr (S Z) t) NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t) - NEFold1Inner :: Var name1 t -> Var name2 t -> NExpr ('(name2, t) : '(name1, t) : env) t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) + NEFold1Inner :: Var name1 t -> Var name2 t -> NExpr ('(name2, t) : '(name1, t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) NEUnit :: NExpr env t -> NExpr env (TArr Z t) NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t) @@ -123,7 +123,7 @@ fromNamedExpr val = \case NEConstArr n t x -> EConstArr ext n t x NEBuild1 a n b -> EBuild1 ext (go a) (lambda val n b) NEBuild k a n b -> EBuild ext k (go a) (lambda val n b) - NEFold1Inner n1 n2 a b -> EFold1Inner ext (lambda2 val n1 n2 a) (go b) + NEFold1Inner n1 n2 a b c -> EFold1Inner ext (lambda2 val n1 n2 a) (go b) (go c) NESum1Inner e -> ESum1Inner ext (go e) NEUnit e -> EUnit ext (go e) NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b) diff --git a/src/Simplify.hs b/src/Simplify.hs index 2007585..3f4c8e3 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -87,7 +87,7 @@ simplify' = \case EConstArr _ n t v -> EConstArr ext n t v EBuild1 _ a b -> EBuild1 ext (simplify' a) (simplify' b) EBuild _ n a b -> EBuild ext n (simplify' a) (simplify' b) - EFold1Inner _ a b -> EFold1Inner ext (simplify' a) (simplify' b) + EFold1Inner _ a b c -> EFold1Inner ext (simplify' a) (simplify' b) (simplify' c) ESum1Inner _ e -> ESum1Inner ext (simplify' e) EUnit _ e -> EUnit ext (simplify' e) EReplicate1Inner _ a b -> EReplicate1Inner ext (simplify' a) (simplify' b) @@ -129,7 +129,7 @@ hasAdds = \case EConstArr _ _ _ _ -> False EBuild1 _ a b -> hasAdds a || hasAdds b EBuild _ _ a b -> hasAdds a || hasAdds b - EFold1Inner _ a b -> hasAdds a || hasAdds b + EFold1Inner _ a b c -> hasAdds a || hasAdds b || hasAdds c ESum1Inner _ e -> hasAdds e EUnit _ e -> hasAdds e EReplicate1Inner _ a b -> hasAdds a || hasAdds b diff --git a/test/Main.hs b/test/Main.hs index d90d9cd..f779352 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -229,8 +229,18 @@ tests :: IO Bool tests = checkParallel $ Group "AD" [("id", adTest $ fromNamed $ lambda #x $ body $ #x) + ,("idx0", adTest $ fromNamed $ lambda #x $ body $ idx0 #x) + ,("sum-vec", adTest $ fromNamed $ lambda #x $ body $ idx0 (sum1i #x)) + ,("sum-replicate", adTest $ fromNamed $ lambda #x $ body $ + idx0 $ sum1i $ replicate1i 10 #x) + + ,("pairs", adTest $ fromNamed $ lambda #x $ lambda #y $ body $ + let_ #p (pair #x #y) $ + let_ #q (pair (snd_ #p * fst_ #p + #y) #x) $ + fst_ #q * #x + snd_ #q * fst_ #p) + ,("build0", adTest $ fromNamed $ lambda @(TArr N0 _) #x $ body $ idx0 $ build SZ (shape #x) $ #idx :-> #x ! #idx) |