diff options
| author | Tom Smeding <t.j.smeding@uu.nl> | 2025-03-07 15:11:59 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-14 15:37:29 +0100 | 
| commit | 137eaa13144c2599ac29da9ebd3af24ac1ce8968 (patch) | |
| tree | 8fc5221824f671dfc27f8064e3fc537859bb73e8 /src/AST | |
| parent | 1abb0c11efd2ba650c0a20de8047efbde2cc6adf (diff) | |
WIP revamp accumulator projection type repr
I stopped working on this because I realised that having sparse products
(and coproducts, prehaps) everywhere is a very bad idea in general, and
that we need to fix that first before really being able to do anything
else productive with performance.
Diffstat (limited to 'src/AST')
| -rw-r--r-- | src/AST/Accum.hs | 44 | ||||
| -rw-r--r-- | src/AST/Count.hs | 4 | ||||
| -rw-r--r-- | src/AST/UnMonoid.hs | 86 | 
3 files changed, 91 insertions, 43 deletions
| diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs new file mode 100644 index 0000000..163f1c3 --- /dev/null +++ b/src/AST/Accum.hs @@ -0,0 +1,44 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} +module AST.Accum where + +import AST.Types +import Data + + +data AcPrj +  = APHere +  | APFst AcPrj +  | APSnd AcPrj +  | APLeft AcPrj +  | APRight AcPrj +  | APJust AcPrj +  | APArrIdx AcPrj +  | APArrSlice Nat + +-- | @b@ is a small part of @a@, indicated by the projection. +data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where +  SAPHere :: SAcPrj APHere a a +  SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair t a) b +  SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair a t) b +  SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TEither t a) b +  SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TEither t a) b +  SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe t) b +  SAPArrIdx :: SAcPrj p a b -> SNat n -> SAcPrj (APArrIdx p) (TArr n t) b +  SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t) +deriving instance Show (SAcPrj p a b) + +type family AcIdx p t where +  AcIdx APHere t = TNil +  AcIdx (APFst p) (TPair a b) = AcIdx p a +  AcIdx (APSnd p) (TPair a b) = AcIdx p b +  AcIdx (APLeft p) (TEither a b) = AcIdx p a +  AcIdx (APRight p) (TEither a b) = AcIdx p b +  AcIdx (APJust p) (TMaybe a) = AcIdx p a +  AcIdx (APArrIdx p) (TArr n a) = TPair (Tup (Replicate n TIx)) (AcIdx p a) +  AcIdx (APArrSlice m) (TArr n a) = Tup (Replicate m TIx) diff --git a/src/AST/Count.hs b/src/AST/Count.hs index b7079ff..c0d8d2d 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -128,8 +128,8 @@ occCountGeneral onehot unpush alter many = go WId        EShape _ e -> re e        EOp _ _ e -> re e        ECustom _ _ _ _ _ _ _ a b -> re a <> re b -      EWith _ a b -> re a <> re1 b -      EAccum _ _ a b e -> re a <> re b <> re e +      EWith _ _ a b -> re a <> re1 b +      EAccum _ _ _ a b e -> re a <> re b <> re e        EZero _ _ -> mempty        EPlus _ _ a b -> re a <> re b        EOneHot _ _ _ a b -> re a <> re b diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index 4b6b523..ec5e11e 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -13,7 +13,7 @@ unMonoid :: Ex env t -> Ex env t  unMonoid = \case    EZero _ t -> zero t    EPlus _ t a b -> plus t (unMonoid a) (unMonoid b) -  EOneHot _ t i a b -> onehot t i (unMonoid a) (unMonoid b) +  EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b)    EVar _ t i -> EVar ext t i    ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body) @@ -42,8 +42,8 @@ unMonoid = \case    EShape _ e -> EShape ext (unMonoid e)    EOp _ op e -> EOp ext op (unMonoid e)    ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2) -  EWith _ a b -> EWith ext (unMonoid a) (unMonoid b) -  EAccum _ n a b e -> EAccum ext n (unMonoid a) (unMonoid b) (unMonoid e) +  EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b) +  EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e)    EError _ t s -> EError ext t s  zero :: STy t -> Ex env (D2 t) @@ -116,9 +116,13 @@ plusSparse t a b adder =            (EVar ext (STMaybe t) (IS IZ))))        (weakenExpr WSink a) -onehot :: STy t -> SNat i -> Ex env (AcIdx (D2 t) i) -> Ex env (AcVal (D2 t) i) -> Ex env (D2 t) -onehot _ SZ _ val = val -onehot t (SS dep) idx val = case t of +onehot :: STy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env (D2 a) -> Ex env (D2 t) +onehot _ topprj arg = case topprj of +  SAPHere -> arg + +  SAPFst prj -> _ + +onehot t (SS dep) arg = case t of    STPair t1 t2 ->      case dep of        SZ -> EJust ext val @@ -165,42 +169,42 @@ onehot t (SS dep) idx val = case t of    STScal{} -> error "Cannot index into scalar"    STAccum{} -> error "Accumulators not allowed in input program" -onehotArrayElem -  :: STy t -> SNat n -> SNat i -  -> Ex env (Tup (Replicate n TIx))    -- ^ where are we now, OUTSIDE-IN as produced by 'outsideInIndex' -  -> Ex env (AcIdx (TArr n (D2 t)) i)  -- ^ where to put the one-hot -  -> Ex env (AcValArr n (D2 t) i)      -- ^ value to put in the hole -  -> Ex env (D2 t) -onehotArrayElem t n dep eltidx idx val = -  ELet ext eltidx $ -  ELet ext (weakenExpr WSink idx) $ -    let (cond, elt) = onehotArrayElemRec t n dep -                                         (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ))) -                                         (EVar ext (typeOf idx) IZ) -                                         (weakenExpr (WSink .> WSink) val) -    in eif cond elt (zero t) +-- onehotArrayElem +--   :: STy t -> SNat n -> SNat i +--   -> Ex env (Tup (Replicate n TIx))    -- ^ where are we now, OUTSIDE-IN as produced by 'outsideInIndex' +--   -> Ex env (AcIdx (TArr n (D2 t)) i)  -- ^ where to put the one-hot +--   -> Ex env (AcValArr n (D2 t) i)      -- ^ value to put in the hole +--   -> Ex env (D2 t) +-- onehotArrayElem t n dep eltidx idx val = +--   ELet ext eltidx $ +--   ELet ext (weakenExpr WSink idx) $ +--     let (cond, elt) = onehotArrayElemRec t n dep +--                                          (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ))) +--                                          (EVar ext (typeOf idx) IZ) +--                                          (weakenExpr (WSink .> WSink) val) +--     in eif cond elt (zero t) --- AcIdx must be duplicable -onehotArrayElemRec -  :: STy t -> SNat n -> SNat i -  -> [Ex env TIx] -  -> Ex env (AcIdx (TArr n (D2 t)) i) -  -> Ex env (AcValArr n (D2 t) i) -  -> (Ex env (TScal TBool), Ex env (D2 t)) -onehotArrayElemRec _ n SZ eltidx _ val = -  (EConst ext STBool True -  ,EIdx ext val (reconstructFromOutsideIn n eltidx)) -onehotArrayElemRec t SZ (SS dep) eltidx idx val = -  case eltidx of -    [] -> (EConst ext STBool True, onehot t dep idx val) -    _ -> error "onehotArrayElemRec: mismatched list length" -onehotArrayElemRec t (SS n) (SS dep) eltidx idx val = -  case eltidx of -    i : eltidx' -> -      let (cond, elt) = onehotArrayElemRec t n dep eltidx' (ESnd ext idx) val -      in (EOp ext OAnd (EPair ext (EOp ext (OEq STI64) (EPair ext i (EFst ext idx))) cond) -         ,elt) -    [] -> error "onehotArrayElemRec: mismatched list length" +-- -- AcIdx must be duplicable +-- onehotArrayElemRec +--   :: STy t -> SNat n -> SNat i +--   -> [Ex env TIx] +--   -> Ex env (AcIdx (TArr n (D2 t)) i) +--   -> Ex env (AcValArr n (D2 t) i) +--   -> (Ex env (TScal TBool), Ex env (D2 t)) +-- onehotArrayElemRec _ n SZ eltidx _ val = +--   (EConst ext STBool True +--   ,EIdx ext val (reconstructFromOutsideIn n eltidx)) +-- onehotArrayElemRec t SZ (SS dep) eltidx idx val = +--   case eltidx of +--     [] -> (EConst ext STBool True, onehot t dep idx val) +--     _ -> error "onehotArrayElemRec: mismatched list length" +-- onehotArrayElemRec t (SS n) (SS dep) eltidx idx val = +--   case eltidx of +--     i : eltidx' -> +--       let (cond, elt) = onehotArrayElemRec t n dep eltidx' (ESnd ext idx) val +--       in (EOp ext OAnd (EPair ext (EOp ext (OEq STI64) (EPair ext i (EFst ext idx))) cond) +--          ,elt) +--     [] -> error "onehotArrayElemRec: mismatched list length"  -- | Outermost index at the head. The input expression must be duplicable.  outsideInIndex :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx] | 
