diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2025-03-20 18:32:22 +0100 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2025-03-20 18:32:22 +0100 |
commit | d030802dd6d960afa80ac84a5580a46d39c02822 (patch) | |
tree | 0c40e8eea6fe12cab0bd74e5e4f457e13bbf9afd | |
parent | 146a846f799f63cd98eee2149c417686adba17a9 (diff) |
Commutativity marker on fold1i
-rw-r--r-- | src/AST.hs | 11 | ||||
-rw-r--r-- | src/AST/Count.hs | 2 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 6 | ||||
-rw-r--r-- | src/AST/UnMonoid.hs | 2 | ||||
-rw-r--r-- | src/Analysis/Identity.hs | 4 | ||||
-rw-r--r-- | src/CHAD.hs | 7 | ||||
-rw-r--r-- | src/ForwardAD/DualNumbers.hs | 2 | ||||
-rw-r--r-- | src/Interpreter.hs | 2 | ||||
-rw-r--r-- | src/Language/AST.hs | 2 | ||||
-rw-r--r-- | src/Simplify.hs | 4 |
10 files changed, 24 insertions, 18 deletions
@@ -60,7 +60,7 @@ data Expr x env t where -- array operations EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t)) EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t) - EFold1Inner :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) + EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) @@ -106,6 +106,9 @@ type Ex = Expr (Const ()) ext :: Const () a ext = Const () +data Commutative = Commut | Noncommut + deriving (Show) + type SOp :: Ty -> Ty -> Type data SOp a t where OAdd :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) @@ -182,7 +185,7 @@ typeOf = \case EConstArr _ n t _ -> STArr n (STScal t) EBuild _ n _ e -> STArr n (typeOf e) - EFold1Inner _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t + EFold1Inner _ _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t EUnit _ e -> STArr SZ (typeOf e) EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t @@ -223,7 +226,7 @@ extOf = \case EMaybe x _ _ _ -> x EConstArr x _ _ _ -> x EBuild x _ _ _ -> x - EFold1Inner x _ _ _ -> x + EFold1Inner x _ _ _ _ -> x ESum1Inner x _ -> x EUnit x _ -> x EReplicate1Inner x _ _ -> x @@ -292,7 +295,7 @@ subst' f w = \case EMaybe x a b e -> EMaybe x (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' f w e) EConstArr x n t a -> EConstArr x n t a EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b) - EFold1Inner x a b c -> EFold1Inner x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) + EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) ESum1Inner x e -> ESum1Inner x (subst' f w e) EUnit x e -> EUnit x (subst' f w e) EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b) diff --git a/src/AST/Count.hs b/src/AST/Count.hs index c0d8d2d..dc8ec72 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -115,7 +115,7 @@ occCountGeneral onehot unpush alter many = go WId EMaybe _ a b e -> re a <> re1 b <> re e EConstArr{} -> mempty EBuild _ _ a b -> re a <> many (re1 b) - EFold1Inner _ a b c -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b <> re c + EFold1Inner _ _ a b c -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b <> re c ESum1Inner _ e -> re e EUnit _ e -> re e EReplicate1Inner _ a b -> re a <> re b diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index b9406d7..527a7ca 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -150,14 +150,16 @@ ppExpr' d val expr = case expr of <> hardline <> e') (ppApp (annotate AHighlight (ppString "build") <> ppX expr) [a', ppLam [ppString name] e']) - EFold1Inner _ a b c -> do + EFold1Inner _ cm a b c -> do name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a name2 <- genNameIfUsedIn (typeOf a) IZ a a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a b' <- ppExpr' 11 val b c' <- ppExpr' 11 val c + let opname = case cm of Commut -> "fold1i(C)" + Noncommut -> "fold1i" return $ ppParen (d > 10) $ - ppApp (annotate AHighlight (ppString "fold1i") <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c'] + ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c'] ESum1Inner _ e -> do e' <- ppExpr' 11 val e diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index ae9728a..b30f7a0 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -29,7 +29,7 @@ unMonoid = \case EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e) EConstArr _ n t x -> EConstArr ext n t x EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b) - EFold1Inner _ a b c -> EFold1Inner ext (unMonoid a) (unMonoid b) (unMonoid c) + EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c) ESum1Inner _ e -> ESum1Inner ext (unMonoid e) EUnit _ e -> EUnit ext (unMonoid e) EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b) diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index 5e36dde..095d0fa 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -145,7 +145,7 @@ idana env expr = case expr of res <- VIArr <$> genId <*> shidsToVec dim shids pure (res, EBuild res dim e1' e2') - EFold1Inner _ e1 e2 e3 -> do + EFold1Inner _ cm e1 e2 e3 -> do let t1 = typeOf e1 x1 <- genIds t1 x2 <- genIds t1 @@ -154,7 +154,7 @@ idana env expr = case expr of (v3, e3') <- idana env e3 let VIArr _ (_ :< sh) = v3 res <- VIArr <$> genId <*> pure sh - pure (res, EFold1Inner res e1' e2' e3') + pure (res, EFold1Inner res cm e1' e2' e3') ESum1Inner _ e1 -> do (v1, e1') <- idana env e1 diff --git a/src/CHAD.hs b/src/CHAD.hs index d7d7da2..a5a5719 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -896,9 +896,10 @@ drev des = \case subtape (EReplicate1Inner ext en1 e1) sub - (ELet ext (EFold1Inner ext (EPlus ext eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) - (EZero ext eltty) - (EVar ext (STArr (SS ndim) (d2 eltty)) IZ)) $ + (ELet ext (EFold1Inner ext Commut + (EPlus ext eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) + (EZero ext eltty) + (EVar ext (STArr (SS ndim) (d2 eltty)) IZ)) $ weakenExpr (WCopy WSink) e2) EIdx0 _ e diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index aa35a5b..9a95f81 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -146,7 +146,7 @@ dfwdDN = \case (EConstArr ext n t x) EBuild _ n a b | Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b) - EFold1Inner _ a b c -> EFold1Inner ext (dfwdDN a) (dfwdDN b) (dfwdDN c) + EFold1Inner _ cm a b c -> EFold1Inner ext cm (dfwdDN a) (dfwdDN b) (dfwdDN c) ESum1Inner _ e -> let STArr n (STScal t) = typeOf e pairty = (STPair (STScal t) (STScal t)) diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 11184c9..3cc7ae4 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -94,7 +94,7 @@ interpret'Rec env = \case EBuild _ dim a b -> do sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx ixUncons dim idx) `SCons` env) b) - EFold1Inner _ a b c -> do + EFold1Inner _ _ a b c -> do let f = \x y -> interpret' (Value y `SCons` Value x `SCons` env) a x0 <- interpret' env b arr <- interpret' env c diff --git a/src/Language/AST.hs b/src/Language/AST.hs index 387915b..b36e151 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -185,7 +185,7 @@ fromNamedExpr val = \case NEConstArr n t x -> EConstArr ext n t x NEBuild k a n b -> EBuild ext k (go a) (lambda val n b) - NEFold1Inner n1 n2 a b c -> EFold1Inner ext (lambda2 val n1 n2 a) (go b) (go c) + NEFold1Inner n1 n2 a b c -> EFold1Inner ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) NESum1Inner e -> ESum1Inner ext (go e) NEUnit e -> EUnit ext (go e) NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b) diff --git a/src/Simplify.hs b/src/Simplify.hs index ac1bb8b..0bf5482 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -171,7 +171,7 @@ simplify' = \case EMaybe _ a b e -> EMaybe ext <$> simplify' a <*> simplify' b <*> simplify' e EConstArr _ n t v -> pure $ EConstArr ext n t v EBuild _ n a b -> EBuild ext n <$> simplify' a <*> simplify' b - EFold1Inner _ a b c -> EFold1Inner ext <$> simplify' a <*> simplify' b <*> simplify' c + EFold1Inner _ cm a b c -> EFold1Inner ext cm <$> simplify' a <*> simplify' b <*> simplify' c ESum1Inner _ e -> ESum1Inner ext <$> simplify' e EUnit _ e -> EUnit ext <$> simplify' e EReplicate1Inner _ a b -> EReplicate1Inner ext <$> simplify' a <*> simplify' b @@ -224,7 +224,7 @@ hasAdds = \case EMaybe _ a b e -> hasAdds a || hasAdds b || hasAdds e EConstArr _ _ _ _ -> False EBuild _ _ a b -> hasAdds a || hasAdds b - EFold1Inner _ a b c -> hasAdds a || hasAdds b || hasAdds c + EFold1Inner _ _ a b c -> hasAdds a || hasAdds b || hasAdds c ESum1Inner _ e -> hasAdds e EUnit _ e -> hasAdds e EReplicate1Inner _ a b -> hasAdds a || hasAdds b |