aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/AST/Types.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-27 21:30:17 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-27 21:30:17 +0100
commit20f7d7be13cd7869b338f98d1ab3fd33e8bbfb3e (patch)
treea21c90034a02cdeb7240563dbbab355e49622d0a /src/CHAD/AST/Types.hs
parentae634c056b500a568b2d89b7f8e225404a2c0c62 (diff)
WIP user-specified custom typesuser-types
The big roadblock encountered is that accumulation wants addition of monoids to be elementwise float addition; this fundamentally clashes with the concept of a user type with a custom zero and plus.
Diffstat (limited to 'src/CHAD/AST/Types.hs')
-rw-r--r--src/CHAD/AST/Types.hs102
1 files changed, 87 insertions, 15 deletions
diff --git a/src/CHAD/AST/Types.hs b/src/CHAD/AST/Types.hs
index f0feb55..bb2fcfa 100644
--- a/src/CHAD/AST/Types.hs
+++ b/src/CHAD/AST/Types.hs
@@ -1,35 +1,34 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-module CHAD.AST.Types where
+{-# LANGUAGE UndecidableSuperClasses #-}
+module CHAD.AST.Types (
+ module CHAD.AST.Types.Ty,
+ module CHAD.AST.Types,
+) where
import Data.Int (Int32, Int64)
import Data.GADT.Compare
import Data.GADT.Show
import Data.Kind (Type)
+import Data.Proxy
import Data.Type.Equality
+import Type.Reflection
import CHAD.Data
+import CHAD.AST.Types.Ty
+import {-# SOURCE #-} CHAD.AST
-type data Ty
- = TNil
- | TPair Ty Ty
- | TEither Ty Ty
- | TLEither Ty Ty
- | TMaybe Ty
- | TArr Nat Ty -- ^ rank, element type
- | TScal ScalTy
- | TAccum Ty -- ^ contained type must be a monoid type
-
-type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
-- | Scalar types happen to be bundled in 'SScalTy' as this is sometimes
-- convenient, but such scalar types are not special in any way.
@@ -43,6 +42,10 @@ data STy t where
STArr :: SNat n -> STy t -> STy (TArr n t)
STScal :: SScalTy t -> STy (TScal t)
STAccum :: SMTy t -> STy (TAccum t)
+ -- the Proxy is here just to provide something that's t-indexed, which is
+ -- damn useful; the UNPACK is to make sure it doesn't actually result in any
+ -- storage here
+ STUser :: UserPrimalType t => {-# UNPACK #-} !(Proxy t) -> STy (TUser t)
deriving instance Show (STy t)
instance GCompare STy where
@@ -62,7 +65,9 @@ instance GCompare STy where
(STScal t) (STScal t') -> gorderingLift1 (gcompare t t')
STScal{} _ -> GLT ; _ STScal{} -> GGT
(STAccum t) (STAccum t') -> gorderingLift1 (gcompare t t')
- -- STAccum{} _ -> GLT ; _ STAccum{} -> GGT
+ STAccum{} _ -> GLT ; _ STAccum{} -> GGT
+ (STUser t) (STUser t') -> gorderingLift1 (gcompare (typeOfProxy t) (typeOfProxy t'))
+ -- STUser{} _ -> GLT ; _ STUser{} -> GGT
instance TestEquality STy where testEquality = geq
instance GEq STy where geq = defaultGeq
@@ -77,6 +82,7 @@ data SMTy t where
SMTMaybe :: SMTy a -> SMTy (TMaybe a)
SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t)
SMTScal :: ScalIsNumeric t ~ True => SScalTy t -> SMTy (TScal t)
+ SMTUser :: UserMonoidType t => {-# UNPACK #-} !(Proxy t) -> SMTy (TUser t)
deriving instance Show (SMTy t)
instance GCompare SMTy where
@@ -92,7 +98,9 @@ instance GCompare SMTy where
(SMTArr n t) (SMTArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t')
SMTArr{} _ -> GLT ; _ SMTArr{} -> GGT
(SMTScal t) (SMTScal t') -> gorderingLift1 (gcompare t t')
- -- SMTScal{} _ -> GLT ; _ SMTScal{} -> GGT
+ SMTScal{} _ -> GLT ; _ SMTScal{} -> GGT
+ (SMTUser t) (SMTUser t') -> gorderingLift1 (gcompare (typeOfProxy t) (typeOfProxy t'))
+ -- SMTUser{} _ -> GLT ; _ SMTUser{} -> GGT
instance TestEquality SMTy where testEquality = geq
instance GEq SMTy where geq = defaultGeq
@@ -106,6 +114,7 @@ fromSMTy = \case
SMTMaybe t -> STMaybe (fromSMTy t)
SMTArr n t -> STArr n (fromSMTy t)
SMTScal sty -> STScal sty
+ SMTUser t -> STUser t
data SScalTy t where
STI32 :: SScalTy TI32
@@ -172,7 +181,65 @@ type family ScalIsIntegral t where
ScalIsIntegral TF64 = False
ScalIsIntegral TBool = False
--- | Returns true for arrays /and/ accumulators.
+type family ZeroInfo t where
+ ZeroInfo TNil = TNil
+ ZeroInfo (TPair a b) = TPair (ZeroInfo a) (ZeroInfo b)
+ ZeroInfo (TLEither a b) = TNil
+ ZeroInfo (TMaybe a) = TNil
+ ZeroInfo (TArr n t) = TArr n (ZeroInfo t)
+ ZeroInfo (TScal t) = TNil
+ ZeroInfo (TUser t) = UserZeroInfo t
+
+tZeroInfo :: SMTy t -> STy (ZeroInfo t)
+tZeroInfo SMTNil = STNil
+tZeroInfo (SMTPair a b) = STPair (tZeroInfo a) (tZeroInfo b)
+tZeroInfo (SMTLEither _ _) = STNil
+tZeroInfo (SMTMaybe _) = STNil
+tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t)
+tZeroInfo (SMTScal _) = STNil
+tZeroInfo (SMTUser t) = userZeroInfo t
+
+-- | Info needed to create a zero-valued deep accumulator for a monoid type.
+-- Constructable from a D1; must not have any dynamic sparsity, i.e. must have
+-- the same structure as the corresponding primal.
+type family DeepZeroInfo t where
+ DeepZeroInfo TNil = TNil
+ DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b)
+ DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b)
+ DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a)
+ DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a)
+ DeepZeroInfo (TScal t) = TNil
+ DeepZeroInfo (TUser t) = UserDeepZeroInfo t
+
+tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t)
+tDeepZeroInfo SMTNil = STNil
+tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b)
+tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b)
+tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a)
+tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t)
+tDeepZeroInfo (SMTScal _) = STNil
+tDeepZeroInfo (SMTUser t) = userDeepZeroInfo t
+
+class (Show t, Typeable t, UserMonoidType (UserD2 t)) => UserPrimalType t where
+ type UserRep t :: Ty
+ type UserD2 t :: Type
+ userRepTy :: proxy t -> STy (UserRep t)
+ euserD2ZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (ZeroInfo (TUser (UserD2 t)))
+ euserD2DeepZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (DeepZeroInfo (TUser (UserD2 t)))
+
+class UserPrimalType t => UserMonoidType t where
+ type UserZeroInfo t :: Ty
+ type UserDeepZeroInfo t :: Ty
+ userZeroInfo :: proxy t -> STy (UserZeroInfo t)
+ userDeepZeroInfo :: proxy t -> STy (UserDeepZeroInfo t)
+ euserZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (UserZeroInfo t)
+ euserDeepZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (UserDeepZeroInfo t)
+ euserZero :: proxy t -> Ex env (UserZeroInfo t) -> Ex env (UserRep t)
+ -- | A deep zero must not have any dynamic sparsity.
+ euserDeepZero :: proxy t -> Ex env (UserDeepZeroInfo t) -> Ex env (UserRep t)
+ euserPlus :: proxy t -> Ex env (UserRep t) -> Ex env (UserRep t) -> Ex env (UserRep t)
+
+-- | Returns true for arrays /and/ accumulators. Returns False for user types.
typeHasArrays :: STy t' -> Bool
typeHasArrays STNil = False
typeHasArrays (STPair a b) = typeHasArrays a || typeHasArrays b
@@ -182,6 +249,7 @@ typeHasArrays (STMaybe t) = typeHasArrays t
typeHasArrays STArr{} = True
typeHasArrays STScal{} = False
typeHasArrays STAccum{} = True
+typeHasArrays STUser{} = False
typeHasAccums :: STy t' -> Bool
typeHasAccums STNil = False
@@ -192,6 +260,7 @@ typeHasAccums (STMaybe t) = typeHasAccums t
typeHasAccums STArr{} = False
typeHasAccums STScal{} = False
typeHasAccums STAccum{} = True
+typeHasAccums STUser{} = False
type family Tup env where
Tup '[] = TNil
@@ -215,3 +284,6 @@ unTup unpack (_ `SCons` list) tup =
type family InvTup core env where
InvTup core '[] = core
InvTup core (t : ts) = InvTup (TPair core t) ts
+
+typeOfProxy :: Typeable a => proxy a -> TypeRep a
+typeOfProxy _ = typeRep