aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-10-30 15:58:08 +0100
committerTom Smeding <tom@tomsmeding.com>2025-10-30 15:58:08 +0100
commit4c9ae47dd5bbd27b1acb6dc5d4a55657ac1f026f (patch)
treee371c4962f1beee96cc68d55accffab16e18b97a /src
parent4d456e4d34b1e4fb3725051d1b8a0c376b704692 (diff)
Simplify foldD2 to not sum x0 contributions
Diffstat (limited to 'src')
-rw-r--r--src/AST.hs42
-rw-r--r--src/AST/Count.hs38
-rw-r--r--src/AST/Pretty.hs11
-rw-r--r--src/AST/SplitLets.hs8
-rw-r--r--src/AST/UnMonoid.hs2
-rw-r--r--src/Analysis/Identity.hs18
-rw-r--r--src/CHAD.hs18
-rw-r--r--src/Interpreter.hs10
-rw-r--r--src/Simplify.hs4
9 files changed, 75 insertions, 76 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 7549ff0..663b83f 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -71,21 +71,24 @@ data Expr x env t where
EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t)
- -- MapAccum-like (is it real mapaccum? If so, rename)
+ -- Primal of EFold1Inner. Looks like a mapAccumL, but differs semantically:
+ -- an implementation is allowed to parallelise this thing and store the b
+ -- values in some implementation-defined order.
+ -- TODO: For a parallel implementation some data will probably need to be stored about the reduction order in addition to simply the array of bs.
EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative
-> Expr x (t1 : t1 : env) (TPair t1 b)
-> Expr x env t1
-> Expr x env (TArr (S n) t1)
-> Expr x env (TPair (TArr n t1) -- normal primal fold output
(TArr (S n) b)) -- additional stores; usually: (prescanl, the tape stores)
- -- Reverse derivative of Efold1Inner.
- EFold1InnerD2 :: x (TPair t2 (TArr (S n) t2)) -> Commutative
+ -- Reverse derivative of EFold1Inner. The contributions to the initial
+ -- element are not yet added together here; we assume a later fusion system
+ -- does that for us.
+ EFold1InnerD2 :: x (TPair (TArr n t2) (TArr (S n) t2)) -> Commutative
-> Expr x (t2 : b : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation)
- -> Expr x env t2 -- zero
- -> Expr x (t2 : t2 : env) t2 -- plus
- -> Expr x env (TArr (S n) b) -- extra data passed to function
+ -> Expr x env (TArr (S n) b) -- stores from EFold1InnerD1
-> Expr x env (TArr n t2) -- incoming cotangent
- -> Expr x env (TPair t2 (TArr (S n) t2)) -- outgoing cotangents to x0 and input array
+ -> Expr x env (TPair (TArr n t2) (TArr (S n) t2)) -- outgoing cotangents to x0 (not summed) and input array
-- expression operations
EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)
@@ -237,7 +240,7 @@ typeOf = \case
EReshape _ n _ e | STArr _ t <- typeOf e -> STArr n t
EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb)
- EFold1InnerD2 _ _ _ e2 _ e4 _ | t2 <- typeOf e2, STArr sn _ <- typeOf e4 -> STPair t2 (STArr sn t2)
+ EFold1InnerD2 _ _ _ _ e3 | STArr n t2 <- typeOf e3 -> STPair (STArr n t2) (STArr (SS n) t2)
EConst _ t _ -> STScal t
EIdx0 _ e | STArr _ t <- typeOf e -> t
@@ -287,7 +290,7 @@ extOf = \case
EMinimum1Inner x _ -> x
EReshape x _ _ _ -> x
EFold1InnerD1 x _ _ _ _ -> x
- EFold1InnerD2 x _ _ _ _ _ _ -> x
+ EFold1InnerD2 x _ _ _ _ -> x
EConst x _ _ -> x
EIdx0 x _ -> x
EIdx1 x _ _ -> x
@@ -336,7 +339,7 @@ travExt f = \case
EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e
EReshape x n a b -> EReshape <$> f x <*> pure n <*> travExt f a <*> travExt f b
EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
- EFold1InnerD2 x cm a b c d e -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f d <*> travExt f e
+ EFold1InnerD2 x cm a b c -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
EConst x t v -> EConst <$> f x <*> pure t <*> pure v
EIdx0 x e -> EIdx0 <$> f x <*> travExt f e
EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b
@@ -398,7 +401,7 @@ subst' f w = \case
EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e)
EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b)
EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
- EFold1InnerD2 x cm a b c d e -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) c) (subst' f w d) (subst' f w e)
+ EFold1InnerD2 x cm a b c -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
EConst x t v -> EConst x t v
EIdx0 x e -> EIdx0 x (subst' f w e)
EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
@@ -565,6 +568,19 @@ eshapeConst :: Shape n -> Ex env (Tup (Replicate n TIx))
eshapeConst ShNil = ENil ext
eshapeConst (sh `ShCons` n) = EPair ext (eshapeConst sh) (EConst ext STI64 (fromIntegral @Int @Int64 n))
+eshapeProd :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx
+eshapeProd SZ _ = EConst ext STI64 1
+eshapeProd (SS SZ) e = ESnd ext e
+eshapeProd (SS n) e =
+ eunPair e $ \_ e1 e2 ->
+ EOp ext (OMul STI64) (EPair ext (eshapeProd n e1) e2)
+
+eflatten :: Ex env (TArr n t) -> Ex env (TArr N1 t)
+eflatten e =
+ let STArr n _ = typeOf e
+ in elet e $
+ EReshape ext (SS SZ) (EPair ext (ENil ext) (eshapeProd n (EShape ext (evar IZ)))) (evar IZ)
+
-- ezeroD2 :: STy t -> Ex env (ZeroInfo (D2 t)) -> Ex env (D2 t)
-- ezeroD2 t ezi = EZero ext (d2M t) ezi
@@ -594,7 +610,9 @@ esnd e = ESnd ext e
elet :: Ex env a -> (KnownTy a => Ex (a : env) b) -> Ex env b
elet rhs body
| Dict <- styKnown (typeOf rhs)
- = ELet ext rhs body
+ = if cheapExpr rhs
+ then substInline rhs body
+ else ELet ext rhs body
-- | Let-bind it but don't use the value (just ensure the expression's effects don't get lost)
use :: Ex env a -> Ex env b -> Ex env b
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index 66b4e0b..296c021 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -515,9 +515,8 @@ occCountX initialS topexpr k = case topexpr of
EConstArr _ n t x ->
case s of
SsNone -> k OccEnd (\_ -> ENil ext)
- SsArr SsNone -> k OccEnd (\_ -> EBuild ext n (eshapeConst (arrayShape x)) (ENil ext))
- SsArr SsFull -> k OccEnd (\_ -> EConstArr ext n t x)
- SsFull -> occCountX (SsArr SsFull) topexpr k
+ SsArr' SsNone -> k OccEnd (\_ -> EBuild ext n (eshapeConst (arrayShape x)) (ENil ext))
+ SsArr' SsFull -> k OccEnd (\_ -> EConstArr ext n t x)
EBuild _ n a b ->
case s of
@@ -533,7 +532,7 @@ occCountX initialS topexpr k = case topexpr of
weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))) $
ENil ext) $
ENil ext
- SsArr s' ->
+ SsArr' s' ->
occCountX SsFull a $ \env1 mka ->
occCountX s' b $ \env2'' mkb ->
withSome (scaleMany (Some env2'')) $ \env2' ->
@@ -543,7 +542,6 @@ occCountX initialS topexpr k = case topexpr of
EBuild ext n (mka env') $
elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $
weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))
- SsFull -> occCountX (SsArr SsFull) topexpr k
EFold1Inner _ commut a b c ->
occCountX SsFull a $ \env1''' mka ->
@@ -552,8 +550,7 @@ occCountX initialS topexpr k = case topexpr of
occEnvPop' env1' $ \env1 s1 ->
let s0 = case s of
SsNone -> Some SsNone
- SsArr s' -> Some s'
- SsFull -> Some SsFull in
+ SsArr' s' -> Some s' in
withSome (Some s1 <> Some s2 <> s0) $ \sElt ->
occCountX sElt b $ \env2 mkb ->
occCountX (SsArr sElt) c $ \env3 mkc ->
@@ -573,11 +570,10 @@ occCountX initialS topexpr k = case topexpr of
occCountX SsNone e $ \env mke ->
k env $ \env' ->
use (mke env') $ ENil ext
- SsArr s' ->
+ SsArr' s' ->
occCountX s' e $ \env mke ->
k env $ \env' ->
EUnit ext (mke env')
- SsFull -> occCountX (SsArr SsFull) topexpr k
EReplicate1Inner _ a b ->
case s of
@@ -587,13 +583,12 @@ occCountX initialS topexpr k = case topexpr of
withSome (Some env1 <> Some env2) $ \env ->
k env $ \env' ->
use (mka env') $ use (mkb env') $ ENil ext
- SsArr s' ->
+ SsArr' s' ->
occCountX SsFull a $ \env1 mka ->
occCountX (SsArr s') b $ \env2 mkb ->
withSome (Some env1 <> Some env2) $ \env ->
k env $ \env' ->
EReplicate1Inner ext (mka env') (mkb env')
- SsFull -> occCountX (SsArr SsFull) topexpr k
EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e
EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e
@@ -654,23 +649,18 @@ occCountX initialS topexpr k = case topexpr of
EFold1InnerD1 ext cm (mka (OccPush (OccPush env' () SsFull) () SsFull))
(mkb env') (mkc env')
- EFold1InnerD2 _ cm ef ez eplus ebog ed ->
+ EFold1InnerD2 _ cm ef ebog ed ->
-- TODO: propagate usage of duals
occCountX SsFull ef $ \env1_2' mkef ->
occEnvPop' env1_2' $ \env1_1' _ ->
occEnvPop' env1_1' $ \env1' sB ->
- occCountX SsFull ez $ \env2' mkez ->
- occCountX SsFull eplus $ \env3_2' mkeplus ->
- occEnvPop' env3_2' $ \env3_1' _ ->
- occEnvPop' env3_1' $ \env3' _ ->
- occCountX (SsArr sB) ebog $ \env4 mkebog ->
- occCountX SsFull ed $ \env5 mked ->
- withSome (scaleMany (Some env1') <> scaleMany (Some env2') <> scaleMany (Some env3') <> Some env4 <> Some env5) $ \env ->
+ occCountX (SsArr sB) ebog $ \env2 mkebog ->
+ occCountX SsFull ed $ \env3 mked ->
+ withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env ->
k env $ \env' ->
projectSmallerSubstruc SsFull s $
EFold1InnerD2 ext cm
(mkef (OccPush (OccPush env' () sB) () SsFull))
- (mkez env') (mkeplus (OccPush (OccPush env' () SsFull) () SsFull))
(mkebog env') (mked env')
EConst _ t x ->
@@ -692,13 +682,12 @@ occCountX initialS topexpr k = case topexpr of
withSome (Some env1 <> Some env2) $ \env ->
k env $ \env' ->
use (mka env') $ use (mkb env') $ ENil ext
- SsArr s' ->
+ SsArr' s' ->
occCountX (SsArr s') a $ \env1 mka ->
occCountX SsFull b $ \env2 mkb ->
withSome (Some env1 <> Some env2) $ \env ->
k env $ \env' ->
EIdx1 ext (mka env') (mkb env')
- SsFull -> occCountX (SsArr SsFull) topexpr k
EIdx _ a b ->
case s of
@@ -863,16 +852,15 @@ occCountX initialS topexpr k = case topexpr of
occCountX SsNone e $ \env mke ->
k env $ \env' ->
use (mke env') $ ENil ext
- SsArr SsNone ->
+ SsArr' SsNone ->
occCountX (SsArr SsNone) e $ \env mke ->
k env $ \env' ->
elet (mke env') $
EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext)
- SsArr SsFull ->
+ SsArr' SsFull ->
occCountX (SsArr SsFull) e $ \env mke ->
k env $ \env' ->
reduce (mke env')
- SsFull -> occCountX (SsArr SsFull) topexpr k
deleteUnused :: SList f env -> Some (OccEnv Occ env) -> (forall env'. Subenv env env' -> r) -> r
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 67197f9..68fc629 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -250,23 +250,18 @@ ppExpr' d val expr = case expr of
return $ ppParen (d > 10) $
ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c']
- EFold1InnerD2 _ cm ef ez eplus ebog ed -> do
+ EFold1InnerD2 _ cm ef ebog ed -> do
let STArr _ tB = typeOf ebog
- t2 = typeOf ez
+ STArr _ t2 = typeOf ed
namef1 <- genNameIfUsedIn tB (IS IZ) ef
namef2 <- genNameIfUsedIn t2 IZ ef
ef' <- ppExpr' 0 (Const namef2 `SCons` Const namef1 `SCons` val) ef
- ez' <- ppExpr' 11 val ez
- namep1 <- genNameIfUsedIn t2 (IS IZ) eplus
- namep2 <- genNameIfUsedIn t2 IZ eplus
- eplus' <- ppExpr' 0 (Const namep2 `SCons` Const namep1 `SCons` val) eplus
ebog' <- ppExpr' 11 val ebog
ed' <- ppExpr' 11 val ed
let opname = "fold1iD2" ++ ppCommut cm
return $ ppParen (d > 10) $
ppApp (annotate AHighlight (ppString opname) <> ppX expr)
- [ppLam [ppString namef1, ppString namef2] ef', ez'
- ,ppLam [ppString namep1, ppString namep2] eplus', ebog', ed']
+ [ppLam [ppString namef1, ppString namef2] ef', ebog', ed']
EConst _ ty v
| Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr
diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs
index 73c1c67..d276e44 100644
--- a/src/AST/SplitLets.hs
+++ b/src/AST/SplitLets.hs
@@ -38,10 +38,10 @@ splitLets' = \sub -> \case
EFold1InnerD1 x cm a b c ->
let STArr _ t1 = typeOf c
in EFold1InnerD1 x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c)
- EFold1InnerD2 x cm a b c d e ->
- let t2 = typeOf b
- STArr _ tB = typeOf d
- in EFold1InnerD2 x cm (split2 sub tB t2 a) (splitLets' sub b) (split2 sub t2 t2 c) (splitLets' sub d) (splitLets' sub e)
+ EFold1InnerD2 x cm a b c ->
+ let STArr _ tB = typeOf b
+ STArr _ t2 = typeOf c
+ in EFold1InnerD2 x cm (split2 sub tB t2 a) (splitLets' sub b) (splitLets' sub c)
EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b)
EFst x e -> EFst x (splitLets' sub e)
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index e5a9708..a22b73f 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -46,7 +46,7 @@ unMonoid = \case
EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b)
EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
- EFold1InnerD2 _ cm a b c d e -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid d) (unMonoid e)
+ EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
EConst _ t x -> EConst ext t x
EIdx0 _ e -> EIdx0 ext (unMonoid e)
EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b)
diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs
index b3a6664..6301dc1 100644
--- a/src/Analysis/Identity.hs
+++ b/src/Analysis/Identity.hs
@@ -261,21 +261,17 @@ idana env expr = case expr of
res <- VIPair <$> (VIArr <$> genId <*> pure sh) <*> (VIArr <$> genId <*> pure sh')
pure (res, EFold1InnerD1 res cm e1' e2' e3')
- EFold1InnerD2 _ cm ef ez eplus ebog ed -> do
+ EFold1InnerD2 _ cm ef ebog ed -> do
let STArr _ tB = typeOf ebog
- t2 = typeOf ez
+ STArr _ t2 = typeOf ed
xf1 <- genIds t2
xf2 <- genIds tB
(_, e1') <- idana (xf1 `SCons` xf2 `SCons` env) ef
- (_, e2') <- idana env ez
- xp1 <- genIds t2
- xp2 <- genIds t2
- (_, e3') <- idana (xp1 `SCons` xp2 `SCons` env) eplus
- (v4, e4') <- idana env ebog
- (_, e5') <- idana env ed
- let VIArr _ sh = v4
- res <- VIPair <$> genIds t2 <*> (VIArr <$> genId <*> pure sh)
- pure (res, EFold1InnerD2 res cm e1' e2' e3' e4' e5')
+ (v2, e2') <- idana env ebog
+ (_, e3') <- idana env ed
+ let VIArr _ sh@(_ :< sh') = v2
+ res <- VIPair <$> (VIArr <$> genId <*> pure sh') <*> (VIArr <$> genId <*> pure sh)
+ pure (res, EFold1InnerD2 res cm e1' e2' e3')
EConst _ t val -> do
res <- VIScal <$> genId
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 04c4231..7594a0f 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1184,7 +1184,7 @@ drev des accumMap sd = \case
subx₀af
(let layout1 = #darr :++: #primal :++: #parr :++: #pzi :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in
elet
- (uninvertTup (d2e envPro) (STPair (d2 eltty) (STArr (SS ndim) (d2 eltty))) $
+ (uninvertTup (d2e envPro) (STPair (STArr ndim (d2 eltty)) (STArr (SS ndim) (d2 eltty))) $
makeAccumulators (autoWeak library #propr layout1) envPro $
let layout2 = #d2acPro :++: layout1 in
EFold1InnerD2 ext commut
@@ -1198,8 +1198,6 @@ drev des accumMap sd = \case
.> wPro (subList (bindingsBinds ef0) subtapeEf))
ef2) $
EPair ext (ESnd ext (EFst ext (evar IZ))) (ESnd ext (evar IZ)))
- (EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (autoWeak library #pzi layout2 @> IZ)))
- (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ))
(ezip
(EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ))
(ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))
@@ -1207,10 +1205,16 @@ drev des accumMap sd = \case
(EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ))
(EFst ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))) $
plus_x₀a_f
- (weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $
- plus_x₀_a
- (subst0 (EFst ext (EFst ext (evar IZ))) ex₀2)
- (subst0 (ESnd ext (EFst ext (evar IZ))) ea2))
+ (plus_x₀_a
+ (elet (EIdx0 ext
+ (EFold1Inner ext Commut
+ (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ))
+ (EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (WSink .> autoWeak library #pzi layout1 @> IZ)))
+ (eflatten (EFst ext (EFst ext (evar IZ)))))) $
+ weakenExpr (WCopy (WSink .> autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1))
+ ex₀2)
+ (weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $
+ subst0 (ESnd ext (EFst ext (evar IZ))) ea2))
(ESnd ext (evar IZ)))
}
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 79d5014..9e3d2a6 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -164,15 +164,14 @@ interpret'Rec env = \case
return (arrayMap fst res
,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) ->
arrayIndexLinear (snd (arrayIndex res idx)) i)
- EFold1InnerD2 _ _ ef ez eplus ebog ed -> do
+ EFold1InnerD2 _ _ ef ebog ed -> do
let STArr _ tB = typeOf ebog
- t2 = typeOf ez
+ STArr _ t2 = typeOf ed
let f = \tape ctg -> interpret' (V t2 ctg `SCons` V tB tape `SCons` env) ef
- zeroval <- interpret' env ez
- let plusfun = \x y -> interpret' (V t2 y `SCons` V t2 x `SCons` env) eplus
bog <- interpret' env ebog
arrctg <- interpret' env ed
let sh `ShCons` n = arrayShape bog
+ when (sh /= arrayShape arrctg) $ error "Interpreter: mismatched shapes in EFold1InnerD2"
res <- arrayGenerateM sh $ \idx -> do
let loop i !ctg !inpctgs | i < 0 = return (ctg, inpctgs)
loop i !ctg !inpctgs = do
@@ -181,8 +180,7 @@ interpret'Rec env = \case
loop (i - 1) ctg1 (ctg2 : inpctgs)
(x0ctg, inpctg) <- loop (n - 1) (arrayIndex arrctg idx) []
return (x0ctg, arrayFromList (ShNil `ShCons` n) inpctg)
- x0ctg <- foldM (\x (y, _) -> plusfun x y) zeroval (arrayToList res)
- return (x0ctg
+ return (arrayMap fst res
,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) ->
arrayIndexLinear (snd (arrayIndex res idx)) i)
EConst _ _ v -> return v
diff --git a/src/Simplify.hs b/src/Simplify.hs
index 74306a1..b89d7f6 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -316,7 +316,7 @@ simplify'Rec = \case
EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |]
EReshape _ n a b -> [simprec| EReshape ext n *a *b |]
EFold1InnerD1 _ cm a b c -> [simprec| EFold1InnerD1 ext cm *a *b *c |]
- EFold1InnerD2 _ cm a b c d e -> [simprec| EFold1InnerD2 ext cm *a *b *c *d *e |]
+ EFold1InnerD2 _ cm a b c -> [simprec| EFold1InnerD2 ext cm *a *b *c |]
EConst _ t v -> pure $ EConst ext t v
EIdx0 _ e -> [simprec| EIdx0 ext *e |]
EIdx1 _ a b -> [simprec| EIdx1 ext *a *b |]
@@ -372,7 +372,7 @@ hasAdds = \case
EMinimum1Inner _ e -> hasAdds e
EReshape _ _ a b -> hasAdds a || hasAdds b
EFold1InnerD1 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c
- EFold1InnerD2 _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e
+ EFold1InnerD2 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c
ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e
EConst _ _ _ -> False
EIdx0 _ e -> hasAdds e