summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2025-03-07 15:11:59 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-14 15:37:29 +0100
commit137eaa13144c2599ac29da9ebd3af24ac1ce8968 (patch)
tree8fc5221824f671dfc27f8064e3fc537859bb73e8
parent1abb0c11efd2ba650c0a20de8047efbde2cc6adf (diff)
WIP revamp accumulator projection type repr
I stopped working on this because I realised that having sparse products (and coproducts, prehaps) everywhere is a very bad idea in general, and that we need to fix that first before really being able to do anything else productive with performance.
-rw-r--r--chad-fast.cabal1
-rw-r--r--src/AST.hs75
-rw-r--r--src/AST/Accum.hs44
-rw-r--r--src/AST/Count.hs4
-rw-r--r--src/AST/UnMonoid.hs88
-rw-r--r--src/Data.hs4
6 files changed, 113 insertions, 103 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index aa4dfcc..e201683 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -13,6 +13,7 @@ library
Analysis.Identity
Array
AST
+ AST.Accum
AST.Bindings
AST.Count
AST.Env
diff --git a/src/AST.hs b/src/AST.hs
index 0e040d4..3fb8822 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -16,42 +16,19 @@
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
-module AST (module AST, module AST.Types, module AST.Weaken) where
+module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where
import Data.Functor.Const
import Data.Kind (Type)
import Array
+import AST.Accum
import AST.Types
import AST.Weaken
import CHAD.Types
import Data
--- | This index is flipped around from the usual direction: the smallest index
--- is at the _heart_ of the nesting, not at the outside. The outermost layer
--- indexes into the _outer_ dimension of the type @t@. This makes indices into
--- compound structures work properly with coproducts.
-type family AcIdx t i where
- AcIdx t Z = TNil
- AcIdx (TPair a b) (S i) = TEither (AcIdx a i) (AcIdx b i)
- AcIdx (TEither a b) (S i) = TEither (AcIdx a i) (AcIdx b i)
- AcIdx (TMaybe t) (S i) = AcIdx t i
- AcIdx (TArr Z t) (S i) = AcIdx t i
- AcIdx (TArr (S n) t) (S i) = TPair TIx (AcIdx (TArr n t) i)
-
-type family AcVal t i where
- AcVal t Z = t
- AcVal (TPair a b) (S i) = TEither (AcVal a i) (AcVal b i)
- AcVal (TEither a b) (S i) = TEither (AcVal a i) (AcVal b i)
- AcVal (TMaybe t) (S i) = AcVal t i
- AcVal (TArr n t) (S i) = TPair (Tup (Replicate n TIx)) (AcValArr n t (S i))
-
-type family AcValArr n t i where
- AcValArr n t Z = TArr n t
- AcValArr Z t (S i) = AcVal t i
- AcValArr (S n) t (S i) = AcValArr n t i
-
-- General assumption: head of the list (whatever way it is associated) is the
-- inner variable / inner array dimension. In pretty printing, the inner
-- variable / inner dimension is printed on the _right_.
@@ -110,15 +87,14 @@ data Expr x env t where
-> Expr x env a -> Expr x env b
-> Expr x env t
- -- accumulation effect
- EWith :: x (TPair a t) -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t)
- EAccum :: x TNil -> SNat i -> Expr x env (AcIdx t i) -> Expr x env (AcVal t i) -> Expr x env (TAccum t) -> Expr x env TNil
- -- EAccum1 :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum (S Z) t) -> Expr x env TNil
+ -- accumulation effect on monoids
+ EWith :: x (TPair a (D2 t)) -> STy t -> Expr x env (D2 t) -> Expr x (TAccum (D2 t) : env) a -> Expr x env (TPair a (D2 t))
+ EAccum :: x TNil -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (TAccum (D2 a)) -> Expr x env TNil
-- monoidal operations (to be desugared to regular operations after simplification)
EZero :: x (D2 t) -> STy t -> Expr x env (D2 t)
EPlus :: x (D2 t) -> STy t -> Expr x env (D2 t) -> Expr x env (D2 t) -> Expr x env (D2 t)
- EOneHot :: x (D2 t) -> STy t -> SNat i -> Expr x env (AcIdx (D2 t) i) -> Expr x env (AcVal (D2 t) i) -> Expr x env (D2 t)
+ EOneHot :: x (D2 t) -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (D2 t)
-- partiality
EError :: x a -> STy a -> String -> Expr x env a
@@ -129,9 +105,6 @@ type Ex = Expr (Const ())
ext :: Const () a
ext = Const ()
-eTup :: SList (Ex env) list -> Ex env (Tup list)
-eTup = mkTup (ENil ext) (EPair ext)
-
type SOp :: Ty -> Ty -> Type
data SOp a t where
OAdd :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
@@ -224,8 +197,8 @@ typeOf = \case
ECustom _ _ _ _ e _ _ _ _ -> typeOf e
- EWith _ e1 e2 -> STPair (typeOf e2) (typeOf e1)
- EAccum _ _ _ _ _ -> STNil
+ EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1)
+ EAccum _ _ _ _ _ _ -> STNil
EZero _ t -> d2 t
EPlus _ t _ _ -> d2 t
@@ -262,8 +235,8 @@ extOf = \case
EShape x _ -> x
EOp x _ _ -> x
ECustom x _ _ _ _ _ _ _ _ -> x
- EWith x _ _ -> x
- EAccum x _ _ _ _ -> x
+ EWith x _ _ _ -> x
+ EAccum x _ _ _ _ _ -> x
EZero x _ -> x
EPlus x _ _ _ -> x
EOneHot x _ _ _ _ -> x
@@ -331,11 +304,11 @@ subst' f w = \case
EShape x e -> EShape x (subst' f w e)
EOp x op e -> EOp x op (subst' f w e)
ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2)
- EWith x e1 e2 -> EWith x (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
- EAccum x i e1 e2 e3 -> EAccum x i (subst' f w e1) (subst' f w e2) (subst' f w e3)
+ EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
+ EAccum x t p e1 e2 e3 -> EAccum x t p (subst' f w e1) (subst' f w e2) (subst' f w e3)
EZero x t -> EZero x t
EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b)
- EOneHot x t i a b -> EOneHot x t i (subst' f w a) (subst' f w b)
+ EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b)
EError x t s -> EError x t s
where
sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
@@ -396,6 +369,9 @@ envKnown :: SList STy env -> Dict (KnownEnv env)
envKnown SNil = Dict
envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict
+eTup :: SList (Ex env) list -> Ex env (Tup list)
+eTup = mkTup (ENil ext) (EPair ext)
+
ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t)
ebuildUp1 n sh size f =
EBuild ext (SS n) (EPair ext sh size) $
@@ -456,22 +432,3 @@ eshapeEmpty (SS n) e =
(EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ))
(EConst ext STI64 0)))
(eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ))))
-
-arrIdxToAcIdx :: proxy t -> SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (AcIdx (TArr n t) n)
-arrIdxToAcIdx = \p (n :: SNat n) e -> case lemPlusZero @n of Refl -> go p n SZ e (ENil ext)
- where
- -- symbolic version of 'invert' in Interpreter
- go :: forall n m t env proxy. proxy t -> SNat n -> SNat m
- -> Ex env (Tup (Replicate n TIx)) -> Ex env (AcIdx (TArr m t) m) -> Ex env (AcIdx (TArr (n + m) t) (n + m))
- go _ SZ _ _ acidx = acidx
- go p (SS n) m idx acidx
- | Refl <- lemPlusSuccRight @n @m
- = ELet ext idx $
- go p n (SS m)
- (EFst ext (EVar ext (typeOf idx) IZ))
- (EPair ext (ESnd ext (EVar ext (typeOf idx) IZ))
- (weakenExpr WSink acidx))
-
-lemAcValArrN :: proxy t -> SNat n -> AcValArr n t n :~: TArr Z t
-lemAcValArrN _ SZ = Refl
-lemAcValArrN p (SS n) | Refl <- lemAcValArrN p n = Refl
diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs
new file mode 100644
index 0000000..163f1c3
--- /dev/null
+++ b/src/AST/Accum.hs
@@ -0,0 +1,44 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE KindSignatures #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE UndecidableInstances #-}
+module AST.Accum where
+
+import AST.Types
+import Data
+
+
+data AcPrj
+ = APHere
+ | APFst AcPrj
+ | APSnd AcPrj
+ | APLeft AcPrj
+ | APRight AcPrj
+ | APJust AcPrj
+ | APArrIdx AcPrj
+ | APArrSlice Nat
+
+-- | @b@ is a small part of @a@, indicated by the projection.
+data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where
+ SAPHere :: SAcPrj APHere a a
+ SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair t a) b
+ SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair a t) b
+ SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TEither t a) b
+ SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TEither t a) b
+ SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe t) b
+ SAPArrIdx :: SAcPrj p a b -> SNat n -> SAcPrj (APArrIdx p) (TArr n t) b
+ SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t)
+deriving instance Show (SAcPrj p a b)
+
+type family AcIdx p t where
+ AcIdx APHere t = TNil
+ AcIdx (APFst p) (TPair a b) = AcIdx p a
+ AcIdx (APSnd p) (TPair a b) = AcIdx p b
+ AcIdx (APLeft p) (TEither a b) = AcIdx p a
+ AcIdx (APRight p) (TEither a b) = AcIdx p b
+ AcIdx (APJust p) (TMaybe a) = AcIdx p a
+ AcIdx (APArrIdx p) (TArr n a) = TPair (Tup (Replicate n TIx)) (AcIdx p a)
+ AcIdx (APArrSlice m) (TArr n a) = Tup (Replicate m TIx)
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index b7079ff..c0d8d2d 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -128,8 +128,8 @@ occCountGeneral onehot unpush alter many = go WId
EShape _ e -> re e
EOp _ _ e -> re e
ECustom _ _ _ _ _ _ _ a b -> re a <> re b
- EWith _ a b -> re a <> re1 b
- EAccum _ _ a b e -> re a <> re b <> re e
+ EWith _ _ a b -> re a <> re1 b
+ EAccum _ _ _ a b e -> re a <> re b <> re e
EZero _ _ -> mempty
EPlus _ _ a b -> re a <> re b
EOneHot _ _ _ a b -> re a <> re b
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index 4b6b523..ec5e11e 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -13,7 +13,7 @@ unMonoid :: Ex env t -> Ex env t
unMonoid = \case
EZero _ t -> zero t
EPlus _ t a b -> plus t (unMonoid a) (unMonoid b)
- EOneHot _ t i a b -> onehot t i (unMonoid a) (unMonoid b)
+ EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b)
EVar _ t i -> EVar ext t i
ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body)
@@ -42,8 +42,8 @@ unMonoid = \case
EShape _ e -> EShape ext (unMonoid e)
EOp _ op e -> EOp ext op (unMonoid e)
ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)
- EWith _ a b -> EWith ext (unMonoid a) (unMonoid b)
- EAccum _ n a b e -> EAccum ext n (unMonoid a) (unMonoid b) (unMonoid e)
+ EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b)
+ EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e)
EError _ t s -> EError ext t s
zero :: STy t -> Ex env (D2 t)
@@ -116,9 +116,13 @@ plusSparse t a b adder =
(EVar ext (STMaybe t) (IS IZ))))
(weakenExpr WSink a)
-onehot :: STy t -> SNat i -> Ex env (AcIdx (D2 t) i) -> Ex env (AcVal (D2 t) i) -> Ex env (D2 t)
-onehot _ SZ _ val = val
-onehot t (SS dep) idx val = case t of
+onehot :: STy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env (D2 a) -> Ex env (D2 t)
+onehot _ topprj arg = case topprj of
+ SAPHere -> arg
+
+ SAPFst prj -> _
+
+onehot t (SS dep) arg = case t of
STPair t1 t2 ->
case dep of
SZ -> EJust ext val
@@ -165,42 +169,42 @@ onehot t (SS dep) idx val = case t of
STScal{} -> error "Cannot index into scalar"
STAccum{} -> error "Accumulators not allowed in input program"
-onehotArrayElem
- :: STy t -> SNat n -> SNat i
- -> Ex env (Tup (Replicate n TIx)) -- ^ where are we now, OUTSIDE-IN as produced by 'outsideInIndex'
- -> Ex env (AcIdx (TArr n (D2 t)) i) -- ^ where to put the one-hot
- -> Ex env (AcValArr n (D2 t) i) -- ^ value to put in the hole
- -> Ex env (D2 t)
-onehotArrayElem t n dep eltidx idx val =
- ELet ext eltidx $
- ELet ext (weakenExpr WSink idx) $
- let (cond, elt) = onehotArrayElemRec t n dep
- (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ)))
- (EVar ext (typeOf idx) IZ)
- (weakenExpr (WSink .> WSink) val)
- in eif cond elt (zero t)
-
--- AcIdx must be duplicable
-onehotArrayElemRec
- :: STy t -> SNat n -> SNat i
- -> [Ex env TIx]
- -> Ex env (AcIdx (TArr n (D2 t)) i)
- -> Ex env (AcValArr n (D2 t) i)
- -> (Ex env (TScal TBool), Ex env (D2 t))
-onehotArrayElemRec _ n SZ eltidx _ val =
- (EConst ext STBool True
- ,EIdx ext val (reconstructFromOutsideIn n eltidx))
-onehotArrayElemRec t SZ (SS dep) eltidx idx val =
- case eltidx of
- [] -> (EConst ext STBool True, onehot t dep idx val)
- _ -> error "onehotArrayElemRec: mismatched list length"
-onehotArrayElemRec t (SS n) (SS dep) eltidx idx val =
- case eltidx of
- i : eltidx' ->
- let (cond, elt) = onehotArrayElemRec t n dep eltidx' (ESnd ext idx) val
- in (EOp ext OAnd (EPair ext (EOp ext (OEq STI64) (EPair ext i (EFst ext idx))) cond)
- ,elt)
- [] -> error "onehotArrayElemRec: mismatched list length"
+-- onehotArrayElem
+-- :: STy t -> SNat n -> SNat i
+-- -> Ex env (Tup (Replicate n TIx)) -- ^ where are we now, OUTSIDE-IN as produced by 'outsideInIndex'
+-- -> Ex env (AcIdx (TArr n (D2 t)) i) -- ^ where to put the one-hot
+-- -> Ex env (AcValArr n (D2 t) i) -- ^ value to put in the hole
+-- -> Ex env (D2 t)
+-- onehotArrayElem t n dep eltidx idx val =
+-- ELet ext eltidx $
+-- ELet ext (weakenExpr WSink idx) $
+-- let (cond, elt) = onehotArrayElemRec t n dep
+-- (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ)))
+-- (EVar ext (typeOf idx) IZ)
+-- (weakenExpr (WSink .> WSink) val)
+-- in eif cond elt (zero t)
+
+-- -- AcIdx must be duplicable
+-- onehotArrayElemRec
+-- :: STy t -> SNat n -> SNat i
+-- -> [Ex env TIx]
+-- -> Ex env (AcIdx (TArr n (D2 t)) i)
+-- -> Ex env (AcValArr n (D2 t) i)
+-- -> (Ex env (TScal TBool), Ex env (D2 t))
+-- onehotArrayElemRec _ n SZ eltidx _ val =
+-- (EConst ext STBool True
+-- ,EIdx ext val (reconstructFromOutsideIn n eltidx))
+-- onehotArrayElemRec t SZ (SS dep) eltidx idx val =
+-- case eltidx of
+-- [] -> (EConst ext STBool True, onehot t dep idx val)
+-- _ -> error "onehotArrayElemRec: mismatched list length"
+-- onehotArrayElemRec t (SS n) (SS dep) eltidx idx val =
+-- case eltidx of
+-- i : eltidx' ->
+-- let (cond, elt) = onehotArrayElemRec t n dep eltidx' (ESnd ext idx) val
+-- in (EOp ext OAnd (EPair ext (EOp ext (OEq STI64) (EPair ext i (EFst ext idx))) cond)
+-- ,elt)
+-- [] -> error "onehotArrayElemRec: mismatched list length"
-- | Outermost index at the head. The input expression must be duplicable.
outsideInIndex :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx]
diff --git a/src/Data.hs b/src/Data.hs
index 1304a5f..60afdd0 100644
--- a/src/Data.hs
+++ b/src/Data.hs
@@ -101,6 +101,10 @@ type family n + m where
Z + m = m
S n + m = S (n + m)
+type family n - m where
+ n - Z = n
+ S n - S m = n - m
+
snatAdd :: SNat n -> SNat m -> SNat (n + m)
snatAdd SZ m = m
snatAdd (SS n) m = SS (snatAdd n m)