aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/AST/Types.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD/AST/Types.hs')
-rw-r--r--src/CHAD/AST/Types.hs215
1 files changed, 215 insertions, 0 deletions
diff --git a/src/CHAD/AST/Types.hs b/src/CHAD/AST/Types.hs
new file mode 100644
index 0000000..059077d
--- /dev/null
+++ b/src/CHAD/AST/Types.hs
@@ -0,0 +1,215 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeData #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.AST.Types where
+
+import Data.Int (Int32, Int64)
+import Data.GADT.Compare
+import Data.GADT.Show
+import Data.Kind (Type)
+import Data.Type.Equality
+
+import CHAD.Data
+
+
+type data Ty
+ = TNil
+ | TPair Ty Ty
+ | TEither Ty Ty
+ | TLEither Ty Ty
+ | TMaybe Ty
+ | TArr Nat Ty -- ^ rank, element type
+ | TScal ScalTy
+ | TAccum Ty -- ^ contained type must be a monoid type
+
+type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
+
+type STy :: Ty -> Type
+data STy t where
+ STNil :: STy TNil
+ STPair :: STy a -> STy b -> STy (TPair a b)
+ STEither :: STy a -> STy b -> STy (TEither a b)
+ STLEither :: STy a -> STy b -> STy (TLEither a b)
+ STMaybe :: STy a -> STy (TMaybe a)
+ STArr :: SNat n -> STy t -> STy (TArr n t)
+ STScal :: SScalTy t -> STy (TScal t)
+ STAccum :: SMTy t -> STy (TAccum t)
+deriving instance Show (STy t)
+
+instance GCompare STy where
+ gcompare = \cases
+ STNil STNil -> GEQ
+ STNil _ -> GLT ; _ STNil -> GGT
+ (STPair a b) (STPair a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
+ STPair{} _ -> GLT ; _ STPair{} -> GGT
+ (STEither a b) (STEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
+ STEither{} _ -> GLT ; _ STEither{} -> GGT
+ (STLEither a b) (STLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
+ STLEither{} _ -> GLT ; _ STLEither{} -> GGT
+ (STMaybe a) (STMaybe a') -> gorderingLift1 (gcompare a a')
+ STMaybe{} _ -> GLT ; _ STMaybe{} -> GGT
+ (STArr n t) (STArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t')
+ STArr{} _ -> GLT ; _ STArr{} -> GGT
+ (STScal t) (STScal t') -> gorderingLift1 (gcompare t t')
+ STScal{} _ -> GLT ; _ STScal{} -> GGT
+ (STAccum t) (STAccum t') -> gorderingLift1 (gcompare t t')
+ -- STAccum{} _ -> GLT ; _ STAccum{} -> GGT
+
+instance TestEquality STy where testEquality = geq
+instance GEq STy where geq = defaultGeq
+instance GShow STy where gshowsPrec = defaultGshowsPrec
+
+-- | Monoid types
+type SMTy :: Ty -> Type
+data SMTy t where
+ SMTNil :: SMTy TNil
+ SMTPair :: SMTy a -> SMTy b -> SMTy (TPair a b)
+ SMTLEither :: SMTy a -> SMTy b -> SMTy (TLEither a b)
+ SMTMaybe :: SMTy a -> SMTy (TMaybe a)
+ SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t)
+ SMTScal :: ScalIsNumeric t ~ True => SScalTy t -> SMTy (TScal t)
+deriving instance Show (SMTy t)
+
+instance GCompare SMTy where
+ gcompare = \cases
+ SMTNil SMTNil -> GEQ
+ SMTNil _ -> GLT ; _ SMTNil -> GGT
+ (SMTPair a b) (SMTPair a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
+ SMTPair{} _ -> GLT ; _ SMTPair{} -> GGT
+ (SMTLEither a b) (SMTLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
+ SMTLEither{} _ -> GLT ; _ SMTLEither{} -> GGT
+ (SMTMaybe a) (SMTMaybe a') -> gorderingLift1 (gcompare a a')
+ SMTMaybe{} _ -> GLT ; _ SMTMaybe{} -> GGT
+ (SMTArr n t) (SMTArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t')
+ SMTArr{} _ -> GLT ; _ SMTArr{} -> GGT
+ (SMTScal t) (SMTScal t') -> gorderingLift1 (gcompare t t')
+ -- SMTScal{} _ -> GLT ; _ SMTScal{} -> GGT
+
+instance TestEquality SMTy where testEquality = geq
+instance GEq SMTy where geq = defaultGeq
+instance GShow SMTy where gshowsPrec = defaultGshowsPrec
+
+fromSMTy :: SMTy t -> STy t
+fromSMTy = \case
+ SMTNil -> STNil
+ SMTPair t1 t2 -> STPair (fromSMTy t1) (fromSMTy t2)
+ SMTLEither t1 t2 -> STLEither (fromSMTy t1) (fromSMTy t2)
+ SMTMaybe t -> STMaybe (fromSMTy t)
+ SMTArr n t -> STArr n (fromSMTy t)
+ SMTScal sty -> STScal sty
+
+data SScalTy t where
+ STI32 :: SScalTy TI32
+ STI64 :: SScalTy TI64
+ STF32 :: SScalTy TF32
+ STF64 :: SScalTy TF64
+ STBool :: SScalTy TBool
+deriving instance Show (SScalTy t)
+
+instance GCompare SScalTy where
+ gcompare = \cases
+ STI32 STI32 -> GEQ
+ STI32 _ -> GLT ; _ STI32 -> GGT
+ STI64 STI64 -> GEQ
+ STI64 _ -> GLT ; _ STI64 -> GGT
+ STF32 STF32 -> GEQ
+ STF32 _ -> GLT ; _ STF32 -> GGT
+ STF64 STF64 -> GEQ
+ STF64 _ -> GLT ; _ STF64 -> GGT
+ STBool STBool -> GEQ
+ -- STBool _ -> GLT ; _ STBool -> GGT
+
+instance TestEquality SScalTy where testEquality = geq
+instance GEq SScalTy where geq = defaultGeq
+instance GShow SScalTy where gshowsPrec = defaultGshowsPrec
+
+scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t))
+scalRepIsShow STI32 = Dict
+scalRepIsShow STI64 = Dict
+scalRepIsShow STF32 = Dict
+scalRepIsShow STF64 = Dict
+scalRepIsShow STBool = Dict
+
+type TIx = TScal TI64
+
+tIx :: STy TIx
+tIx = STScal STI64
+
+type family ScalRep t where
+ ScalRep TI32 = Int32
+ ScalRep TI64 = Int64
+ ScalRep TF32 = Float
+ ScalRep TF64 = Double
+ ScalRep TBool = Bool
+
+type family ScalIsNumeric t where
+ ScalIsNumeric TI32 = True
+ ScalIsNumeric TI64 = True
+ ScalIsNumeric TF32 = True
+ ScalIsNumeric TF64 = True
+ ScalIsNumeric TBool = False
+
+type family ScalIsFloating t where
+ ScalIsFloating TI32 = False
+ ScalIsFloating TI64 = False
+ ScalIsFloating TF32 = True
+ ScalIsFloating TF64 = True
+ ScalIsFloating TBool = False
+
+type family ScalIsIntegral t where
+ ScalIsIntegral TI32 = True
+ ScalIsIntegral TI64 = True
+ ScalIsIntegral TF32 = False
+ ScalIsIntegral TF64 = False
+ ScalIsIntegral TBool = False
+
+-- | Returns true for arrays /and/ accumulators.
+typeHasArrays :: STy t' -> Bool
+typeHasArrays STNil = False
+typeHasArrays (STPair a b) = typeHasArrays a || typeHasArrays b
+typeHasArrays (STEither a b) = typeHasArrays a || typeHasArrays b
+typeHasArrays (STLEither a b) = typeHasArrays a || typeHasArrays b
+typeHasArrays (STMaybe t) = typeHasArrays t
+typeHasArrays STArr{} = True
+typeHasArrays STScal{} = False
+typeHasArrays STAccum{} = True
+
+typeHasAccums :: STy t' -> Bool
+typeHasAccums STNil = False
+typeHasAccums (STPair a b) = typeHasAccums a || typeHasAccums b
+typeHasAccums (STEither a b) = typeHasAccums a || typeHasAccums b
+typeHasAccums (STLEither a b) = typeHasAccums a || typeHasAccums b
+typeHasAccums (STMaybe t) = typeHasAccums t
+typeHasAccums STArr{} = False
+typeHasAccums STScal{} = False
+typeHasAccums STAccum{} = True
+
+type family Tup env where
+ Tup '[] = TNil
+ Tup (t : ts) = TPair (Tup ts) t
+
+mkTup :: f TNil -> (forall a b. f a -> f b -> f (TPair a b))
+ -> SList f list -> f (Tup list)
+mkTup nil _ SNil = nil
+mkTup nil pair (e `SCons` es) = pair (mkTup nil pair es) e
+
+tTup :: SList STy env -> STy (Tup env)
+tTup = mkTup STNil STPair
+
+unTup :: (forall a b. c (TPair a b) -> (c a, c b))
+ -> SList f list -> c (Tup list) -> SList c list
+unTup _ SNil _ = SNil
+unTup unpack (_ `SCons` list) tup =
+ let (xs, x) = unpack tup
+ in x `SCons` unTup unpack list xs
+
+type family InvTup core env where
+ InvTup core '[] = core
+ InvTup core (t : ts) = InvTup (TPair core t) ts