diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-27 21:30:17 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-27 21:30:17 +0100 |
| commit | 20f7d7be13cd7869b338f98d1ab3fd33e8bbfb3e (patch) | |
| tree | a21c90034a02cdeb7240563dbbab355e49622d0a /src/CHAD/AST/Types.hs | |
| parent | ae634c056b500a568b2d89b7f8e225404a2c0c62 (diff) | |
WIP user-specified custom typesuser-types
The big roadblock encountered is that accumulation wants addition of
monoids to be elementwise float addition; this fundamentally clashes
with the concept of a user type with a custom zero and plus.
Diffstat (limited to 'src/CHAD/AST/Types.hs')
| -rw-r--r-- | src/CHAD/AST/Types.hs | 102 |
1 files changed, 87 insertions, 15 deletions
diff --git a/src/CHAD/AST/Types.hs b/src/CHAD/AST/Types.hs index f0feb55..bb2fcfa 100644 --- a/src/CHAD/AST/Types.hs +++ b/src/CHAD/AST/Types.hs @@ -1,35 +1,34 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE TypeData #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module CHAD.AST.Types where +{-# LANGUAGE UndecidableSuperClasses #-} +module CHAD.AST.Types ( + module CHAD.AST.Types.Ty, + module CHAD.AST.Types, +) where import Data.Int (Int32, Int64) import Data.GADT.Compare import Data.GADT.Show import Data.Kind (Type) +import Data.Proxy import Data.Type.Equality +import Type.Reflection import CHAD.Data +import CHAD.AST.Types.Ty +import {-# SOURCE #-} CHAD.AST -type data Ty - = TNil - | TPair Ty Ty - | TEither Ty Ty - | TLEither Ty Ty - | TMaybe Ty - | TArr Nat Ty -- ^ rank, element type - | TScal ScalTy - | TAccum Ty -- ^ contained type must be a monoid type - -type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool -- | Scalar types happen to be bundled in 'SScalTy' as this is sometimes -- convenient, but such scalar types are not special in any way. @@ -43,6 +42,10 @@ data STy t where STArr :: SNat n -> STy t -> STy (TArr n t) STScal :: SScalTy t -> STy (TScal t) STAccum :: SMTy t -> STy (TAccum t) + -- the Proxy is here just to provide something that's t-indexed, which is + -- damn useful; the UNPACK is to make sure it doesn't actually result in any + -- storage here + STUser :: UserPrimalType t => {-# UNPACK #-} !(Proxy t) -> STy (TUser t) deriving instance Show (STy t) instance GCompare STy where @@ -62,7 +65,9 @@ instance GCompare STy where (STScal t) (STScal t') -> gorderingLift1 (gcompare t t') STScal{} _ -> GLT ; _ STScal{} -> GGT (STAccum t) (STAccum t') -> gorderingLift1 (gcompare t t') - -- STAccum{} _ -> GLT ; _ STAccum{} -> GGT + STAccum{} _ -> GLT ; _ STAccum{} -> GGT + (STUser t) (STUser t') -> gorderingLift1 (gcompare (typeOfProxy t) (typeOfProxy t')) + -- STUser{} _ -> GLT ; _ STUser{} -> GGT instance TestEquality STy where testEquality = geq instance GEq STy where geq = defaultGeq @@ -77,6 +82,7 @@ data SMTy t where SMTMaybe :: SMTy a -> SMTy (TMaybe a) SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t) SMTScal :: ScalIsNumeric t ~ True => SScalTy t -> SMTy (TScal t) + SMTUser :: UserMonoidType t => {-# UNPACK #-} !(Proxy t) -> SMTy (TUser t) deriving instance Show (SMTy t) instance GCompare SMTy where @@ -92,7 +98,9 @@ instance GCompare SMTy where (SMTArr n t) (SMTArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t') SMTArr{} _ -> GLT ; _ SMTArr{} -> GGT (SMTScal t) (SMTScal t') -> gorderingLift1 (gcompare t t') - -- SMTScal{} _ -> GLT ; _ SMTScal{} -> GGT + SMTScal{} _ -> GLT ; _ SMTScal{} -> GGT + (SMTUser t) (SMTUser t') -> gorderingLift1 (gcompare (typeOfProxy t) (typeOfProxy t')) + -- SMTUser{} _ -> GLT ; _ SMTUser{} -> GGT instance TestEquality SMTy where testEquality = geq instance GEq SMTy where geq = defaultGeq @@ -106,6 +114,7 @@ fromSMTy = \case SMTMaybe t -> STMaybe (fromSMTy t) SMTArr n t -> STArr n (fromSMTy t) SMTScal sty -> STScal sty + SMTUser t -> STUser t data SScalTy t where STI32 :: SScalTy TI32 @@ -172,7 +181,65 @@ type family ScalIsIntegral t where ScalIsIntegral TF64 = False ScalIsIntegral TBool = False --- | Returns true for arrays /and/ accumulators. +type family ZeroInfo t where + ZeroInfo TNil = TNil + ZeroInfo (TPair a b) = TPair (ZeroInfo a) (ZeroInfo b) + ZeroInfo (TLEither a b) = TNil + ZeroInfo (TMaybe a) = TNil + ZeroInfo (TArr n t) = TArr n (ZeroInfo t) + ZeroInfo (TScal t) = TNil + ZeroInfo (TUser t) = UserZeroInfo t + +tZeroInfo :: SMTy t -> STy (ZeroInfo t) +tZeroInfo SMTNil = STNil +tZeroInfo (SMTPair a b) = STPair (tZeroInfo a) (tZeroInfo b) +tZeroInfo (SMTLEither _ _) = STNil +tZeroInfo (SMTMaybe _) = STNil +tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t) +tZeroInfo (SMTScal _) = STNil +tZeroInfo (SMTUser t) = userZeroInfo t + +-- | Info needed to create a zero-valued deep accumulator for a monoid type. +-- Constructable from a D1; must not have any dynamic sparsity, i.e. must have +-- the same structure as the corresponding primal. +type family DeepZeroInfo t where + DeepZeroInfo TNil = TNil + DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b) + DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b) + DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a) + DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a) + DeepZeroInfo (TScal t) = TNil + DeepZeroInfo (TUser t) = UserDeepZeroInfo t + +tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t) +tDeepZeroInfo SMTNil = STNil +tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a) +tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t) +tDeepZeroInfo (SMTScal _) = STNil +tDeepZeroInfo (SMTUser t) = userDeepZeroInfo t + +class (Show t, Typeable t, UserMonoidType (UserD2 t)) => UserPrimalType t where + type UserRep t :: Ty + type UserD2 t :: Type + userRepTy :: proxy t -> STy (UserRep t) + euserD2ZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (ZeroInfo (TUser (UserD2 t))) + euserD2DeepZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (DeepZeroInfo (TUser (UserD2 t))) + +class UserPrimalType t => UserMonoidType t where + type UserZeroInfo t :: Ty + type UserDeepZeroInfo t :: Ty + userZeroInfo :: proxy t -> STy (UserZeroInfo t) + userDeepZeroInfo :: proxy t -> STy (UserDeepZeroInfo t) + euserZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (UserZeroInfo t) + euserDeepZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (UserDeepZeroInfo t) + euserZero :: proxy t -> Ex env (UserZeroInfo t) -> Ex env (UserRep t) + -- | A deep zero must not have any dynamic sparsity. + euserDeepZero :: proxy t -> Ex env (UserDeepZeroInfo t) -> Ex env (UserRep t) + euserPlus :: proxy t -> Ex env (UserRep t) -> Ex env (UserRep t) -> Ex env (UserRep t) + +-- | Returns true for arrays /and/ accumulators. Returns False for user types. typeHasArrays :: STy t' -> Bool typeHasArrays STNil = False typeHasArrays (STPair a b) = typeHasArrays a || typeHasArrays b @@ -182,6 +249,7 @@ typeHasArrays (STMaybe t) = typeHasArrays t typeHasArrays STArr{} = True typeHasArrays STScal{} = False typeHasArrays STAccum{} = True +typeHasArrays STUser{} = False typeHasAccums :: STy t' -> Bool typeHasAccums STNil = False @@ -192,6 +260,7 @@ typeHasAccums (STMaybe t) = typeHasAccums t typeHasAccums STArr{} = False typeHasAccums STScal{} = False typeHasAccums STAccum{} = True +typeHasAccums STUser{} = False type family Tup env where Tup '[] = TNil @@ -215,3 +284,6 @@ unTup unpack (_ `SCons` list) tup = type family InvTup core env where InvTup core '[] = core InvTup core (t : ts) = InvTup (TPair core t) ts + +typeOfProxy :: Typeable a => proxy a -> TypeRep a +typeOfProxy _ = typeRep |
