diff options
Diffstat (limited to 'src/AST.hs')
-rw-r--r-- | src/AST.hs | 75 |
1 files changed, 16 insertions, 59 deletions
@@ -16,42 +16,19 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -module AST (module AST, module AST.Types, module AST.Weaken) where +module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where import Data.Functor.Const import Data.Kind (Type) import Array +import AST.Accum import AST.Types import AST.Weaken import CHAD.Types import Data --- | This index is flipped around from the usual direction: the smallest index --- is at the _heart_ of the nesting, not at the outside. The outermost layer --- indexes into the _outer_ dimension of the type @t@. This makes indices into --- compound structures work properly with coproducts. -type family AcIdx t i where - AcIdx t Z = TNil - AcIdx (TPair a b) (S i) = TEither (AcIdx a i) (AcIdx b i) - AcIdx (TEither a b) (S i) = TEither (AcIdx a i) (AcIdx b i) - AcIdx (TMaybe t) (S i) = AcIdx t i - AcIdx (TArr Z t) (S i) = AcIdx t i - AcIdx (TArr (S n) t) (S i) = TPair TIx (AcIdx (TArr n t) i) - -type family AcVal t i where - AcVal t Z = t - AcVal (TPair a b) (S i) = TEither (AcVal a i) (AcVal b i) - AcVal (TEither a b) (S i) = TEither (AcVal a i) (AcVal b i) - AcVal (TMaybe t) (S i) = AcVal t i - AcVal (TArr n t) (S i) = TPair (Tup (Replicate n TIx)) (AcValArr n t (S i)) - -type family AcValArr n t i where - AcValArr n t Z = TArr n t - AcValArr Z t (S i) = AcVal t i - AcValArr (S n) t (S i) = AcValArr n t i - -- General assumption: head of the list (whatever way it is associated) is the -- inner variable / inner array dimension. In pretty printing, the inner -- variable / inner dimension is printed on the _right_. @@ -110,15 +87,14 @@ data Expr x env t where -> Expr x env a -> Expr x env b -> Expr x env t - -- accumulation effect - EWith :: x (TPair a t) -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t) - EAccum :: x TNil -> SNat i -> Expr x env (AcIdx t i) -> Expr x env (AcVal t i) -> Expr x env (TAccum t) -> Expr x env TNil - -- EAccum1 :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum (S Z) t) -> Expr x env TNil + -- accumulation effect on monoids + EWith :: x (TPair a (D2 t)) -> STy t -> Expr x env (D2 t) -> Expr x (TAccum (D2 t) : env) a -> Expr x env (TPair a (D2 t)) + EAccum :: x TNil -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (TAccum (D2 a)) -> Expr x env TNil -- monoidal operations (to be desugared to regular operations after simplification) EZero :: x (D2 t) -> STy t -> Expr x env (D2 t) EPlus :: x (D2 t) -> STy t -> Expr x env (D2 t) -> Expr x env (D2 t) -> Expr x env (D2 t) - EOneHot :: x (D2 t) -> STy t -> SNat i -> Expr x env (AcIdx (D2 t) i) -> Expr x env (AcVal (D2 t) i) -> Expr x env (D2 t) + EOneHot :: x (D2 t) -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (D2 t) -- partiality EError :: x a -> STy a -> String -> Expr x env a @@ -129,9 +105,6 @@ type Ex = Expr (Const ()) ext :: Const () a ext = Const () -eTup :: SList (Ex env) list -> Ex env (Tup list) -eTup = mkTup (ENil ext) (EPair ext) - type SOp :: Ty -> Ty -> Type data SOp a t where OAdd :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) @@ -224,8 +197,8 @@ typeOf = \case ECustom _ _ _ _ e _ _ _ _ -> typeOf e - EWith _ e1 e2 -> STPair (typeOf e2) (typeOf e1) - EAccum _ _ _ _ _ -> STNil + EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1) + EAccum _ _ _ _ _ _ -> STNil EZero _ t -> d2 t EPlus _ t _ _ -> d2 t @@ -262,8 +235,8 @@ extOf = \case EShape x _ -> x EOp x _ _ -> x ECustom x _ _ _ _ _ _ _ _ -> x - EWith x _ _ -> x - EAccum x _ _ _ _ -> x + EWith x _ _ _ -> x + EAccum x _ _ _ _ _ -> x EZero x _ -> x EPlus x _ _ _ -> x EOneHot x _ _ _ _ -> x @@ -331,11 +304,11 @@ subst' f w = \case EShape x e -> EShape x (subst' f w e) EOp x op e -> EOp x op (subst' f w e) ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2) - EWith x e1 e2 -> EWith x (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) - EAccum x i e1 e2 e3 -> EAccum x i (subst' f w e1) (subst' f w e2) (subst' f w e3) + EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) + EAccum x t p e1 e2 e3 -> EAccum x t p (subst' f w e1) (subst' f w e2) (subst' f w e3) EZero x t -> EZero x t EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b) - EOneHot x t i a b -> EOneHot x t i (subst' f w a) (subst' f w b) + EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b) EError x t s -> EError x t s where sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) @@ -396,6 +369,9 @@ envKnown :: SList STy env -> Dict (KnownEnv env) envKnown SNil = Dict envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict +eTup :: SList (Ex env) list -> Ex env (Tup list) +eTup = mkTup (ENil ext) (EPair ext) + ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t) ebuildUp1 n sh size f = EBuild ext (SS n) (EPair ext sh size) $ @@ -456,22 +432,3 @@ eshapeEmpty (SS n) e = (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)) (EConst ext STI64 0))) (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)))) - -arrIdxToAcIdx :: proxy t -> SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (AcIdx (TArr n t) n) -arrIdxToAcIdx = \p (n :: SNat n) e -> case lemPlusZero @n of Refl -> go p n SZ e (ENil ext) - where - -- symbolic version of 'invert' in Interpreter - go :: forall n m t env proxy. proxy t -> SNat n -> SNat m - -> Ex env (Tup (Replicate n TIx)) -> Ex env (AcIdx (TArr m t) m) -> Ex env (AcIdx (TArr (n + m) t) (n + m)) - go _ SZ _ _ acidx = acidx - go p (SS n) m idx acidx - | Refl <- lemPlusSuccRight @n @m - = ELet ext idx $ - go p n (SS m) - (EFst ext (EVar ext (typeOf idx) IZ)) - (EPair ext (ESnd ext (EVar ext (typeOf idx) IZ)) - (weakenExpr WSink acidx)) - -lemAcValArrN :: proxy t -> SNat n -> AcValArr n t n :~: TArr Z t -lemAcValArrN _ SZ = Refl -lemAcValArrN p (SS n) | Refl <- lemAcValArrN p n = Refl |