From 4c9ae47dd5bbd27b1acb6dc5d4a55657ac1f026f Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 30 Oct 2025 15:58:08 +0100 Subject: Simplify foldD2 to not sum x0 contributions --- src/AST.hs | 42 ++++++++++++++++++++++++++++++------------ src/AST/Count.hs | 38 +++++++++++++------------------------- src/AST/Pretty.hs | 11 +++-------- src/AST/SplitLets.hs | 8 ++++---- src/AST/UnMonoid.hs | 2 +- src/Analysis/Identity.hs | 18 +++++++----------- src/CHAD.hs | 18 +++++++++++------- src/Interpreter.hs | 10 ++++------ src/Simplify.hs | 4 ++-- 9 files changed, 75 insertions(+), 76 deletions(-) (limited to 'src') diff --git a/src/AST.hs b/src/AST.hs index 7549ff0..663b83f 100644 --- a/src/AST.hs +++ b/src/AST.hs @@ -71,21 +71,24 @@ data Expr x env t where EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t) - -- MapAccum-like (is it real mapaccum? If so, rename) + -- Primal of EFold1Inner. Looks like a mapAccumL, but differs semantically: + -- an implementation is allowed to parallelise this thing and store the b + -- values in some implementation-defined order. + -- TODO: For a parallel implementation some data will probably need to be stored about the reduction order in addition to simply the array of bs. EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative -> Expr x (t1 : t1 : env) (TPair t1 b) -> Expr x env t1 -> Expr x env (TArr (S n) t1) -> Expr x env (TPair (TArr n t1) -- normal primal fold output (TArr (S n) b)) -- additional stores; usually: (prescanl, the tape stores) - -- Reverse derivative of Efold1Inner. - EFold1InnerD2 :: x (TPair t2 (TArr (S n) t2)) -> Commutative + -- Reverse derivative of EFold1Inner. The contributions to the initial + -- element are not yet added together here; we assume a later fusion system + -- does that for us. + EFold1InnerD2 :: x (TPair (TArr n t2) (TArr (S n) t2)) -> Commutative -> Expr x (t2 : b : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation) - -> Expr x env t2 -- zero - -> Expr x (t2 : t2 : env) t2 -- plus - -> Expr x env (TArr (S n) b) -- extra data passed to function + -> Expr x env (TArr (S n) b) -- stores from EFold1InnerD1 -> Expr x env (TArr n t2) -- incoming cotangent - -> Expr x env (TPair t2 (TArr (S n) t2)) -- outgoing cotangents to x0 and input array + -> Expr x env (TPair (TArr n t2) (TArr (S n) t2)) -- outgoing cotangents to x0 (not summed) and input array -- expression operations EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t) @@ -237,7 +240,7 @@ typeOf = \case EReshape _ n _ e | STArr _ t <- typeOf e -> STArr n t EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb) - EFold1InnerD2 _ _ _ e2 _ e4 _ | t2 <- typeOf e2, STArr sn _ <- typeOf e4 -> STPair t2 (STArr sn t2) + EFold1InnerD2 _ _ _ _ e3 | STArr n t2 <- typeOf e3 -> STPair (STArr n t2) (STArr (SS n) t2) EConst _ t _ -> STScal t EIdx0 _ e | STArr _ t <- typeOf e -> t @@ -287,7 +290,7 @@ extOf = \case EMinimum1Inner x _ -> x EReshape x _ _ _ -> x EFold1InnerD1 x _ _ _ _ -> x - EFold1InnerD2 x _ _ _ _ _ _ -> x + EFold1InnerD2 x _ _ _ _ -> x EConst x _ _ -> x EIdx0 x _ -> x EIdx1 x _ _ -> x @@ -336,7 +339,7 @@ travExt f = \case EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e EReshape x n a b -> EReshape <$> f x <*> pure n <*> travExt f a <*> travExt f b EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c - EFold1InnerD2 x cm a b c d e -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f d <*> travExt f e + EFold1InnerD2 x cm a b c -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c EConst x t v -> EConst <$> f x <*> pure t <*> pure v EIdx0 x e -> EIdx0 <$> f x <*> travExt f e EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b @@ -398,7 +401,7 @@ subst' f w = \case EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e) EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b) EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) - EFold1InnerD2 x cm a b c d e -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) c) (subst' f w d) (subst' f w e) + EFold1InnerD2 x cm a b c -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (subst' f w e) EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) @@ -565,6 +568,19 @@ eshapeConst :: Shape n -> Ex env (Tup (Replicate n TIx)) eshapeConst ShNil = ENil ext eshapeConst (sh `ShCons` n) = EPair ext (eshapeConst sh) (EConst ext STI64 (fromIntegral @Int @Int64 n)) +eshapeProd :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx +eshapeProd SZ _ = EConst ext STI64 1 +eshapeProd (SS SZ) e = ESnd ext e +eshapeProd (SS n) e = + eunPair e $ \_ e1 e2 -> + EOp ext (OMul STI64) (EPair ext (eshapeProd n e1) e2) + +eflatten :: Ex env (TArr n t) -> Ex env (TArr N1 t) +eflatten e = + let STArr n _ = typeOf e + in elet e $ + EReshape ext (SS SZ) (EPair ext (ENil ext) (eshapeProd n (EShape ext (evar IZ)))) (evar IZ) + -- ezeroD2 :: STy t -> Ex env (ZeroInfo (D2 t)) -> Ex env (D2 t) -- ezeroD2 t ezi = EZero ext (d2M t) ezi @@ -594,7 +610,9 @@ esnd e = ESnd ext e elet :: Ex env a -> (KnownTy a => Ex (a : env) b) -> Ex env b elet rhs body | Dict <- styKnown (typeOf rhs) - = ELet ext rhs body + = if cheapExpr rhs + then substInline rhs body + else ELet ext rhs body -- | Let-bind it but don't use the value (just ensure the expression's effects don't get lost) use :: Ex env a -> Ex env b -> Ex env b diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 66b4e0b..296c021 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -515,9 +515,8 @@ occCountX initialS topexpr k = case topexpr of EConstArr _ n t x -> case s of SsNone -> k OccEnd (\_ -> ENil ext) - SsArr SsNone -> k OccEnd (\_ -> EBuild ext n (eshapeConst (arrayShape x)) (ENil ext)) - SsArr SsFull -> k OccEnd (\_ -> EConstArr ext n t x) - SsFull -> occCountX (SsArr SsFull) topexpr k + SsArr' SsNone -> k OccEnd (\_ -> EBuild ext n (eshapeConst (arrayShape x)) (ENil ext)) + SsArr' SsFull -> k OccEnd (\_ -> EConstArr ext n t x) EBuild _ n a b -> case s of @@ -533,7 +532,7 @@ occCountX initialS topexpr k = case topexpr of weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))) $ ENil ext) $ ENil ext - SsArr s' -> + SsArr' s' -> occCountX SsFull a $ \env1 mka -> occCountX s' b $ \env2'' mkb -> withSome (scaleMany (Some env2'')) $ \env2' -> @@ -543,7 +542,6 @@ occCountX initialS topexpr k = case topexpr of EBuild ext n (mka env') $ elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ weakenExpr (WCopy WSink) (mkb (OccPush env' () s2)) - SsFull -> occCountX (SsArr SsFull) topexpr k EFold1Inner _ commut a b c -> occCountX SsFull a $ \env1''' mka -> @@ -552,8 +550,7 @@ occCountX initialS topexpr k = case topexpr of occEnvPop' env1' $ \env1 s1 -> let s0 = case s of SsNone -> Some SsNone - SsArr s' -> Some s' - SsFull -> Some SsFull in + SsArr' s' -> Some s' in withSome (Some s1 <> Some s2 <> s0) $ \sElt -> occCountX sElt b $ \env2 mkb -> occCountX (SsArr sElt) c $ \env3 mkc -> @@ -573,11 +570,10 @@ occCountX initialS topexpr k = case topexpr of occCountX SsNone e $ \env mke -> k env $ \env' -> use (mke env') $ ENil ext - SsArr s' -> + SsArr' s' -> occCountX s' e $ \env mke -> k env $ \env' -> EUnit ext (mke env') - SsFull -> occCountX (SsArr SsFull) topexpr k EReplicate1Inner _ a b -> case s of @@ -587,13 +583,12 @@ occCountX initialS topexpr k = case topexpr of withSome (Some env1 <> Some env2) $ \env -> k env $ \env' -> use (mka env') $ use (mkb env') $ ENil ext - SsArr s' -> + SsArr' s' -> occCountX SsFull a $ \env1 mka -> occCountX (SsArr s') b $ \env2 mkb -> withSome (Some env1 <> Some env2) $ \env -> k env $ \env' -> EReplicate1Inner ext (mka env') (mkb env') - SsFull -> occCountX (SsArr SsFull) topexpr k EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e @@ -654,23 +649,18 @@ occCountX initialS topexpr k = case topexpr of EFold1InnerD1 ext cm (mka (OccPush (OccPush env' () SsFull) () SsFull)) (mkb env') (mkc env') - EFold1InnerD2 _ cm ef ez eplus ebog ed -> + EFold1InnerD2 _ cm ef ebog ed -> -- TODO: propagate usage of duals occCountX SsFull ef $ \env1_2' mkef -> occEnvPop' env1_2' $ \env1_1' _ -> occEnvPop' env1_1' $ \env1' sB -> - occCountX SsFull ez $ \env2' mkez -> - occCountX SsFull eplus $ \env3_2' mkeplus -> - occEnvPop' env3_2' $ \env3_1' _ -> - occEnvPop' env3_1' $ \env3' _ -> - occCountX (SsArr sB) ebog $ \env4 mkebog -> - occCountX SsFull ed $ \env5 mked -> - withSome (scaleMany (Some env1') <> scaleMany (Some env2') <> scaleMany (Some env3') <> Some env4 <> Some env5) $ \env -> + occCountX (SsArr sB) ebog $ \env2 mkebog -> + occCountX SsFull ed $ \env3 mked -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> k env $ \env' -> projectSmallerSubstruc SsFull s $ EFold1InnerD2 ext cm (mkef (OccPush (OccPush env' () sB) () SsFull)) - (mkez env') (mkeplus (OccPush (OccPush env' () SsFull) () SsFull)) (mkebog env') (mked env') EConst _ t x -> @@ -692,13 +682,12 @@ occCountX initialS topexpr k = case topexpr of withSome (Some env1 <> Some env2) $ \env -> k env $ \env' -> use (mka env') $ use (mkb env') $ ENil ext - SsArr s' -> + SsArr' s' -> occCountX (SsArr s') a $ \env1 mka -> occCountX SsFull b $ \env2 mkb -> withSome (Some env1 <> Some env2) $ \env -> k env $ \env' -> EIdx1 ext (mka env') (mkb env') - SsFull -> occCountX (SsArr SsFull) topexpr k EIdx _ a b -> case s of @@ -863,16 +852,15 @@ occCountX initialS topexpr k = case topexpr of occCountX SsNone e $ \env mke -> k env $ \env' -> use (mke env') $ ENil ext - SsArr SsNone -> + SsArr' SsNone -> occCountX (SsArr SsNone) e $ \env mke -> k env $ \env' -> elet (mke env') $ EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext) - SsArr SsFull -> + SsArr' SsFull -> occCountX (SsArr SsFull) e $ \env mke -> k env $ \env' -> reduce (mke env') - SsFull -> occCountX (SsArr SsFull) topexpr k deleteUnused :: SList f env -> Some (OccEnv Occ env) -> (forall env'. Subenv env env' -> r) -> r diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 67197f9..68fc629 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -250,23 +250,18 @@ ppExpr' d val expr = case expr of return $ ppParen (d > 10) $ ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c'] - EFold1InnerD2 _ cm ef ez eplus ebog ed -> do + EFold1InnerD2 _ cm ef ebog ed -> do let STArr _ tB = typeOf ebog - t2 = typeOf ez + STArr _ t2 = typeOf ed namef1 <- genNameIfUsedIn tB (IS IZ) ef namef2 <- genNameIfUsedIn t2 IZ ef ef' <- ppExpr' 0 (Const namef2 `SCons` Const namef1 `SCons` val) ef - ez' <- ppExpr' 11 val ez - namep1 <- genNameIfUsedIn t2 (IS IZ) eplus - namep2 <- genNameIfUsedIn t2 IZ eplus - eplus' <- ppExpr' 0 (Const namep2 `SCons` Const namep1 `SCons` val) eplus ebog' <- ppExpr' 11 val ebog ed' <- ppExpr' 11 val ed let opname = "fold1iD2" ++ ppCommut cm return $ ppParen (d > 10) $ ppApp (annotate AHighlight (ppString opname) <> ppX expr) - [ppLam [ppString namef1, ppString namef2] ef', ez' - ,ppLam [ppString namep1, ppString namep2] eplus', ebog', ed'] + [ppLam [ppString namef1, ppString namef2] ef', ebog', ed'] EConst _ ty v | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index 73c1c67..d276e44 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -38,10 +38,10 @@ splitLets' = \sub -> \case EFold1InnerD1 x cm a b c -> let STArr _ t1 = typeOf c in EFold1InnerD1 x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) - EFold1InnerD2 x cm a b c d e -> - let t2 = typeOf b - STArr _ tB = typeOf d - in EFold1InnerD2 x cm (split2 sub tB t2 a) (splitLets' sub b) (split2 sub t2 t2 c) (splitLets' sub d) (splitLets' sub e) + EFold1InnerD2 x cm a b c -> + let STArr _ tB = typeOf b + STArr _ t2 = typeOf c + in EFold1InnerD2 x cm (split2 sub tB t2 a) (splitLets' sub b) (splitLets' sub c) EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b) EFst x e -> EFst x (splitLets' sub e) diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index e5a9708..a22b73f 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -46,7 +46,7 @@ unMonoid = \case EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e) EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b) EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c) - EFold1InnerD2 _ cm a b c d e -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid d) (unMonoid e) + EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c) EConst _ t x -> EConst ext t x EIdx0 _ e -> EIdx0 ext (unMonoid e) EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b) diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index b3a6664..6301dc1 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -261,21 +261,17 @@ idana env expr = case expr of res <- VIPair <$> (VIArr <$> genId <*> pure sh) <*> (VIArr <$> genId <*> pure sh') pure (res, EFold1InnerD1 res cm e1' e2' e3') - EFold1InnerD2 _ cm ef ez eplus ebog ed -> do + EFold1InnerD2 _ cm ef ebog ed -> do let STArr _ tB = typeOf ebog - t2 = typeOf ez + STArr _ t2 = typeOf ed xf1 <- genIds t2 xf2 <- genIds tB (_, e1') <- idana (xf1 `SCons` xf2 `SCons` env) ef - (_, e2') <- idana env ez - xp1 <- genIds t2 - xp2 <- genIds t2 - (_, e3') <- idana (xp1 `SCons` xp2 `SCons` env) eplus - (v4, e4') <- idana env ebog - (_, e5') <- idana env ed - let VIArr _ sh = v4 - res <- VIPair <$> genIds t2 <*> (VIArr <$> genId <*> pure sh) - pure (res, EFold1InnerD2 res cm e1' e2' e3' e4' e5') + (v2, e2') <- idana env ebog + (_, e3') <- idana env ed + let VIArr _ sh@(_ :< sh') = v2 + res <- VIPair <$> (VIArr <$> genId <*> pure sh') <*> (VIArr <$> genId <*> pure sh) + pure (res, EFold1InnerD2 res cm e1' e2' e3') EConst _ t val -> do res <- VIScal <$> genId diff --git a/src/CHAD.hs b/src/CHAD.hs index 04c4231..7594a0f 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1184,7 +1184,7 @@ drev des accumMap sd = \case subx₀af (let layout1 = #darr :++: #primal :++: #parr :++: #pzi :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in elet - (uninvertTup (d2e envPro) (STPair (d2 eltty) (STArr (SS ndim) (d2 eltty))) $ + (uninvertTup (d2e envPro) (STPair (STArr ndim (d2 eltty)) (STArr (SS ndim) (d2 eltty))) $ makeAccumulators (autoWeak library #propr layout1) envPro $ let layout2 = #d2acPro :++: layout1 in EFold1InnerD2 ext commut @@ -1198,8 +1198,6 @@ drev des accumMap sd = \case .> wPro (subList (bindingsBinds ef0) subtapeEf)) ef2) $ EPair ext (ESnd ext (EFst ext (evar IZ))) (ESnd ext (evar IZ))) - (EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (autoWeak library #pzi layout2 @> IZ))) - (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) (ezip (EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ)) (ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ))) @@ -1207,10 +1205,16 @@ drev des accumMap sd = \case (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ)) (EFst ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))) $ plus_x₀a_f - (weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $ - plus_x₀_a - (subst0 (EFst ext (EFst ext (evar IZ))) ex₀2) - (subst0 (ESnd ext (EFst ext (evar IZ))) ea2)) + (plus_x₀_a + (elet (EIdx0 ext + (EFold1Inner ext Commut + (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) + (EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (WSink .> autoWeak library #pzi layout1 @> IZ))) + (eflatten (EFst ext (EFst ext (evar IZ)))))) $ + weakenExpr (WCopy (WSink .> autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) + ex₀2) + (weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $ + subst0 (ESnd ext (EFst ext (evar IZ))) ea2)) (ESnd ext (evar IZ))) } diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 79d5014..9e3d2a6 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -164,15 +164,14 @@ interpret'Rec env = \case return (arrayMap fst res ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> arrayIndexLinear (snd (arrayIndex res idx)) i) - EFold1InnerD2 _ _ ef ez eplus ebog ed -> do + EFold1InnerD2 _ _ ef ebog ed -> do let STArr _ tB = typeOf ebog - t2 = typeOf ez + STArr _ t2 = typeOf ed let f = \tape ctg -> interpret' (V t2 ctg `SCons` V tB tape `SCons` env) ef - zeroval <- interpret' env ez - let plusfun = \x y -> interpret' (V t2 y `SCons` V t2 x `SCons` env) eplus bog <- interpret' env ebog arrctg <- interpret' env ed let sh `ShCons` n = arrayShape bog + when (sh /= arrayShape arrctg) $ error "Interpreter: mismatched shapes in EFold1InnerD2" res <- arrayGenerateM sh $ \idx -> do let loop i !ctg !inpctgs | i < 0 = return (ctg, inpctgs) loop i !ctg !inpctgs = do @@ -181,8 +180,7 @@ interpret'Rec env = \case loop (i - 1) ctg1 (ctg2 : inpctgs) (x0ctg, inpctg) <- loop (n - 1) (arrayIndex arrctg idx) [] return (x0ctg, arrayFromList (ShNil `ShCons` n) inpctg) - x0ctg <- foldM (\x (y, _) -> plusfun x y) zeroval (arrayToList res) - return (x0ctg + return (arrayMap fst res ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> arrayIndexLinear (snd (arrayIndex res idx)) i) EConst _ _ v -> return v diff --git a/src/Simplify.hs b/src/Simplify.hs index 74306a1..b89d7f6 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -316,7 +316,7 @@ simplify'Rec = \case EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |] EReshape _ n a b -> [simprec| EReshape ext n *a *b |] EFold1InnerD1 _ cm a b c -> [simprec| EFold1InnerD1 ext cm *a *b *c |] - EFold1InnerD2 _ cm a b c d e -> [simprec| EFold1InnerD2 ext cm *a *b *c *d *e |] + EFold1InnerD2 _ cm a b c -> [simprec| EFold1InnerD2 ext cm *a *b *c |] EConst _ t v -> pure $ EConst ext t v EIdx0 _ e -> [simprec| EIdx0 ext *e |] EIdx1 _ a b -> [simprec| EIdx1 ext *a *b |] @@ -372,7 +372,7 @@ hasAdds = \case EMinimum1Inner _ e -> hasAdds e EReshape _ _ a b -> hasAdds a || hasAdds b EFold1InnerD1 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c - EFold1InnerD2 _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e + EFold1InnerD2 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e EConst _ _ _ -> False EIdx0 _ e -> hasAdds e -- cgit v1.2.3-70-g09d2