diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-27 21:30:17 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-27 21:30:17 +0100 |
| commit | 20f7d7be13cd7869b338f98d1ab3fd33e8bbfb3e (patch) | |
| tree | a21c90034a02cdeb7240563dbbab355e49622d0a /src | |
| parent | ae634c056b500a568b2d89b7f8e225404a2c0c62 (diff) | |
WIP user-specified custom typesuser-types
The big roadblock encountered is that accumulation wants addition of
monoids to be elementwise float addition; this fundamentally clashes
with the concept of a user type with a custom zero and plus.
Diffstat (limited to 'src')
| -rw-r--r-- | src/CHAD/APIv1.hs | 2 | ||||
| -rw-r--r-- | src/CHAD/AST.hs | 21 | ||||
| -rw-r--r-- | src/CHAD/AST.hs-boot | 15 | ||||
| -rw-r--r-- | src/CHAD/AST/Accum.hs | 34 | ||||
| -rw-r--r-- | src/CHAD/AST/Count.hs | 10 | ||||
| -rw-r--r-- | src/CHAD/AST/Env.hs | 2 | ||||
| -rw-r--r-- | src/CHAD/AST/Pretty.hs | 12 | ||||
| -rw-r--r-- | src/CHAD/AST/Sparse.hs | 23 | ||||
| -rw-r--r-- | src/CHAD/AST/Sparse/Types.hs | 7 | ||||
| -rw-r--r-- | src/CHAD/AST/SplitLets.hs | 4 | ||||
| -rw-r--r-- | src/CHAD/AST/Types.hs | 102 | ||||
| -rw-r--r-- | src/CHAD/AST/Types/Ty.hs | 20 | ||||
| -rw-r--r-- | src/CHAD/AST/UnMonoid.hs | 9 | ||||
| -rw-r--r-- | src/CHAD/AST/UnUser.hs | 102 | ||||
| -rw-r--r-- | src/CHAD/Analysis/Identity.hs | 13 | ||||
| -rw-r--r-- | src/CHAD/Drev.hs | 6 | ||||
| -rw-r--r-- | src/CHAD/Drev/Accum.hs | 2 | ||||
| -rw-r--r-- | src/CHAD/Drev/Types.hs | 8 | ||||
| -rw-r--r-- | src/CHAD/Drev/Types/ToTan.hs | 1 | ||||
| -rw-r--r-- | src/CHAD/ForwardAD.hs | 7 | ||||
| -rw-r--r-- | src/CHAD/ForwardAD/DualNumbers.hs | 4 | ||||
| -rw-r--r-- | src/CHAD/ForwardAD/DualNumbers/Types.hs | 1 | ||||
| -rw-r--r-- | src/CHAD/Interpreter.hs | 14 | ||||
| -rw-r--r-- | src/CHAD/Interpreter/Rep.hs | 5 | ||||
| -rw-r--r-- | src/CHAD/Simplify.hs | 10 |
25 files changed, 379 insertions, 55 deletions
diff --git a/src/CHAD/APIv1.hs b/src/CHAD/APIv1.hs index 73d1580..ef9b685 100644 --- a/src/CHAD/APIv1.hs +++ b/src/CHAD/APIv1.hs @@ -117,6 +117,7 @@ jvp term STI64 -> EVar ext (STScal STI64) (IS IZ) STBool -> EVar ext (STScal STBool) (IS IZ) ezipDN STAccum{} = error "jvp: Accumulators not supported in source program" + ezipDN STUser{} = error "User types not yet supported in forward AD" eunzipDN :: forall env t'. STy t' -> Ex (DN t' : env) (TPair t' (Tan t')) eunzipDN STNil = EPair ext (ENil ext) (ENil ext) @@ -153,6 +154,7 @@ jvp term STI64 -> EPair ext (EVar ext (STScal STI64) IZ) (ENil ext) STBool -> EPair ext (EVar ext (STScal STBool) IZ) (ENil ext) eunzipDN STAccum{} = error "jvp: Accumulators not supported in source program" + eunzipDN STUser{} = error "User types not yet supported in forward AD" -- | Interpret an expression in a given environment. interpret :: KnownEnv env => SList Value env -> Ex env t -> Rep t diff --git a/src/CHAD/AST.hs b/src/CHAD/AST.hs index ce9eb20..eab9425 100644 --- a/src/CHAD/AST.hs +++ b/src/CHAD/AST.hs @@ -20,6 +20,7 @@ import Data.Functor.Const import Data.Functor.Identity import Data.Int (Int64) import Data.Kind (Type) +import Data.Proxy import CHAD.Array import CHAD.AST.Accum @@ -137,6 +138,10 @@ data Expr x env t where -- partiality EError :: x a -> STy a -> String -> Expr x env a + + -- user types + EUser :: x (TUser t) -> STy (TUser t) -> Expr x env (UserRep t) -> Expr x env (TUser t) + EUnUser :: x (UserRep t) -> Expr x env (TUser t) -> Expr x env (UserRep t) deriving instance (forall ty. Show (x ty)) => Show (Expr x env t) -- | A (well-typed, well-scoped) expression using De Bruijn indices. The full @@ -271,6 +276,9 @@ typeOf = \case EError _ t _ -> t + EUser _ t _ -> t + EUnUser _ e | STUser t <- typeOf e -> userRepTy t + extOf :: Expr x env t -> x t extOf = \case EVar x _ _ -> x @@ -317,6 +325,8 @@ extOf = \case EPlus x _ _ _ -> x EOneHot x _ _ _ _ -> x EError x _ _ -> x + EUser x _ _ -> x + EUnUser x _ -> x mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t mapExt f = runIdentity . travExt (Identity . f) @@ -368,6 +378,8 @@ travExt f = \case EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b EError x t s -> EError <$> f x <*> pure t <*> pure s + EUser x t e -> EUser <$> f x <*> pure t <*> travExt f e + EUnUser x e -> EUnUser <$> f x <*> travExt f e substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t substInline repl = @@ -430,8 +442,10 @@ subst' f w = \case EZero x t e -> EZero x t (subst' f w e) EDeepZero x t e -> EDeepZero x t (subst' f w e) EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b) - EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b) + EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b) EError x t s -> EError x t s + EUser x t e -> EUser x t (subst' f w e) + EUnUser x e -> EUnUser x (subst' f w e) where sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t @@ -458,6 +472,7 @@ instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy +instance UserPrimalType t => KnownTy (TUser t) where knownTy = STUser Proxy class KnownMTy t where knownMTy :: SMTy t instance KnownMTy TNil where knownMTy = SMTNil @@ -466,6 +481,7 @@ instance KnownMTy t => KnownMTy (TMaybe t) where knownMTy = SMTMaybe knownMTy instance (KnownMTy s, KnownMTy t) => KnownMTy (TLEither s t) where knownMTy = SMTLEither knownMTy knownMTy instance (KnownNat n, KnownMTy t) => KnownMTy (TArr n t) where knownMTy = SMTArr knownNat knownMTy instance (KnownScalTy t, ScalIsNumeric t ~ True) => KnownMTy (TScal t) where knownMTy = SMTScal knownScalTy +instance UserMonoidType t => KnownMTy (TUser t) where knownMTy = SMTUser Proxy class KnownEnv env where knownEnv :: SList STy env instance KnownEnv '[] where knownEnv = SNil @@ -480,6 +496,7 @@ styKnown (STMaybe t) | Dict <- styKnown t = Dict styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict styKnown (STScal t) | Dict <- sscaltyKnown t = Dict styKnown (STAccum t) | Dict <- smtyKnown t = Dict +styKnown (STUser _) = Dict smtyKnown :: SMTy t -> Dict (KnownMTy t) smtyKnown SMTNil = Dict @@ -488,6 +505,7 @@ smtyKnown (SMTLEither a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict smtyKnown (SMTMaybe t) | Dict <- smtyKnown t = Dict smtyKnown (SMTArr n t) | Dict <- snatKnown n, Dict <- smtyKnown t = Dict smtyKnown (SMTScal t) | Dict <- sscaltyKnown t = Dict +smtyKnown (SMTUser _) = Dict sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t) sscaltyKnown STI32 = Dict @@ -657,6 +675,7 @@ makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy t go SMTMaybe{} _ = ENil ext go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e go SMTScal{} _ = ENil ext + go (SMTUser t) e = euserZeroInfo t (EUnUser ext e) splitSparsePair :: -- given a sparsity diff --git a/src/CHAD/AST.hs-boot b/src/CHAD/AST.hs-boot new file mode 100644 index 0000000..d1b8a62 --- /dev/null +++ b/src/CHAD/AST.hs-boot @@ -0,0 +1,15 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE StandaloneKindSignatures #-} +module CHAD.AST where + +import Data.Functor.Const (Const) +import Data.Kind (Type) + +import CHAD.AST.Types.Ty + +type role Expr representational nominal nominal +type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type +data Expr x env t + +type Ex = Expr (Const ()) diff --git a/src/CHAD/AST/Accum.hs b/src/CHAD/AST/Accum.hs index ea74a95..f61f00f 100644 --- a/src/CHAD/AST/Accum.hs +++ b/src/CHAD/AST/Accum.hs @@ -76,40 +76,6 @@ acPrjTy (SAPRight prj) (SMTLEither _ t) = acPrjTy prj t acPrjTy (SAPJust prj) (SMTMaybe t) = acPrjTy prj t acPrjTy (SAPArrIdx prj) (SMTArr _ t) = acPrjTy prj t -type family ZeroInfo t where - ZeroInfo TNil = TNil - ZeroInfo (TPair a b) = TPair (ZeroInfo a) (ZeroInfo b) - ZeroInfo (TLEither a b) = TNil - ZeroInfo (TMaybe a) = TNil - ZeroInfo (TArr n t) = TArr n (ZeroInfo t) - ZeroInfo (TScal t) = TNil - -tZeroInfo :: SMTy t -> STy (ZeroInfo t) -tZeroInfo SMTNil = STNil -tZeroInfo (SMTPair a b) = STPair (tZeroInfo a) (tZeroInfo b) -tZeroInfo (SMTLEither _ _) = STNil -tZeroInfo (SMTMaybe _) = STNil -tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t) -tZeroInfo (SMTScal _) = STNil - --- | Info needed to create a zero-valued deep accumulator for a monoid type. --- Should be constructable from a D1. -type family DeepZeroInfo t where - DeepZeroInfo TNil = TNil - DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b) - DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b) - DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a) - DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a) - DeepZeroInfo (TScal t) = TNil - -tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t) -tDeepZeroInfo SMTNil = STNil -tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b) -tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b) -tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a) -tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t) -tDeepZeroInfo (SMTScal _) = STNil - -- -- | Additional info needed for accumulation. This is empty unless there is -- -- sparsity in the monoid. -- type family AccumInfo t where diff --git a/src/CHAD/AST/Count.hs b/src/CHAD/AST/Count.hs index 46173d2..8923e13 100644 --- a/src/CHAD/AST/Count.hs +++ b/src/CHAD/AST/Count.hs @@ -880,6 +880,16 @@ occCountX initialS topexpr k = case topexpr of EError _ t msg -> k OccEnd $ \_ -> EError ext (applySubstruc s t) msg + + EUser _ t e -> + occCountX SsFull e $ \env1 mke -> + k env1 $ \env' -> + projectSmallerSubstruc SsFull s $ EUser ext t (mke env') + + EUnUser _ e -> + occCountX SsFull e $ \env1 mke -> + k env1 $ \env' -> + projectSmallerSubstruc SsFull s $ EUnUser ext (mke env') where s = simplifySubstruc (typeOf topexpr) initialS diff --git a/src/CHAD/AST/Env.hs b/src/CHAD/AST/Env.hs index 8e6b745..73f2dcc 100644 --- a/src/CHAD/AST/Env.hs +++ b/src/CHAD/AST/Env.hs @@ -11,7 +11,7 @@ module CHAD.AST.Env where import Data.Type.Equality -import CHAD.AST.Sparse +import CHAD.AST.Sparse.Types import CHAD.AST.Weaken import CHAD.Data import CHAD.Drev.Types diff --git a/src/CHAD/AST/Pretty.hs b/src/CHAD/AST/Pretty.hs index 9ddcb35..d9ac8b2 100644 --- a/src/CHAD/AST/Pretty.hs +++ b/src/CHAD/AST/Pretty.hs @@ -374,6 +374,15 @@ ppExpr' d val expr = case expr of EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s) + EUser _ t@STUser{} e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ + ppApp (ppString ("user[" ++ show (typeOfProxy t) ++ "]") <> ppX expr) [e'] + + EUnUser _ e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppApp (ppString "unuser" <> ppX expr) [e'] + ppExprLet :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc ppExprLet d val etop = do let collect :: PrettyX x => SVal env -> Expr x env t -> M ([(String, Occ, ADoc)], ADoc) @@ -421,6 +430,7 @@ ppSparse (SMTLEither t1 t2) (SpLEither s1 s2) = "(" ++ ppSparse t1 s1 ++ "|" ++ ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s ppSparse (SMTScal _) SpScal = "." +ppSparse (SMTUser _) SpUser = "U" ppCommut :: Commutative -> String ppCommut Commut = "(C)" @@ -469,6 +479,7 @@ ppSTy' _ (STScal sty) = ppString $ case sty of STF64 -> "f64" STBool -> "bool" ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSMTy' 11 t +ppSTy' d (STUser t) = ppParen (d > 10) $ ppString ("User " ++ showsPrec 11 (typeOfProxy t) "") ppSMTy :: Int -> SMTy t -> String ppSMTy d ty = render $ ppSMTy' d ty @@ -485,6 +496,7 @@ ppSMTy' _ (SMTScal sty) = ppString $ case sty of STI64 -> "i64" STF32 -> "f32" STF64 -> "f64" +ppSMTy' d (SMTUser t) = ppSTy' d (STUser t) ppString :: String -> Doc x ppString = fromString diff --git a/src/CHAD/AST/Sparse.hs b/src/CHAD/AST/Sparse.hs index 85f2882..30e6b6f 100644 --- a/src/CHAD/AST/Sparse.hs +++ b/src/CHAD/AST/Sparse.hs @@ -1,7 +1,9 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ImpredicativeTypes #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} {-# OPTIONS_GHC -fmax-pmcheck-models=80 #-} module CHAD.AST.Sparse (module CHAD.AST.Sparse, module CHAD.AST.Sparse.Types) where @@ -9,8 +11,10 @@ module CHAD.AST.Sparse (module CHAD.AST.Sparse, module CHAD.AST.Sparse.Types) wh import Data.Type.Equality import CHAD.AST +import CHAD.AST.Count +import CHAD.AST.Env import CHAD.AST.Sparse.Types -import CHAD.Data (SBool(..)) +import CHAD.Data sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t' @@ -43,6 +47,7 @@ sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 = (EJust ext (sparsePlus t sp (evar (IS IZ)) (evar IZ)))) sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS IZ)) (evar IZ)) e1 e2 sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2 +sparsePlus (SMTUser t) SpUser e1 e2 = EPlus ext (SMTUser t) e1 e2 cheapZero :: SMTy t -> Maybe (forall env. Ex env t) @@ -61,6 +66,19 @@ cheapZero (SMTScal t) = case t of STI64 -> Just (EConst ext t 0) STF32 -> Just (EConst ext t 0.0) STF64 -> Just (EConst ext t 0.0) +cheapZero (SMTUser t) = + let zero1 = euserZero t (EVar ext (userZeroInfo t) IZ) + occenv1 = occCountAll @_ @'[_] zero1 + zero2 = euserZero t (euserZeroInfo t (EVar ext (userRepTy t) IZ)) + occenv2 = occCountAll @_ @'[_] zero2 + in deleteUnused (userZeroInfo t `SCons` SNil) occenv1 $ \case + sub@(SENo SETop) | cheapExpr zero1 -> + Just (EUser ext (STUser t) (weakenExpr WClosed (unsafeWeakenWithSubenv sub zero1))) + _ -> + deleteUnused (userRepTy t `SCons` SNil) occenv2 $ \case + sub@(SENo SETop) | cheapExpr zero2 -> + Just (EUser ext (STUser t) (weakenExpr WClosed (unsafeWeakenWithSubenv sub zero2))) + _ -> Nothing data Injection sp a b where @@ -294,3 +312,6 @@ sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k = -- scalars sparsePlusS _ _ (SMTScal t) SpScal SpScal k = k SpScal (Inj id) (Inj id) (EPlus ext (SMTScal t)) + +-- user types +sparsePlusS _ _ (SMTUser t) SpUser SpUser k = k SpUser (Inj id) (Inj id) (EPlus ext (SMTUser t)) diff --git a/src/CHAD/AST/Sparse/Types.hs b/src/CHAD/AST/Sparse/Types.hs index 8f41ba4..f97a261 100644 --- a/src/CHAD/AST/Sparse/Types.hs +++ b/src/CHAD/AST/Sparse/Types.hs @@ -20,6 +20,7 @@ data Sparse t t' where SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t') SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t') SpScal :: Sparse (TScal t) (TScal t) + SpUser :: Sparse (TUser t) (TUser t) deriving instance Show (Sparse t t') class ApplySparse f where @@ -33,6 +34,7 @@ instance ApplySparse STy where applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t) applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t) applySparse SpScal t = t + applySparse SpUser t = t instance ApplySparse SMTy where applySparse (SpSparse s) t = SMTMaybe (applySparse s t) @@ -42,6 +44,7 @@ instance ApplySparse SMTy where applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t) applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t) applySparse SpScal t = t + applySparse SpUser t = t class IsSubType s where @@ -68,6 +71,7 @@ instance IsSubType Sparse where subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2) subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2) subtTrans SpScal SpScal = SpScal + subtTrans SpUser SpUser = SpUser subtFull = spDense @@ -78,6 +82,7 @@ spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2) spDense (SMTMaybe t) = SpMaybe (spDense t) spDense (SMTArr _ t) = SpArr (spDense t) spDense (SMTScal _) = SpScal +spDense (SMTUser _) = SpUser isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t') isDense SMTNil SpAbsent = Just Refl @@ -96,6 +101,7 @@ isDense (SMTArr _ t) (SpArr s) | Just Refl <- isDense t s = Just Refl | otherwise = Nothing isDense (SMTScal _) SpScal = Just Refl +isDense (SMTUser _) SpUser = Just Refl isAbsent :: Sparse t t' -> Bool isAbsent (SpSparse s) = isAbsent s @@ -105,3 +111,4 @@ isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2 isAbsent (SpMaybe s) = isAbsent s isAbsent (SpArr s) = isAbsent s isAbsent SpScal = False +isAbsent SpUser = False diff --git a/src/CHAD/AST/SplitLets.hs b/src/CHAD/AST/SplitLets.hs index 34267e4..75c70ea 100644 --- a/src/CHAD/AST/SplitLets.hs +++ b/src/CHAD/AST/SplitLets.hs @@ -79,6 +79,8 @@ splitLets' = \sub -> \case EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b) EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b) EError x t s -> EError x t s + EUser x t e -> EUser x t (splitLets' sub e) + EUnUser x e -> EUnUser x (splitLets' sub e) where sinkF :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> STy t -> Idx (b : env) t -> (b : env') :> env3 -> Ex env3 t @@ -167,6 +169,7 @@ split typ = case typ of STArr{} -> other STScal{} -> other STAccum{} -> other + STUser{} -> other where other :: (Pointers (t : env) t, Bindings Ex (t : env) '[]) other = (Point typ IZ, BTop) @@ -186,6 +189,7 @@ splitRec rhs typ = case typ of STArr{} -> other STScal{} -> other STAccum{} -> other + STUser{} -> other where other :: (Pointers (t : env) t, Bindings Ex env '[t]) other = (Point typ IZ, BPush BTop (typ, rhs)) diff --git a/src/CHAD/AST/Types.hs b/src/CHAD/AST/Types.hs index f0feb55..bb2fcfa 100644 --- a/src/CHAD/AST/Types.hs +++ b/src/CHAD/AST/Types.hs @@ -1,35 +1,34 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE TypeData #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module CHAD.AST.Types where +{-# LANGUAGE UndecidableSuperClasses #-} +module CHAD.AST.Types ( + module CHAD.AST.Types.Ty, + module CHAD.AST.Types, +) where import Data.Int (Int32, Int64) import Data.GADT.Compare import Data.GADT.Show import Data.Kind (Type) +import Data.Proxy import Data.Type.Equality +import Type.Reflection import CHAD.Data +import CHAD.AST.Types.Ty +import {-# SOURCE #-} CHAD.AST -type data Ty - = TNil - | TPair Ty Ty - | TEither Ty Ty - | TLEither Ty Ty - | TMaybe Ty - | TArr Nat Ty -- ^ rank, element type - | TScal ScalTy - | TAccum Ty -- ^ contained type must be a monoid type - -type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool -- | Scalar types happen to be bundled in 'SScalTy' as this is sometimes -- convenient, but such scalar types are not special in any way. @@ -43,6 +42,10 @@ data STy t where STArr :: SNat n -> STy t -> STy (TArr n t) STScal :: SScalTy t -> STy (TScal t) STAccum :: SMTy t -> STy (TAccum t) + -- the Proxy is here just to provide something that's t-indexed, which is + -- damn useful; the UNPACK is to make sure it doesn't actually result in any + -- storage here + STUser :: UserPrimalType t => {-# UNPACK #-} !(Proxy t) -> STy (TUser t) deriving instance Show (STy t) instance GCompare STy where @@ -62,7 +65,9 @@ instance GCompare STy where (STScal t) (STScal t') -> gorderingLift1 (gcompare t t') STScal{} _ -> GLT ; _ STScal{} -> GGT (STAccum t) (STAccum t') -> gorderingLift1 (gcompare t t') - -- STAccum{} _ -> GLT ; _ STAccum{} -> GGT + STAccum{} _ -> GLT ; _ STAccum{} -> GGT + (STUser t) (STUser t') -> gorderingLift1 (gcompare (typeOfProxy t) (typeOfProxy t')) + -- STUser{} _ -> GLT ; _ STUser{} -> GGT instance TestEquality STy where testEquality = geq instance GEq STy where geq = defaultGeq @@ -77,6 +82,7 @@ data SMTy t where SMTMaybe :: SMTy a -> SMTy (TMaybe a) SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t) SMTScal :: ScalIsNumeric t ~ True => SScalTy t -> SMTy (TScal t) + SMTUser :: UserMonoidType t => {-# UNPACK #-} !(Proxy t) -> SMTy (TUser t) deriving instance Show (SMTy t) instance GCompare SMTy where @@ -92,7 +98,9 @@ instance GCompare SMTy where (SMTArr n t) (SMTArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t') SMTArr{} _ -> GLT ; _ SMTArr{} -> GGT (SMTScal t) (SMTScal t') -> gorderingLift1 (gcompare t t') - -- SMTScal{} _ -> GLT ; _ SMTScal{} -> GGT + SMTScal{} _ -> GLT ; _ SMTScal{} -> GGT + (SMTUser t) (SMTUser t') -> gorderingLift1 (gcompare (typeOfProxy t) (typeOfProxy t')) + -- SMTUser{} _ -> GLT ; _ SMTUser{} -> GGT instance TestEquality SMTy where testEquality = geq instance GEq SMTy where geq = defaultGeq @@ -106,6 +114,7 @@ fromSMTy = \case SMTMaybe t -> STMaybe (fromSMTy t) SMTArr n t -> STArr n (fromSMTy t) SMTScal sty -> STScal sty + SMTUser t -> STUser t data SScalTy t where STI32 :: SScalTy TI32 @@ -172,7 +181,65 @@ type family ScalIsIntegral t where ScalIsIntegral TF64 = False ScalIsIntegral TBool = False --- | Returns true for arrays /and/ accumulators. +type family ZeroInfo t where + ZeroInfo TNil = TNil + ZeroInfo (TPair a b) = TPair (ZeroInfo a) (ZeroInfo b) + ZeroInfo (TLEither a b) = TNil + ZeroInfo (TMaybe a) = TNil + ZeroInfo (TArr n t) = TArr n (ZeroInfo t) + ZeroInfo (TScal t) = TNil + ZeroInfo (TUser t) = UserZeroInfo t + +tZeroInfo :: SMTy t -> STy (ZeroInfo t) +tZeroInfo SMTNil = STNil +tZeroInfo (SMTPair a b) = STPair (tZeroInfo a) (tZeroInfo b) +tZeroInfo (SMTLEither _ _) = STNil +tZeroInfo (SMTMaybe _) = STNil +tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t) +tZeroInfo (SMTScal _) = STNil +tZeroInfo (SMTUser t) = userZeroInfo t + +-- | Info needed to create a zero-valued deep accumulator for a monoid type. +-- Constructable from a D1; must not have any dynamic sparsity, i.e. must have +-- the same structure as the corresponding primal. +type family DeepZeroInfo t where + DeepZeroInfo TNil = TNil + DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b) + DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b) + DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a) + DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a) + DeepZeroInfo (TScal t) = TNil + DeepZeroInfo (TUser t) = UserDeepZeroInfo t + +tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t) +tDeepZeroInfo SMTNil = STNil +tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a) +tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t) +tDeepZeroInfo (SMTScal _) = STNil +tDeepZeroInfo (SMTUser t) = userDeepZeroInfo t + +class (Show t, Typeable t, UserMonoidType (UserD2 t)) => UserPrimalType t where + type UserRep t :: Ty + type UserD2 t :: Type + userRepTy :: proxy t -> STy (UserRep t) + euserD2ZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (ZeroInfo (TUser (UserD2 t))) + euserD2DeepZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (DeepZeroInfo (TUser (UserD2 t))) + +class UserPrimalType t => UserMonoidType t where + type UserZeroInfo t :: Ty + type UserDeepZeroInfo t :: Ty + userZeroInfo :: proxy t -> STy (UserZeroInfo t) + userDeepZeroInfo :: proxy t -> STy (UserDeepZeroInfo t) + euserZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (UserZeroInfo t) + euserDeepZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (UserDeepZeroInfo t) + euserZero :: proxy t -> Ex env (UserZeroInfo t) -> Ex env (UserRep t) + -- | A deep zero must not have any dynamic sparsity. + euserDeepZero :: proxy t -> Ex env (UserDeepZeroInfo t) -> Ex env (UserRep t) + euserPlus :: proxy t -> Ex env (UserRep t) -> Ex env (UserRep t) -> Ex env (UserRep t) + +-- | Returns true for arrays /and/ accumulators. Returns False for user types. typeHasArrays :: STy t' -> Bool typeHasArrays STNil = False typeHasArrays (STPair a b) = typeHasArrays a || typeHasArrays b @@ -182,6 +249,7 @@ typeHasArrays (STMaybe t) = typeHasArrays t typeHasArrays STArr{} = True typeHasArrays STScal{} = False typeHasArrays STAccum{} = True +typeHasArrays STUser{} = False typeHasAccums :: STy t' -> Bool typeHasAccums STNil = False @@ -192,6 +260,7 @@ typeHasAccums (STMaybe t) = typeHasAccums t typeHasAccums STArr{} = False typeHasAccums STScal{} = False typeHasAccums STAccum{} = True +typeHasAccums STUser{} = False type family Tup env where Tup '[] = TNil @@ -215,3 +284,6 @@ unTup unpack (_ `SCons` list) tup = type family InvTup core env where InvTup core '[] = core InvTup core (t : ts) = InvTup (TPair core t) ts + +typeOfProxy :: Typeable a => proxy a -> TypeRep a +typeOfProxy _ = typeRep diff --git a/src/CHAD/AST/Types/Ty.hs b/src/CHAD/AST/Types/Ty.hs new file mode 100644 index 0000000..cee03be --- /dev/null +++ b/src/CHAD/AST/Types/Ty.hs @@ -0,0 +1,20 @@ +{-# LANGUAGE TypeData #-} +module CHAD.AST.Types.Ty where + +import Data.Kind (Type) + +import CHAD.Data (Nat) + + +type data Ty + = TNil + | TPair Ty Ty + | TEither Ty Ty + | TLEither Ty Ty + | TMaybe Ty + | TArr Nat Ty -- ^ rank, element type + | TScal ScalTy + | TAccum Ty -- ^ contained type must be a monoid type + | TUser Type + +type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool diff --git a/src/CHAD/AST/UnMonoid.hs b/src/CHAD/AST/UnMonoid.hs index d3cad25..2166fc6 100644 --- a/src/CHAD/AST/UnMonoid.hs +++ b/src/CHAD/AST/UnMonoid.hs @@ -63,6 +63,8 @@ unMonoid = \case acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' -> EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc)) EError _ t s -> EError ext t s + EUser _ t e -> EUser ext t (unMonoid e) + EUnUser _ e -> EUnUser ext (unMonoid e) zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t -- don't destroy the effects! @@ -78,6 +80,7 @@ zero (SMTScal t) _ = case t of STI64 -> EConst ext STI64 0 STF32 -> EConst ext STF32 0.0 STF64 -> EConst ext STF64 0.0 +zero (SMTUser t) e = EUser ext (STUser t) (euserZero t e) deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t deepZero SMTNil e = elet e $ ENil ext @@ -99,6 +102,7 @@ deepZero (SMTScal t) _ = case t of STI64 -> EConst ext STI64 0 STF32 -> EConst ext STF32 0.0 STF64 -> EConst ext STF64 0.0 +deepZero (SMTUser t) e = EUser ext (STUser t) (euserDeepZero t e) plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t -- don't destroy the effects! @@ -136,6 +140,7 @@ plus (SMTArr _ t) a b = ezipWith (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ)) a b plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b) +plus (SMTUser t) a b = EUser ext (STUser t) (euserPlus t (EUnUser ext a) (EUnUser ext b)) onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t onehot typ topprj idx arg = case (typ, topprj) of @@ -183,8 +188,8 @@ accumulateSparse accumulateSparse topty topsp arg accum = case (topty, topsp) of (_, s) | Just Refl <- isDense topty s -> accum WId SAPHere (ENil ext) arg - (SMTScal _, SpScal) -> - accum WId SAPHere (ENil ext) arg -- should be handled by isDense already, but meh + (SMTScal _, SpScal) -> error "TScal is dense" + (SMTUser _, SpUser) -> error "TUser is dense" (_, SpSparse s) -> emaybe arg (ENil ext) diff --git a/src/CHAD/AST/UnUser.hs b/src/CHAD/AST/UnUser.hs new file mode 100644 index 0000000..73b216c --- /dev/null +++ b/src/CHAD/AST/UnUser.hs @@ -0,0 +1,102 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +module CHAD.AST.UnUser where + +import CHAD.AST + + +type family UnUser t where + UnUser TNil = TNil + UnUser (TPair a b) = TPair (UnUser a) (UnUser b) + UnUser (TEither a b) = TEither (UnUser a) (UnUser b) + UnUser (TLEither a b) = TLEither (UnUser a) (UnUser b) + UnUser (TMaybe a) = TMaybe (UnUser a) + UnUser (TArr n t) = TArr n (UnUser t) + UnUser (TScal t) = TScal t + UnUser (TAccum t) = TAccum (UnUser t) + UnUser (TUser t) = UnUser (UserRep t) + +type family UnUserE env where + UnUserE '[] = '[] + UnUserE (t : ts) = UnUser t : UnUserE ts + +unUserTy :: STy t -> STy (UnUser t) +unUserTy = \case + STNil -> STNil + STPair a b -> STPair (unUserTy a) (unUserTy b) + STEither a b -> STEither (unUserTy a) (unUserTy b) + STLEither a b -> STLEither (unUserTy a) (unUserTy b) + STMaybe t -> STMaybe (unUserTy t) + STArr n t -> STArr n (unUserTy t) + STScal t -> STScal t + STAccum t -> STAccum (unUserMTy t) + STUser t -> unUserTy (userRepTy t) + +unUserMTy :: SMTy t -> SMTy (UnUser t) +unUserMTy = \case + SMTNil -> SMTNil + SMTPair a b -> SMTPair (unUserMTy a) (unUserMTy b) + SMTLEither a b -> SMTLEither (unUserMTy a) (unUserMTy b) + SMTMaybe t -> SMTMaybe (unUserMTy t) + SMTArr n t -> SMTArr n (unUserMTy t) + SMTScal t -> SMTScal t + SMTUser t -> unUserMTy (userRepTy t) + +unUser :: Ex env t -> Ex (UnUserE env) (UnUser t) +unUser = \case + EUser _ _ e -> unUser e + EUnUser _ e -> unUser e + + EVar _ t i -> EVar ext t (goIdx i) + ELet _ rhs body -> ELet ext (unUser rhs) (unUser body) + EPair _ a b -> EPair ext (unUser a) (unUser b) + EFst _ e -> EFst ext (unUser e) + ESnd _ e -> ESnd ext (unUser e) + ENil _ -> ENil ext + EInl _ t e -> EInl ext (unUserTy t) (unUser e) + EInr _ t e -> EInr ext (unUserTy t) (unUser e) + ECase _ e a b -> ECase ext (unUser e) (unUser a) (unUser b) + ENothing _ t -> ENothing ext t + EJust _ e -> EJust ext (unUser e) + EMaybe _ a b e -> EMaybe ext (unUser a) (unUser b) (unUser e) + ELNil _ t1 t2 -> ELNil ext t1 t2 + ELInl _ t e -> ELInl ext t (unUser e) + ELInr _ t e -> ELInr ext t (unUser e) + ELCase _ e a b c -> ELCase ext (unUser e) (unUser a) (unUser b) (unUser c) + EConstArr _ n t x -> EConstArr ext n t x + EBuild _ n a b -> EBuild ext n (unUser a) (unUser b) + EMap _ a b -> EMap ext (unUser a) (unUser b) + EFold1Inner _ cm a b c -> EFold1Inner ext cm (unUser a) (unUser b) (unUser c) + ESum1Inner _ e -> ESum1Inner ext (unUser e) + EUnit _ e -> EUnit ext (unUser e) + EReplicate1Inner _ a b -> EReplicate1Inner ext (unUser a) (unUser b) + EMaximum1Inner _ e -> EMaximum1Inner ext (unUser e) + EMinimum1Inner _ e -> EMinimum1Inner ext (unUser e) + EReshape _ n a b -> EReshape ext n (unUser a) (unUser b) + EZip _ a b -> EZip ext (unUser a) (unUser b) + EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unUser a) (unUser b) (unUser c) + EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unUser a) (unUser b) (unUser c) + EConst _ t x -> EConst ext t x + EIdx0 _ e -> EIdx0 ext (unUser e) + EIdx1 _ a b -> EIdx1 ext (unUser a) (unUser b) + EIdx _ a b -> EIdx ext (unUser a) (unUser b) + EShape _ e -> EShape ext (unUser e) + EOp _ op e -> EOp ext op (unUser e) + ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unUser a) (unUser b) (unUser c) (unUser e1) (unUser e2) + ERecompute _ e -> ERecompute ext (unUser e) + EWith _ t a b -> EWith ext t (unUser a) (unUser b) + EAccum _ t p eidx sp eval eacc -> + accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 -> + acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' -> + EAccum ext t prj' (unUser idx') (spDense (acPrjTy prj' t)) (unUser val2) (weakenExpr w (unUser eacc)) + EError _ t s -> EError ext t s + + EZero{} -> err_monoid + EDeepZero{} -> err_monoid + EPlus{} -> err_monoid + EOneHot{} -> err_monoid + where + err_monoid = error "unUser: Monoid ops found" diff --git a/src/CHAD/Analysis/Identity.hs b/src/CHAD/Analysis/Identity.hs index 212cc7d..284ab49 100644 --- a/src/CHAD/Analysis/Identity.hs +++ b/src/CHAD/Analysis/Identity.hs @@ -34,6 +34,7 @@ data ValId t where VIArr :: Int -> Vec n Int -> ValId (TArr n t) VIScal :: Int -> ValId (TScal t) VIAccum :: Int -> ValId (TAccum t) + VIUser :: ValId (UserRep t) -> ValId (TUser t) deriving instance Show (ValId t) instance PrettyX ValId where @@ -56,6 +57,7 @@ instance PrettyX ValId where VIArr i is -> 'A' : show i ++ "[" ++ intercalate "," (map show (toList is)) ++ "]" VIScal i -> show i VIAccum i -> 'C' : show i + VIUser a -> 'U' : show a validSplitEither :: ValId (TEither a b) -> (Maybe (ValId a), Maybe (ValId b)) validSplitEither (VIEither (Left v)) = (Just v, Nothing) @@ -386,6 +388,15 @@ idana env expr = case expr of res <- genIds t pure (res, EError res t s) + EUser _ t e -> do + (v, e') <- idana env e + pure (VIUser v, EUser (VIUser v) t e') + + EUnUser _ e -> do + (v, e') <- idana env e + let VIUser v' = v + pure (v', EUnUser v' e') + -- | This value might be either of the two arguments; we don't know which. unify :: ValId t -> ValId t -> IdGen (ValId t) unify VINil VINil = pure VINil @@ -412,6 +423,7 @@ unify (VILEither a) (VILEither b) = VILEither <$> unify a b unify (VIArr i is) (VIArr j js) = VIArr <$> unifyID i j <*> vecZipWithA unifyID is js unify (VIScal i) (VIScal j) = VIScal <$> unifyID i j unify (VIAccum i) (VIAccum j) = VIAccum <$> unifyID i j +unify (VIUser i) (VIUser j) = VIUser <$> unify i j unifyID :: Int -> Int -> IdGen Int unifyID i j | i == j = pure i @@ -426,6 +438,7 @@ genIds (STMaybe t) = VIMaybe' <$> genIds t genIds (STArr n _) = VIArr <$> genId <*> vecReplicateA n genId genIds STScal{} = VIScal <$> genId genIds STAccum{} = VIAccum <$> genId +genIds (STUser t) = VIUser <$> genIds (userRepTy t) shidsToVec :: SNat n -> ValId (Tup (Replicate n TIx)) -> IdGen (Vec n Int) shidsToVec SZ _ = pure VNil diff --git a/src/CHAD/Drev.hs b/src/CHAD/Drev.hs index bfa964b..6dc586c 100644 --- a/src/CHAD/Drev.hs +++ b/src/CHAD/Drev.hs @@ -392,6 +392,7 @@ expandSparse (STArr _ t) (SpArr s) epr e = expandSparse (STScal STF32) SpScal _ e = e expandSparse (STScal STF64) SpScal _ e = e expandSparse (STAccum{}) _ _ _ = error "accumulators not allowed in source program" +expandSparse _ SpUser _ e = e subenvPlus :: SBool req1 -> SBool req2 -> SList SMTy env @@ -601,6 +602,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of STF64 -> False STBool -> True STAccum{} -> False + STUser{} -> False ---------------------------- RETURN TRIPLE FROM CHAD --------------------------- @@ -1378,11 +1380,15 @@ drev des accumMap sd = \case EFold1InnerD1{} -> err_targetlang "EFold1InnerD1" EFold1InnerD2{} -> err_targetlang "EFold1InnerD2" + EUser _ t@STUser{} _ -> err_user ("EUser " ++ show (typeOfProxy t)) + EUnUser _ e | t@STUser{} <- typeOf e -> err_user ("EUnUser " ++ show (typeOfProxy t)) + where err_accum = error "Accumulator operations unsupported in the source program" err_monoid = error "Monoid operations unsupported in the source program" err_unsupported s = error $ "CHAD: unsupported " ++ s err_targetlang s = error $ "CHAD: Target language operation " ++ s ++ " not supported in source program" + err_user s = error $ "CHAD: operations on user types must always be provided a custom derivative with ECustom, encountered " ++ s contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs) contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub)) diff --git a/src/CHAD/Drev/Accum.hs b/src/CHAD/Drev/Accum.hs index 6f25f11..43305e6 100644 --- a/src/CHAD/Drev/Accum.hs +++ b/src/CHAD/Drev/Accum.hs @@ -21,6 +21,7 @@ d2zeroInfo STMaybe{} _ = ENil ext d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program" +d2zeroInfo (STUser t) e = euserD2ZeroInfo t (EUnUser ext e) d2deepZeroInfo :: STy t -> Ex env (D1 t) -> Ex env (DeepZeroInfo (D2 t)) d2deepZeroInfo STNil _ = ENil ext @@ -43,6 +44,7 @@ d2deepZeroInfo (STMaybe a) e = d2deepZeroInfo (STArr _ t) e = emap (d2deepZeroInfo t (EVar ext (d1 t) IZ)) e d2deepZeroInfo (STScal t) _ | Refl <- lemDeepZeroInfoScal t = ENil ext d2deepZeroInfo STAccum{} _ = error "accumulators not allowed in source program" +d2deepZeroInfo (STUser t) e = euserD2DeepZeroInfo t (EUnUser ext e) -- The weakening is necessary because we need to initialise the created -- accumulators with zeros. Those zeros are deep and need full primals. This diff --git a/src/CHAD/Drev/Types.hs b/src/CHAD/Drev/Types.hs index 367a974..e119de2 100644 --- a/src/CHAD/Drev/Types.hs +++ b/src/CHAD/Drev/Types.hs @@ -4,7 +4,8 @@ {-# LANGUAGE TypeOperators #-} module CHAD.Drev.Types where -import CHAD.AST.Accum +import Data.Proxy + import CHAD.AST.Types import CHAD.Data @@ -17,6 +18,7 @@ type family D1 t where D1 (TMaybe a) = TMaybe (D1 a) D1 (TArr n t) = TArr n (D1 t) D1 (TScal t) = TScal t + D1 (TUser t) = TUser t type family D2 t where D2 TNil = TNil @@ -26,6 +28,7 @@ type family D2 t where D2 (TMaybe t) = TMaybe (D2 t) D2 (TArr n t) = TArr n (D2 t) D2 (TScal t) = D2s t + D2 (TUser t) = TUser (UserD2 t) type family D2s t where D2s TI32 = TNil @@ -55,6 +58,7 @@ d1 (STMaybe t) = STMaybe (d1 t) d1 (STArr n t) = STArr n (d1 t) d1 (STScal t) = STScal t d1 STAccum{} = error "Accumulators not allowed in input program" +d1 (STUser t) = STUser t d1e :: SList STy env -> SList STy (D1E env) d1e SNil = SNil @@ -74,6 +78,7 @@ d2M (STScal t) = case t of STF64 -> SMTScal STF64 STBool -> SMTNil d2M STAccum{} = error "Accumulators not allowed in input program" +d2M (STUser _) = SMTUser Proxy d2 :: STy t -> STy (D2 t) d2 = fromSMTy . d2M @@ -147,6 +152,7 @@ d1Identity = \case STArr _ t | Refl <- d1Identity t -> Refl STScal _ -> Refl STAccum{} -> error "Accumulators not allowed in input program" + STUser{} -> Refl d1eIdentity :: SList STy env -> D1E env :~: env d1eIdentity SNil = Refl diff --git a/src/CHAD/Drev/Types/ToTan.hs b/src/CHAD/Drev/Types/ToTan.hs index 019119c..51403f5 100644 --- a/src/CHAD/Drev/Types/ToTan.hs +++ b/src/CHAD/Drev/Types/ToTan.hs @@ -41,3 +41,4 @@ toTan typ primal der = case typ of STScal sty -> case sty of STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der STAccum{} -> error "Accumulators not allowed in input program" + STUser{} -> error "User types not yet supported in forward AD" diff --git a/src/CHAD/ForwardAD.hs b/src/CHAD/ForwardAD.hs index 0ae88ce..933a259 100644 --- a/src/CHAD/ForwardAD.hs +++ b/src/CHAD/ForwardAD.hs @@ -57,6 +57,7 @@ tanty (STScal t) = case t of STF64 -> STScal STF64 STBool -> STNil tanty STAccum{} = error "Accumulators not allowed in input program" +tanty STUser{} = error "User types not yet supported in forward AD" tanenv :: SList STy env -> SList STy (TanE env) tanenv SNil = SNil @@ -79,6 +80,7 @@ zeroTan (STScal STF32) _ = 0.0 zeroTan (STScal STF64) _ = 0.0 zeroTan (STScal STBool) _ = () zeroTan STAccum{} _ = error "Accumulators not allowed in input program" +zeroTan STUser{} _ = error "User types not yet supported in forward AD" tanScalars :: STy t -> Rep (Tan t) -> [Double] tanScalars STNil () = [] @@ -97,6 +99,7 @@ tanScalars (STScal STF32) x = [realToFrac x] tanScalars (STScal STF64) x = [x] tanScalars (STScal STBool) _ = [] tanScalars STAccum{} _ = error "Accumulators not allowed in input program" +tanScalars STUser{} _ = [] tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double] tanEScalars SNil SNil = [] @@ -128,6 +131,7 @@ unzipDN (STScal ty) d = case ty of STF64 -> d STBool -> (d, ()) unzipDN STAccum{} _ = error "Accumulators not allowed in input program" +unzipDN STUser{} _ = error "User types not yet supported in forward AD" dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double dotprodTan STNil _ _ = 0.0 @@ -160,6 +164,7 @@ dotprodTan (STScal ty) x y = case ty of STF64 -> x * y STBool -> 0.0 dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program" +dotprodTan STUser{} _ _ = 0.0 -- -- Primal expression must be duplicable -- dnConstE :: STy t -> Ex env t -> Ex env (DN t) @@ -198,6 +203,7 @@ dnConst (STScal t) = case t of STF64 -> (,0.0) STBool -> id dnConst STAccum{} = error "Accumulators not allowed in input program" +dnConst STUser{} = error "User types not yet supported in forward AD" -- | Given a function that computes the forward derivative for a particular -- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this @@ -233,6 +239,7 @@ dnOnehots (STScal t) x = case t of STF64 -> \f -> f (x, 1.0) STBool -> \_ -> () dnOnehots STAccum{} _ = error "Accumulators not allowed in input program" +dnOnehots ty@STUser{} x = const (zeroTan ty x) dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env) dnConstEnv SNil SNil = SNil diff --git a/src/CHAD/ForwardAD/DualNumbers.hs b/src/CHAD/ForwardAD/DualNumbers.hs index 540ec2b..4a07a2d 100644 --- a/src/CHAD/ForwardAD/DualNumbers.hs +++ b/src/CHAD/ForwardAD/DualNumbers.hs @@ -200,10 +200,14 @@ dfwdDN = \case EFold1InnerD1{} -> err_targetlang "EFold1InnerD1" EFold1InnerD2{} -> err_targetlang "EFold1InnerD2" + + EUser{} -> err_user "EUser" + EUnUser{} -> err_user "EUnUser" where err_accum = error "Accumulator operations unsupported in the source program" err_monoid = error "Monoid operations unsupported in the source program" err_targetlang s = error $ "Target language operation " ++ s ++ " not supported in source program" + err_user s = error $ "User types not yet supported in forward AD (" ++ s ++ ")" deriv_extremum :: ScalIsNumeric t ~ True => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t))) diff --git a/src/CHAD/ForwardAD/DualNumbers/Types.hs b/src/CHAD/ForwardAD/DualNumbers/Types.hs index 5d5dd9e..6dcef1c 100644 --- a/src/CHAD/ForwardAD/DualNumbers/Types.hs +++ b/src/CHAD/ForwardAD/DualNumbers/Types.hs @@ -42,6 +42,7 @@ dn (STScal t) = case t of STI64 -> STScal STI64 STBool -> STScal STBool dn STAccum{} = error "Accum in source program" +dn STUser{} = error "User types not yet supported in forward AD" dne :: SList STy env -> SList STy (DNE env) dne SNil = SNil diff --git a/src/CHAD/Interpreter.hs b/src/CHAD/Interpreter.hs index 6410b5b..8aa02d7 100644 --- a/src/CHAD/Interpreter.hs +++ b/src/CHAD/Interpreter.hs @@ -227,6 +227,8 @@ interpret'Rec env = \case b' <- interpret' env b return $ onehotM p t a' b' EError _ _ s -> error $ "Interpreter: Program threw error: " ++ s + EUser _ _ e -> interpret' env e + EUnUser _ e -> interpret' env e interpretOp :: SOp a t -> Rep a -> Rep t interpretOp op arg = case op of @@ -267,6 +269,9 @@ zeroM typ zi = case typ of STI64 -> 0 STF32 -> 0.0 STF64 -> 0.0 + SMTUser t -> + interpretOpen False (userZeroInfo t `SCons` SNil) (Value zi `SCons` SNil) + (euserZero t (EVar ext (userZeroInfo t) IZ)) deepZeroM :: SMTy t -> Rep (DeepZeroInfo t) -> Rep t deepZeroM typ zi = case typ of @@ -280,6 +285,9 @@ deepZeroM typ zi = case typ of STI64 -> 0 STF32 -> 0.0 STF64 -> 0.0 + SMTUser t -> + interpretOpen False (userDeepZeroInfo t `SCons` SNil) (Value zi `SCons` SNil) + (euserDeepZero t (EVar ext (userDeepZeroInfo t) IZ)) addM :: SMTy t -> Rep t -> Rep t -> Rep t addM typ a b = case typ of @@ -303,6 +311,9 @@ addM typ a b = case typ of | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> addM t (arrayIndexLinear a i) (arrayIndexLinear b i)) | otherwise -> error "Plus of inconsistently shaped arrays" SMTScal sty -> numericIsNum sty $ a + b + SMTUser t -> + interpretOpen False (userRepTy t `SCons` userRepTy t `SCons` SNil) (Value a `SCons` Value b `SCons` SNil) + (euserPlus t (EVar ext (userRepTy t) IZ) (EVar ext (userRepTy t) (IS IZ))) onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdxS p a) -> Rep b -> Rep a onehotM SAPHere _ _ val = val @@ -329,6 +340,7 @@ newAcDense typ val = case typ of SMTMaybe t1 -> newIORef =<< traverse (newAcDense t1) val SMTArr _ t1 -> arrayMapM (newAcDense t1) val SMTScal _ -> newIORef val + SMTUser _ -> newIORef val onehotArray :: Monad m => (Rep (AcIdxS p a) -> m v) -- ^ the "one" @@ -348,6 +360,7 @@ readAc typ val = case typ of SMTMaybe t -> traverse (readAc t) =<< readIORef val SMTArr _ t -> traverse (readAc t) val SMTScal _ -> readIORef val + SMTUser _ -> readIORef val accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Sparse b c -> Rep c -> AcM s () accumAddSparseD typ prj ref idx sp val = case (typ, prj) of @@ -408,6 +421,7 @@ accumAddDense typ ref sp val = case (typ, sp) of forM_ [0 .. arraySize ref - 1] $ \i -> accumAddDense t1 (arrayIndexLinear ref i) s (arrayIndexLinear val i) (SMTScal sty, SpScal) -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ())) + (SMTUser t, SpUser) -> AcM $ atomicModifyIORef' ref (\x -> (addM (SMTUser t) x val, ())) -- TODO: makeval is always 'error' now. Simplify? realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s () diff --git a/src/CHAD/Interpreter/Rep.hs b/src/CHAD/Interpreter/Rep.hs index fadc6be..32a1a48 100644 --- a/src/CHAD/Interpreter/Rep.hs +++ b/src/CHAD/Interpreter/Rep.hs @@ -27,6 +27,7 @@ type family Rep t where Rep (TArr n t) = Array n (Rep t) Rep (TScal sty) = ScalRep sty Rep (TAccum t) = RepAc t + Rep (TUser t) = Rep (UserRep t) -- Mutable, represents monoid types t. type family RepAc t where @@ -36,6 +37,7 @@ type family RepAc t where RepAc (TMaybe t) = IORef (Maybe (RepAc t)) RepAc (TArr n t) = Array n (RepAc t) RepAc (TScal sty) = IORef (ScalRep sty) + RepAc (TUser t) = IORef (Rep (UserRep t)) newtype Value t = Value { unValue :: Rep t } @@ -73,6 +75,8 @@ showValue d (STScal sty) x = case sty of STI64 -> showsPrec d x STBool -> showsPrec d x showValue _ (STAccum t) _ = showString $ "<accumulator for " ++ ppSMTy 0 t ++ ">" +showValue d (STUser t) x = + showParen (d > 10) $ showString ("User[" ++ show (typeOfProxy t) ++ "] ") . showValue 11 (userRepTy t) x showEnv :: SList STy env -> SList Value env -> String showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" @@ -100,6 +104,7 @@ rnfRep (STScal t) x = case t of STF64 -> rnf x STBool -> rnf x rnfRep STAccum{} _ = error "Cannot rnf accumulators" +rnfRep (STUser t) x = rnfRep (userRepTy t) x instance KnownTy t => NFData (Value t) where rnf (Value x) = rnfRep (knownTy @t) x diff --git a/src/CHAD/Simplify.hs b/src/CHAD/Simplify.hs index ea253d6..bbc2db8 100644 --- a/src/CHAD/Simplify.hs +++ b/src/CHAD/Simplify.hs @@ -12,6 +12,8 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} + +{-# OPTIONS_GHC -fmax-pmcheck-models=50 #-} module CHAD.Simplify ( simplifyN, simplifyFix, SimplifyConfig(..), defaultSimplifyConfig, simplifyWith, simplifyFixWith, @@ -307,6 +309,9 @@ simplify'Rec = \case EPlus _ SMTMaybe{} ENothing{} e -> acted $ simplify' e EPlus _ SMTMaybe{} e ENothing{} -> acted $ simplify' e + -- user types + EUnUser _ (EUser _ _ e) -> acted $ simplify' e + -- fallback recursion EVar _ t i -> pure $ EVar ext t i ELet _ a b -> [simprec| ELet ext *a *b |] @@ -361,6 +366,8 @@ simplify'Rec = \case EDeepZero _ t e -> [simprec| EDeepZero ext t *e |] EPlus _ t a b -> [simprec| EPlus ext t *a *b |] EError _ t s -> pure $ EError ext t s + EUser _ t e -> [simprec| EUser ext t *e |] + EUnUser _ e -> [simprec| EUnUser ext *e |] -- | This can be made more precise by tracking (and not counting) adds on -- locally eliminated accumulators. @@ -410,6 +417,8 @@ hasAdds = \case EPlus _ _ a b -> hasAdds a || hasAdds b EOneHot _ _ _ a b -> hasAdds a || hasAdds b EError _ _ _ -> False + EUser _ _ e -> hasAdds e + EUnUser _ e -> hasAdds e checkAccumInScope :: SList STy env -> Bool checkAccumInScope = \case SNil -> False @@ -424,6 +433,7 @@ checkAccumInScope = \case SNil -> False check (STArr _ t) = check t check (STScal _) = False check STAccum{} = True + check (STUser t) = check (userRepTy t) data OneHotTerm dense env a where OneHotTerm :: SAIDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Sparse b c -> Ex env c -> OneHotTerm dense env a |
