diff options
Diffstat (limited to 'src/AST.hs')
-rw-r--r-- | src/AST.hs | 20 |
1 files changed, 15 insertions, 5 deletions
@@ -25,6 +25,7 @@ import Data.Kind (Type) import Array import AST.Accum +import AST.Sparse.Types import AST.Types import AST.Weaken import CHAD.Types @@ -91,11 +92,16 @@ data Expr x env t where ERecompute :: x t -> Expr x env t -> Expr x env t -- accumulation effect on monoids + -- | The initialiser for an accumulator __MUST__ be deep! If it is zero, it + -- must be EDeepZero, not just EZero. This is to ensure that EAccum does not + -- need to create any zeros. EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t) - EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Expr x env a -> Expr x env (TAccum t) -> Expr x env TNil + -- The 'Sparse' here is eliminated to dense by UnMonoid. + EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Sparse a b -> Expr x env b -> Expr x env (TAccum t) -> Expr x env TNil -- monoidal operations (to be desugared to regular operations after simplification) EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t + EDeepZero :: x t -> SMTy t -> Expr x env (DeepZeroInfo t) -> Expr x env t EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t @@ -218,9 +224,10 @@ typeOf = \case ERecompute _ e -> typeOf e EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1) - EAccum _ _ _ _ _ _ -> STNil + EAccum _ _ _ _ _ _ _ -> STNil EZero _ t _ -> fromSMTy t + EDeepZero _ t _ -> fromSMTy t EPlus _ t _ _ -> fromSMTy t EOneHot _ t _ _ _ -> fromSMTy t @@ -261,8 +268,9 @@ extOf = \case ECustom x _ _ _ _ _ _ _ _ -> x ERecompute x _ -> x EWith x _ _ _ -> x - EAccum x _ _ _ _ _ -> x + EAccum x _ _ _ _ _ _ -> x EZero x _ _ -> x + EDeepZero x _ _ -> x EPlus x _ _ _ -> x EOneHot x _ _ _ _ -> x EError x _ _ -> x @@ -306,8 +314,9 @@ travExt f = \case ECustom x s t p a b c e1 e2 -> ECustom <$> f x <*> pure s <*> pure t <*> pure p <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f e1 <*> travExt f e2 ERecompute x e -> ERecompute <$> f x <*> travExt f e EWith x t e1 e2 -> EWith <$> f x <*> pure t <*> travExt f e1 <*> travExt f e2 - EAccum x t p e1 e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> travExt f e2 <*> travExt f e3 + EAccum x t p e1 sp e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> pure sp <*> travExt f e2 <*> travExt f e3 EZero x t e -> EZero <$> f x <*> pure t <*> travExt f e + EDeepZero x t e -> EDeepZero <$> f x <*> pure t <*> travExt f e EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b EError x t s -> EError <$> f x <*> pure t <*> pure s @@ -364,8 +373,9 @@ subst' f w = \case ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2) ERecompute x e -> ERecompute x (subst' f w e) EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) - EAccum x t p e1 e2 e3 -> EAccum x t p (subst' f w e1) (subst' f w e2) (subst' f w e3) + EAccum x t p e1 sp e2 e3 -> EAccum x t p (subst' f w e1) sp (subst' f w e2) (subst' f w e3) EZero x t e -> EZero x t (subst' f w e) + EDeepZero x t e -> EDeepZero x t (subst' f w e) EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b) EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b) EError x t s -> EError x t s |