summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2025-03-20 18:32:22 +0100
committerTom Smeding <t.j.smeding@uu.nl>2025-03-20 18:32:22 +0100
commitd030802dd6d960afa80ac84a5580a46d39c02822 (patch)
tree0c40e8eea6fe12cab0bd74e5e4f457e13bbf9afd
parent146a846f799f63cd98eee2149c417686adba17a9 (diff)
Commutativity marker on fold1i
-rw-r--r--src/AST.hs11
-rw-r--r--src/AST/Count.hs2
-rw-r--r--src/AST/Pretty.hs6
-rw-r--r--src/AST/UnMonoid.hs2
-rw-r--r--src/Analysis/Identity.hs4
-rw-r--r--src/CHAD.hs7
-rw-r--r--src/ForwardAD/DualNumbers.hs2
-rw-r--r--src/Interpreter.hs2
-rw-r--r--src/Language/AST.hs2
-rw-r--r--src/Simplify.hs4
10 files changed, 24 insertions, 18 deletions
diff --git a/src/AST.hs b/src/AST.hs
index ecd4647..a4898c0 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -60,7 +60,7 @@ data Expr x env t where
-- array operations
EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t))
EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t)
- EFold1Inner :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
+ EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t)
EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t)
@@ -106,6 +106,9 @@ type Ex = Expr (Const ())
ext :: Const () a
ext = Const ()
+data Commutative = Commut | Noncommut
+ deriving (Show)
+
type SOp :: Ty -> Ty -> Type
data SOp a t where
OAdd :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
@@ -182,7 +185,7 @@ typeOf = \case
EConstArr _ n t _ -> STArr n (STScal t)
EBuild _ n _ e -> STArr n (typeOf e)
- EFold1Inner _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t
+ EFold1Inner _ _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t
ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
EUnit _ e -> STArr SZ (typeOf e)
EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t
@@ -223,7 +226,7 @@ extOf = \case
EMaybe x _ _ _ -> x
EConstArr x _ _ _ -> x
EBuild x _ _ _ -> x
- EFold1Inner x _ _ _ -> x
+ EFold1Inner x _ _ _ _ -> x
ESum1Inner x _ -> x
EUnit x _ -> x
EReplicate1Inner x _ _ -> x
@@ -292,7 +295,7 @@ subst' f w = \case
EMaybe x a b e -> EMaybe x (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' f w e)
EConstArr x n t a -> EConstArr x n t a
EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b)
- EFold1Inner x a b c -> EFold1Inner x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
+ EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
ESum1Inner x e -> ESum1Inner x (subst' f w e)
EUnit x e -> EUnit x (subst' f w e)
EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b)
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index c0d8d2d..dc8ec72 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -115,7 +115,7 @@ occCountGeneral onehot unpush alter many = go WId
EMaybe _ a b e -> re a <> re1 b <> re e
EConstArr{} -> mempty
EBuild _ _ a b -> re a <> many (re1 b)
- EFold1Inner _ a b c -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b <> re c
+ EFold1Inner _ _ a b c -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b <> re c
ESum1Inner _ e -> re e
EUnit _ e -> re e
EReplicate1Inner _ a b -> re a <> re b
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index b9406d7..527a7ca 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -150,14 +150,16 @@ ppExpr' d val expr = case expr of
<> hardline <> e')
(ppApp (annotate AHighlight (ppString "build") <> ppX expr) [a', ppLam [ppString name] e'])
- EFold1Inner _ a b c -> do
+ EFold1Inner _ cm a b c -> do
name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a
name2 <- genNameIfUsedIn (typeOf a) IZ a
a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a
b' <- ppExpr' 11 val b
c' <- ppExpr' 11 val c
+ let opname = case cm of Commut -> "fold1i(C)"
+ Noncommut -> "fold1i"
return $ ppParen (d > 10) $
- ppApp (annotate AHighlight (ppString "fold1i") <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c']
+ ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c']
ESum1Inner _ e -> do
e' <- ppExpr' 11 val e
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index ae9728a..b30f7a0 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -29,7 +29,7 @@ unMonoid = \case
EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e)
EConstArr _ n t x -> EConstArr ext n t x
EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b)
- EFold1Inner _ a b c -> EFold1Inner ext (unMonoid a) (unMonoid b) (unMonoid c)
+ EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c)
ESum1Inner _ e -> ESum1Inner ext (unMonoid e)
EUnit _ e -> EUnit ext (unMonoid e)
EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b)
diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs
index 5e36dde..095d0fa 100644
--- a/src/Analysis/Identity.hs
+++ b/src/Analysis/Identity.hs
@@ -145,7 +145,7 @@ idana env expr = case expr of
res <- VIArr <$> genId <*> shidsToVec dim shids
pure (res, EBuild res dim e1' e2')
- EFold1Inner _ e1 e2 e3 -> do
+ EFold1Inner _ cm e1 e2 e3 -> do
let t1 = typeOf e1
x1 <- genIds t1
x2 <- genIds t1
@@ -154,7 +154,7 @@ idana env expr = case expr of
(v3, e3') <- idana env e3
let VIArr _ (_ :< sh) = v3
res <- VIArr <$> genId <*> pure sh
- pure (res, EFold1Inner res e1' e2' e3')
+ pure (res, EFold1Inner res cm e1' e2' e3')
ESum1Inner _ e1 -> do
(v1, e1') <- idana env e1
diff --git a/src/CHAD.hs b/src/CHAD.hs
index d7d7da2..a5a5719 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -896,9 +896,10 @@ drev des = \case
subtape
(EReplicate1Inner ext en1 e1)
sub
- (ELet ext (EFold1Inner ext (EPlus ext eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ))
- (EZero ext eltty)
- (EVar ext (STArr (SS ndim) (d2 eltty)) IZ)) $
+ (ELet ext (EFold1Inner ext Commut
+ (EPlus ext eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ))
+ (EZero ext eltty)
+ (EVar ext (STArr (SS ndim) (d2 eltty)) IZ)) $
weakenExpr (WCopy WSink) e2)
EIdx0 _ e
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index aa35a5b..9a95f81 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -146,7 +146,7 @@ dfwdDN = \case
(EConstArr ext n t x)
EBuild _ n a b
| Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b)
- EFold1Inner _ a b c -> EFold1Inner ext (dfwdDN a) (dfwdDN b) (dfwdDN c)
+ EFold1Inner _ cm a b c -> EFold1Inner ext cm (dfwdDN a) (dfwdDN b) (dfwdDN c)
ESum1Inner _ e ->
let STArr n (STScal t) = typeOf e
pairty = (STPair (STScal t) (STScal t))
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 11184c9..3cc7ae4 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -94,7 +94,7 @@ interpret'Rec env = \case
EBuild _ dim a b -> do
sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a
arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx ixUncons dim idx) `SCons` env) b)
- EFold1Inner _ a b c -> do
+ EFold1Inner _ _ a b c -> do
let f = \x y -> interpret' (Value y `SCons` Value x `SCons` env) a
x0 <- interpret' env b
arr <- interpret' env c
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
index 387915b..b36e151 100644
--- a/src/Language/AST.hs
+++ b/src/Language/AST.hs
@@ -185,7 +185,7 @@ fromNamedExpr val = \case
NEConstArr n t x -> EConstArr ext n t x
NEBuild k a n b -> EBuild ext k (go a) (lambda val n b)
- NEFold1Inner n1 n2 a b c -> EFold1Inner ext (lambda2 val n1 n2 a) (go b) (go c)
+ NEFold1Inner n1 n2 a b c -> EFold1Inner ext Noncommut (lambda2 val n1 n2 a) (go b) (go c)
NESum1Inner e -> ESum1Inner ext (go e)
NEUnit e -> EUnit ext (go e)
NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b)
diff --git a/src/Simplify.hs b/src/Simplify.hs
index ac1bb8b..0bf5482 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -171,7 +171,7 @@ simplify' = \case
EMaybe _ a b e -> EMaybe ext <$> simplify' a <*> simplify' b <*> simplify' e
EConstArr _ n t v -> pure $ EConstArr ext n t v
EBuild _ n a b -> EBuild ext n <$> simplify' a <*> simplify' b
- EFold1Inner _ a b c -> EFold1Inner ext <$> simplify' a <*> simplify' b <*> simplify' c
+ EFold1Inner _ cm a b c -> EFold1Inner ext cm <$> simplify' a <*> simplify' b <*> simplify' c
ESum1Inner _ e -> ESum1Inner ext <$> simplify' e
EUnit _ e -> EUnit ext <$> simplify' e
EReplicate1Inner _ a b -> EReplicate1Inner ext <$> simplify' a <*> simplify' b
@@ -224,7 +224,7 @@ hasAdds = \case
EMaybe _ a b e -> hasAdds a || hasAdds b || hasAdds e
EConstArr _ _ _ _ -> False
EBuild _ _ a b -> hasAdds a || hasAdds b
- EFold1Inner _ a b c -> hasAdds a || hasAdds b || hasAdds c
+ EFold1Inner _ _ a b c -> hasAdds a || hasAdds b || hasAdds c
ESum1Inner _ e -> hasAdds e
EUnit _ e -> hasAdds e
EReplicate1Inner _ a b -> hasAdds a || hasAdds b