{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-- | Type representations for the AST: the plain value-level description
-- 'Ty', its type-indexed singleton 'STy', scalar types, conversions between
-- the two forms, and helpers for encoding type-level lists as nested pairs.
module AST.Types where

import Data.Int (Int32, Int64)
import Data.GADT.Show
import Data.Kind (Type)
import Data.Some
import Data.Type.Equality

import Data


-- | Value-level (unindexed) description of a type.
data Ty
  = TNil
  | TPair Ty Ty
  | TEither Ty Ty
  | TMaybe Ty
  | TArr Nat Ty  -- ^ rank, element type
  | TScal ScalTy
  | TAccum Ty  -- ^ the accumulator contains D2 of this type
  deriving (Show, Eq, Ord)

-- | Value-level description of a scalar type.
data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
  deriving (Show, Eq, Ord)

-- | Singleton for 'Ty': the index @t@ is the promoted 'Ty' that the
-- value describes.
type STy :: Ty -> Type
data STy t where
  STNil :: STy TNil
  STPair :: STy a -> STy b -> STy (TPair a b)
  STEither :: STy a -> STy b -> STy (TEither a b)
  STMaybe :: STy a -> STy (TMaybe a)
  STArr :: SNat n -> STy t -> STy (TArr n t)
  STScal :: SScalTy t -> STy (TScal t)
  STAccum :: STy t -> STy (TAccum t)
deriving instance Show (STy t)

-- | Structural comparison: succeeds only when both singletons have the same
-- head constructor and every component compares equal.
instance TestEquality STy where
  testEquality STNil STNil = Just Refl
  testEquality (STPair l r) (STPair l' r') = do
    Refl <- testEquality l l'
    Refl <- testEquality r r'
    pure Refl
  testEquality (STEither l r) (STEither l' r') = do
    Refl <- testEquality l l'
    Refl <- testEquality r r'
    pure Refl
  testEquality (STMaybe inner) (STMaybe inner') = do
    Refl <- testEquality inner inner'
    pure Refl
  testEquality (STArr rank elt) (STArr rank' elt') = do
    Refl <- testEquality rank rank'
    Refl <- testEquality elt elt'
    pure Refl
  testEquality (STScal s) (STScal s') = do
    Refl <- testEquality s s'
    pure Refl
  testEquality (STAccum inner) (STAccum inner') = do
    Refl <- testEquality inner inner'
    pure Refl
  -- Any pair of differing head constructors.
  testEquality _ _ = Nothing

instance GShow STy where
  gshowsPrec = defaultGshowsPrec

-- | Singleton for 'ScalTy'.
data SScalTy t where
  STI32 :: SScalTy TI32
  STI64 :: SScalTy TI64
  STF32 :: SScalTy TF32
  STF64 :: SScalTy TF64
  STBool :: SScalTy TBool
deriving instance Show (SScalTy t)

instance TestEquality SScalTy where
  testEquality s s' = case (s, s') of
    (STI32, STI32) -> Just Refl
    (STI64, STI64) -> Just Refl
    (STF32, STF32) -> Just Refl
    (STF64, STF64) -> Just Refl
    (STBool, STBool) -> Just Refl
    _ -> Nothing

instance GShow SScalTy where
  gshowsPrec = defaultGshowsPrec

-- | A 'Dict' witnessing 'Show' for the Haskell representation ('ScalRep')
-- of the given scalar type.
scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t))
scalRepIsShow = \case
  STI32 -> Dict
  STI64 -> Dict
  STF32 -> Dict
  STF64 -> Dict
  STBool -> Dict

-- | Alias for the 64-bit integer scalar type. The name suggests it is used
-- for indices; confirm with callers.
type TIx = TScal TI64

-- | Singleton for 'TIx'.
tIx :: STy TIx
tIx = STScal STI64

-- | Forget the type index of a singleton, yielding the plain description.
unSTy :: STy t -> Ty
unSTy STNil = TNil
unSTy (STPair l r) = TPair (unSTy l) (unSTy r)
unSTy (STEither l r) = TEither (unSTy l) (unSTy r)
unSTy (STMaybe inner) = TMaybe (unSTy inner)
unSTy (STArr rank elt) = TArr (unSNat rank) (unSTy elt)
unSTy (STScal s) = TScal (unSScalTy s)
unSTy (STAccum inner) = TAccum (unSTy inner)

-- | Forget the type indices of an environment of singletons, element-wise.
unSEnv :: SList STy env -> [Ty]
unSEnv = \case
  SNil -> []
  SCons ty rest -> unSTy ty : unSEnv rest

-- | Forget the type index of a scalar-type singleton.
unSScalTy :: SScalTy t -> ScalTy
unSScalTy STI32 = TI32
unSScalTy STI64 = TI64
unSScalTy STF32 = TF32
unSScalTy STF64 = TF64
unSScalTy STBool = TBool

-- | Rebuild a singleton from a plain description. The type index is not
-- statically known, so it is hidden inside 'Some'.
reSTy :: Ty -> Some STy
reSTy TNil = Some STNil
reSTy (TPair l r) = case (reSTy l, reSTy r) of
  (Some l', Some r') -> Some (STPair l' r')
reSTy (TEither l r) = case (reSTy l, reSTy r) of
  (Some l', Some r') -> Some (STEither l' r')
reSTy (TMaybe inner) = case reSTy inner of
  Some inner' -> Some (STMaybe inner')
reSTy (TArr rank elt) = case (reSNat rank, reSTy elt) of
  (Some rank', Some elt') -> Some (STArr rank' elt')
reSTy (TScal s) = case reSScalTy s of
  Some s' -> Some (STScal s')
reSTy (TAccum inner) = case reSTy inner of
  Some inner' -> Some (STAccum inner')

-- | Rebuild an environment of singletons from a list of plain descriptions,
-- preserving element order.
reSEnv :: [Ty] -> Some (SList STy)
reSEnv = foldr cons (Some SNil)
  where
    -- Convert the head first, then the already-folded tail.
    cons :: Ty -> Some (SList STy) -> Some (SList STy)
    cons ty rest = case reSTy ty of
      Some ty' -> case rest of
        Some env -> Some (SCons ty' env)

-- | Rebuild a scalar-type singleton from its plain description.
reSScalTy :: ScalTy -> Some SScalTy
reSScalTy TI32 = Some STI32
reSScalTy TI64 = Some STI64
reSScalTy TF32 = Some STF32
reSScalTy TF64 = Some STF64
reSScalTy TBool = Some STBool

-- | Haskell type used to represent values of each scalar type.
type family ScalRep t where
  ScalRep TI32 = Int32
  ScalRep TI64 = Int64
  ScalRep TF32 = Float
  ScalRep TF64 = Double
  ScalRep TBool = Bool

-- | 'True' for the integer and floating-point scalar types, 'False' for
-- 'TBool'.
type family ScalIsNumeric t where
  ScalIsNumeric TI32 = True
  ScalIsNumeric TI64 = True
  ScalIsNumeric TF32 = True
  ScalIsNumeric TF64 = True
  ScalIsNumeric TBool = False

-- | 'True' exactly for 'TF32' and 'TF64'.
type family ScalIsFloating t where
  ScalIsFloating TI32 = False
  ScalIsFloating TI64 = False
  ScalIsFloating TF32 = True
  ScalIsFloating TF64 = True
  ScalIsFloating TBool = False

-- | 'True' exactly for 'TI32' and 'TI64'.
type family ScalIsIntegral t where
  ScalIsIntegral TI32 = True
  ScalIsIntegral TI64 = True
  ScalIsIntegral TF32 = False
  ScalIsIntegral TF64 = False
  ScalIsIntegral TBool = False

-- | Whether the type contains an array /or/ an accumulator anywhere inside
-- it. Returns true for both arrays and accumulators.
hasArrays :: STy t' -> Bool
hasArrays = \case
  STNil -> False
  STPair l r -> hasArrays l || hasArrays r
  STEither l r -> hasArrays l || hasArrays r
  STMaybe inner -> hasArrays inner
  STArr{} -> True
  STScal{} -> False
  STAccum{} -> True

-- | Encode a type-level list as nested pairs. The empty list is 'TNil', and
-- the list's head becomes the /right/ component of the outermost pair:
-- @Tup [a, b] = TPair (TPair TNil b) a@.
type family Tup env where
  Tup '[] = TNil
  Tup (t : ts) = TPair (Tup ts) t

-- | Build an @f (Tup list)@ from a list of @f@s. @nil@ handles the empty
-- list and @pair@ combines the encoded tail (left) with the head (right).
mkTup :: f TNil -> (forall a b. f a -> f b -> f (TPair a b)) -> SList f list -> f (Tup list)
mkTup nil pair = \case
  SNil -> nil
  SCons hd tl -> pair (mkTup nil pair tl) hd

-- | Singleton for the nested-pair encoding of an environment.
tTup :: SList STy env -> STy (Tup env)
tTup env = mkTup STNil STPair env

-- | Inverse of 'mkTup': split an encoded tuple back into a list using
-- @unpack@. The 'SList' argument only supplies the shape; its elements are
-- never inspected. Splitting is lazy, so @unpack@ runs only when a component
-- is demanded.
unTup :: (forall a b. c (TPair a b) -> (c a, c b)) -> SList f list -> c (Tup list) -> SList c list
unTup _ SNil _ = SNil
unTup unpack (SCons _ shape) tup = SCons (snd halves) (unTup unpack shape (fst halves))
  where
    halves = unpack tup

-- | Fold a type-level list onto @core@ from the left:
-- @InvTup core [a, b] = TPair (TPair core a) b@.
type family InvTup core env where
  InvTup core '[] = core
  InvTup core (t : ts) = InvTup (TPair core t) ts