diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2025-03-07 15:11:59 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-14 15:37:29 +0100 |
commit | 137eaa13144c2599ac29da9ebd3af24ac1ce8968 (patch) | |
tree | 8fc5221824f671dfc27f8064e3fc537859bb73e8 | |
parent | 1abb0c11efd2ba650c0a20de8047efbde2cc6adf (diff) |
WIP revamp accumulator projection type repr
I stopped working on this because I realised that having sparse products
(and coproducts, prehaps) everywhere is a very bad idea in general, and
that we need to fix that first before really being able to do anything
else productive with performance.
-rw-r--r-- | chad-fast.cabal | 1 | ||||
-rw-r--r-- | src/AST.hs | 75 | ||||
-rw-r--r-- | src/AST/Accum.hs | 44 | ||||
-rw-r--r-- | src/AST/Count.hs | 4 | ||||
-rw-r--r-- | src/AST/UnMonoid.hs | 88 | ||||
-rw-r--r-- | src/Data.hs | 4 |
6 files changed, 113 insertions, 103 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index aa4dfcc..e201683 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -13,6 +13,7 @@ library Analysis.Identity Array AST + AST.Accum AST.Bindings AST.Count AST.Env @@ -16,42 +16,19 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -module AST (module AST, module AST.Types, module AST.Weaken) where +module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where import Data.Functor.Const import Data.Kind (Type) import Array +import AST.Accum import AST.Types import AST.Weaken import CHAD.Types import Data --- | This index is flipped around from the usual direction: the smallest index --- is at the _heart_ of the nesting, not at the outside. The outermost layer --- indexes into the _outer_ dimension of the type @t@. This makes indices into --- compound structures work properly with coproducts. -type family AcIdx t i where - AcIdx t Z = TNil - AcIdx (TPair a b) (S i) = TEither (AcIdx a i) (AcIdx b i) - AcIdx (TEither a b) (S i) = TEither (AcIdx a i) (AcIdx b i) - AcIdx (TMaybe t) (S i) = AcIdx t i - AcIdx (TArr Z t) (S i) = AcIdx t i - AcIdx (TArr (S n) t) (S i) = TPair TIx (AcIdx (TArr n t) i) - -type family AcVal t i where - AcVal t Z = t - AcVal (TPair a b) (S i) = TEither (AcVal a i) (AcVal b i) - AcVal (TEither a b) (S i) = TEither (AcVal a i) (AcVal b i) - AcVal (TMaybe t) (S i) = AcVal t i - AcVal (TArr n t) (S i) = TPair (Tup (Replicate n TIx)) (AcValArr n t (S i)) - -type family AcValArr n t i where - AcValArr n t Z = TArr n t - AcValArr Z t (S i) = AcVal t i - AcValArr (S n) t (S i) = AcValArr n t i - -- General assumption: head of the list (whatever way it is associated) is the -- inner variable / inner array dimension. In pretty printing, the inner -- variable / inner dimension is printed on the _right_. @@ -110,15 +87,14 @@ data Expr x env t where -> Expr x env a -> Expr x env b -> Expr x env t - -- accumulation effect - EWith :: x (TPair a t) -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t) - EAccum :: x TNil -> SNat i -> Expr x env (AcIdx t i) -> Expr x env (AcVal t i) -> Expr x env (TAccum t) -> Expr x env TNil - -- EAccum1 :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum (S Z) t) -> Expr x env TNil + -- accumulation effect on monoids + EWith :: x (TPair a (D2 t)) -> STy t -> Expr x env (D2 t) -> Expr x (TAccum (D2 t) : env) a -> Expr x env (TPair a (D2 t)) + EAccum :: x TNil -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (TAccum (D2 a)) -> Expr x env TNil -- monoidal operations (to be desugared to regular operations after simplification) EZero :: x (D2 t) -> STy t -> Expr x env (D2 t) EPlus :: x (D2 t) -> STy t -> Expr x env (D2 t) -> Expr x env (D2 t) -> Expr x env (D2 t) - EOneHot :: x (D2 t) -> STy t -> SNat i -> Expr x env (AcIdx (D2 t) i) -> Expr x env (AcVal (D2 t) i) -> Expr x env (D2 t) + EOneHot :: x (D2 t) -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (D2 t) -- partiality EError :: x a -> STy a -> String -> Expr x env a @@ -129,9 +105,6 @@ type Ex = Expr (Const ()) ext :: Const () a ext = Const () -eTup :: SList (Ex env) list -> Ex env (Tup list) -eTup = mkTup (ENil ext) (EPair ext) - type SOp :: Ty -> Ty -> Type data SOp a t where OAdd :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) @@ -224,8 +197,8 @@ typeOf = \case ECustom _ _ _ _ e _ _ _ _ -> typeOf e - EWith _ e1 e2 -> STPair (typeOf e2) (typeOf e1) - EAccum _ _ _ _ _ -> STNil + EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1) + EAccum _ _ _ _ _ _ -> STNil EZero _ t -> d2 t EPlus _ t _ _ -> d2 t @@ -262,8 +235,8 @@ extOf = \case EShape x _ -> x EOp x _ _ -> x ECustom x _ _ _ _ _ _ _ _ -> x - EWith x _ _ -> x - EAccum x _ _ _ _ -> x + EWith x _ _ _ -> x + EAccum x _ _ _ _ _ -> x EZero x _ -> x EPlus x _ _ _ -> x EOneHot x _ _ _ _ -> x @@ -331,11 +304,11 @@ subst' f w = \case EShape x e -> EShape x (subst' f w e) EOp x op e -> EOp x op (subst' f w e) ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2) - EWith x e1 e2 -> EWith x (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) - EAccum x i e1 e2 e3 -> EAccum x i (subst' f w e1) (subst' f w e2) (subst' f w e3) + EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) + EAccum x t p e1 e2 e3 -> EAccum x t p (subst' f w e1) (subst' f w e2) (subst' f w e3) EZero x t -> EZero x t EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b) - EOneHot x t i a b -> EOneHot x t i (subst' f w a) (subst' f w b) + EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b) EError x t s -> EError x t s where sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) @@ -396,6 +369,9 @@ envKnown :: SList STy env -> Dict (KnownEnv env) envKnown SNil = Dict envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict +eTup :: SList (Ex env) list -> Ex env (Tup list) +eTup = mkTup (ENil ext) (EPair ext) + ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t) ebuildUp1 n sh size f = EBuild ext (SS n) (EPair ext sh size) $ @@ -456,22 +432,3 @@ eshapeEmpty (SS n) e = (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)) (EConst ext STI64 0))) (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)))) - -arrIdxToAcIdx :: proxy t -> SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (AcIdx (TArr n t) n) -arrIdxToAcIdx = \p (n :: SNat n) e -> case lemPlusZero @n of Refl -> go p n SZ e (ENil ext) - where - -- symbolic version of 'invert' in Interpreter - go :: forall n m t env proxy. proxy t -> SNat n -> SNat m - -> Ex env (Tup (Replicate n TIx)) -> Ex env (AcIdx (TArr m t) m) -> Ex env (AcIdx (TArr (n + m) t) (n + m)) - go _ SZ _ _ acidx = acidx - go p (SS n) m idx acidx - | Refl <- lemPlusSuccRight @n @m - = ELet ext idx $ - go p n (SS m) - (EFst ext (EVar ext (typeOf idx) IZ)) - (EPair ext (ESnd ext (EVar ext (typeOf idx) IZ)) - (weakenExpr WSink acidx)) - -lemAcValArrN :: proxy t -> SNat n -> AcValArr n t n :~: TArr Z t -lemAcValArrN _ SZ = Refl -lemAcValArrN p (SS n) | Refl <- lemAcValArrN p n = Refl diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs new file mode 100644 index 0000000..163f1c3 --- /dev/null +++ b/src/AST/Accum.hs @@ -0,0 +1,44 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} +module AST.Accum where + +import AST.Types +import Data + + +data AcPrj + = APHere + | APFst AcPrj + | APSnd AcPrj + | APLeft AcPrj + | APRight AcPrj + | APJust AcPrj + | APArrIdx AcPrj + | APArrSlice Nat + +-- | @b@ is a small part of @a@, indicated by the projection. +data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where + SAPHere :: SAcPrj APHere a a + SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair t a) b + SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair a t) b + SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TEither t a) b + SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TEither t a) b + SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe t) b + SAPArrIdx :: SAcPrj p a b -> SNat n -> SAcPrj (APArrIdx p) (TArr n t) b + SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t) +deriving instance Show (SAcPrj p a b) + +type family AcIdx p t where + AcIdx APHere t = TNil + AcIdx (APFst p) (TPair a b) = AcIdx p a + AcIdx (APSnd p) (TPair a b) = AcIdx p b + AcIdx (APLeft p) (TEither a b) = AcIdx p a + AcIdx (APRight p) (TEither a b) = AcIdx p b + AcIdx (APJust p) (TMaybe a) = AcIdx p a + AcIdx (APArrIdx p) (TArr n a) = TPair (Tup (Replicate n TIx)) (AcIdx p a) + AcIdx (APArrSlice m) (TArr n a) = Tup (Replicate m TIx) diff --git a/src/AST/Count.hs b/src/AST/Count.hs index b7079ff..c0d8d2d 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -128,8 +128,8 @@ occCountGeneral onehot unpush alter many = go WId EShape _ e -> re e EOp _ _ e -> re e ECustom _ _ _ _ _ _ _ a b -> re a <> re b - EWith _ a b -> re a <> re1 b - EAccum _ _ a b e -> re a <> re b <> re e + EWith _ _ a b -> re a <> re1 b + EAccum _ _ _ a b e -> re a <> re b <> re e EZero _ _ -> mempty EPlus _ _ a b -> re a <> re b EOneHot _ _ _ a b -> re a <> re b diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index 4b6b523..ec5e11e 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -13,7 +13,7 @@ unMonoid :: Ex env t -> Ex env t unMonoid = \case EZero _ t -> zero t EPlus _ t a b -> plus t (unMonoid a) (unMonoid b) - EOneHot _ t i a b -> onehot t i (unMonoid a) (unMonoid b) + EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b) EVar _ t i -> EVar ext t i ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body) @@ -42,8 +42,8 @@ unMonoid = \case EShape _ e -> EShape ext (unMonoid e) EOp _ op e -> EOp ext op (unMonoid e) ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2) - EWith _ a b -> EWith ext (unMonoid a) (unMonoid b) - EAccum _ n a b e -> EAccum ext n (unMonoid a) (unMonoid b) (unMonoid e) + EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b) + EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e) EError _ t s -> EError ext t s zero :: STy t -> Ex env (D2 t) @@ -116,9 +116,13 @@ plusSparse t a b adder = (EVar ext (STMaybe t) (IS IZ)))) (weakenExpr WSink a) -onehot :: STy t -> SNat i -> Ex env (AcIdx (D2 t) i) -> Ex env (AcVal (D2 t) i) -> Ex env (D2 t) -onehot _ SZ _ val = val -onehot t (SS dep) idx val = case t of +onehot :: STy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env (D2 a) -> Ex env (D2 t) +onehot _ topprj arg = case topprj of + SAPHere -> arg + + SAPFst prj -> _ + +onehot t (SS dep) arg = case t of STPair t1 t2 -> case dep of SZ -> EJust ext val @@ -165,42 +169,42 @@ onehot t (SS dep) idx val = case t of STScal{} -> error "Cannot index into scalar" STAccum{} -> error "Accumulators not allowed in input program" -onehotArrayElem - :: STy t -> SNat n -> SNat i - -> Ex env (Tup (Replicate n TIx)) -- ^ where are we now, OUTSIDE-IN as produced by 'outsideInIndex' - -> Ex env (AcIdx (TArr n (D2 t)) i) -- ^ where to put the one-hot - -> Ex env (AcValArr n (D2 t) i) -- ^ value to put in the hole - -> Ex env (D2 t) -onehotArrayElem t n dep eltidx idx val = - ELet ext eltidx $ - ELet ext (weakenExpr WSink idx) $ - let (cond, elt) = onehotArrayElemRec t n dep - (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ))) - (EVar ext (typeOf idx) IZ) - (weakenExpr (WSink .> WSink) val) - in eif cond elt (zero t) - --- AcIdx must be duplicable -onehotArrayElemRec - :: STy t -> SNat n -> SNat i - -> [Ex env TIx] - -> Ex env (AcIdx (TArr n (D2 t)) i) - -> Ex env (AcValArr n (D2 t) i) - -> (Ex env (TScal TBool), Ex env (D2 t)) -onehotArrayElemRec _ n SZ eltidx _ val = - (EConst ext STBool True - ,EIdx ext val (reconstructFromOutsideIn n eltidx)) -onehotArrayElemRec t SZ (SS dep) eltidx idx val = - case eltidx of - [] -> (EConst ext STBool True, onehot t dep idx val) - _ -> error "onehotArrayElemRec: mismatched list length" -onehotArrayElemRec t (SS n) (SS dep) eltidx idx val = - case eltidx of - i : eltidx' -> - let (cond, elt) = onehotArrayElemRec t n dep eltidx' (ESnd ext idx) val - in (EOp ext OAnd (EPair ext (EOp ext (OEq STI64) (EPair ext i (EFst ext idx))) cond) - ,elt) - [] -> error "onehotArrayElemRec: mismatched list length" +-- onehotArrayElem +-- :: STy t -> SNat n -> SNat i +-- -> Ex env (Tup (Replicate n TIx)) -- ^ where are we now, OUTSIDE-IN as produced by 'outsideInIndex' +-- -> Ex env (AcIdx (TArr n (D2 t)) i) -- ^ where to put the one-hot +-- -> Ex env (AcValArr n (D2 t) i) -- ^ value to put in the hole +-- -> Ex env (D2 t) +-- onehotArrayElem t n dep eltidx idx val = +-- ELet ext eltidx $ +-- ELet ext (weakenExpr WSink idx) $ +-- let (cond, elt) = onehotArrayElemRec t n dep +-- (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ))) +-- (EVar ext (typeOf idx) IZ) +-- (weakenExpr (WSink .> WSink) val) +-- in eif cond elt (zero t) + +-- -- AcIdx must be duplicable +-- onehotArrayElemRec +-- :: STy t -> SNat n -> SNat i +-- -> [Ex env TIx] +-- -> Ex env (AcIdx (TArr n (D2 t)) i) +-- -> Ex env (AcValArr n (D2 t) i) +-- -> (Ex env (TScal TBool), Ex env (D2 t)) +-- onehotArrayElemRec _ n SZ eltidx _ val = +-- (EConst ext STBool True +-- ,EIdx ext val (reconstructFromOutsideIn n eltidx)) +-- onehotArrayElemRec t SZ (SS dep) eltidx idx val = +-- case eltidx of +-- [] -> (EConst ext STBool True, onehot t dep idx val) +-- _ -> error "onehotArrayElemRec: mismatched list length" +-- onehotArrayElemRec t (SS n) (SS dep) eltidx idx val = +-- case eltidx of +-- i : eltidx' -> +-- let (cond, elt) = onehotArrayElemRec t n dep eltidx' (ESnd ext idx) val +-- in (EOp ext OAnd (EPair ext (EOp ext (OEq STI64) (EPair ext i (EFst ext idx))) cond) +-- ,elt) +-- [] -> error "onehotArrayElemRec: mismatched list length" -- | Outermost index at the head. The input expression must be duplicable. outsideInIndex :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx] diff --git a/src/Data.hs b/src/Data.hs index 1304a5f..60afdd0 100644 --- a/src/Data.hs +++ b/src/Data.hs @@ -101,6 +101,10 @@ type family n + m where Z + m = m S n + m = S (n + m) +type family n - m where + n - Z = n + S n - S m = n - m + snatAdd :: SNat n -> SNat m -> SNat (n + m) snatAdd SZ m = m snatAdd (SS n) m = SS (snatAdd n m) |