diff options
| author | Tom Smeding <t.j.smeding@uu.nl> | 2025-03-07 15:11:59 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-14 15:37:29 +0100 | 
| commit | 137eaa13144c2599ac29da9ebd3af24ac1ce8968 (patch) | |
| tree | 8fc5221824f671dfc27f8064e3fc537859bb73e8 /src | |
| parent | 1abb0c11efd2ba650c0a20de8047efbde2cc6adf (diff) | |
WIP revamp accumulator projection type repr
I stopped working on this because I realised that having sparse products
(and coproducts, prehaps) everywhere is a very bad idea in general, and
that we need to fix that first before really being able to do anything
else productive with performance.
Diffstat (limited to 'src')
| -rw-r--r-- | src/AST.hs | 75 | ||||
| -rw-r--r-- | src/AST/Accum.hs | 44 | ||||
| -rw-r--r-- | src/AST/Count.hs | 4 | ||||
| -rw-r--r-- | src/AST/UnMonoid.hs | 86 | ||||
| -rw-r--r-- | src/Data.hs | 4 | 
5 files changed, 111 insertions, 102 deletions
| @@ -16,42 +16,19 @@  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-}  {-# LANGUAGE UndecidableInstances #-} -module AST (module AST, module AST.Types, module AST.Weaken) where +module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where  import Data.Functor.Const  import Data.Kind (Type)  import Array +import AST.Accum  import AST.Types  import AST.Weaken  import CHAD.Types  import Data --- | This index is flipped around from the usual direction: the smallest index --- is at the _heart_ of the nesting, not at the outside. The outermost layer --- indexes into the _outer_ dimension of the type @t@. This makes indices into --- compound structures work properly with coproducts. -type family AcIdx t i where -  AcIdx t Z = TNil -  AcIdx (TPair a b) (S i) = TEither (AcIdx a i) (AcIdx b i) -  AcIdx (TEither a b) (S i) = TEither (AcIdx a i) (AcIdx b i) -  AcIdx (TMaybe t) (S i) = AcIdx t i -  AcIdx (TArr Z t) (S i) = AcIdx t i -  AcIdx (TArr (S n) t) (S i) = TPair TIx (AcIdx (TArr n t) i) - -type family AcVal t i where -  AcVal t Z = t -  AcVal (TPair a b) (S i) = TEither (AcVal a i) (AcVal b i) -  AcVal (TEither a b) (S i) = TEither (AcVal a i) (AcVal b i) -  AcVal (TMaybe t) (S i) = AcVal t i -  AcVal (TArr n t) (S i) = TPair (Tup (Replicate n TIx)) (AcValArr n t (S i)) - -type family AcValArr n t i where -  AcValArr n t Z = TArr n t -  AcValArr Z t (S i) = AcVal t i -  AcValArr (S n) t (S i) = AcValArr n t i -  -- General assumption: head of the list (whatever way it is associated) is the  -- inner variable / inner array dimension. In pretty printing, the inner  -- variable / inner dimension is printed on the _right_. @@ -110,15 +87,14 @@ data Expr x env t where            -> Expr x env a -> Expr x env b            -> Expr x env t -  -- accumulation effect -  EWith :: x (TPair a t) -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t) -  EAccum :: x TNil -> SNat i -> Expr x env (AcIdx t i) -> Expr x env (AcVal t i) -> Expr x env (TAccum t) -> Expr x env TNil -  -- EAccum1 :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum (S Z) t) -> Expr x env TNil +  -- accumulation effect on monoids +  EWith :: x (TPair a (D2 t)) -> STy t -> Expr x env (D2 t) -> Expr x (TAccum (D2 t) : env) a -> Expr x env (TPair a (D2 t)) +  EAccum :: x TNil -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (TAccum (D2 a)) -> Expr x env TNil    -- monoidal operations (to be desugared to regular operations after simplification)    EZero :: x (D2 t) -> STy t -> Expr x env (D2 t)    EPlus :: x (D2 t) -> STy t -> Expr x env (D2 t) -> Expr x env (D2 t) -> Expr x env (D2 t) -  EOneHot :: x (D2 t) -> STy t -> SNat i -> Expr x env (AcIdx (D2 t) i) -> Expr x env (AcVal (D2 t) i) -> Expr x env (D2 t) +  EOneHot :: x (D2 t) -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (D2 t)    -- partiality    EError :: x a -> STy a -> String -> Expr x env a @@ -129,9 +105,6 @@ type Ex = Expr (Const ())  ext :: Const () a  ext = Const () -eTup :: SList (Ex env) list -> Ex env (Tup list) -eTup = mkTup (ENil ext) (EPair ext) -  type SOp :: Ty -> Ty -> Type  data SOp a t where    OAdd :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) @@ -224,8 +197,8 @@ typeOf = \case    ECustom _ _ _ _ e _ _ _ _ -> typeOf e -  EWith _ e1 e2 -> STPair (typeOf e2) (typeOf e1) -  EAccum _ _ _ _ _ -> STNil +  EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1) +  EAccum _ _ _ _ _ _ -> STNil    EZero _ t -> d2 t    EPlus _ t _ _ -> d2 t @@ -262,8 +235,8 @@ extOf = \case    EShape x _ -> x    EOp x _ _ -> x    ECustom x _ _ _ _ _ _ _ _ -> x -  EWith x _ _ -> x -  EAccum x _ _ _ _ -> x +  EWith x _ _ _ -> x +  EAccum x _ _ _ _ _ -> x    EZero x _ -> x    EPlus x _ _ _ -> x    EOneHot x _ _ _ _ -> x @@ -331,11 +304,11 @@ subst' f w = \case    EShape x e -> EShape x (subst' f w e)    EOp x op e -> EOp x op (subst' f w e)    ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2) -  EWith x e1 e2 -> EWith x (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) -  EAccum x i e1 e2 e3 -> EAccum x i (subst' f w e1) (subst' f w e2) (subst' f w e3) +  EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) +  EAccum x t p e1 e2 e3 -> EAccum x t p (subst' f w e1) (subst' f w e2) (subst' f w e3)    EZero x t -> EZero x t    EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b) -  EOneHot x t i a b -> EOneHot x t i (subst' f w a) (subst' f w b) +  EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b)    EError x t s -> EError x t s    where      sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) @@ -396,6 +369,9 @@ envKnown :: SList STy env -> Dict (KnownEnv env)  envKnown SNil = Dict  envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict +eTup :: SList (Ex env) list -> Ex env (Tup list) +eTup = mkTup (ENil ext) (EPair ext) +  ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t)  ebuildUp1 n sh size f =    EBuild ext (SS n) (EPair ext sh size) $ @@ -456,22 +432,3 @@ eshapeEmpty (SS n) e =        (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ))                                        (EConst ext STI64 0)))        (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)))) - -arrIdxToAcIdx :: proxy t -> SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (AcIdx (TArr n t) n) -arrIdxToAcIdx = \p (n :: SNat n) e -> case lemPlusZero @n of Refl -> go p n SZ e (ENil ext) -  where -    -- symbolic version of 'invert' in Interpreter -    go :: forall n m t env proxy. proxy t -> SNat n -> SNat m -       -> Ex env (Tup (Replicate n TIx)) -> Ex env (AcIdx (TArr m t) m) -> Ex env (AcIdx (TArr (n + m) t) (n + m)) -    go _ SZ _ _ acidx = acidx -    go p (SS n) m idx acidx -      | Refl <- lemPlusSuccRight @n @m -      = ELet ext idx $ -          go p n (SS m) -             (EFst ext (EVar ext (typeOf idx) IZ)) -             (EPair ext (ESnd ext (EVar ext (typeOf idx) IZ)) -                        (weakenExpr WSink acidx)) - -lemAcValArrN :: proxy t -> SNat n -> AcValArr n t n :~: TArr Z t -lemAcValArrN _ SZ = Refl -lemAcValArrN p (SS n) | Refl <- lemAcValArrN p n = Refl diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs new file mode 100644 index 0000000..163f1c3 --- /dev/null +++ b/src/AST/Accum.hs @@ -0,0 +1,44 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} +module AST.Accum where + +import AST.Types +import Data + + +data AcPrj +  = APHere +  | APFst AcPrj +  | APSnd AcPrj +  | APLeft AcPrj +  | APRight AcPrj +  | APJust AcPrj +  | APArrIdx AcPrj +  | APArrSlice Nat + +-- | @b@ is a small part of @a@, indicated by the projection. +data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where +  SAPHere :: SAcPrj APHere a a +  SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair t a) b +  SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair a t) b +  SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TEither t a) b +  SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TEither t a) b +  SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe t) b +  SAPArrIdx :: SAcPrj p a b -> SNat n -> SAcPrj (APArrIdx p) (TArr n t) b +  SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t) +deriving instance Show (SAcPrj p a b) + +type family AcIdx p t where +  AcIdx APHere t = TNil +  AcIdx (APFst p) (TPair a b) = AcIdx p a +  AcIdx (APSnd p) (TPair a b) = AcIdx p b +  AcIdx (APLeft p) (TEither a b) = AcIdx p a +  AcIdx (APRight p) (TEither a b) = AcIdx p b +  AcIdx (APJust p) (TMaybe a) = AcIdx p a +  AcIdx (APArrIdx p) (TArr n a) = TPair (Tup (Replicate n TIx)) (AcIdx p a) +  AcIdx (APArrSlice m) (TArr n a) = Tup (Replicate m TIx) diff --git a/src/AST/Count.hs b/src/AST/Count.hs index b7079ff..c0d8d2d 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -128,8 +128,8 @@ occCountGeneral onehot unpush alter many = go WId        EShape _ e -> re e        EOp _ _ e -> re e        ECustom _ _ _ _ _ _ _ a b -> re a <> re b -      EWith _ a b -> re a <> re1 b -      EAccum _ _ a b e -> re a <> re b <> re e +      EWith _ _ a b -> re a <> re1 b +      EAccum _ _ _ a b e -> re a <> re b <> re e        EZero _ _ -> mempty        EPlus _ _ a b -> re a <> re b        EOneHot _ _ _ a b -> re a <> re b diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index 4b6b523..ec5e11e 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -13,7 +13,7 @@ unMonoid :: Ex env t -> Ex env t  unMonoid = \case    EZero _ t -> zero t    EPlus _ t a b -> plus t (unMonoid a) (unMonoid b) -  EOneHot _ t i a b -> onehot t i (unMonoid a) (unMonoid b) +  EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b)    EVar _ t i -> EVar ext t i    ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body) @@ -42,8 +42,8 @@ unMonoid = \case    EShape _ e -> EShape ext (unMonoid e)    EOp _ op e -> EOp ext op (unMonoid e)    ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2) -  EWith _ a b -> EWith ext (unMonoid a) (unMonoid b) -  EAccum _ n a b e -> EAccum ext n (unMonoid a) (unMonoid b) (unMonoid e) +  EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b) +  EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e)    EError _ t s -> EError ext t s  zero :: STy t -> Ex env (D2 t) @@ -116,9 +116,13 @@ plusSparse t a b adder =            (EVar ext (STMaybe t) (IS IZ))))        (weakenExpr WSink a) -onehot :: STy t -> SNat i -> Ex env (AcIdx (D2 t) i) -> Ex env (AcVal (D2 t) i) -> Ex env (D2 t) -onehot _ SZ _ val = val -onehot t (SS dep) idx val = case t of +onehot :: STy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env (D2 a) -> Ex env (D2 t) +onehot _ topprj arg = case topprj of +  SAPHere -> arg + +  SAPFst prj -> _ + +onehot t (SS dep) arg = case t of    STPair t1 t2 ->      case dep of        SZ -> EJust ext val @@ -165,42 +169,42 @@ onehot t (SS dep) idx val = case t of    STScal{} -> error "Cannot index into scalar"    STAccum{} -> error "Accumulators not allowed in input program" -onehotArrayElem -  :: STy t -> SNat n -> SNat i -  -> Ex env (Tup (Replicate n TIx))    -- ^ where are we now, OUTSIDE-IN as produced by 'outsideInIndex' -  -> Ex env (AcIdx (TArr n (D2 t)) i)  -- ^ where to put the one-hot -  -> Ex env (AcValArr n (D2 t) i)      -- ^ value to put in the hole -  -> Ex env (D2 t) -onehotArrayElem t n dep eltidx idx val = -  ELet ext eltidx $ -  ELet ext (weakenExpr WSink idx) $ -    let (cond, elt) = onehotArrayElemRec t n dep -                                         (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ))) -                                         (EVar ext (typeOf idx) IZ) -                                         (weakenExpr (WSink .> WSink) val) -    in eif cond elt (zero t) +-- onehotArrayElem +--   :: STy t -> SNat n -> SNat i +--   -> Ex env (Tup (Replicate n TIx))    -- ^ where are we now, OUTSIDE-IN as produced by 'outsideInIndex' +--   -> Ex env (AcIdx (TArr n (D2 t)) i)  -- ^ where to put the one-hot +--   -> Ex env (AcValArr n (D2 t) i)      -- ^ value to put in the hole +--   -> Ex env (D2 t) +-- onehotArrayElem t n dep eltidx idx val = +--   ELet ext eltidx $ +--   ELet ext (weakenExpr WSink idx) $ +--     let (cond, elt) = onehotArrayElemRec t n dep +--                                          (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ))) +--                                          (EVar ext (typeOf idx) IZ) +--                                          (weakenExpr (WSink .> WSink) val) +--     in eif cond elt (zero t) --- AcIdx must be duplicable -onehotArrayElemRec -  :: STy t -> SNat n -> SNat i -  -> [Ex env TIx] -  -> Ex env (AcIdx (TArr n (D2 t)) i) -  -> Ex env (AcValArr n (D2 t) i) -  -> (Ex env (TScal TBool), Ex env (D2 t)) -onehotArrayElemRec _ n SZ eltidx _ val = -  (EConst ext STBool True -  ,EIdx ext val (reconstructFromOutsideIn n eltidx)) -onehotArrayElemRec t SZ (SS dep) eltidx idx val = -  case eltidx of -    [] -> (EConst ext STBool True, onehot t dep idx val) -    _ -> error "onehotArrayElemRec: mismatched list length" -onehotArrayElemRec t (SS n) (SS dep) eltidx idx val = -  case eltidx of -    i : eltidx' -> -      let (cond, elt) = onehotArrayElemRec t n dep eltidx' (ESnd ext idx) val -      in (EOp ext OAnd (EPair ext (EOp ext (OEq STI64) (EPair ext i (EFst ext idx))) cond) -         ,elt) -    [] -> error "onehotArrayElemRec: mismatched list length" +-- -- AcIdx must be duplicable +-- onehotArrayElemRec +--   :: STy t -> SNat n -> SNat i +--   -> [Ex env TIx] +--   -> Ex env (AcIdx (TArr n (D2 t)) i) +--   -> Ex env (AcValArr n (D2 t) i) +--   -> (Ex env (TScal TBool), Ex env (D2 t)) +-- onehotArrayElemRec _ n SZ eltidx _ val = +--   (EConst ext STBool True +--   ,EIdx ext val (reconstructFromOutsideIn n eltidx)) +-- onehotArrayElemRec t SZ (SS dep) eltidx idx val = +--   case eltidx of +--     [] -> (EConst ext STBool True, onehot t dep idx val) +--     _ -> error "onehotArrayElemRec: mismatched list length" +-- onehotArrayElemRec t (SS n) (SS dep) eltidx idx val = +--   case eltidx of +--     i : eltidx' -> +--       let (cond, elt) = onehotArrayElemRec t n dep eltidx' (ESnd ext idx) val +--       in (EOp ext OAnd (EPair ext (EOp ext (OEq STI64) (EPair ext i (EFst ext idx))) cond) +--          ,elt) +--     [] -> error "onehotArrayElemRec: mismatched list length"  -- | Outermost index at the head. The input expression must be duplicable.  outsideInIndex :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx] diff --git a/src/Data.hs b/src/Data.hs index 1304a5f..60afdd0 100644 --- a/src/Data.hs +++ b/src/Data.hs @@ -101,6 +101,10 @@ type family n + m where    Z + m = m    S n + m = S (n + m) +type family n - m where +  n - Z = n +  S n - S m = n - m +  snatAdd :: SNat n -> SNat m -> SNat (n + m)  snatAdd SZ m = m  snatAdd (SS n) m = SS (snatAdd n m) | 
