{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- | Typed abstract syntax: object-language types, their singletons, and the
-- expression representation built on top of them.
module AST (module AST, module AST.Weaken) where

import Data.Functor.Const
import Data.Int
import Data.Kind (Type)

import AST.Weaken
import Data

-- | Object-language types. Promoted with 'DataKinds' so they can index
-- singletons ('STy') and other type-indexed structures.
data Ty
  = TNil           -- ^ unit type
  | TPair Ty Ty    -- ^ product of two types
  | TEither Ty Ty  -- ^ binary sum of two types
  | TArr Nat Ty    -- ^ array: rank, element type
  | TScal ScalTy   -- ^ scalar
  | TAccum Nat Ty  -- ^ accumulator: rank and element type of the array being accumulated to
  deriving (Show, Eq, Ord)

-- | The scalar element types.
data ScalTy
  = TI32
  | TI64
  | TF32
  | TF64
  | TBool
  deriving (Show, Eq, Ord)

-- | Singleton for 'Ty': a term-level witness of a type-level 'Ty'.
type STy :: Ty -> Type
data STy t where
  STNil    :: STy TNil
  STPair   :: STy a -> STy b -> STy (TPair a b)
  STEither :: STy a -> STy b -> STy (TEither a b)
  STArr    :: SNat n -> STy t -> STy (TArr n t)
  STScal   :: SScalTy t -> STy (TScal t)
  STAccum  :: SNat n -> STy t -> STy (TAccum n t)
deriving instance Show (STy t)

-- | Singleton for 'ScalTy'.
type SScalTy :: ScalTy -> Type
data SScalTy t where
  STI32  :: SScalTy TI32
  STI64  :: SScalTy TI64
  STF32  :: SScalTy TF32
  STF64  :: SScalTy TF64
  STBool :: SScalTy TBool
deriving instance Show (SScalTy t)

-- | The type of array indices: a 64-bit integer scalar.
type TIx = TScal TI64

-- | The Haskell type used to represent values of each scalar type.
type ScalRep :: ScalTy -> Type
type family ScalRep t where
  ScalRep TI32  = Int32
  ScalRep TI64  = Int64
  ScalRep TF32  = Float
  ScalRep TF64  = Double
  ScalRep TBool = Bool

-- | @ConsN n x l@ prepends @n@ copies of @x@ to the type-level list @l@.
type ConsN :: Nat -> a -> [a] -> [a]
type family ConsN n x l where
  ConsN Z     x l = l
  ConsN (S n) x l = x : ConsN n x l

-- General assumption: the head of a type-level list (however it is
-- associated) is the innermost variable / innermost array dimension. In
-- pretty printing, the innermost variable / dimension is printed on the
-- _right_.
-- | Typed expressions.
--
-- * @x@ is a per-node annotation functor: most constructors carry an @x t@
--   for their own result type @t@ (see 'Ex' for the unannotated instance).
--   'EWith', 'EAccum' and 'EError' carry no annotation.
-- * @env@ is the typing environment as a type-level list. Variables are
--   de Bruijn indices ('Idx') into it, and binders push onto its head.
-- * @t@ is the type of the expression.
type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type
data Expr x env t where
  -- lambda calculus

  -- | Variable: its type singleton and its index into the environment.
  EVar :: x t -> STy t -> Idx env t -> Expr x env t
  -- | @ELet _ rhs body@: @body@ sees the value of @rhs@ as the head of its
  -- environment.
  ELet :: x t -> Expr x env a -> Expr x (a : env) t -> Expr x env t

  -- base types

  EPair :: x (TPair a b) -> Expr x env a -> Expr x env b -> Expr x env (TPair a b)
  EFst :: x a -> Expr x env (TPair a b) -> Expr x env a
  ESnd :: x b -> Expr x env (TPair a b) -> Expr x env b
  ENil :: x TNil -> Expr x env TNil
  -- | Left injection. Stores the singleton of the /other/ (right) type so
  -- that 'typeOf' can rebuild the full sum type.
  EInl :: x (TEither a b) -> STy b -> Expr x env a -> Expr x env (TEither a b)
  -- | Right injection. Stores the singleton of the left type, as for 'EInl'.
  EInr :: x (TEither a b) -> STy a -> Expr x env b -> Expr x env (TEither a b)
  -- | Case analysis on a sum. Each branch sees the matched payload as the
  -- head of its environment.
  ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c

  -- array operations

  -- | One-dimensional build. The body sees the element index as the head of
  -- its environment. The 'TIx' argument is presumably the array length
  -- (TODO confirm against the evaluator).
  EBuild1 :: x (TArr (S Z) t) -> Expr x env TIx -> Expr x (TIx : env) t -> Expr x env (TArr (S Z) t)
  -- | Rank-@n@ build from @n@ 'TIx' expressions (presumably the shape). The
  -- body sees @n@ index variables pushed onto its environment ('ConsN').
  EBuild :: x (TArr n t) -> Vec n (Expr x env TIx) -> Expr x (ConsN n TIx env) t -> Expr x env (TArr n t)
  -- | Reduces a rank-@S n@ array to rank @n@ with a binary combining
  -- function whose two operands are the two head variables of its
  -- environment. Which dimension is reduced is not visible here.
  EFold1 :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)

  -- expression operations

  -- | Scalar literal. The 'Show' constraint lets the derived 'Show' instance
  -- print the stored value.
  EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)
  -- | Indexes away one dimension of a rank-@S n@ array with a single index.
  EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t)
  -- | Full indexing: @n@ indices into a rank-@n@ array yield one element.
  EIdx :: x t -> Expr x env (TArr n t) -> Vec n (Expr x env TIx) -> Expr x env t
  -- | Application of a primitive operation ('SOp') to its argument.
  EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t

  -- accumulation effect

  -- | @EWith arr body@: @body@ runs with an accumulator of type
  -- @TAccum n t@ as the head of its environment. The result pairs the
  -- body's value with an array of the same type as @arr@. Presumably @arr@
  -- provides the accumulator's initial contents (TODO confirm).
  EWith :: Expr x env (TArr n t) -> Expr x (TAccum n t : env) a -> Expr x env (TPair a (TArr n t))
  -- | Accumulates a value of the element type into an accumulator at an
  -- index.
  -- NOTE(review): takes a single 'TIx' even for a rank-@n@ accumulator.
  -- Confirm whether this is a linear index or whether @n@ is meant to be 1.
  EAccum :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum n t) -> Expr x env TNil

  -- partiality

  -- | Runtime error with a message; carries its type explicitly.
  EError :: STy a -> String -> Expr x env a
deriving instance (forall ty. Show (x ty)) => Show (Expr x env t)

-- | Expressions without annotations: every annotation slot is @Const ()@.
type Ex = Expr (Const ())

-- | The trivial annotation, for building 'Ex' terms.
ext :: Const () a
ext = Const ()

-- | Primitive operations, indexed by argument type @a@ and result type @t@.
-- Binary operations take their two operands as a pair.
type SOp :: Ty -> Ty -> Type
data SOp a t where
  OAdd :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
  OMul :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
  ONeg :: SScalTy a -> SOp (TScal a) (TScal a)
  OLt :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
  OLe :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
  OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
  ONot :: SOp (TScal TBool) (TScal TBool)
  -- | Converts a boolean into a two-way sum of units, presumably so that it
  -- can be scrutinised with 'ECase'. Which side corresponds to 'True' is
  -- not visible here (TODO confirm).
  OIf :: SOp (TScal TBool) (TEither TNil TNil)
deriving instance Show (SOp a t)

-- | Singleton for the result type @t@ of an operation.
opt2 :: SOp a t -> STy t
opt2 = \case
  OAdd t -> STScal t
  OMul t -> STScal t
  ONeg t -> STScal t
  OLt _ -> STScal STBool
  OLe _ -> STScal STBool
  OEq _ -> STScal STBool
  ONot -> STScal STBool
  OIf -> STEither STNil STNil

-- | Reconstructs the type of an expression from the singletons stored in it.
--
-- Annotations (@x@) are never inspected. For 'ECase' only the left branch
-- is consulted. The pattern guards cannot fail: the GADT index of the
-- sub-expression pins its type singleton to the matched constructor
-- (assuming 'SNat' has the usual @SZ@/@SS@ indexing).
typeOf :: Expr x env t -> STy t
typeOf = \case
  EVar _ t _ -> t
  ELet _ _ e -> typeOf e
  EPair _ a b -> STPair (typeOf a) (typeOf b)
  EFst _ e | STPair t _ <- typeOf e -> t
  ESnd _ e | STPair _ t <- typeOf e -> t
  ENil _ -> STNil
  EInl _ t2 e -> STEither (typeOf e) t2
  EInr _ t1 e -> STEither t1 (typeOf e)
  ECase _ _ a _ -> typeOf a
  EBuild1 _ _ e -> STArr (SS SZ) (typeOf e)
  EBuild _ es e -> STArr (vecLength es) (typeOf e)
  EFold1 _ _ e | STArr (SS n) t <- typeOf e -> STArr n t
  EConst _ t _ -> STScal t
  EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t
  EIdx _ e _ | STArr _ t <- typeOf e -> t
  EOp _ op _ -> opt2 op
  EWith e1 e2 -> STPair (typeOf e2) (typeOf e1)
  EAccum _ _ _ -> STNil
  EError t _ -> t

-- | Forgets a 'SNat' singleton, recovering the plain 'Nat'.
unSNat :: SNat n -> Nat
unSNat SZ = Z
unSNat (SS n) = S (unSNat n)

-- | Forgets an 'STy' singleton, recovering the plain 'Ty'.
unSTy :: STy t -> Ty
unSTy = \case
  STNil -> TNil
  STPair a b -> TPair (unSTy a) (unSTy b)
  STEither a b -> TEither (unSTy a) (unSTy b)
  STArr n t -> TArr (unSNat n) (unSTy t)
  STScal t -> TScal (unSScalTy t)
  STAccum n t -> TAccum (unSNat n) (unSTy t)

-- | Forgets an environment singleton. The head of the environment comes
-- first in the resulting list.
unSList :: SList STy env -> [Ty]
unSList SNil = []
unSList (SCons t l) = unSTy t : unSList l

-- | Forgets an 'SScalTy' singleton, recovering the plain 'ScalTy'.
unSScalTy :: SScalTy t -> ScalTy
unSScalTy = \case
  STI32 -> TI32
  STI64 -> TI64
  STF32 -> TF32
  STF64 -> TF64
  STBool -> TBool

-- | @subst1 repl body@ substitutes @repl@ for the head variable ('IZ') of
-- @body@; every other variable moves down one index.
--
-- @repl@ is copied to each use site (weakened as needed under binders), not
-- shared. Callers that must not duplicate its work should bind it with
-- 'ELet' instead.
subst1 :: Expr x env a -> Expr x (a : env) t -> Expr x env t
subst1 repl = subst $ \x t -> \case
  IZ -> repl
  IS i -> EVar x t i

-- | Substitution: every variable of the input environment @env@ is replaced
-- by the expression @f@ returns for it, built in @env'@. Results used under
-- binders are weakened past those binders.
subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a)
      -> Expr x env t -> Expr x env' t
subst f = subst' (\x t w i -> weakenExpr w (f x t i)) WId

-- | Worker for 'subst' and 'weakenExpr'.
--
-- In addition to the variable mapping @f@, it threads a weakening @w@ from
-- @f@'s target environment @env'@ to the current output environment. @f@
-- receives the weakening in force at each variable occurrence. At every
-- binder the weakening is extended with 'WCopy', and @f@ is wrapped by
-- 'sinkF' / 'sinkFN'. Newly bound variables are therefore routed through
-- the extended weakening rather than through @f@.
subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a)
       -> env' :> envOut -> Expr x env t -> Expr x envOut t
subst' f w = \case
    EVar x t i -> f x t w i
    ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body)
    EPair x a b -> EPair x (subst' f w a) (subst' f w b)
    EFst x e -> EFst x (subst' f w e)
    ESnd x e -> ESnd x (subst' f w e)
    ENil x -> ENil x
    EInl x t e -> EInl x t (subst' f w e)
    EInr x t e -> EInr x t (subst' f w e)
    ECase x e a b ->
      ECase x (subst' f w e)
              (subst' (sinkF f) (WCopy w) a)
              (subst' (sinkF f) (WCopy w) b)
    EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b)
    -- the body binds one TIx variable per dimension
    EBuild x es e ->
      EBuild x (fmap (subst' f w) es)
               (subst' (sinkFN (vecLength es) f) (wcopyN (vecLength es) w) e)
    -- the combining function binds two variables
    EFold1 x a b -> EFold1 x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b)
    EConst x t v -> EConst x t v
    EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
    EIdx x e es -> EIdx x (subst' f w e) (fmap (subst' f w) es)
    EOp x op e -> EOp x op (subst' f w e)
    -- the body binds the accumulator
    EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
    EAccum e1 e2 e3 -> EAccum (subst' f w e1) (subst' f w e2) (subst' f w e3)
    EError t s -> EError t s
  where
    -- Lifts a variable mapping under one binder. The bound variable ('IZ')
    -- is sent through the weakening; every other variable is handled by
    -- @f'@, with the binder popped off the weakening's source ('WPop').
    sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
          -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t
    sinkF f' x' t w' = \case
      IZ -> EVar x' t (w' @> IZ)
      IS i -> f' x' t (WPop w') i

    -- 'sinkF' iterated @n@ times, for the @n@ 'TIx' index variables that
    -- 'EBuild' binds.
    sinkFN :: SNat n
           -> (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
           -> x t -> STy t -> (ConsN n TIx env' :> env2) -> Idx (ConsN n TIx env) t
           -> Expr x env2 t
    sinkFN SZ f' x t w' i = f' x t w' i
    sinkFN (SS _) _ x t w' IZ = EVar x t (w' @> IZ)
    sinkFN (SS n) f' x t w' (IS i) = sinkFN n f' x t (WPop w') i

-- | Re-indexes every free variable through a weakening; the expression is
-- otherwise rebuilt unchanged.
weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t
weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i))

-- | Weakening that makes room for @n@ fresh 'TIx' variables at the head of
-- the environment.
wsinkN :: SNat n -> env :> ConsN n TIx env
wsinkN SZ = WId
wsinkN (SS n) = WSink .> wsinkN n

-- | Extends a weakening under @n@ 'TIx' binders (@n@-fold 'WCopy').
wcopyN :: SNat n -> env :> env' -> ConsN n TIx env :> ConsN n TIx env'
wcopyN SZ w = w
wcopyN (SS n) w = WCopy (wcopyN n w)

-- | Drops @n@ 'TIx' entries from the head of a weakening's source
-- (@n@-fold 'WPop').
wpopN :: SNat n -> ConsN n TIx env :> env' -> env :> env'
wpopN SZ w = w
wpopN (SS n) w = wpopN n (WPop w)

-- | Looks up the element at an index in an environment-indexed list. The
-- 'SNil' case is unreachable because no 'Idx' points into an empty list,
-- hence the empty case.
slistIdx :: SList f list -> Idx list t -> f t
slistIdx (SCons x _) IZ = x
slistIdx (SCons _ list) (IS i) = slistIdx list i
slistIdx SNil i = case i of {}

-- | The numeric value of a de Bruijn index ('IZ' is 0). Linear in the
-- index.
idx2int :: Idx env t -> Int
idx2int IZ = 0
idx2int (IS n) = 1 + idx2int n

-- | Scalar types whose singleton can be produced implicitly.
class KnownScalTy t where knownScalTy :: SScalTy t
instance KnownScalTy TI32 where knownScalTy = STI32
instance KnownScalTy TI64 where knownScalTy = STI64
instance KnownScalTy TF32 where knownScalTy = STF32
instance KnownScalTy TF64 where knownScalTy = STF64
instance KnownScalTy TBool where knownScalTy = STBool

-- | Types whose singleton can be produced implicitly, by structural
-- recursion on the type. Array and accumulator ranks come from 'KnownNat'.
class KnownTy t where knownTy :: STy t
instance KnownTy TNil where knownTy = STNil
instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy
instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy
instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy
instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy
instance (KnownNat n, KnownTy t) => KnownTy (TAccum n t) where knownTy = STAccum knownNat knownTy

-- | Environments whose 'SList' singleton can be produced implicitly.
class KnownEnv env where knownEnv :: SList STy env
instance KnownEnv '[] where knownEnv = SNil
instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv