diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-27 21:30:17 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-27 21:30:17 +0100 |
| commit | 20f7d7be13cd7869b338f98d1ab3fd33e8bbfb3e (patch) | |
| tree | a21c90034a02cdeb7240563dbbab355e49622d0a /src/CHAD/AST | |
| parent | ae634c056b500a568b2d89b7f8e225404a2c0c62 (diff) | |
WIP user-specified custom typesuser-types
The big roadblock encountered is that accumulation wants addition of
monoids to be elementwise float addition; this fundamentally clashes
with the concept of a user type with a custom zero and plus.
Diffstat (limited to 'src/CHAD/AST')
| -rw-r--r-- | src/CHAD/AST/Accum.hs | 34 | ||||
| -rw-r--r-- | src/CHAD/AST/Count.hs | 10 | ||||
| -rw-r--r-- | src/CHAD/AST/Env.hs | 2 | ||||
| -rw-r--r-- | src/CHAD/AST/Pretty.hs | 12 | ||||
| -rw-r--r-- | src/CHAD/AST/Sparse.hs | 23 | ||||
| -rw-r--r-- | src/CHAD/AST/Sparse/Types.hs | 7 | ||||
| -rw-r--r-- | src/CHAD/AST/SplitLets.hs | 4 | ||||
| -rw-r--r-- | src/CHAD/AST/Types.hs | 102 | ||||
| -rw-r--r-- | src/CHAD/AST/Types/Ty.hs | 20 | ||||
| -rw-r--r-- | src/CHAD/AST/UnMonoid.hs | 9 | ||||
| -rw-r--r-- | src/CHAD/AST/UnUser.hs | 102 |
11 files changed, 272 insertions, 53 deletions
diff --git a/src/CHAD/AST/Accum.hs b/src/CHAD/AST/Accum.hs index ea74a95..f61f00f 100644 --- a/src/CHAD/AST/Accum.hs +++ b/src/CHAD/AST/Accum.hs @@ -76,40 +76,6 @@ acPrjTy (SAPRight prj) (SMTLEither _ t) = acPrjTy prj t acPrjTy (SAPJust prj) (SMTMaybe t) = acPrjTy prj t acPrjTy (SAPArrIdx prj) (SMTArr _ t) = acPrjTy prj t -type family ZeroInfo t where - ZeroInfo TNil = TNil - ZeroInfo (TPair a b) = TPair (ZeroInfo a) (ZeroInfo b) - ZeroInfo (TLEither a b) = TNil - ZeroInfo (TMaybe a) = TNil - ZeroInfo (TArr n t) = TArr n (ZeroInfo t) - ZeroInfo (TScal t) = TNil - -tZeroInfo :: SMTy t -> STy (ZeroInfo t) -tZeroInfo SMTNil = STNil -tZeroInfo (SMTPair a b) = STPair (tZeroInfo a) (tZeroInfo b) -tZeroInfo (SMTLEither _ _) = STNil -tZeroInfo (SMTMaybe _) = STNil -tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t) -tZeroInfo (SMTScal _) = STNil - --- | Info needed to create a zero-valued deep accumulator for a monoid type. --- Should be constructable from a D1. -type family DeepZeroInfo t where - DeepZeroInfo TNil = TNil - DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b) - DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b) - DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a) - DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a) - DeepZeroInfo (TScal t) = TNil - -tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t) -tDeepZeroInfo SMTNil = STNil -tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b) -tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b) -tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a) -tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t) -tDeepZeroInfo (SMTScal _) = STNil - -- -- | Additional info needed for accumulation. This is empty unless there is -- -- sparsity in the monoid. -- type family AccumInfo t where diff --git a/src/CHAD/AST/Count.hs b/src/CHAD/AST/Count.hs index 46173d2..8923e13 100644 --- a/src/CHAD/AST/Count.hs +++ b/src/CHAD/AST/Count.hs @@ -880,6 +880,16 @@ occCountX initialS topexpr k = case topexpr of EError _ t msg -> k OccEnd $ \_ -> EError ext (applySubstruc s t) msg + + EUser _ t e -> + occCountX SsFull e $ \env1 mke -> + k env1 $ \env' -> + projectSmallerSubstruc SsFull s $ EUser ext t (mke env') + + EUnUser _ e -> + occCountX SsFull e $ \env1 mke -> + k env1 $ \env' -> + projectSmallerSubstruc SsFull s $ EUnUser ext (mke env') where s = simplifySubstruc (typeOf topexpr) initialS diff --git a/src/CHAD/AST/Env.hs b/src/CHAD/AST/Env.hs index 8e6b745..73f2dcc 100644 --- a/src/CHAD/AST/Env.hs +++ b/src/CHAD/AST/Env.hs @@ -11,7 +11,7 @@ module CHAD.AST.Env where import Data.Type.Equality -import CHAD.AST.Sparse +import CHAD.AST.Sparse.Types import CHAD.AST.Weaken import CHAD.Data import CHAD.Drev.Types diff --git a/src/CHAD/AST/Pretty.hs b/src/CHAD/AST/Pretty.hs index 9ddcb35..d9ac8b2 100644 --- a/src/CHAD/AST/Pretty.hs +++ b/src/CHAD/AST/Pretty.hs @@ -374,6 +374,15 @@ ppExpr' d val expr = case expr of EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s) + EUser _ t@STUser{} e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ + ppApp (ppString ("user[" ++ show (typeOfProxy t) ++ "]") <> ppX expr) [e'] + + EUnUser _ e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppApp (ppString "unuser" <> ppX expr) [e'] + ppExprLet :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc ppExprLet d val etop = do let collect :: PrettyX x => SVal env -> Expr x env t -> M ([(String, Occ, ADoc)], ADoc) @@ -421,6 +430,7 @@ ppSparse (SMTLEither t1 t2) (SpLEither s1 s2) = "(" ++ ppSparse t1 s1 ++ "|" ++ ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s ppSparse (SMTScal _) SpScal = "." +ppSparse (SMTUser _) SpUser = "U" ppCommut :: Commutative -> String ppCommut Commut = "(C)" @@ -469,6 +479,7 @@ ppSTy' _ (STScal sty) = ppString $ case sty of STF64 -> "f64" STBool -> "bool" ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSMTy' 11 t +ppSTy' d (STUser t) = ppParen (d > 10) $ ppString ("User " ++ showsPrec 11 (typeOfProxy t) "") ppSMTy :: Int -> SMTy t -> String ppSMTy d ty = render $ ppSMTy' d ty @@ -485,6 +496,7 @@ ppSMTy' _ (SMTScal sty) = ppString $ case sty of STI64 -> "i64" STF32 -> "f32" STF64 -> "f64" +ppSMTy' d (SMTUser t) = ppSTy' d (STUser t) ppString :: String -> Doc x ppString = fromString diff --git a/src/CHAD/AST/Sparse.hs b/src/CHAD/AST/Sparse.hs index 85f2882..30e6b6f 100644 --- a/src/CHAD/AST/Sparse.hs +++ b/src/CHAD/AST/Sparse.hs @@ -1,7 +1,9 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ImpredicativeTypes #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} {-# OPTIONS_GHC -fmax-pmcheck-models=80 #-} module CHAD.AST.Sparse (module CHAD.AST.Sparse, module CHAD.AST.Sparse.Types) where @@ -9,8 +11,10 @@ module CHAD.AST.Sparse (module CHAD.AST.Sparse, module CHAD.AST.Sparse.Types) wh import Data.Type.Equality import CHAD.AST +import CHAD.AST.Count +import CHAD.AST.Env import CHAD.AST.Sparse.Types -import CHAD.Data (SBool(..)) +import CHAD.Data sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t' @@ -43,6 +47,7 @@ sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 = (EJust ext (sparsePlus t sp (evar (IS IZ)) (evar IZ)))) sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS IZ)) (evar IZ)) e1 e2 sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2 +sparsePlus (SMTUser t) SpUser e1 e2 = EPlus ext (SMTUser t) e1 e2 cheapZero :: SMTy t -> Maybe (forall env. Ex env t) @@ -61,6 +66,19 @@ cheapZero (SMTScal t) = case t of STI64 -> Just (EConst ext t 0) STF32 -> Just (EConst ext t 0.0) STF64 -> Just (EConst ext t 0.0) +cheapZero (SMTUser t) = + let zero1 = euserZero t (EVar ext (userZeroInfo t) IZ) + occenv1 = occCountAll @_ @'[_] zero1 + zero2 = euserZero t (euserZeroInfo t (EVar ext (userRepTy t) IZ)) + occenv2 = occCountAll @_ @'[_] zero2 + in deleteUnused (userZeroInfo t `SCons` SNil) occenv1 $ \case + sub@(SENo SETop) | cheapExpr zero1 -> + Just (EUser ext (STUser t) (weakenExpr WClosed (unsafeWeakenWithSubenv sub zero1))) + _ -> + deleteUnused (userRepTy t `SCons` SNil) occenv2 $ \case + sub@(SENo SETop) | cheapExpr zero2 -> + Just (EUser ext (STUser t) (weakenExpr WClosed (unsafeWeakenWithSubenv sub zero2))) + _ -> Nothing data Injection sp a b where @@ -294,3 +312,6 @@ sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k = -- scalars sparsePlusS _ _ (SMTScal t) SpScal SpScal k = k SpScal (Inj id) (Inj id) (EPlus ext (SMTScal t)) + +-- user types +sparsePlusS _ _ (SMTUser t) SpUser SpUser k = k SpUser (Inj id) (Inj id) (EPlus ext (SMTUser t)) diff --git a/src/CHAD/AST/Sparse/Types.hs b/src/CHAD/AST/Sparse/Types.hs index 8f41ba4..f97a261 100644 --- a/src/CHAD/AST/Sparse/Types.hs +++ b/src/CHAD/AST/Sparse/Types.hs @@ -20,6 +20,7 @@ data Sparse t t' where SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t') SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t') SpScal :: Sparse (TScal t) (TScal t) + SpUser :: Sparse (TUser t) (TUser t) deriving instance Show (Sparse t t') class ApplySparse f where @@ -33,6 +34,7 @@ instance ApplySparse STy where applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t) applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t) applySparse SpScal t = t + applySparse SpUser t = t instance ApplySparse SMTy where applySparse (SpSparse s) t = SMTMaybe (applySparse s t) @@ -42,6 +44,7 @@ instance ApplySparse SMTy where applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t) applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t) applySparse SpScal t = t + applySparse SpUser t = t class IsSubType s where @@ -68,6 +71,7 @@ instance IsSubType Sparse where subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2) subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2) subtTrans SpScal SpScal = SpScal + subtTrans SpUser SpUser = SpUser subtFull = spDense @@ -78,6 +82,7 @@ spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2) spDense (SMTMaybe t) = SpMaybe (spDense t) spDense (SMTArr _ t) = SpArr (spDense t) spDense (SMTScal _) = SpScal +spDense (SMTUser _) = SpUser isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t') isDense SMTNil SpAbsent = Just Refl @@ -96,6 +101,7 @@ isDense (SMTArr _ t) (SpArr s) | Just Refl <- isDense t s = Just Refl | otherwise = Nothing isDense (SMTScal _) SpScal = Just Refl +isDense (SMTUser _) SpUser = Just Refl isAbsent :: Sparse t t' -> Bool isAbsent (SpSparse s) = isAbsent s @@ -105,3 +111,4 @@ isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2 isAbsent (SpMaybe s) = isAbsent s isAbsent (SpArr s) = isAbsent s isAbsent SpScal = False +isAbsent SpUser = False diff --git a/src/CHAD/AST/SplitLets.hs b/src/CHAD/AST/SplitLets.hs index 34267e4..75c70ea 100644 --- a/src/CHAD/AST/SplitLets.hs +++ b/src/CHAD/AST/SplitLets.hs @@ -79,6 +79,8 @@ splitLets' = \sub -> \case EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b) EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b) EError x t s -> EError x t s + EUser x t e -> EUser x t (splitLets' sub e) + EUnUser x e -> EUnUser x (splitLets' sub e) where sinkF :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> STy t -> Idx (b : env) t -> (b : env') :> env3 -> Ex env3 t @@ -167,6 +169,7 @@ split typ = case typ of STArr{} -> other STScal{} -> other STAccum{} -> other + STUser{} -> other where other :: (Pointers (t : env) t, Bindings Ex (t : env) '[]) other = (Point typ IZ, BTop) @@ -186,6 +189,7 @@ splitRec rhs typ = case typ of STArr{} -> other STScal{} -> other STAccum{} -> other + STUser{} -> other where other :: (Pointers (t : env) t, Bindings Ex env '[t]) other = (Point typ IZ, BPush BTop (typ, rhs)) diff --git a/src/CHAD/AST/Types.hs b/src/CHAD/AST/Types.hs index f0feb55..bb2fcfa 100644 --- a/src/CHAD/AST/Types.hs +++ b/src/CHAD/AST/Types.hs @@ -1,35 +1,34 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE TypeData #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module CHAD.AST.Types where +{-# LANGUAGE UndecidableSuperClasses #-} +module CHAD.AST.Types ( + module CHAD.AST.Types.Ty, + module CHAD.AST.Types, +) where import Data.Int (Int32, Int64) import Data.GADT.Compare import Data.GADT.Show import Data.Kind (Type) +import Data.Proxy import Data.Type.Equality +import Type.Reflection import CHAD.Data +import CHAD.AST.Types.Ty +import {-# SOURCE #-} CHAD.AST -type data Ty - = TNil - | TPair Ty Ty - | TEither Ty Ty - | TLEither Ty Ty - | TMaybe Ty - | TArr Nat Ty -- ^ rank, element type - | TScal ScalTy - | TAccum Ty -- ^ contained type must be a monoid type - -type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool -- | Scalar types happen to be bundled in 'SScalTy' as this is sometimes -- convenient, but such scalar types are not special in any way. @@ -43,6 +42,10 @@ data STy t where STArr :: SNat n -> STy t -> STy (TArr n t) STScal :: SScalTy t -> STy (TScal t) STAccum :: SMTy t -> STy (TAccum t) + -- the Proxy is here just to provide something that's t-indexed, which is + -- damn useful; the UNPACK is to make sure it doesn't actually result in any + -- storage here + STUser :: UserPrimalType t => {-# UNPACK #-} !(Proxy t) -> STy (TUser t) deriving instance Show (STy t) instance GCompare STy where @@ -62,7 +65,9 @@ instance GCompare STy where (STScal t) (STScal t') -> gorderingLift1 (gcompare t t') STScal{} _ -> GLT ; _ STScal{} -> GGT (STAccum t) (STAccum t') -> gorderingLift1 (gcompare t t') - -- STAccum{} _ -> GLT ; _ STAccum{} -> GGT + STAccum{} _ -> GLT ; _ STAccum{} -> GGT + (STUser t) (STUser t') -> gorderingLift1 (gcompare (typeOfProxy t) (typeOfProxy t')) + -- STUser{} _ -> GLT ; _ STUser{} -> GGT instance TestEquality STy where testEquality = geq instance GEq STy where geq = defaultGeq @@ -77,6 +82,7 @@ data SMTy t where SMTMaybe :: SMTy a -> SMTy (TMaybe a) SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t) SMTScal :: ScalIsNumeric t ~ True => SScalTy t -> SMTy (TScal t) + SMTUser :: UserMonoidType t => {-# UNPACK #-} !(Proxy t) -> SMTy (TUser t) deriving instance Show (SMTy t) instance GCompare SMTy where @@ -92,7 +98,9 @@ instance GCompare SMTy where (SMTArr n t) (SMTArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t') SMTArr{} _ -> GLT ; _ SMTArr{} -> GGT (SMTScal t) (SMTScal t') -> gorderingLift1 (gcompare t t') - -- SMTScal{} _ -> GLT ; _ SMTScal{} -> GGT + SMTScal{} _ -> GLT ; _ SMTScal{} -> GGT + (SMTUser t) (SMTUser t') -> gorderingLift1 (gcompare (typeOfProxy t) (typeOfProxy t')) + -- SMTUser{} _ -> GLT ; _ SMTUser{} -> GGT instance TestEquality SMTy where testEquality = geq instance GEq SMTy where geq = defaultGeq @@ -106,6 +114,7 @@ fromSMTy = \case SMTMaybe t -> STMaybe (fromSMTy t) SMTArr n t -> STArr n (fromSMTy t) SMTScal sty -> STScal sty + SMTUser t -> STUser t data SScalTy t where STI32 :: SScalTy TI32 @@ -172,7 +181,65 @@ type family ScalIsIntegral t where ScalIsIntegral TF64 = False ScalIsIntegral TBool = False --- | Returns true for arrays /and/ accumulators. +type family ZeroInfo t where + ZeroInfo TNil = TNil + ZeroInfo (TPair a b) = TPair (ZeroInfo a) (ZeroInfo b) + ZeroInfo (TLEither a b) = TNil + ZeroInfo (TMaybe a) = TNil + ZeroInfo (TArr n t) = TArr n (ZeroInfo t) + ZeroInfo (TScal t) = TNil + ZeroInfo (TUser t) = UserZeroInfo t + +tZeroInfo :: SMTy t -> STy (ZeroInfo t) +tZeroInfo SMTNil = STNil +tZeroInfo (SMTPair a b) = STPair (tZeroInfo a) (tZeroInfo b) +tZeroInfo (SMTLEither _ _) = STNil +tZeroInfo (SMTMaybe _) = STNil +tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t) +tZeroInfo (SMTScal _) = STNil +tZeroInfo (SMTUser t) = userZeroInfo t + +-- | Info needed to create a zero-valued deep accumulator for a monoid type. +-- Constructable from a D1; must not have any dynamic sparsity, i.e. must have +-- the same structure as the corresponding primal. +type family DeepZeroInfo t where + DeepZeroInfo TNil = TNil + DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b) + DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b) + DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a) + DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a) + DeepZeroInfo (TScal t) = TNil + DeepZeroInfo (TUser t) = UserDeepZeroInfo t + +tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t) +tDeepZeroInfo SMTNil = STNil +tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a) +tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t) +tDeepZeroInfo (SMTScal _) = STNil +tDeepZeroInfo (SMTUser t) = userDeepZeroInfo t + +class (Show t, Typeable t, UserMonoidType (UserD2 t)) => UserPrimalType t where + type UserRep t :: Ty + type UserD2 t :: Type + userRepTy :: proxy t -> STy (UserRep t) + euserD2ZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (ZeroInfo (TUser (UserD2 t))) + euserD2DeepZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (DeepZeroInfo (TUser (UserD2 t))) + +class UserPrimalType t => UserMonoidType t where + type UserZeroInfo t :: Ty + type UserDeepZeroInfo t :: Ty + userZeroInfo :: proxy t -> STy (UserZeroInfo t) + userDeepZeroInfo :: proxy t -> STy (UserDeepZeroInfo t) + euserZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (UserZeroInfo t) + euserDeepZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (UserDeepZeroInfo t) + euserZero :: proxy t -> Ex env (UserZeroInfo t) -> Ex env (UserRep t) + -- | A deep zero must not have any dynamic sparsity. + euserDeepZero :: proxy t -> Ex env (UserDeepZeroInfo t) -> Ex env (UserRep t) + euserPlus :: proxy t -> Ex env (UserRep t) -> Ex env (UserRep t) -> Ex env (UserRep t) + +-- | Returns true for arrays /and/ accumulators. Returns False for user types. typeHasArrays :: STy t' -> Bool typeHasArrays STNil = False typeHasArrays (STPair a b) = typeHasArrays a || typeHasArrays b @@ -182,6 +249,7 @@ typeHasArrays (STMaybe t) = typeHasArrays t typeHasArrays STArr{} = True typeHasArrays STScal{} = False typeHasArrays STAccum{} = True +typeHasArrays STUser{} = False typeHasAccums :: STy t' -> Bool typeHasAccums STNil = False @@ -192,6 +260,7 @@ typeHasAccums (STMaybe t) = typeHasAccums t typeHasAccums STArr{} = False typeHasAccums STScal{} = False typeHasAccums STAccum{} = True +typeHasAccums STUser{} = False type family Tup env where Tup '[] = TNil @@ -215,3 +284,6 @@ unTup unpack (_ `SCons` list) tup = type family InvTup core env where InvTup core '[] = core InvTup core (t : ts) = InvTup (TPair core t) ts + +typeOfProxy :: Typeable a => proxy a -> TypeRep a +typeOfProxy _ = typeRep diff --git a/src/CHAD/AST/Types/Ty.hs b/src/CHAD/AST/Types/Ty.hs new file mode 100644 index 0000000..cee03be --- /dev/null +++ b/src/CHAD/AST/Types/Ty.hs @@ -0,0 +1,20 @@ +{-# LANGUAGE TypeData #-} +module CHAD.AST.Types.Ty where + +import Data.Kind (Type) + +import CHAD.Data (Nat) + + +type data Ty + = TNil + | TPair Ty Ty + | TEither Ty Ty + | TLEither Ty Ty + | TMaybe Ty + | TArr Nat Ty -- ^ rank, element type + | TScal ScalTy + | TAccum Ty -- ^ contained type must be a monoid type + | TUser Type + +type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool diff --git a/src/CHAD/AST/UnMonoid.hs b/src/CHAD/AST/UnMonoid.hs index d3cad25..2166fc6 100644 --- a/src/CHAD/AST/UnMonoid.hs +++ b/src/CHAD/AST/UnMonoid.hs @@ -63,6 +63,8 @@ unMonoid = \case acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' -> EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc)) EError _ t s -> EError ext t s + EUser _ t e -> EUser ext t (unMonoid e) + EUnUser _ e -> EUnUser ext (unMonoid e) zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t -- don't destroy the effects! @@ -78,6 +80,7 @@ zero (SMTScal t) _ = case t of STI64 -> EConst ext STI64 0 STF32 -> EConst ext STF32 0.0 STF64 -> EConst ext STF64 0.0 +zero (SMTUser t) e = EUser ext (STUser t) (euserZero t e) deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t deepZero SMTNil e = elet e $ ENil ext @@ -99,6 +102,7 @@ deepZero (SMTScal t) _ = case t of STI64 -> EConst ext STI64 0 STF32 -> EConst ext STF32 0.0 STF64 -> EConst ext STF64 0.0 +deepZero (SMTUser t) e = EUser ext (STUser t) (euserDeepZero t e) plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t -- don't destroy the effects! @@ -136,6 +140,7 @@ plus (SMTArr _ t) a b = ezipWith (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ)) a b plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b) +plus (SMTUser t) a b = EUser ext (STUser t) (euserPlus t (EUnUser ext a) (EUnUser ext b)) onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t onehot typ topprj idx arg = case (typ, topprj) of @@ -183,8 +188,8 @@ accumulateSparse accumulateSparse topty topsp arg accum = case (topty, topsp) of (_, s) | Just Refl <- isDense topty s -> accum WId SAPHere (ENil ext) arg - (SMTScal _, SpScal) -> - accum WId SAPHere (ENil ext) arg -- should be handled by isDense already, but meh + (SMTScal _, SpScal) -> error "TScal is dense" + (SMTUser _, SpUser) -> error "TUser is dense" (_, SpSparse s) -> emaybe arg (ENil ext) diff --git a/src/CHAD/AST/UnUser.hs b/src/CHAD/AST/UnUser.hs new file mode 100644 index 0000000..73b216c --- /dev/null +++ b/src/CHAD/AST/UnUser.hs @@ -0,0 +1,102 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +module CHAD.AST.UnUser where + +import CHAD.AST + + +type family UnUser t where + UnUser TNil = TNil + UnUser (TPair a b) = TPair (UnUser a) (UnUser b) + UnUser (TEither a b) = TEither (UnUser a) (UnUser b) + UnUser (TLEither a b) = TLEither (UnUser a) (UnUser b) + UnUser (TMaybe a) = TMaybe (UnUser a) + UnUser (TArr n t) = TArr n (UnUser t) + UnUser (TScal t) = TScal t + UnUser (TAccum t) = TAccum (UnUser t) + UnUser (TUser t) = UnUser (UserRep t) + +type family UnUserE env where + UnUserE '[] = '[] + UnUserE (t : ts) = UnUser t : UnUserE ts + +unUserTy :: STy t -> STy (UnUser t) +unUserTy = \case + STNil -> STNil + STPair a b -> STPair (unUserTy a) (unUserTy b) + STEither a b -> STEither (unUserTy a) (unUserTy b) + STLEither a b -> STLEither (unUserTy a) (unUserTy b) + STMaybe t -> STMaybe (unUserTy t) + STArr n t -> STArr n (unUserTy t) + STScal t -> STScal t + STAccum t -> STAccum (unUserMTy t) + STUser t -> unUserTy (userRepTy t) + +unUserMTy :: SMTy t -> SMTy (UnUser t) +unUserMTy = \case + SMTNil -> SMTNil + SMTPair a b -> SMTPair (unUserMTy a) (unUserMTy b) + SMTLEither a b -> SMTLEither (unUserMTy a) (unUserMTy b) + SMTMaybe t -> SMTMaybe (unUserMTy t) + SMTArr n t -> SMTArr n (unUserMTy t) + SMTScal t -> SMTScal t + SMTUser t -> unUserMTy (userRepTy t) + +unUser :: Ex env t -> Ex (UnUserE env) (UnUser t) +unUser = \case + EUser _ _ e -> unUser e + EUnUser _ e -> unUser e + + EVar _ t i -> EVar ext t (goIdx i) + ELet _ rhs body -> ELet ext (unUser rhs) (unUser body) + EPair _ a b -> EPair ext (unUser a) (unUser b) + EFst _ e -> EFst ext (unUser e) + ESnd _ e -> ESnd ext (unUser e) + ENil _ -> ENil ext + EInl _ t e -> EInl ext (unUserTy t) (unUser e) + EInr _ t e -> EInr ext (unUserTy t) (unUser e) + ECase _ e a b -> ECase ext (unUser e) (unUser a) (unUser b) + ENothing _ t -> ENothing ext t + EJust _ e -> EJust ext (unUser e) + EMaybe _ a b e -> EMaybe ext (unUser a) (unUser b) (unUser e) + ELNil _ t1 t2 -> ELNil ext t1 t2 + ELInl _ t e -> ELInl ext t (unUser e) + ELInr _ t e -> ELInr ext t (unUser e) + ELCase _ e a b c -> ELCase ext (unUser e) (unUser a) (unUser b) (unUser c) + EConstArr _ n t x -> EConstArr ext n t x + EBuild _ n a b -> EBuild ext n (unUser a) (unUser b) + EMap _ a b -> EMap ext (unUser a) (unUser b) + EFold1Inner _ cm a b c -> EFold1Inner ext cm (unUser a) (unUser b) (unUser c) + ESum1Inner _ e -> ESum1Inner ext (unUser e) + EUnit _ e -> EUnit ext (unUser e) + EReplicate1Inner _ a b -> EReplicate1Inner ext (unUser a) (unUser b) + EMaximum1Inner _ e -> EMaximum1Inner ext (unUser e) + EMinimum1Inner _ e -> EMinimum1Inner ext (unUser e) + EReshape _ n a b -> EReshape ext n (unUser a) (unUser b) + EZip _ a b -> EZip ext (unUser a) (unUser b) + EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unUser a) (unUser b) (unUser c) + EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unUser a) (unUser b) (unUser c) + EConst _ t x -> EConst ext t x + EIdx0 _ e -> EIdx0 ext (unUser e) + EIdx1 _ a b -> EIdx1 ext (unUser a) (unUser b) + EIdx _ a b -> EIdx ext (unUser a) (unUser b) + EShape _ e -> EShape ext (unUser e) + EOp _ op e -> EOp ext op (unUser e) + ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unUser a) (unUser b) (unUser c) (unUser e1) (unUser e2) + ERecompute _ e -> ERecompute ext (unUser e) + EWith _ t a b -> EWith ext t (unUser a) (unUser b) + EAccum _ t p eidx sp eval eacc -> + accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 -> + acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' -> + EAccum ext t prj' (unUser idx') (spDense (acPrjTy prj' t)) (unUser val2) (weakenExpr w (unUser eacc)) + EError _ t s -> EError ext t s + + EZero{} -> err_monoid + EDeepZero{} -> err_monoid + EPlus{} -> err_monoid + EOneHot{} -> err_monoid + where + err_monoid = error "unUser: Monoid ops found" |
