diff options
Diffstat (limited to 'src/AST')
-rw-r--r-- | src/AST/Accum.hs | 44 | ||||
-rw-r--r-- | src/AST/Count.hs | 4 | ||||
-rw-r--r-- | src/AST/UnMonoid.hs | 88 |
3 files changed, 92 insertions, 44 deletions
diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs new file mode 100644 index 0000000..163f1c3 --- /dev/null +++ b/src/AST/Accum.hs @@ -0,0 +1,44 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} +module AST.Accum where + +import AST.Types +import Data + + +data AcPrj + = APHere + | APFst AcPrj + | APSnd AcPrj + | APLeft AcPrj + | APRight AcPrj + | APJust AcPrj + | APArrIdx AcPrj + | APArrSlice Nat + +-- | @b@ is a small part of @a@, indicated by the projection. +data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where + SAPHere :: SAcPrj APHere a a + SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair t a) b + SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair a t) b + SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TEither t a) b + SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TEither t a) b + SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe t) b + SAPArrIdx :: SAcPrj p a b -> SNat n -> SAcPrj (APArrIdx p) (TArr n t) b + SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t) +deriving instance Show (SAcPrj p a b) + +type family AcIdx p t where + AcIdx APHere t = TNil + AcIdx (APFst p) (TPair a b) = AcIdx p a + AcIdx (APSnd p) (TPair a b) = AcIdx p b + AcIdx (APLeft p) (TEither a b) = AcIdx p a + AcIdx (APRight p) (TEither a b) = AcIdx p b + AcIdx (APJust p) (TMaybe a) = AcIdx p a + AcIdx (APArrIdx p) (TArr n a) = TPair (Tup (Replicate n TIx)) (AcIdx p a) + AcIdx (APArrSlice m) (TArr n a) = Tup (Replicate m TIx) diff --git a/src/AST/Count.hs b/src/AST/Count.hs index b7079ff..c0d8d2d 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -128,8 +128,8 @@ occCountGeneral onehot unpush alter many = go WId EShape _ e -> re e EOp _ _ e -> re e ECustom _ _ _ _ _ _ _ a b -> re a <> re b - EWith _ a b -> re a <> re1 b - EAccum _ _ a b e -> re a <> re b <> re e + EWith _ _ a b -> re a <> re1 b + EAccum _ _ _ a b e -> re a <> re b <> re e EZero _ _ -> mempty EPlus _ _ a b -> re a <> re b EOneHot _ _ _ a b -> re a <> re b diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index 4b6b523..ec5e11e 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -13,7 +13,7 @@ unMonoid :: Ex env t -> Ex env t unMonoid = \case EZero _ t -> zero t EPlus _ t a b -> plus t (unMonoid a) (unMonoid b) - EOneHot _ t i a b -> onehot t i (unMonoid a) (unMonoid b) + EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b) EVar _ t i -> EVar ext t i ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body) @@ -42,8 +42,8 @@ unMonoid = \case EShape _ e -> EShape ext (unMonoid e) EOp _ op e -> EOp ext op (unMonoid e) ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2) - EWith _ a b -> EWith ext (unMonoid a) (unMonoid b) - EAccum _ n a b e -> EAccum ext n (unMonoid a) (unMonoid b) (unMonoid e) + EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b) + EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e) EError _ t s -> EError ext t s zero :: STy t -> Ex env (D2 t) @@ -116,9 +116,13 @@ plusSparse t a b adder = (EVar ext (STMaybe t) (IS IZ)))) (weakenExpr WSink a) -onehot :: STy t -> SNat i -> Ex env (AcIdx (D2 t) i) -> Ex env (AcVal (D2 t) i) -> Ex env (D2 t) -onehot _ SZ _ val = val -onehot t (SS dep) idx val = case t of +onehot :: STy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env (D2 a) -> Ex env (D2 t) +onehot _ topprj arg = case topprj of + SAPHere -> arg + + SAPFst prj -> _ + +onehot t (SS dep) arg = case t of STPair t1 t2 -> case dep of SZ -> EJust ext val @@ -165,42 +169,42 @@ onehot t (SS dep) idx val = case t of STScal{} -> error "Cannot index into scalar" STAccum{} -> error "Accumulators not allowed in input program" -onehotArrayElem - :: STy t -> SNat n -> SNat i - -> Ex env (Tup (Replicate n TIx)) -- ^ where are we now, OUTSIDE-IN as produced by 'outsideInIndex' - -> Ex env (AcIdx (TArr n (D2 t)) i) -- ^ where to put the one-hot - -> Ex env (AcValArr n (D2 t) i) -- ^ value to put in the hole - -> Ex env (D2 t) -onehotArrayElem t n dep eltidx idx val = - ELet ext eltidx $ - ELet ext (weakenExpr WSink idx) $ - let (cond, elt) = onehotArrayElemRec t n dep - (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ))) - (EVar ext (typeOf idx) IZ) - (weakenExpr (WSink .> WSink) val) - in eif cond elt (zero t) - --- AcIdx must be duplicable -onehotArrayElemRec - :: STy t -> SNat n -> SNat i - -> [Ex env TIx] - -> Ex env (AcIdx (TArr n (D2 t)) i) - -> Ex env (AcValArr n (D2 t) i) - -> (Ex env (TScal TBool), Ex env (D2 t)) -onehotArrayElemRec _ n SZ eltidx _ val = - (EConst ext STBool True - ,EIdx ext val (reconstructFromOutsideIn n eltidx)) -onehotArrayElemRec t SZ (SS dep) eltidx idx val = - case eltidx of - [] -> (EConst ext STBool True, onehot t dep idx val) - _ -> error "onehotArrayElemRec: mismatched list length" -onehotArrayElemRec t (SS n) (SS dep) eltidx idx val = - case eltidx of - i : eltidx' -> - let (cond, elt) = onehotArrayElemRec t n dep eltidx' (ESnd ext idx) val - in (EOp ext OAnd (EPair ext (EOp ext (OEq STI64) (EPair ext i (EFst ext idx))) cond) - ,elt) - [] -> error "onehotArrayElemRec: mismatched list length" +-- onehotArrayElem +-- :: STy t -> SNat n -> SNat i +-- -> Ex env (Tup (Replicate n TIx)) -- ^ where are we now, OUTSIDE-IN as produced by 'outsideInIndex' +-- -> Ex env (AcIdx (TArr n (D2 t)) i) -- ^ where to put the one-hot +-- -> Ex env (AcValArr n (D2 t) i) -- ^ value to put in the hole +-- -> Ex env (D2 t) +-- onehotArrayElem t n dep eltidx idx val = +-- ELet ext eltidx $ +-- ELet ext (weakenExpr WSink idx) $ +-- let (cond, elt) = onehotArrayElemRec t n dep +-- (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ))) +-- (EVar ext (typeOf idx) IZ) +-- (weakenExpr (WSink .> WSink) val) +-- in eif cond elt (zero t) + +-- -- AcIdx must be duplicable +-- onehotArrayElemRec +-- :: STy t -> SNat n -> SNat i +-- -> [Ex env TIx] +-- -> Ex env (AcIdx (TArr n (D2 t)) i) +-- -> Ex env (AcValArr n (D2 t) i) +-- -> (Ex env (TScal TBool), Ex env (D2 t)) +-- onehotArrayElemRec _ n SZ eltidx _ val = +-- (EConst ext STBool True +-- ,EIdx ext val (reconstructFromOutsideIn n eltidx)) +-- onehotArrayElemRec t SZ (SS dep) eltidx idx val = +-- case eltidx of +-- [] -> (EConst ext STBool True, onehot t dep idx val) +-- _ -> error "onehotArrayElemRec: mismatched list length" +-- onehotArrayElemRec t (SS n) (SS dep) eltidx idx val = +-- case eltidx of +-- i : eltidx' -> +-- let (cond, elt) = onehotArrayElemRec t n dep eltidx' (ESnd ext idx) val +-- in (EOp ext OAnd (EPair ext (EOp ext (OEq STI64) (EPair ext i (EFst ext idx))) cond) +-- ,elt) +-- [] -> error "onehotArrayElemRec: mismatched list length" -- | Outermost index at the head. The input expression must be duplicable. outsideInIndex :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx] |