{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-- | Typed, well-scoped abstract syntax for the object language: its types
-- ('Ty' and the singleton 'STy'), expressions ('Expr'), and substitution /
-- weakening on expressions. Also re-exports "AST.Weaken".
module AST (module AST, module AST.Weaken) where

import Data.Functor.Const
import Data.Kind (Type)
import Data.Int
import Data.Type.Equality

import Array
import AST.Env
import AST.Weaken
import Data


-- | Object-language types. Used at the value level (see 'unSTy') and, via
-- @DataKinds@, as the kind of the type indices of 'STy' and 'Expr'.
data Ty
  = TNil
  | TPair Ty Ty
  | TEither Ty Ty
  | TArr Nat Ty  -- ^ rank, element type
  | TScal ScalTy
  | TAccum Ty  -- ^ accumulator over values of the given type (see 'EWith', 'EAccum')
  deriving (Show, Eq, Ord)

-- | Scalar types. The Haskell representation of each is given by 'ScalRep'.
data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
  deriving (Show, Eq, Ord)

type STy :: Ty -> Type
-- | Singleton for 'Ty': a runtime value whose constructor structure mirrors
-- its type index, so pattern matching on it refines that index.
data STy t where
  STNil :: STy TNil
  STPair :: STy a -> STy b -> STy (TPair a b)
  STEither :: STy a -> STy b -> STy (TEither a b)
  STArr :: SNat n -> STy t -> STy (TArr n t)
  STScal :: SScalTy t -> STy (TScal t)
  STAccum :: STy t -> STy (TAccum t)
deriving instance Show (STy t)

-- | Structural equality test on type singletons. A @'Just' 'Refl'@ result
-- proves the two type indices equal; array ranks are compared through the
-- 'TestEquality' instance of 'SNat'. Mismatched constructors (or any failed
-- guard) fall through to 'Nothing'.
instance TestEquality STy where
  testEquality STNil STNil = Just Refl
  testEquality (STPair a b) (STPair a' b')
    | Just Refl <- testEquality a a'
    , Just Refl <- testEquality b b'
    = Just Refl
  testEquality (STEither a b) (STEither a' b')
    | Just Refl <- testEquality a a'
    , Just Refl <- testEquality b b'
    = Just Refl
  testEquality (STArr a b) (STArr a' b')
    | Just Refl <- testEquality a a'
    , Just Refl <- testEquality b b'
    = Just Refl
  testEquality (STScal a) (STScal a')
    | Just Refl <- testEquality a a'
    = Just Refl
  testEquality (STAccum a) (STAccum a')
    | Just Refl <- testEquality a a'
    = Just Refl
  testEquality _ _ = Nothing

-- | Singleton for 'ScalTy'.
data SScalTy t where
  STI32 :: SScalTy TI32
  STI64 :: SScalTy TI64
  STF32 :: SScalTy TF32
  STF64 :: SScalTy TF64
  STBool :: SScalTy TBool
deriving instance Show (SScalTy t)

-- | Equality test on scalar type singletons: identical constructors yield
-- @'Just' 'Refl'@, anything else 'Nothing'.
instance TestEquality SScalTy where
  testEquality STI32 STI32 = Just Refl
  testEquality STI64 STI64 = Just Refl
  testEquality STF32 STF32 = Just Refl
  testEquality STF64 STF64 = Just Refl
  testEquality STBool STBool = Just Refl
  testEquality _ _ = Nothing

-- | Recover a 'Show' instance for the Haskell representation of a scalar
-- type. Every 'ScalRep' result has a 'Show' instance, but GHC can only see
-- that once the singleton has been matched to fix @t@.
scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t))
scalRepIsShow STI32 = Dict
scalRepIsShow STI64 = Dict
scalRepIsShow STF32 = Dict
scalRepIsShow STF64 = Dict
scalRepIsShow STBool = Dict

-- | Type of array indices and index-like integers (e.g. the extent argument
-- of 'EBuild1' and the index argument of 'EIdx1'); represented as 'Int64'.
type TIx = TScal TI64

-- | Singleton for 'TIx'.
tIx :: STy TIx
tIx = STScal STI64

-- | Haskell representation of values of each scalar type.
type family ScalRep t where
  ScalRep TI32 = Int32
  ScalRep TI64 = Int64
  ScalRep TF32 = Float
  ScalRep TF64 = Double
  ScalRep TBool = Bool

-- | Whether a scalar type is numeric. Required (as @~ True@) by arithmetic
-- and comparison primitives such as 'OAdd' and 'OLt', and by 'ESum1Inner'.
type family ScalIsNumeric t where
  ScalIsNumeric TI32 = True
  ScalIsNumeric TI64 = True
  ScalIsNumeric TF32 = True
  ScalIsNumeric TF64 = True
  ScalIsNumeric TBool = False

-- | Type of the index argument of 'EAccum' when accumulating at depth @i@
-- into an accumulator of type @'TAccum' t@; the matching value argument has
-- type 'AcVal' @t i@.
--
-- At depth 'Z' the index is 'TNil' (all of @t@ is addressed). Each further
-- level descends one step into @t@:
--
-- * into one component of a 'TPair' or 'TEither', chosen by a 'TEither';
-- * through a rank-0 array, without contributing any index component;
-- * through the outermost dimension of an array of positive rank, paired
--   with one 'TIx' for that dimension.
--
-- There are no equations for 'TNil', 'TScal' or 'TAccum' at a positive
-- depth, so such applications do not reduce.
--
-- This index is flipped around from the usual direction: the smallest index
-- is at the /heart/ of the nesting, not at the outside. The outermost layer
-- indexes into the /outer/ dimension of the type @t@. This makes indices into
-- compound structures work properly with coproducts.
type family AcIdx t i where
  AcIdx t Z = TNil
  AcIdx (TPair a b) (S i) = TEither (AcIdx a i) (AcIdx b i)
  AcIdx (TEither a b) (S i) = TEither (AcIdx a i) (AcIdx b i)
  AcIdx (TArr Z t) (S i) = AcIdx t i
  AcIdx (TArr (S n) t) (S i) = TPair TIx (AcIdx (TArr n t) i)

-- | Type of the value argument of 'EAccum': the part of @t@ reached by an
-- 'AcIdx' @t i@ index. It follows the same descent as 'AcIdx'. Wherever the
-- descent passes through a 'TPair' or 'TEither', the value is a 'TEither' of
-- the two possible component value types. Array dimensions contribute
-- nothing to the value type.
type family AcVal t i where
  AcVal t Z = t
  AcVal (TPair a b) (S i) = TEither (AcVal a i) (AcVal b i)
  AcVal (TEither a b) (S i) = TEither (AcVal a i) (AcVal b i)
  AcVal (TArr Z t) (S i) = AcVal t i
  AcVal (TArr (S n) t) (S i) = AcVal (TArr n t) i

-- General assumption: head of the list (whatever way it is associated) is the
-- inner variable / inner array dimension. In pretty printing, the inner
-- variable / inner dimension is printed on the _right_.
type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type
-- | Well-scoped, well-typed expressions.
--
-- * @x@ is an annotation attached to most nodes, indexed by that node's type
--   ('Ex' uses the trivial annotation @'Const' ()@). 'EWith', 'EAccum' and
--   'EError' carry no annotation.
-- * @env@ is the typing environment. Binders ('ELet', 'ECase', 'EBuild1',
--   'EBuild', 'EFold1Inner', 'EWith') push onto its head, so the head is the
--   innermost variable and 'IZ' refers to it.
-- * @t@ is the type of the expression.
data Expr x env t where
  -- lambda calculus
  -- 'EVar' stores the variable's type singleton so 'typeOf' can return it.
  EVar :: x t -> STy t -> Idx env t -> Expr x env t
  -- 'ELet': the body sees the bound value as its innermost variable.
  ELet :: x t -> Expr x env a -> Expr x (a : env) t -> Expr x env t

  -- base types
  EPair :: x (TPair a b) -> Expr x env a -> Expr x env b -> Expr x env (TPair a b)
  EFst :: x a -> Expr x env (TPair a b) -> Expr x env a
  ESnd :: x b -> Expr x env (TPair a b) -> Expr x env b
  ENil :: x TNil -> Expr x env TNil
  -- The injections store the singleton of the *other* summand, which cannot
  -- be recovered from the payload.
  EInl :: x (TEither a b) -> STy b -> Expr x env a -> Expr x env (TEither a b)
  EInr :: x (TEither a b) -> STy a -> Expr x env b -> Expr x env (TEither a b)
  -- Each branch sees the payload of its summand as its innermost variable.
  ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c

  -- array operations
  EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t))
  -- Rank-1 build: a 'TIx' argument, and a body that sees a 'TIx' as its
  -- innermost variable.
  EBuild1 :: x (TArr (S Z) t) -> Expr x env TIx -> Expr x (TIx : env) t -> Expr x env (TArr (S Z) t)
  -- Rank-n build: an n-tuple of 'TIx' (the shape), and a body that sees an
  -- n-tuple of 'TIx' as its innermost variable.
  EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t)
  -- Rank (S n) -> rank n; the combining function receives two elements as
  -- its two innermost variables.
  EFold1Inner :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
  -- Rank (S n) -> rank n, numeric scalars only.
  ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
  -- Wrap a value as a rank-0 array.
  EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t)
  -- Rank n -> rank (S n), with a 'TIx' argument for the new dimension.
  EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t)

  -- expression operations
  EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)
  -- Index a rank-0 array.
  EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t
  -- Index one dimension away: rank (S n) -> rank n.
  EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t)
  -- Full indexing of a rank-n array with an n-tuple of 'TIx'.
  EIdx :: x t -> SNat n -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t
  -- Shape of a rank-n array, as an n-tuple of 'TIx'.
  EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx))
  -- Primitive operation; binary operations take a 'TPair' argument.
  EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t

  -- accumulation effect
  -- 'EWith': the first argument has the accumulator's content type @t@
  -- (presumably its initial value -- NOTE(review): confirm against the
  -- evaluator); the body sees the accumulator as its innermost variable; the
  -- result pairs the body's value with a @t@.
  EWith :: Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t)
  -- 'EAccum': depth, index ('AcIdx'), value ('AcVal'), accumulator.
  EAccum :: SNat i -> Expr x env (AcIdx t i) -> Expr x env (AcVal t i) -> Expr x env (TAccum t) -> Expr x env TNil
  -- NOTE(review): the commented-out signature below is stale; 'TAccum' now
  -- takes a single type argument, not a rank and a type.
  -- EAccum1 :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum (S Z) t) -> Expr x env TNil

  -- partiality
  EError :: STy a -> String -> Expr x env a
deriving instance (forall ty. Show (x ty)) => Show (Expr x env t)

-- | Expressions without annotations (every annotation is @'Const' ()@).
type Ex = Expr (Const ())

-- | The trivial annotation, used when building 'Ex' terms.
ext :: Const () a
ext = Const ()

-- | Left-nested tuple type of a type list. The head of the list becomes the
-- /second/ component of the outermost 'TPair', and the empty list is 'TNil':
-- @Tup '[a, b] = TPair (TPair TNil b) a@.
type family Tup env where
  Tup '[] = TNil
  Tup (t : ts) = TPair (Tup ts) t

-- | Build a 'Tup'-shaped value from a list of components, given how to make
-- the empty tuple and a pair. Generic over @f@ so it serves both type
-- singletons ('tTup') and expressions ('eTup').
mkTup :: f TNil -> (forall a b. f a -> f b -> f (TPair a b)) -> SList f list -> f (Tup list)
mkTup nil _ SNil = nil
mkTup nil pair (e `SCons` es) = pair (mkTup nil pair es) e

-- | Singleton of the 'Tup' of a list of types.
tTup :: SList STy env -> STy (Tup env)
tTup = mkTup STNil STPair

-- | Tuple expression built from a list of component expressions.
eTup :: SList (Ex env) list -> Ex env (Tup list)
eTup = mkTup (ENil ext) (EPair ext)

-- | Like 'Tup', but starting from @core@ and pairing each list element on the
-- right in order, so the head of the list ends up innermost:
-- @InvTup c '[a, b] = TPair (TPair c a) b@.
type family InvTup core env where
  InvTup core '[] = core
  InvTup core (t : ts) = InvTup (TPair core t) ts

type SOp :: Ty -> Ty -> Type
-- | Primitive operations, indexed by argument type @a@ and result type @t@.
-- Binary operations take a 'TPair' of their two operands.
data SOp a t where
  OAdd :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
  OMul :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
  ONeg :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
  -- Comparisons are restricted to numeric scalar types.
  OLt :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
  OLe :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
  OEq :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
  ONot :: SOp (TScal TBool) (TScal TBool)
  -- Turn a Bool into a two-way choice usable with 'ECase'. NOTE(review):
  -- which summand corresponds to True is fixed by the evaluator, not here.
  OIf :: SOp (TScal TBool) (TEither TNil TNil)
deriving instance Show (SOp a t)

-- | Result type of a primitive operation.
opt2 :: SOp a t -> STy t
opt2 = \case
  OAdd t -> STScal t
  OMul t -> STScal t
  ONeg t -> STScal t
  OLt _ -> STScal STBool
  OLe _ -> STScal STBool
  OEq _ -> STScal STBool
  ONot -> STScal STBool
  OIf -> STEither STNil STNil

-- | Reconstruct the type singleton of an expression, either from a singleton
-- stored in the node or by recursing into a subexpression. Each pattern guard
-- matches on a subexpression's singleton whose type index already fixes its
-- outermost constructor. Assuming 'SNat' is the usual 'SZ' / 'SS' singleton,
-- the guards are therefore not expected to fail.
typeOf :: Expr x env t -> STy t
typeOf = \case
  EVar _ t _ -> t
  ELet _ _ e -> typeOf e

  EPair _ a b -> STPair (typeOf a) (typeOf b)
  EFst _ e | STPair t _ <- typeOf e -> t
  ESnd _ e | STPair _ t <- typeOf e -> t
  ENil _ -> STNil
  EInl _ t2 e -> STEither (typeOf e) t2
  EInr _ t1 e -> STEither t1 (typeOf e)
  ECase _ _ a _ -> typeOf a

  EConstArr _ n t _ -> STArr n (STScal t)
  EBuild1 _ _ e -> STArr (SS SZ) (typeOf e)
  EBuild _ n _ e -> STArr n (typeOf e)
  EFold1Inner _ _ e | STArr (SS n) t <- typeOf e -> STArr n t
  ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
  EUnit _ e -> STArr SZ (typeOf e)
  EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t

  EConst _ t _ -> STScal t
  EIdx0 _ e | STArr _ t <- typeOf e -> t
  EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t
  EIdx _ _ e _ | STArr _ t <- typeOf e -> t
  EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx)
  EOp _ op _ -> opt2 op

  EWith e1 e2 -> STPair (typeOf e2) (typeOf e1)
  EAccum _ _ _ _ -> STNil

  EError t _ -> t

-- | Forget the type index of a natural-number singleton.
unSNat :: SNat n -> Nat
unSNat SZ = Z
unSNat (SS n) = S (unSNat n)

-- | Forget the type index of a type singleton.
unSTy :: STy t -> Ty
unSTy = \case
  STNil -> TNil
  STPair a b -> TPair (unSTy a) (unSTy b)
  STEither a b -> TEither (unSTy a) (unSTy b)
  STArr n t -> TArr (unSNat n) (unSTy t)
  STScal t -> TScal (unSScalTy t)
  STAccum t -> TAccum (unSTy t)

-- | Forget the type index of an environment of type singletons; list order
-- is preserved.
unSList :: SList STy env -> [Ty]
unSList SNil = []
unSList (SCons t l) = unSTy t : unSList l

-- | Forget the type index of a scalar type singleton.
unSScalTy :: SScalTy t -> ScalTy
unSScalTy = \case
  STI32 -> TI32
  STI64 -> TI64
  STF32 -> TF32
  STF64 -> TF64
  STBool -> TBool

-- | Substitute @repl@ for the innermost variable ('IZ'); every other variable
-- @'IS' i@ becomes @i@ in the smaller environment.
subst1 :: Expr x env a -> Expr x (a : env) t -> Expr x env t
subst1 repl = subst $ \x t -> \case
  IZ -> repl
  IS i -> EVar x t i

-- | Simultaneous substitution: replace each variable by the expression @f@
-- gives for it, which lives in the target environment @env'@. Results of @f@
-- are weakened ('weakenExpr') when the substitution happens under binders.
subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a)
      -> Expr x env t -> Expr x env' t
subst f = subst' (\x t w i -> weakenExpr w (f x t i)) WId

-- | Worker for 'subst' and 'weakenExpr'.
--
-- For a variable of the source environment @env@, @f@ receives a weakening
-- from @env'@ into the current target environment @env2@ and must produce the
-- replacement there. @w@ is that weakening at the current traversal position.
-- Each time the traversal passes a binder, @w@ is extended with 'WCopy' and
-- @f@ with 'sinkF', so the newly bound variable maps to itself.
subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a)
       -> env' :> envOut
       -> Expr x env t
       -> Expr x envOut t
subst' f w = \case
    EVar x t i -> f x t w i
    ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body)
    EPair x a b -> EPair x (subst' f w a) (subst' f w b)
    EFst x e -> EFst x (subst' f w e)
    ESnd x e -> ESnd x (subst' f w e)
    ENil x -> ENil x
    EInl x t e -> EInl x t (subst' f w e)
    EInr x t e -> EInr x t (subst' f w e)
    ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b)
    EConstArr x n t a -> EConstArr x n t a
    EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b)
    EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b)
    -- The combining function binds two variables, hence two levels of sinking.
    EFold1Inner x a b -> EFold1Inner x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b)
    ESum1Inner x e -> ESum1Inner x (subst' f w e)
    EUnit x e -> EUnit x (subst' f w e)
    EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b)
    EConst x t v -> EConst x t v
    EIdx0 x e -> EIdx0 x (subst' f w e)
    EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
    EIdx x n e es -> EIdx x n (subst' f w e) (subst' f w es)
    EShape x e -> EShape x (subst' f w e)
    EOp x op e -> EOp x op (subst' f w e)
    -- Only the body of 'EWith' is under the accumulator binder.
    EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
    EAccum i e1 e2 e3 -> EAccum i (subst' f w e1) (subst' f w e2) (subst' f w e3)
    EError t s -> EError t s
  where
    -- Extend a substitution under one binder: the new innermost source
    -- variable ('IZ') maps to the new innermost target variable, and older
    -- variables are delegated to the old substitution with the weakening
    -- popped back past the binder ('WPop').
    sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
          -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t
    sinkF f' x' t w' = \case
      IZ -> EVar x' t (w' @> IZ)
      IS i -> f' x' t (WPop w') i

-- | Move an expression into a larger environment: the substitution that maps
-- every variable to itself, renamed through the weakening.
weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t
weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i))

-- | Turn a 'Subenv' witness into a weakening from the sub-environment back
-- into the full environment: 'SEYes' entries are kept ('WCopy') and 'SENo'
-- entries are skipped ('WSink').
wUndoSubenv :: Subenv env env' -> env' :> env
wUndoSubenv SETop = WId
wUndoSubenv (SEYes sub) = WCopy (wUndoSubenv sub)
wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub

-- | Look up the element of an 'SList' at a typed index. The 'SNil' case is
-- unreachable because 'Idx' has no values for an empty list; the empty case
-- witnesses this.
slistIdx :: SList f list -> Idx list t -> f t
slistIdx (SCons x _) IZ = x
slistIdx (SCons _ list) (IS i) = slistIdx list i
slistIdx SNil i = case i of {}

-- | The numeric value of a De Bruijn index ('IZ' is 0).
idx2int :: Idx env t -> Int
idx2int IZ = 0
idx2int (IS n) = 1 + idx2int n

-- | Implicitly passed 'SScalTy' singleton.
class KnownScalTy t where knownScalTy :: SScalTy t
instance KnownScalTy TI32 where knownScalTy = STI32
instance KnownScalTy TI64 where knownScalTy = STI64
instance KnownScalTy TF32 where knownScalTy = STF32
instance KnownScalTy TF64 where knownScalTy = STF64
instance KnownScalTy TBool where knownScalTy = STBool

-- | Implicitly passed 'STy' singleton.
class KnownTy t where knownTy :: STy t
instance KnownTy TNil where knownTy = STNil
instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy
instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy
instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy
instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy
instance KnownTy t => KnownTy (TAccum t) where knownTy = STAccum knownTy

-- | Implicitly passed singleton for a whole environment.
class KnownEnv env where knownEnv :: SList STy env
instance KnownEnv '[] where knownEnv = SNil
instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv

-- | Convert an explicit type singleton into the implicit 'KnownTy'
-- dictionary, recursing on components (the converse of 'knownTy').
styKnown :: STy t -> Dict (KnownTy t)
styKnown STNil = Dict
styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict
styKnown (STScal t) | Dict <- sscaltyKnown t = Dict
styKnown (STAccum t) | Dict <- styKnown t = Dict

-- | Convert an explicit scalar type singleton into the implicit
-- 'KnownScalTy' dictionary.
sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t)
sscaltyKnown STI32 = Dict
sscaltyKnown STI64 = Dict
sscaltyKnown STF32 = Dict
sscaltyKnown STF64 = Dict
sscaltyKnown STBool = Dict

-- | Build a rank-(n+1) array from:
--
-- * a rank-n shape @sh@;
-- * a 'TIx' @size@ for one extra dimension, which becomes the innermost
--   component of the new shape tuple;
-- * a function body @f@ that, with an index in that extra dimension as its
--   innermost variable, produces a rank-n array.
--
-- The element at index @(ix, i)@ is @f@ evaluated at @i@, indexed at @ix@.
-- NOTE(review): @f@ sits inside the 'EBuild' body, so as written it is
-- presumably re-evaluated once per element. Each result of @f@ is presumably
-- expected to have shape @sh@; this is not checked here.
ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t)
ebuildUp1 n sh size f =
  EBuild ext (SS n) (EPair ext sh size) $
    -- The full (n+1)-dimensional index tuple bound by the build.
    let arg = EVar ext (tTup (sreplicate (SS n) tIx)) IZ
    -- Bind the innermost index component for @f@ (skipping over the index
    -- tuple with 'WSink'), then index the result with the remaining n.
    in EIdx ext n (ELet ext (ESnd ext arg) (weakenExpr (WCopy WSink) f)) (EFst ext arg)