diff options
Diffstat (limited to 'src/AST.hs')
-rw-r--r-- | src/AST.hs | 94 |
1 files changed, 49 insertions, 45 deletions
@@ -30,7 +30,7 @@ data Ty | TEither Ty Ty | TArr Nat Ty -- ^ rank, element type | TScal ScalTy - | TAccum Nat Ty -- ^ rank and element type of the array being accumulated to + | TAccum Ty deriving (Show, Eq, Ord) data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool @@ -43,7 +43,7 @@ data STy t where STEither :: STy a -> STy b -> STy (TEither a b) STArr :: SNat n -> STy t -> STy (TArr n t) STScal :: SScalTy t -> STy (TScal t) - STAccum :: SNat n -> STy t -> STy (TAccum n t) + STAccum :: STy t -> STy (TAccum t) deriving instance Show (STy t) data SScalTy t where @@ -66,10 +66,23 @@ type family ScalRep t where ScalRep TF64 = Double ScalRep TBool = Bool -type ConsN :: Nat -> a -> [a] -> [a] -type family ConsN n x l where - ConsN Z x l = l - ConsN (S n) x l = x : ConsN n x l +-- | This index is flipped around from the usual direction: the smallest index +-- is at the _heart_ of the nesting, not at the outside. The outermost layer +-- indexes into the _outer_ dimension of the type @t@. This makes indices into +-- compound structures work properly with coproducts. +type family AcIdx t i where + AcIdx t Z = TNil + AcIdx (TPair a b) (S i) = TEither (AcIdx a i) (AcIdx b i) + AcIdx (TEither a b) (S i) = TEither (AcIdx a i) (AcIdx b i) + AcIdx (TArr Z t) (S i) = AcIdx t i + AcIdx (TArr (S n) t) (S i) = TPair TIx (AcIdx (TArr n t) i) + +type family AcVal t i where + AcVal t Z = t + AcVal (TPair a b) (S i) = TEither (AcVal a i) (AcVal b i) + AcVal (TEither a b) (S i) = TEither (AcVal a i) (AcVal b i) + AcVal (TArr Z t) (S i) = AcVal t i + AcVal (TArr (S n) t) (S i) = AcVal (TArr n t) i -- General assumption: head of the list (whatever way it is associated) is the -- inner variable / inner array dimension. In pretty printing, the inner @@ -91,22 +104,23 @@ data Expr x env t where -- array operations EBuild1 :: x (TArr (S Z) t) -> Expr x env TIx -> Expr x (TIx : env) t -> Expr x env (TArr (S Z) t) - EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (ConsN n TIx env) t -> Expr x env (TArr n t) + EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t) EFold1 :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) - EReplicate :: x (TArr (S n) t) -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) -- TODO: unused + -- EReplicate :: x (TArr (S n) t) -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) -- TODO: unused -- expression operations EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t) EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) - EIdx :: x t -> Expr x env (TArr n t) -> Vec n (Expr x env TIx) -> Expr x env t + EIdx :: x t -> SNat n -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t -- accumulation effect - EWith :: Expr x env (TArr n t) -> Expr x (TAccum n t : env) a -> Expr x env (TPair a (TArr n t)) - EAccum1 :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum (S Z) t) -> Expr x env TNil + EWith :: Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t) + EAccum :: SNat i -> Expr x env (AcIdx t i) -> Expr x env (AcVal t i) -> Expr x env (TAccum t) -> Expr x env TNil + -- EAccum1 :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum (S Z) t) -> Expr x env TNil -- partiality EError :: STy a -> String -> Expr x env a @@ -117,10 +131,6 @@ type Ex = Expr (Const ()) ext :: Const () a ext = Const () -type family Replicate n x where - Replicate Z x = '[] - Replicate (S n) x = x : Replicate n x - type family Tup env where Tup '[] = TNil Tup (t : ts) = TPair (Tup ts) t @@ -129,6 +139,14 @@ tTup :: SList STy env -> STy (Tup env) tTup SNil = STNil tTup (SCons t ts) = STPair (tTup ts) t +eTup :: SList (Ex env) list -> Ex env (Tup list) +eTup SNil = ENil ext +eTup (e `SCons` es) = EPair ext (eTup es) e + +type family InvTup core env where + InvTup core '[] = core + InvTup core (t : ts) = InvTup (TPair core t) ts + type SOp :: Ty -> Ty -> Type data SOp a t where OAdd :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) @@ -169,17 +187,17 @@ typeOf = \case EBuild _ n _ e -> STArr n (typeOf e) EFold1 _ _ e | STArr (SS n) t <- typeOf e -> STArr n t EUnit _ e -> STArr SZ (typeOf e) - EReplicate _ e | STArr n t <- typeOf e -> STArr (SS n) t + -- EReplicate _ e | STArr n t <- typeOf e -> STArr (SS n) t EConst _ t _ -> STScal t EIdx0 _ e | STArr _ t <- typeOf e -> t EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t - EIdx _ e _ | STArr _ t <- typeOf e -> t - -- EShape _ e | STArr n _ <- typeOf e -> _ + EIdx _ _ e _ | STArr _ t <- typeOf e -> t + EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx) EOp _ op _ -> opt2 op EWith e1 e2 -> STPair (typeOf e2) (typeOf e1) - EAccum1 _ _ _ -> STNil + EAccum _ _ _ _ -> STNil EError t _ -> t @@ -194,7 +212,7 @@ unSTy = \case STEither a b -> TEither (unSTy a) (unSTy b) STArr n t -> TArr (unSNat n) (unSTy t) STScal t -> TScal (unSScalTy t) - STAccum n t -> TAccum (unSNat n) (unSTy t) + STAccum t -> TAccum (unSTy t) unSList :: SList STy env -> [Ty] unSList SNil = [] @@ -231,17 +249,18 @@ subst' f w = \case EInr x t e -> EInr x t (subst' f w e) ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b) EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b) - EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkFN n f) (wcopyN n w) b) + EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b) EFold1 x a b -> EFold1 x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) EUnit x e -> EUnit x (subst' f w e) - EReplicate x e -> EReplicate x (subst' f w e) + -- EReplicate x e -> EReplicate x (subst' f w e) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (subst' f w e) EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) - EIdx x e es -> EIdx x (subst' f w e) (fmap (subst' f w) es) + EIdx x n e es -> EIdx x n (subst' f w e) (subst' f w es) + EShape x e -> EShape x (subst' f w e) EOp x op e -> EOp x op (subst' f w e) EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) - EAccum1 e1 e2 e3 -> EAccum1 (subst' f w e1) (subst' f w e2) (subst' f w e3) + EAccum i e1 e2 e3 -> EAccum i (subst' f w e1) (subst' f w e2) (subst' f w e3) EError t s -> EError t s where sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) @@ -250,28 +269,9 @@ subst' f w = \case IZ -> EVar x' t (w' @> IZ) IS i -> f' x' t (WPop w') i - sinkFN :: SNat n - -> (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) - -> x t -> STy t -> (ConsN n TIx env' :> env2) -> Idx (ConsN n TIx env) t -> Expr x env2 t - sinkFN SZ f' x t w' i = f' x t w' i - sinkFN (SS _) _ x t w' IZ = EVar x t (w' @> IZ) - sinkFN (SS n) f' x t w' (IS i) = sinkFN n f' x t (WPop w') i - weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i)) -wsinkN :: SNat n -> env :> ConsN n TIx env -wsinkN SZ = WId -wsinkN (SS n) = WSink .> wsinkN n - -wcopyN :: SNat n -> env :> env' -> ConsN n TIx env :> ConsN n TIx env' -wcopyN SZ w = w -wcopyN (SS n) w = WCopy (wcopyN n w) - -wpopN :: SNat n -> ConsN n TIx env :> env' -> env :> env' -wpopN SZ w = w -wpopN (SS n) w = wpopN n (WPop w) - wUndoSubenv :: Subenv env env' -> env' :> env wUndoSubenv SETop = WId wUndoSubenv (SEYes sub) = WCopy (wUndoSubenv sub) @@ -299,11 +299,15 @@ instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair kn instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy -instance (KnownNat n, KnownTy t) => KnownTy (TAccum n t) where knownTy = STAccum knownNat knownTy +instance KnownTy t => KnownTy (TAccum t) where knownTy = STAccum knownTy class KnownEnv env where knownEnv :: SList STy env instance KnownEnv '[] where knownEnv = SNil instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t) -ebuildUp1 n sh size f = EBuild ext (SS n) (EPair ext sh size) (error "TODO" f) +ebuildUp1 n sh size f = + EBuild ext (SS n) (EPair ext sh size) $ + let arg = EVar ext (tTup (sreplicate (SS n) tIx)) IZ + in EIdx ext n (ELet ext (ESnd ext arg) (weakenExpr (WCopy WSink) f)) + (EFst ext arg) |