From 174af2ba568de66e0d890825b8bda930b8e7bb96 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 10 Nov 2025 21:49:45 +0100 Subject: Move module hierarchy under CHAD. --- src/CHAD/AST/Types.hs | 215 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 215 insertions(+) create mode 100644 src/CHAD/AST/Types.hs (limited to 'src/CHAD/AST/Types.hs') diff --git a/src/CHAD/AST/Types.hs b/src/CHAD/AST/Types.hs new file mode 100644 index 0000000..059077d --- /dev/null +++ b/src/CHAD/AST/Types.hs @@ -0,0 +1,215 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeData #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.AST.Types where + +import Data.Int (Int32, Int64) +import Data.GADT.Compare +import Data.GADT.Show +import Data.Kind (Type) +import Data.Type.Equality + +import CHAD.Data + + +type data Ty + = TNil + | TPair Ty Ty + | TEither Ty Ty + | TLEither Ty Ty + | TMaybe Ty + | TArr Nat Ty -- ^ rank, element type + | TScal ScalTy + | TAccum Ty -- ^ contained type must be a monoid type + +type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool + +type STy :: Ty -> Type +data STy t where + STNil :: STy TNil + STPair :: STy a -> STy b -> STy (TPair a b) + STEither :: STy a -> STy b -> STy (TEither a b) + STLEither :: STy a -> STy b -> STy (TLEither a b) + STMaybe :: STy a -> STy (TMaybe a) + STArr :: SNat n -> STy t -> STy (TArr n t) + STScal :: SScalTy t -> STy (TScal t) + STAccum :: SMTy t -> STy (TAccum t) +deriving instance Show (STy t) + +instance GCompare STy where + gcompare = \cases + STNil STNil -> GEQ + STNil _ -> GLT ; _ STNil -> GGT + (STPair a b) (STPair a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') + STPair{} _ -> GLT ; _ STPair{} -> GGT + (STEither a b) (STEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') + STEither{} _ -> GLT ; _ STEither{} -> GGT + (STLEither a b) (STLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') + STLEither{} _ -> GLT ; _ STLEither{} -> GGT + (STMaybe a) (STMaybe a') -> gorderingLift1 (gcompare a a') + STMaybe{} _ -> GLT ; _ STMaybe{} -> GGT + (STArr n t) (STArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t') + STArr{} _ -> GLT ; _ STArr{} -> GGT + (STScal t) (STScal t') -> gorderingLift1 (gcompare t t') + STScal{} _ -> GLT ; _ STScal{} -> GGT + (STAccum t) (STAccum t') -> gorderingLift1 (gcompare t t') + -- STAccum{} _ -> GLT ; _ STAccum{} -> GGT + +instance TestEquality STy where testEquality = geq +instance GEq STy where geq = defaultGeq +instance GShow STy where gshowsPrec = defaultGshowsPrec + +-- | Monoid types +type SMTy :: Ty -> Type +data SMTy t where + SMTNil :: SMTy TNil + SMTPair :: SMTy a -> SMTy b -> SMTy (TPair a b) + SMTLEither :: SMTy a -> SMTy b -> SMTy (TLEither a b) + SMTMaybe :: SMTy a -> SMTy (TMaybe a) + SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t) + SMTScal :: ScalIsNumeric t ~ True => SScalTy t -> SMTy (TScal t) +deriving instance Show (SMTy t) + +instance GCompare SMTy where + gcompare = \cases + SMTNil SMTNil -> GEQ + SMTNil _ -> GLT ; _ SMTNil -> GGT + (SMTPair a b) (SMTPair a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') + SMTPair{} _ -> GLT ; _ SMTPair{} -> GGT + (SMTLEither a b) (SMTLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') + SMTLEither{} _ -> GLT ; _ SMTLEither{} -> GGT + (SMTMaybe a) (SMTMaybe a') -> gorderingLift1 (gcompare a a') + SMTMaybe{} _ -> GLT ; _ SMTMaybe{} -> GGT + (SMTArr n t) (SMTArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t') + SMTArr{} _ -> GLT ; _ SMTArr{} -> GGT + (SMTScal t) (SMTScal t') -> gorderingLift1 (gcompare t t') + -- SMTScal{} _ -> GLT ; _ SMTScal{} -> GGT + +instance TestEquality SMTy where testEquality = geq +instance GEq SMTy where geq = defaultGeq +instance GShow SMTy where gshowsPrec = defaultGshowsPrec + +fromSMTy :: SMTy t -> STy t +fromSMTy = \case + SMTNil -> STNil + SMTPair t1 t2 -> STPair (fromSMTy t1) (fromSMTy t2) + SMTLEither t1 t2 -> STLEither (fromSMTy t1) (fromSMTy t2) + SMTMaybe t -> STMaybe (fromSMTy t) + SMTArr n t -> STArr n (fromSMTy t) + SMTScal sty -> STScal sty + +data SScalTy t where + STI32 :: SScalTy TI32 + STI64 :: SScalTy TI64 + STF32 :: SScalTy TF32 + STF64 :: SScalTy TF64 + STBool :: SScalTy TBool +deriving instance Show (SScalTy t) + +instance GCompare SScalTy where + gcompare = \cases + STI32 STI32 -> GEQ + STI32 _ -> GLT ; _ STI32 -> GGT + STI64 STI64 -> GEQ + STI64 _ -> GLT ; _ STI64 -> GGT + STF32 STF32 -> GEQ + STF32 _ -> GLT ; _ STF32 -> GGT + STF64 STF64 -> GEQ + STF64 _ -> GLT ; _ STF64 -> GGT + STBool STBool -> GEQ + -- STBool _ -> GLT ; _ STBool -> GGT + +instance TestEquality SScalTy where testEquality = geq +instance GEq SScalTy where geq = defaultGeq +instance GShow SScalTy where gshowsPrec = defaultGshowsPrec + +scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t)) +scalRepIsShow STI32 = Dict +scalRepIsShow STI64 = Dict +scalRepIsShow STF32 = Dict +scalRepIsShow STF64 = Dict +scalRepIsShow STBool = Dict + +type TIx = TScal TI64 + +tIx :: STy TIx +tIx = STScal STI64 + +type family ScalRep t where + ScalRep TI32 = Int32 + ScalRep TI64 = Int64 + ScalRep TF32 = Float + ScalRep TF64 = Double + ScalRep TBool = Bool + +type family ScalIsNumeric t where + ScalIsNumeric TI32 = True + ScalIsNumeric TI64 = True + ScalIsNumeric TF32 = True + ScalIsNumeric TF64 = True + ScalIsNumeric TBool = False + +type family ScalIsFloating t where + ScalIsFloating TI32 = False + ScalIsFloating TI64 = False + ScalIsFloating TF32 = True + ScalIsFloating TF64 = True + ScalIsFloating TBool = False + +type family ScalIsIntegral t where + ScalIsIntegral TI32 = True + ScalIsIntegral TI64 = True + ScalIsIntegral TF32 = False + ScalIsIntegral TF64 = False + ScalIsIntegral TBool = False + +-- | Returns true for arrays /and/ accumulators. +typeHasArrays :: STy t' -> Bool +typeHasArrays STNil = False +typeHasArrays (STPair a b) = typeHasArrays a || typeHasArrays b +typeHasArrays (STEither a b) = typeHasArrays a || typeHasArrays b +typeHasArrays (STLEither a b) = typeHasArrays a || typeHasArrays b +typeHasArrays (STMaybe t) = typeHasArrays t +typeHasArrays STArr{} = True +typeHasArrays STScal{} = False +typeHasArrays STAccum{} = True + +typeHasAccums :: STy t' -> Bool +typeHasAccums STNil = False +typeHasAccums (STPair a b) = typeHasAccums a || typeHasAccums b +typeHasAccums (STEither a b) = typeHasAccums a || typeHasAccums b +typeHasAccums (STLEither a b) = typeHasAccums a || typeHasAccums b +typeHasAccums (STMaybe t) = typeHasAccums t +typeHasAccums STArr{} = False +typeHasAccums STScal{} = False +typeHasAccums STAccum{} = True + +type family Tup env where + Tup '[] = TNil + Tup (t : ts) = TPair (Tup ts) t + +mkTup :: f TNil -> (forall a b. f a -> f b -> f (TPair a b)) + -> SList f list -> f (Tup list) +mkTup nil _ SNil = nil +mkTup nil pair (e `SCons` es) = pair (mkTup nil pair es) e + +tTup :: SList STy env -> STy (Tup env) +tTup = mkTup STNil STPair + +unTup :: (forall a b. c (TPair a b) -> (c a, c b)) + -> SList f list -> c (Tup list) -> SList c list +unTup _ SNil _ = SNil +unTup unpack (_ `SCons` list) tup = + let (xs, x) = unpack tup + in x `SCons` unTup unpack list xs + +type family InvTup core env where + InvTup core '[] = core + InvTup core (t : ts) = InvTup (TPair core t) ts -- cgit v1.2.3-70-g09d2