{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableSuperClasses #-}
module CHAD.AST.Types (
  module CHAD.AST.Types.Ty,
  module CHAD.AST.Types,
) where

import Data.Int (Int32, Int64)
import Data.GADT.Compare
import Data.GADT.Show
import Data.Kind (Type)
import Data.Proxy
import Data.Type.Equality
import Type.Reflection

import CHAD.Data
import CHAD.AST.Types.Ty
import {-# SOURCE #-} CHAD.AST


-- | Runtime (singleton) representation of a type-level 'Ty'.
--
-- Scalar types happen to be bundled in 'SScalTy' as this is sometimes
-- convenient, but such scalar types are not special in any way.
type STy :: Ty -> Type
data STy t where
  STNil :: STy TNil
  STPair :: STy a -> STy b -> STy (TPair a b)
  STEither :: STy a -> STy b -> STy (TEither a b)
  STLEither :: STy a -> STy b -> STy (TLEither a b)
  STMaybe :: STy a -> STy (TMaybe a)
  STArr :: SNat n -> STy t -> STy (TArr n t)
  STScal :: SScalTy t -> STy (TScal t)
  STAccum :: SMTy t -> STy (TAccum t)
  -- The Proxy is only here to give this constructor a t-indexed field, which
  -- is very convenient (e.g. for 'typeOfProxy'); the strict UNPACKed field is
  -- there to make sure it does not actually take up any storage.
  STUser :: UserPrimalType t => {-# UNPACK #-} !(Proxy t) -> STy (TUser t)
deriving instance Show (STy t)

-- | Constructors are ordered by their order of declaration. Equal
-- constructors compare their fields, combined with 'gorderingLift1' /
-- 'gorderingLift2'. User types compare by their 'TypeRep'.
instance GCompare STy where
  gcompare = \cases
    STNil STNil -> GEQ
    STNil _ -> GLT
    _ STNil -> GGT
    (STPair a b) (STPair a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
    STPair{} _ -> GLT
    _ STPair{} -> GGT
    (STEither a b) (STEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
    STEither{} _ -> GLT
    _ STEither{} -> GGT
    (STLEither a b) (STLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
    STLEither{} _ -> GLT
    _ STLEither{} -> GGT
    (STMaybe a) (STMaybe a') -> gorderingLift1 (gcompare a a')
    STMaybe{} _ -> GLT
    _ STMaybe{} -> GGT
    (STArr n t) (STArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t')
    STArr{} _ -> GLT
    _ STArr{} -> GGT
    (STScal t) (STScal t') -> gorderingLift1 (gcompare t t')
    STScal{} _ -> GLT
    _ STScal{} -> GGT
    (STAccum t) (STAccum t') -> gorderingLift1 (gcompare t t')
    STAccum{} _ -> GLT
    _ STAccum{} -> GGT
    (STUser t) (STUser t') -> gorderingLift1 (gcompare (typeOfProxy t) (typeOfProxy t'))
    -- STUser is the last constructor: every mismatched pair was already
    -- handled above, so the usual catch-all clauses would be redundant:
    -- STUser{} _ -> GLT ; _ STUser{} -> GGT

instance TestEquality STy where testEquality = geq
instance GEq STy where geq = defaultGeq
instance GShow STy where gshowsPrec = defaultGshowsPrec

-- | Monoid types: the subset of types that carry a monoid structure.
--
-- Compared to 'STy' there are no cases for 'TEither' or 'TAccum', and scalars
-- must be numeric (so 'TBool' is excluded).
type SMTy :: Ty -> Type
data SMTy t where
  SMTNil :: SMTy TNil
  SMTPair :: SMTy a -> SMTy b -> SMTy (TPair a b)
  SMTLEither :: SMTy a -> SMTy b -> SMTy (TLEither a b)
  SMTMaybe :: SMTy a -> SMTy (TMaybe a)
  SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t)
  SMTScal :: ScalIsNumeric t ~ True => SScalTy t -> SMTy (TScal t)
  SMTUser :: UserMonoidType t => {-# UNPACK #-} !(Proxy t) -> SMTy (TUser t)
deriving instance Show (SMTy t)

-- | Same scheme as the 'GCompare' instance for 'STy'.
instance GCompare SMTy where
  gcompare = \cases
    SMTNil SMTNil -> GEQ
    SMTNil _ -> GLT
    _ SMTNil -> GGT
    (SMTPair a b) (SMTPair a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
    SMTPair{} _ -> GLT
    _ SMTPair{} -> GGT
    (SMTLEither a b) (SMTLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
    SMTLEither{} _ -> GLT
    _ SMTLEither{} -> GGT
    (SMTMaybe a) (SMTMaybe a') -> gorderingLift1 (gcompare a a')
    SMTMaybe{} _ -> GLT
    _ SMTMaybe{} -> GGT
    (SMTArr n t) (SMTArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t')
    SMTArr{} _ -> GLT
    _ SMTArr{} -> GGT
    (SMTScal t) (SMTScal t') -> gorderingLift1 (gcompare t t')
    SMTScal{} _ -> GLT
    _ SMTScal{} -> GGT
    (SMTUser t) (SMTUser t') -> gorderingLift1 (gcompare (typeOfProxy t) (typeOfProxy t'))
    -- SMTUser is the last constructor: every mismatched pair was already
    -- handled above, so the usual catch-all clauses would be redundant:
    -- SMTUser{} _ -> GLT ; _ SMTUser{} -> GGT

instance TestEquality SMTy where testEquality = geq
instance GEq SMTy where geq = defaultGeq
instance GShow SMTy where gshowsPrec = defaultGshowsPrec

-- | Every monoid type is also an ordinary type; this forgets the monoid
-- structure.
fromSMTy :: SMTy t -> STy t
fromSMTy = \case
  SMTNil -> STNil
  SMTPair t1 t2 -> STPair (fromSMTy t1) (fromSMTy t2)
  SMTLEither t1 t2 -> STLEither (fromSMTy t1) (fromSMTy t2)
  SMTMaybe t -> STMaybe (fromSMTy t)
  SMTArr n t -> STArr n (fromSMTy t)
  SMTScal sty -> STScal sty
  SMTUser t -> STUser t

-- | Singletons for the scalar types.
data SScalTy t where
  STI32 :: SScalTy TI32
  STI64 :: SScalTy TI64
  STF32 :: SScalTy TF32
  STF64 :: SScalTy TF64
  STBool :: SScalTy TBool
deriving instance Show (SScalTy t)

-- | Constructors are ordered by their order of declaration.
instance GCompare SScalTy where
  gcompare = \cases
    STI32 STI32 -> GEQ
    STI32 _ -> GLT
    _ STI32 -> GGT
    STI64 STI64 -> GEQ
    STI64 _ -> GLT
    _ STI64 -> GGT
    STF32 STF32 -> GEQ
    STF32 _ -> GLT
    _ STF32 -> GGT
    STF64 STF64 -> GEQ
    STF64 _ -> GLT
    _ STF64 -> GGT
    STBool STBool -> GEQ
    -- STBool is the last constructor: every mismatched pair was already
    -- handled above, so the usual catch-all clauses would be redundant:
    -- STBool _ -> GLT ; _ STBool -> GGT

instance TestEquality SScalTy where testEquality = geq
instance GEq SScalTy where geq = defaultGeq
instance GShow SScalTy where gshowsPrec = defaultGshowsPrec

-- | Every scalar representation type ('ScalRep') has a 'Show' instance; this
-- recovers that evidence by case analysis on the scalar type.
scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t))
scalRepIsShow STI32 = Dict
scalRepIsShow STI64 = Dict
scalRepIsShow STF32 = Dict
scalRepIsShow STF64 = Dict
scalRepIsShow STBool = Dict

-- | The index type: a 64-bit integer scalar.
type TIx = TScal TI64

-- | Singleton for 'TIx'.
tIx :: STy TIx
tIx = STScal STI64

-- | The Haskell type used to represent values of each scalar type.
type family ScalRep t where
  ScalRep TI32 = Int32
  ScalRep TI64 = Int64
  ScalRep TF32 = Float
  ScalRep TF64 = Double
  ScalRep TBool = Bool

-- | Whether a scalar type is numeric (everything except 'TBool').
type family ScalIsNumeric t where
  ScalIsNumeric TI32 = True
  ScalIsNumeric TI64 = True
  ScalIsNumeric TF32 = True
  ScalIsNumeric TF64 = True
  ScalIsNumeric TBool = False

-- | Whether a scalar type is a floating-point type.
type family ScalIsFloating t where
  ScalIsFloating TI32 = False
  ScalIsFloating TI64 = False
  ScalIsFloating TF32 = True
  ScalIsFloating TF64 = True
  ScalIsFloating TBool = False

-- | Whether a scalar type is an integral type.
type family ScalIsIntegral t where
  ScalIsIntegral TI32 = True
  ScalIsIntegral TI64 = True
  ScalIsIntegral TF32 = False
  ScalIsIntegral TF64 = False
  ScalIsIntegral TBool = False

-- | Info needed to create a zero value of a monoid type; the counterpart of
-- 'DeepZeroInfo' for ordinary zeros. Pairs and arrays are handled
-- structurally; 'TLEither', 'TMaybe' and scalars need no information, and
-- user types supply their own via 'UserZeroInfo'.
type family ZeroInfo t where
  ZeroInfo TNil = TNil
  ZeroInfo (TPair a b) = TPair (ZeroInfo a) (ZeroInfo b)
  ZeroInfo (TLEither a b) = TNil
  ZeroInfo (TMaybe a) = TNil
  ZeroInfo (TArr n t) = TArr n (ZeroInfo t)
  ZeroInfo (TScal t) = TNil
  ZeroInfo (TUser t) = UserZeroInfo t

-- | Singleton for the 'ZeroInfo' of a monoid type.
tZeroInfo :: SMTy t -> STy (ZeroInfo t)
tZeroInfo SMTNil = STNil
tZeroInfo (SMTPair a b) = STPair (tZeroInfo a) (tZeroInfo b)
tZeroInfo (SMTLEither _ _) = STNil
tZeroInfo (SMTMaybe _) = STNil
tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t)
tZeroInfo (SMTScal _) = STNil
tZeroInfo (SMTUser t) = userZeroInfo t

-- | Info needed to create a zero-valued deep accumulator for a monoid type.
-- Constructable from a D1; must not have any dynamic sparsity, i.e. must have
-- the same structure as the corresponding primal.
type family DeepZeroInfo t where
  DeepZeroInfo TNil = TNil
  DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b)
  DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b)
  DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a)
  DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a)
  DeepZeroInfo (TScal t) = TNil
  DeepZeroInfo (TUser t) = UserDeepZeroInfo t

-- | Singleton for the 'DeepZeroInfo' of a monoid type.
tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t)
tDeepZeroInfo SMTNil = STNil
tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b)
tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b)
tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a)
tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t)
tDeepZeroInfo (SMTScal _) = STNil
tDeepZeroInfo (SMTUser t) = userDeepZeroInfo t

-- | A user-defined primal type, usable in the language via 'TUser'.
--
-- 'UserRep' is the language type used to represent values of @t@.
-- 'UserD2' names the user type used as its D2; the superclass requires that
-- type to be a 'UserMonoidType'.
class (Show t, Typeable t, UserMonoidType (UserD2 t)) => UserPrimalType t where
  type UserRep t :: Ty
  type UserD2 t :: Type

  -- | Singleton for the representation type.
  userRepTy :: proxy t -> STy (UserRep t)

  -- | Compute the 'ZeroInfo' of the D2 type from a value of this type.
  euserD2ZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (ZeroInfo (TUser (UserD2 t)))
  -- | Compute the 'DeepZeroInfo' of the D2 type from a value of this type.
  euserD2DeepZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (DeepZeroInfo (TUser (UserD2 t)))

-- | A user-defined type whose representation ('UserRep') carries a monoid
-- structure: zeros (plain and deep) and 'euserPlus'.
class UserPrimalType t => UserMonoidType t where
  type UserZeroInfo t :: Ty
  type UserDeepZeroInfo t :: Ty

  userZeroInfo :: proxy t -> STy (UserZeroInfo t)
  userDeepZeroInfo :: proxy t -> STy (UserDeepZeroInfo t)

  euserZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (UserZeroInfo t)
  euserDeepZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (UserDeepZeroInfo t)

  euserZero :: proxy t -> Ex env (UserZeroInfo t) -> Ex env (UserRep t)
  -- | A deep zero must not have any dynamic sparsity.
  euserDeepZero :: proxy t -> Ex env (UserDeepZeroInfo t) -> Ex env (UserRep t)

  euserPlus :: proxy t -> Ex env (UserRep t) -> Ex env (UserRep t) -> Ex env (UserRep t)

-- | Whether a type contains arrays /or/ accumulators anywhere in its
-- structure. Returns 'False' for user types: their representation
-- ('UserRep') is not inspected.
typeHasArrays :: STy t' -> Bool
typeHasArrays STNil = False
typeHasArrays (STPair a b) = typeHasArrays a || typeHasArrays b
typeHasArrays (STEither a b) = typeHasArrays a || typeHasArrays b
typeHasArrays (STLEither a b) = typeHasArrays a || typeHasArrays b
typeHasArrays (STMaybe t) = typeHasArrays t
typeHasArrays STArr{} = True
typeHasArrays STScal{} = False
typeHasArrays STAccum{} = True
typeHasArrays STUser{} = False

-- | Whether a type contains accumulators ('TAccum') in its structure.
-- Arrays and user types are reported as accumulator-free, without inspecting
-- their element type / representation.
--
-- NOTE(review): an array whose element type contains an accumulator yields
-- 'False' here; presumably such arrays cannot occur — confirm.
typeHasAccums :: STy t' -> Bool
typeHasAccums STNil = False
typeHasAccums (STPair a b) = typeHasAccums a || typeHasAccums b
typeHasAccums (STEither a b) = typeHasAccums a || typeHasAccums b
typeHasAccums (STLEither a b) = typeHasAccums a || typeHasAccums b
typeHasAccums (STMaybe t) = typeHasAccums t
typeHasAccums STArr{} = False
typeHasAccums STScal{} = False
typeHasAccums STAccum{} = True
typeHasAccums STUser{} = False

-- | Turn a type-level list into a nested tuple type. The list head becomes
-- the outermost (last) component:
-- @Tup '[a, b] = TPair (TPair TNil b) a@.
type family Tup env where
  Tup '[] = TNil
  Tup (t : ts) = TPair (Tup ts) t

-- | Build a 'Tup'-shaped value from a list of components, given how to make
-- the empty tuple and how to make a pair.
mkTup :: f TNil -> (forall a b. f a -> f b -> f (TPair a b)) -> SList f list -> f (Tup list)
mkTup nil _ SNil = nil
mkTup nil pair (e `SCons` es) = pair (mkTup nil pair es) e

-- | Singleton for the 'Tup' of a list of types.
tTup :: SList STy env -> STy (Tup env)
tTup = mkTup STNil STPair

-- | Inverse of 'mkTup': split a 'Tup'-shaped value into its components using
-- a pair-projection function. The elements of the @SList f list@ argument
-- are ignored; only its spine is used.
unTup :: (forall a b. c (TPair a b) -> (c a, c b)) -> SList f list -> c (Tup list) -> SList c list
unTup _ SNil _ = SNil
unTup unpack (_ `SCons` list) tup =
  let (xs, x) = unpack tup
  in x `SCons` unTup unpack list xs

-- | Like 'Tup', but nesting onto a given @core@ and with the list head
-- innermost: @InvTup core '[a, b] = TPair (TPair core a) b@.
type family InvTup core env where
  InvTup core '[] = core
  InvTup core (t : ts) = InvTup (TPair core t) ts

-- | The 'TypeRep' of the type that a proxy is indexed by.
typeOfProxy :: Typeable a => proxy a -> TypeRep a
typeOfProxy _ = typeRep