aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-27 21:30:17 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-27 21:30:17 +0100
commit20f7d7be13cd7869b338f98d1ab3fd33e8bbfb3e (patch)
treea21c90034a02cdeb7240563dbbab355e49622d0a
parentae634c056b500a568b2d89b7f8e225404a2c0c62 (diff)
WIP user-specified custom typesuser-types
The big roadblock encountered is that accumulation wants addition of monoids to be elementwise float addition; this fundamentally clashes with the concept of a user type with a custom zero and plus.
-rw-r--r--chad-fast.cabal2
-rw-r--r--src/CHAD/APIv1.hs2
-rw-r--r--src/CHAD/AST.hs21
-rw-r--r--src/CHAD/AST.hs-boot15
-rw-r--r--src/CHAD/AST/Accum.hs34
-rw-r--r--src/CHAD/AST/Count.hs10
-rw-r--r--src/CHAD/AST/Env.hs2
-rw-r--r--src/CHAD/AST/Pretty.hs12
-rw-r--r--src/CHAD/AST/Sparse.hs23
-rw-r--r--src/CHAD/AST/Sparse/Types.hs7
-rw-r--r--src/CHAD/AST/SplitLets.hs4
-rw-r--r--src/CHAD/AST/Types.hs102
-rw-r--r--src/CHAD/AST/Types/Ty.hs20
-rw-r--r--src/CHAD/AST/UnMonoid.hs9
-rw-r--r--src/CHAD/AST/UnUser.hs102
-rw-r--r--src/CHAD/Analysis/Identity.hs13
-rw-r--r--src/CHAD/Drev.hs6
-rw-r--r--src/CHAD/Drev/Accum.hs2
-rw-r--r--src/CHAD/Drev/Types.hs8
-rw-r--r--src/CHAD/Drev/Types/ToTan.hs1
-rw-r--r--src/CHAD/ForwardAD.hs7
-rw-r--r--src/CHAD/ForwardAD/DualNumbers.hs4
-rw-r--r--src/CHAD/ForwardAD/DualNumbers/Types.hs1
-rw-r--r--src/CHAD/Interpreter.hs14
-rw-r--r--src/CHAD/Interpreter/Rep.hs5
-rw-r--r--src/CHAD/Simplify.hs10
26 files changed, 381 insertions, 55 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index 1eef3ed..5d800bc 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -26,7 +26,9 @@ library
CHAD.AST.Sparse.Types
CHAD.AST.SplitLets
CHAD.AST.Types
+ CHAD.AST.Types.Ty
CHAD.AST.UnMonoid
+ CHAD.AST.UnUser
CHAD.AST.Weaken
CHAD.AST.Weaken.Auto
CHAD.Compile
diff --git a/src/CHAD/APIv1.hs b/src/CHAD/APIv1.hs
index 73d1580..ef9b685 100644
--- a/src/CHAD/APIv1.hs
+++ b/src/CHAD/APIv1.hs
@@ -117,6 +117,7 @@ jvp term
STI64 -> EVar ext (STScal STI64) (IS IZ)
STBool -> EVar ext (STScal STBool) (IS IZ)
ezipDN STAccum{} = error "jvp: Accumulators not supported in source program"
+ ezipDN STUser{} = error "User types not yet supported in forward AD"
eunzipDN :: forall env t'. STy t' -> Ex (DN t' : env) (TPair t' (Tan t'))
eunzipDN STNil = EPair ext (ENil ext) (ENil ext)
@@ -153,6 +154,7 @@ jvp term
STI64 -> EPair ext (EVar ext (STScal STI64) IZ) (ENil ext)
STBool -> EPair ext (EVar ext (STScal STBool) IZ) (ENil ext)
eunzipDN STAccum{} = error "jvp: Accumulators not supported in source program"
+ eunzipDN STUser{} = error "User types not yet supported in forward AD"
-- | Interpret an expression in a given environment.
interpret :: KnownEnv env => SList Value env -> Ex env t -> Rep t
diff --git a/src/CHAD/AST.hs b/src/CHAD/AST.hs
index ce9eb20..eab9425 100644
--- a/src/CHAD/AST.hs
+++ b/src/CHAD/AST.hs
@@ -20,6 +20,7 @@ import Data.Functor.Const
import Data.Functor.Identity
import Data.Int (Int64)
import Data.Kind (Type)
+import Data.Proxy
import CHAD.Array
import CHAD.AST.Accum
@@ -137,6 +138,10 @@ data Expr x env t where
-- partiality
EError :: x a -> STy a -> String -> Expr x env a
+
+ -- user types
+ EUser :: x (TUser t) -> STy (TUser t) -> Expr x env (UserRep t) -> Expr x env (TUser t)
+ EUnUser :: x (UserRep t) -> Expr x env (TUser t) -> Expr x env (UserRep t)
deriving instance (forall ty. Show (x ty)) => Show (Expr x env t)
-- | A (well-typed, well-scoped) expression using De Bruijn indices. The full
@@ -271,6 +276,9 @@ typeOf = \case
EError _ t _ -> t
+ EUser _ t _ -> t
+ EUnUser _ e | STUser t <- typeOf e -> userRepTy t
+
extOf :: Expr x env t -> x t
extOf = \case
EVar x _ _ -> x
@@ -317,6 +325,8 @@ extOf = \case
EPlus x _ _ _ -> x
EOneHot x _ _ _ _ -> x
EError x _ _ -> x
+ EUser x _ _ -> x
+ EUnUser x _ -> x
mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t
mapExt f = runIdentity . travExt (Identity . f)
@@ -368,6 +378,8 @@ travExt f = \case
EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b
EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b
EError x t s -> EError <$> f x <*> pure t <*> pure s
+ EUser x t e -> EUser <$> f x <*> pure t <*> travExt f e
+ EUnUser x e -> EUnUser <$> f x <*> travExt f e
substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t
substInline repl =
@@ -430,8 +442,10 @@ subst' f w = \case
EZero x t e -> EZero x t (subst' f w e)
EDeepZero x t e -> EDeepZero x t (subst' f w e)
EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b)
- EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b)
+ EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b)
EError x t s -> EError x t s
+ EUser x t e -> EUser x t (subst' f w e)
+ EUnUser x e -> EUnUser x (subst' f w e)
where
sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
-> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t
@@ -458,6 +472,7 @@ instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy
instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy
instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy
instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy
+instance UserPrimalType t => KnownTy (TUser t) where knownTy = STUser Proxy
class KnownMTy t where knownMTy :: SMTy t
instance KnownMTy TNil where knownMTy = SMTNil
@@ -466,6 +481,7 @@ instance KnownMTy t => KnownMTy (TMaybe t) where knownMTy = SMTMaybe knownMTy
instance (KnownMTy s, KnownMTy t) => KnownMTy (TLEither s t) where knownMTy = SMTLEither knownMTy knownMTy
instance (KnownNat n, KnownMTy t) => KnownMTy (TArr n t) where knownMTy = SMTArr knownNat knownMTy
instance (KnownScalTy t, ScalIsNumeric t ~ True) => KnownMTy (TScal t) where knownMTy = SMTScal knownScalTy
+instance UserMonoidType t => KnownMTy (TUser t) where knownMTy = SMTUser Proxy
class KnownEnv env where knownEnv :: SList STy env
instance KnownEnv '[] where knownEnv = SNil
@@ -480,6 +496,7 @@ styKnown (STMaybe t) | Dict <- styKnown t = Dict
styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict
styKnown (STScal t) | Dict <- sscaltyKnown t = Dict
styKnown (STAccum t) | Dict <- smtyKnown t = Dict
+styKnown (STUser _) = Dict
smtyKnown :: SMTy t -> Dict (KnownMTy t)
smtyKnown SMTNil = Dict
@@ -488,6 +505,7 @@ smtyKnown (SMTLEither a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict
smtyKnown (SMTMaybe t) | Dict <- smtyKnown t = Dict
smtyKnown (SMTArr n t) | Dict <- snatKnown n, Dict <- smtyKnown t = Dict
smtyKnown (SMTScal t) | Dict <- sscaltyKnown t = Dict
+smtyKnown (SMTUser _) = Dict
sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t)
sscaltyKnown STI32 = Dict
@@ -657,6 +675,7 @@ makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy t
go SMTMaybe{} _ = ENil ext
go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e
go SMTScal{} _ = ENil ext
+ go (SMTUser t) e = euserZeroInfo t (EUnUser ext e)
splitSparsePair
:: -- given a sparsity
diff --git a/src/CHAD/AST.hs-boot b/src/CHAD/AST.hs-boot
new file mode 100644
index 0000000..d1b8a62
--- /dev/null
+++ b/src/CHAD/AST.hs-boot
@@ -0,0 +1,15 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE RoleAnnotations #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+module CHAD.AST where
+
+import Data.Functor.Const (Const)
+import Data.Kind (Type)
+
+import CHAD.AST.Types.Ty
+
+type role Expr representational nominal nominal
+type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type
+data Expr x env t
+
+type Ex = Expr (Const ())
diff --git a/src/CHAD/AST/Accum.hs b/src/CHAD/AST/Accum.hs
index ea74a95..f61f00f 100644
--- a/src/CHAD/AST/Accum.hs
+++ b/src/CHAD/AST/Accum.hs
@@ -76,40 +76,6 @@ acPrjTy (SAPRight prj) (SMTLEither _ t) = acPrjTy prj t
acPrjTy (SAPJust prj) (SMTMaybe t) = acPrjTy prj t
acPrjTy (SAPArrIdx prj) (SMTArr _ t) = acPrjTy prj t
-type family ZeroInfo t where
- ZeroInfo TNil = TNil
- ZeroInfo (TPair a b) = TPair (ZeroInfo a) (ZeroInfo b)
- ZeroInfo (TLEither a b) = TNil
- ZeroInfo (TMaybe a) = TNil
- ZeroInfo (TArr n t) = TArr n (ZeroInfo t)
- ZeroInfo (TScal t) = TNil
-
-tZeroInfo :: SMTy t -> STy (ZeroInfo t)
-tZeroInfo SMTNil = STNil
-tZeroInfo (SMTPair a b) = STPair (tZeroInfo a) (tZeroInfo b)
-tZeroInfo (SMTLEither _ _) = STNil
-tZeroInfo (SMTMaybe _) = STNil
-tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t)
-tZeroInfo (SMTScal _) = STNil
-
--- | Info needed to create a zero-valued deep accumulator for a monoid type.
--- Should be constructable from a D1.
-type family DeepZeroInfo t where
- DeepZeroInfo TNil = TNil
- DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b)
- DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b)
- DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a)
- DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a)
- DeepZeroInfo (TScal t) = TNil
-
-tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t)
-tDeepZeroInfo SMTNil = STNil
-tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b)
-tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b)
-tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a)
-tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t)
-tDeepZeroInfo (SMTScal _) = STNil
-
-- -- | Additional info needed for accumulation. This is empty unless there is
-- -- sparsity in the monoid.
-- type family AccumInfo t where
diff --git a/src/CHAD/AST/Count.hs b/src/CHAD/AST/Count.hs
index 46173d2..8923e13 100644
--- a/src/CHAD/AST/Count.hs
+++ b/src/CHAD/AST/Count.hs
@@ -880,6 +880,16 @@ occCountX initialS topexpr k = case topexpr of
EError _ t msg ->
k OccEnd $ \_ -> EError ext (applySubstruc s t) msg
+
+ EUser _ t e ->
+ occCountX SsFull e $ \env1 mke ->
+ k env1 $ \env' ->
+ projectSmallerSubstruc SsFull s $ EUser ext t (mke env')
+
+ EUnUser _ e ->
+ occCountX SsFull e $ \env1 mke ->
+ k env1 $ \env' ->
+ projectSmallerSubstruc SsFull s $ EUnUser ext (mke env')
where
s = simplifySubstruc (typeOf topexpr) initialS
diff --git a/src/CHAD/AST/Env.hs b/src/CHAD/AST/Env.hs
index 8e6b745..73f2dcc 100644
--- a/src/CHAD/AST/Env.hs
+++ b/src/CHAD/AST/Env.hs
@@ -11,7 +11,7 @@ module CHAD.AST.Env where
import Data.Type.Equality
-import CHAD.AST.Sparse
+import CHAD.AST.Sparse.Types
import CHAD.AST.Weaken
import CHAD.Data
import CHAD.Drev.Types
diff --git a/src/CHAD/AST/Pretty.hs b/src/CHAD/AST/Pretty.hs
index 9ddcb35..d9ac8b2 100644
--- a/src/CHAD/AST/Pretty.hs
+++ b/src/CHAD/AST/Pretty.hs
@@ -374,6 +374,15 @@ ppExpr' d val expr = case expr of
EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s)
+ EUser _ t@STUser{} e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $
+ ppApp (ppString ("user[" ++ show (typeOfProxy t) ++ "]") <> ppX expr) [e']
+
+ EUnUser _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppApp (ppString "unuser" <> ppX expr) [e']
+
ppExprLet :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc
ppExprLet d val etop = do
let collect :: PrettyX x => SVal env -> Expr x env t -> M ([(String, Occ, ADoc)], ADoc)
@@ -421,6 +430,7 @@ ppSparse (SMTLEither t1 t2) (SpLEither s1 s2) = "(" ++ ppSparse t1 s1 ++ "|" ++
ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s
ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s
ppSparse (SMTScal _) SpScal = "."
+ppSparse (SMTUser _) SpUser = "U"
ppCommut :: Commutative -> String
ppCommut Commut = "(C)"
@@ -469,6 +479,7 @@ ppSTy' _ (STScal sty) = ppString $ case sty of
STF64 -> "f64"
STBool -> "bool"
ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSMTy' 11 t
+ppSTy' d (STUser t) = ppParen (d > 10) $ ppString ("User " ++ showsPrec 11 (typeOfProxy t) "")
ppSMTy :: Int -> SMTy t -> String
ppSMTy d ty = render $ ppSMTy' d ty
@@ -485,6 +496,7 @@ ppSMTy' _ (SMTScal sty) = ppString $ case sty of
STI64 -> "i64"
STF32 -> "f32"
STF64 -> "f64"
+ppSMTy' d (SMTUser t) = ppSTy' d (STUser t)
ppString :: String -> Doc x
ppString = fromString
diff --git a/src/CHAD/AST/Sparse.hs b/src/CHAD/AST/Sparse.hs
index 85f2882..30e6b6f 100644
--- a/src/CHAD/AST/Sparse.hs
+++ b/src/CHAD/AST/Sparse.hs
@@ -1,7 +1,9 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImpredicativeTypes #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -fmax-pmcheck-models=80 #-}
module CHAD.AST.Sparse (module CHAD.AST.Sparse, module CHAD.AST.Sparse.Types) where
@@ -9,8 +11,10 @@ module CHAD.AST.Sparse (module CHAD.AST.Sparse, module CHAD.AST.Sparse.Types) wh
import Data.Type.Equality
import CHAD.AST
+import CHAD.AST.Count
+import CHAD.AST.Env
import CHAD.AST.Sparse.Types
-import CHAD.Data (SBool(..))
+import CHAD.Data
sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t'
@@ -43,6 +47,7 @@ sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 =
(EJust ext (sparsePlus t sp (evar (IS IZ)) (evar IZ))))
sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS IZ)) (evar IZ)) e1 e2
sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2
+sparsePlus (SMTUser t) SpUser e1 e2 = EPlus ext (SMTUser t) e1 e2
cheapZero :: SMTy t -> Maybe (forall env. Ex env t)
@@ -61,6 +66,19 @@ cheapZero (SMTScal t) = case t of
STI64 -> Just (EConst ext t 0)
STF32 -> Just (EConst ext t 0.0)
STF64 -> Just (EConst ext t 0.0)
+cheapZero (SMTUser t) =
+ let zero1 = euserZero t (EVar ext (userZeroInfo t) IZ)
+ occenv1 = occCountAll @_ @'[_] zero1
+ zero2 = euserZero t (euserZeroInfo t (EVar ext (userRepTy t) IZ))
+ occenv2 = occCountAll @_ @'[_] zero2
+ in deleteUnused (userZeroInfo t `SCons` SNil) occenv1 $ \case
+ sub@(SENo SETop) | cheapExpr zero1 ->
+ Just (EUser ext (STUser t) (weakenExpr WClosed (unsafeWeakenWithSubenv sub zero1)))
+ _ ->
+ deleteUnused (userRepTy t `SCons` SNil) occenv2 $ \case
+ sub@(SENo SETop) | cheapExpr zero2 ->
+ Just (EUser ext (STUser t) (weakenExpr WClosed (unsafeWeakenWithSubenv sub zero2)))
+ _ -> Nothing
data Injection sp a b where
@@ -294,3 +312,6 @@ sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k =
-- scalars
sparsePlusS _ _ (SMTScal t) SpScal SpScal k = k SpScal (Inj id) (Inj id) (EPlus ext (SMTScal t))
+
+-- user types
+sparsePlusS _ _ (SMTUser t) SpUser SpUser k = k SpUser (Inj id) (Inj id) (EPlus ext (SMTUser t))
diff --git a/src/CHAD/AST/Sparse/Types.hs b/src/CHAD/AST/Sparse/Types.hs
index 8f41ba4..f97a261 100644
--- a/src/CHAD/AST/Sparse/Types.hs
+++ b/src/CHAD/AST/Sparse/Types.hs
@@ -20,6 +20,7 @@ data Sparse t t' where
SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t')
SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t')
SpScal :: Sparse (TScal t) (TScal t)
+ SpUser :: Sparse (TUser t) (TUser t)
deriving instance Show (Sparse t t')
class ApplySparse f where
@@ -33,6 +34,7 @@ instance ApplySparse STy where
applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t)
applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t)
applySparse SpScal t = t
+ applySparse SpUser t = t
instance ApplySparse SMTy where
applySparse (SpSparse s) t = SMTMaybe (applySparse s t)
@@ -42,6 +44,7 @@ instance ApplySparse SMTy where
applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t)
applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t)
applySparse SpScal t = t
+ applySparse SpUser t = t
class IsSubType s where
@@ -68,6 +71,7 @@ instance IsSubType Sparse where
subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2)
subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2)
subtTrans SpScal SpScal = SpScal
+ subtTrans SpUser SpUser = SpUser
subtFull = spDense
@@ -78,6 +82,7 @@ spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2)
spDense (SMTMaybe t) = SpMaybe (spDense t)
spDense (SMTArr _ t) = SpArr (spDense t)
spDense (SMTScal _) = SpScal
+spDense (SMTUser _) = SpUser
isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t')
isDense SMTNil SpAbsent = Just Refl
@@ -96,6 +101,7 @@ isDense (SMTArr _ t) (SpArr s)
| Just Refl <- isDense t s = Just Refl
| otherwise = Nothing
isDense (SMTScal _) SpScal = Just Refl
+isDense (SMTUser _) SpUser = Just Refl
isAbsent :: Sparse t t' -> Bool
isAbsent (SpSparse s) = isAbsent s
@@ -105,3 +111,4 @@ isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2
isAbsent (SpMaybe s) = isAbsent s
isAbsent (SpArr s) = isAbsent s
isAbsent SpScal = False
+isAbsent SpUser = False
diff --git a/src/CHAD/AST/SplitLets.hs b/src/CHAD/AST/SplitLets.hs
index 34267e4..75c70ea 100644
--- a/src/CHAD/AST/SplitLets.hs
+++ b/src/CHAD/AST/SplitLets.hs
@@ -79,6 +79,8 @@ splitLets' = \sub -> \case
EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b)
EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b)
EError x t s -> EError x t s
+ EUser x t e -> EUser x t (splitLets' sub e)
+ EUnUser x e -> EUnUser x (splitLets' sub e)
where
sinkF :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
-> STy t -> Idx (b : env) t -> (b : env') :> env3 -> Ex env3 t
@@ -167,6 +169,7 @@ split typ = case typ of
STArr{} -> other
STScal{} -> other
STAccum{} -> other
+ STUser{} -> other
where
other :: (Pointers (t : env) t, Bindings Ex (t : env) '[])
other = (Point typ IZ, BTop)
@@ -186,6 +189,7 @@ splitRec rhs typ = case typ of
STArr{} -> other
STScal{} -> other
STAccum{} -> other
+ STUser{} -> other
where
other :: (Pointers (t : env) t, Bindings Ex env '[t])
other = (Point typ IZ, BPush BTop (typ, rhs))
diff --git a/src/CHAD/AST/Types.hs b/src/CHAD/AST/Types.hs
index f0feb55..bb2fcfa 100644
--- a/src/CHAD/AST/Types.hs
+++ b/src/CHAD/AST/Types.hs
@@ -1,35 +1,34 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-module CHAD.AST.Types where
+{-# LANGUAGE UndecidableSuperClasses #-}
+module CHAD.AST.Types (
+ module CHAD.AST.Types.Ty,
+ module CHAD.AST.Types,
+) where
import Data.Int (Int32, Int64)
import Data.GADT.Compare
import Data.GADT.Show
import Data.Kind (Type)
+import Data.Proxy
import Data.Type.Equality
+import Type.Reflection
import CHAD.Data
+import CHAD.AST.Types.Ty
+import {-# SOURCE #-} CHAD.AST
-type data Ty
- = TNil
- | TPair Ty Ty
- | TEither Ty Ty
- | TLEither Ty Ty
- | TMaybe Ty
- | TArr Nat Ty -- ^ rank, element type
- | TScal ScalTy
- | TAccum Ty -- ^ contained type must be a monoid type
-
-type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
-- | Scalar types happen to be bundled in 'SScalTy' as this is sometimes
-- convenient, but such scalar types are not special in any way.
@@ -43,6 +42,10 @@ data STy t where
STArr :: SNat n -> STy t -> STy (TArr n t)
STScal :: SScalTy t -> STy (TScal t)
STAccum :: SMTy t -> STy (TAccum t)
+ -- the Proxy is here just to provide something that's t-indexed, which is
+ -- damn useful; the UNPACK is to make sure it doesn't actually result in any
+ -- storage here
+ STUser :: UserPrimalType t => {-# UNPACK #-} !(Proxy t) -> STy (TUser t)
deriving instance Show (STy t)
instance GCompare STy where
@@ -62,7 +65,9 @@ instance GCompare STy where
(STScal t) (STScal t') -> gorderingLift1 (gcompare t t')
STScal{} _ -> GLT ; _ STScal{} -> GGT
(STAccum t) (STAccum t') -> gorderingLift1 (gcompare t t')
- -- STAccum{} _ -> GLT ; _ STAccum{} -> GGT
+ STAccum{} _ -> GLT ; _ STAccum{} -> GGT
+ (STUser t) (STUser t') -> gorderingLift1 (gcompare (typeOfProxy t) (typeOfProxy t'))
+ -- STUser{} _ -> GLT ; _ STUser{} -> GGT
instance TestEquality STy where testEquality = geq
instance GEq STy where geq = defaultGeq
@@ -77,6 +82,7 @@ data SMTy t where
SMTMaybe :: SMTy a -> SMTy (TMaybe a)
SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t)
SMTScal :: ScalIsNumeric t ~ True => SScalTy t -> SMTy (TScal t)
+ SMTUser :: UserMonoidType t => {-# UNPACK #-} !(Proxy t) -> SMTy (TUser t)
deriving instance Show (SMTy t)
instance GCompare SMTy where
@@ -92,7 +98,9 @@ instance GCompare SMTy where
(SMTArr n t) (SMTArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t')
SMTArr{} _ -> GLT ; _ SMTArr{} -> GGT
(SMTScal t) (SMTScal t') -> gorderingLift1 (gcompare t t')
- -- SMTScal{} _ -> GLT ; _ SMTScal{} -> GGT
+ SMTScal{} _ -> GLT ; _ SMTScal{} -> GGT
+ (SMTUser t) (SMTUser t') -> gorderingLift1 (gcompare (typeOfProxy t) (typeOfProxy t'))
+ -- SMTUser{} _ -> GLT ; _ SMTUser{} -> GGT
instance TestEquality SMTy where testEquality = geq
instance GEq SMTy where geq = defaultGeq
@@ -106,6 +114,7 @@ fromSMTy = \case
SMTMaybe t -> STMaybe (fromSMTy t)
SMTArr n t -> STArr n (fromSMTy t)
SMTScal sty -> STScal sty
+ SMTUser t -> STUser t
data SScalTy t where
STI32 :: SScalTy TI32
@@ -172,7 +181,65 @@ type family ScalIsIntegral t where
ScalIsIntegral TF64 = False
ScalIsIntegral TBool = False
--- | Returns true for arrays /and/ accumulators.
+type family ZeroInfo t where
+ ZeroInfo TNil = TNil
+ ZeroInfo (TPair a b) = TPair (ZeroInfo a) (ZeroInfo b)
+ ZeroInfo (TLEither a b) = TNil
+ ZeroInfo (TMaybe a) = TNil
+ ZeroInfo (TArr n t) = TArr n (ZeroInfo t)
+ ZeroInfo (TScal t) = TNil
+ ZeroInfo (TUser t) = UserZeroInfo t
+
+tZeroInfo :: SMTy t -> STy (ZeroInfo t)
+tZeroInfo SMTNil = STNil
+tZeroInfo (SMTPair a b) = STPair (tZeroInfo a) (tZeroInfo b)
+tZeroInfo (SMTLEither _ _) = STNil
+tZeroInfo (SMTMaybe _) = STNil
+tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t)
+tZeroInfo (SMTScal _) = STNil
+tZeroInfo (SMTUser t) = userZeroInfo t
+
+-- | Info needed to create a zero-valued deep accumulator for a monoid type.
+-- Constructable from a D1; must not have any dynamic sparsity, i.e. must have
+-- the same structure as the corresponding primal.
+type family DeepZeroInfo t where
+ DeepZeroInfo TNil = TNil
+ DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b)
+ DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b)
+ DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a)
+ DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a)
+ DeepZeroInfo (TScal t) = TNil
+ DeepZeroInfo (TUser t) = UserDeepZeroInfo t
+
+tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t)
+tDeepZeroInfo SMTNil = STNil
+tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b)
+tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b)
+tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a)
+tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t)
+tDeepZeroInfo (SMTScal _) = STNil
+tDeepZeroInfo (SMTUser t) = userDeepZeroInfo t
+
+class (Show t, Typeable t, UserMonoidType (UserD2 t)) => UserPrimalType t where
+ type UserRep t :: Ty
+ type UserD2 t :: Type
+ userRepTy :: proxy t -> STy (UserRep t)
+ euserD2ZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (ZeroInfo (TUser (UserD2 t)))
+ euserD2DeepZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (DeepZeroInfo (TUser (UserD2 t)))
+
+class UserPrimalType t => UserMonoidType t where
+ type UserZeroInfo t :: Ty
+ type UserDeepZeroInfo t :: Ty
+ userZeroInfo :: proxy t -> STy (UserZeroInfo t)
+ userDeepZeroInfo :: proxy t -> STy (UserDeepZeroInfo t)
+ euserZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (UserZeroInfo t)
+ euserDeepZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (UserDeepZeroInfo t)
+ euserZero :: proxy t -> Ex env (UserZeroInfo t) -> Ex env (UserRep t)
+ -- | A deep zero must not have any dynamic sparsity.
+ euserDeepZero :: proxy t -> Ex env (UserDeepZeroInfo t) -> Ex env (UserRep t)
+ euserPlus :: proxy t -> Ex env (UserRep t) -> Ex env (UserRep t) -> Ex env (UserRep t)
+
+-- | Returns true for arrays /and/ accumulators. Returns False for user types.
typeHasArrays :: STy t' -> Bool
typeHasArrays STNil = False
typeHasArrays (STPair a b) = typeHasArrays a || typeHasArrays b
@@ -182,6 +249,7 @@ typeHasArrays (STMaybe t) = typeHasArrays t
typeHasArrays STArr{} = True
typeHasArrays STScal{} = False
typeHasArrays STAccum{} = True
+typeHasArrays STUser{} = False
typeHasAccums :: STy t' -> Bool
typeHasAccums STNil = False
@@ -192,6 +260,7 @@ typeHasAccums (STMaybe t) = typeHasAccums t
typeHasAccums STArr{} = False
typeHasAccums STScal{} = False
typeHasAccums STAccum{} = True
+typeHasAccums STUser{} = False
type family Tup env where
Tup '[] = TNil
@@ -215,3 +284,6 @@ unTup unpack (_ `SCons` list) tup =
type family InvTup core env where
InvTup core '[] = core
InvTup core (t : ts) = InvTup (TPair core t) ts
+
+typeOfProxy :: Typeable a => proxy a -> TypeRep a
+typeOfProxy _ = typeRep
diff --git a/src/CHAD/AST/Types/Ty.hs b/src/CHAD/AST/Types/Ty.hs
new file mode 100644
index 0000000..cee03be
--- /dev/null
+++ b/src/CHAD/AST/Types/Ty.hs
@@ -0,0 +1,20 @@
+{-# LANGUAGE TypeData #-}
+module CHAD.AST.Types.Ty where
+
+import Data.Kind (Type)
+
+import CHAD.Data (Nat)
+
+
+type data Ty
+ = TNil
+ | TPair Ty Ty
+ | TEither Ty Ty
+ | TLEither Ty Ty
+ | TMaybe Ty
+ | TArr Nat Ty -- ^ rank, element type
+ | TScal ScalTy
+ | TAccum Ty -- ^ contained type must be a monoid type
+ | TUser Type
+
+type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
diff --git a/src/CHAD/AST/UnMonoid.hs b/src/CHAD/AST/UnMonoid.hs
index d3cad25..2166fc6 100644
--- a/src/CHAD/AST/UnMonoid.hs
+++ b/src/CHAD/AST/UnMonoid.hs
@@ -63,6 +63,8 @@ unMonoid = \case
acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' ->
EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc))
EError _ t s -> EError ext t s
+ EUser _ t e -> EUser ext t (unMonoid e)
+ EUnUser _ e -> EUnUser ext (unMonoid e)
zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t
-- don't destroy the effects!
@@ -78,6 +80,7 @@ zero (SMTScal t) _ = case t of
STI64 -> EConst ext STI64 0
STF32 -> EConst ext STF32 0.0
STF64 -> EConst ext STF64 0.0
+zero (SMTUser t) e = EUser ext (STUser t) (euserZero t e)
deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t
deepZero SMTNil e = elet e $ ENil ext
@@ -99,6 +102,7 @@ deepZero (SMTScal t) _ = case t of
STI64 -> EConst ext STI64 0
STF32 -> EConst ext STF32 0.0
STF64 -> EConst ext STF64 0.0
+deepZero (SMTUser t) e = EUser ext (STUser t) (euserDeepZero t e)
plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t
-- don't destroy the effects!
@@ -136,6 +140,7 @@ plus (SMTArr _ t) a b =
ezipWith (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ))
a b
plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b)
+plus (SMTUser t) a b = EUser ext (STUser t) (euserPlus t (EUnUser ext a) (EUnUser ext b))
onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t
onehot typ topprj idx arg = case (typ, topprj) of
@@ -183,8 +188,8 @@ accumulateSparse
accumulateSparse topty topsp arg accum = case (topty, topsp) of
(_, s) | Just Refl <- isDense topty s ->
accum WId SAPHere (ENil ext) arg
- (SMTScal _, SpScal) ->
- accum WId SAPHere (ENil ext) arg -- should be handled by isDense already, but meh
+ (SMTScal _, SpScal) -> error "TScal is dense"
+ (SMTUser _, SpUser) -> error "TUser is dense"
(_, SpSparse s) ->
emaybe arg
(ENil ext)
diff --git a/src/CHAD/AST/UnUser.hs b/src/CHAD/AST/UnUser.hs
new file mode 100644
index 0000000..73b216c
--- /dev/null
+++ b/src/CHAD/AST/UnUser.hs
@@ -0,0 +1,102 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+module CHAD.AST.UnUser where
+
+import CHAD.AST
+
+
+type family UnUser t where
+ UnUser TNil = TNil
+ UnUser (TPair a b) = TPair (UnUser a) (UnUser b)
+ UnUser (TEither a b) = TEither (UnUser a) (UnUser b)
+ UnUser (TLEither a b) = TLEither (UnUser a) (UnUser b)
+ UnUser (TMaybe a) = TMaybe (UnUser a)
+ UnUser (TArr n t) = TArr n (UnUser t)
+ UnUser (TScal t) = TScal t
+ UnUser (TAccum t) = TAccum (UnUser t)
+ UnUser (TUser t) = UnUser (UserRep t)
+
+type family UnUserE env where
+ UnUserE '[] = '[]
+ UnUserE (t : ts) = UnUser t : UnUserE ts
+
+unUserTy :: STy t -> STy (UnUser t)
+unUserTy = \case
+ STNil -> STNil
+ STPair a b -> STPair (unUserTy a) (unUserTy b)
+ STEither a b -> STEither (unUserTy a) (unUserTy b)
+ STLEither a b -> STLEither (unUserTy a) (unUserTy b)
+ STMaybe t -> STMaybe (unUserTy t)
+ STArr n t -> STArr n (unUserTy t)
+ STScal t -> STScal t
+ STAccum t -> STAccum (unUserMTy t)
+ STUser t -> unUserTy (userRepTy t)
+
+unUserMTy :: SMTy t -> SMTy (UnUser t)
+unUserMTy = \case
+ SMTNil -> SMTNil
+ SMTPair a b -> SMTPair (unUserMTy a) (unUserMTy b)
+ SMTLEither a b -> SMTLEither (unUserMTy a) (unUserMTy b)
+ SMTMaybe t -> SMTMaybe (unUserMTy t)
+ SMTArr n t -> SMTArr n (unUserMTy t)
+ SMTScal t -> SMTScal t
+ SMTUser t -> unUserMTy (userRepTy t)
+
+unUser :: Ex env t -> Ex (UnUserE env) (UnUser t)
+unUser = \case
+ EUser _ _ e -> unUser e
+ EUnUser _ e -> unUser e
+
+ EVar _ t i -> EVar ext t (goIdx i)
+ ELet _ rhs body -> ELet ext (unUser rhs) (unUser body)
+ EPair _ a b -> EPair ext (unUser a) (unUser b)
+ EFst _ e -> EFst ext (unUser e)
+ ESnd _ e -> ESnd ext (unUser e)
+ ENil _ -> ENil ext
+ EInl _ t e -> EInl ext (unUserTy t) (unUser e)
+ EInr _ t e -> EInr ext (unUserTy t) (unUser e)
+ ECase _ e a b -> ECase ext (unUser e) (unUser a) (unUser b)
+ ENothing _ t -> ENothing ext t
+ EJust _ e -> EJust ext (unUser e)
+ EMaybe _ a b e -> EMaybe ext (unUser a) (unUser b) (unUser e)
+ ELNil _ t1 t2 -> ELNil ext t1 t2
+ ELInl _ t e -> ELInl ext t (unUser e)
+ ELInr _ t e -> ELInr ext t (unUser e)
+ ELCase _ e a b c -> ELCase ext (unUser e) (unUser a) (unUser b) (unUser c)
+ EConstArr _ n t x -> EConstArr ext n t x
+ EBuild _ n a b -> EBuild ext n (unUser a) (unUser b)
+ EMap _ a b -> EMap ext (unUser a) (unUser b)
+ EFold1Inner _ cm a b c -> EFold1Inner ext cm (unUser a) (unUser b) (unUser c)
+ ESum1Inner _ e -> ESum1Inner ext (unUser e)
+ EUnit _ e -> EUnit ext (unUser e)
+ EReplicate1Inner _ a b -> EReplicate1Inner ext (unUser a) (unUser b)
+ EMaximum1Inner _ e -> EMaximum1Inner ext (unUser e)
+ EMinimum1Inner _ e -> EMinimum1Inner ext (unUser e)
+ EReshape _ n a b -> EReshape ext n (unUser a) (unUser b)
+ EZip _ a b -> EZip ext (unUser a) (unUser b)
+ EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unUser a) (unUser b) (unUser c)
+ EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unUser a) (unUser b) (unUser c)
+ EConst _ t x -> EConst ext t x
+ EIdx0 _ e -> EIdx0 ext (unUser e)
+ EIdx1 _ a b -> EIdx1 ext (unUser a) (unUser b)
+ EIdx _ a b -> EIdx ext (unUser a) (unUser b)
+ EShape _ e -> EShape ext (unUser e)
+ EOp _ op e -> EOp ext op (unUser e)
+ ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unUser a) (unUser b) (unUser c) (unUser e1) (unUser e2)
+ ERecompute _ e -> ERecompute ext (unUser e)
+ EWith _ t a b -> EWith ext t (unUser a) (unUser b)
+ EAccum _ t p eidx sp eval eacc ->
+ accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 ->
+ acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' ->
+ EAccum ext t prj' (unUser idx') (spDense (acPrjTy prj' t)) (unUser val2) (weakenExpr w (unUser eacc))
+ EError _ t s -> EError ext t s
+
+ EZero{} -> err_monoid
+ EDeepZero{} -> err_monoid
+ EPlus{} -> err_monoid
+ EOneHot{} -> err_monoid
+ where
+ err_monoid = error "unUser: Monoid ops found"
diff --git a/src/CHAD/Analysis/Identity.hs b/src/CHAD/Analysis/Identity.hs
index 212cc7d..284ab49 100644
--- a/src/CHAD/Analysis/Identity.hs
+++ b/src/CHAD/Analysis/Identity.hs
@@ -34,6 +34,7 @@ data ValId t where
VIArr :: Int -> Vec n Int -> ValId (TArr n t)
VIScal :: Int -> ValId (TScal t)
VIAccum :: Int -> ValId (TAccum t)
+ VIUser :: ValId (UserRep t) -> ValId (TUser t)
deriving instance Show (ValId t)
instance PrettyX ValId where
@@ -56,6 +57,7 @@ instance PrettyX ValId where
VIArr i is -> 'A' : show i ++ "[" ++ intercalate "," (map show (toList is)) ++ "]"
VIScal i -> show i
VIAccum i -> 'C' : show i
+ VIUser a -> 'U' : show a
validSplitEither :: ValId (TEither a b) -> (Maybe (ValId a), Maybe (ValId b))
validSplitEither (VIEither (Left v)) = (Just v, Nothing)
@@ -386,6 +388,15 @@ idana env expr = case expr of
res <- genIds t
pure (res, EError res t s)
+ EUser _ t e -> do
+ (v, e') <- idana env e
+ pure (VIUser v, EUser (VIUser v) t e')
+
+ EUnUser _ e -> do
+ (v, e') <- idana env e
+ let VIUser v' = v
+ pure (v', EUnUser v' e')
+
-- | This value might be either of the two arguments; we don't know which.
unify :: ValId t -> ValId t -> IdGen (ValId t)
unify VINil VINil = pure VINil
@@ -412,6 +423,7 @@ unify (VILEither a) (VILEither b) = VILEither <$> unify a b
unify (VIArr i is) (VIArr j js) = VIArr <$> unifyID i j <*> vecZipWithA unifyID is js
unify (VIScal i) (VIScal j) = VIScal <$> unifyID i j
unify (VIAccum i) (VIAccum j) = VIAccum <$> unifyID i j
+unify (VIUser i) (VIUser j) = VIUser <$> unify i j
unifyID :: Int -> Int -> IdGen Int
unifyID i j | i == j = pure i
@@ -426,6 +438,7 @@ genIds (STMaybe t) = VIMaybe' <$> genIds t
genIds (STArr n _) = VIArr <$> genId <*> vecReplicateA n genId
genIds STScal{} = VIScal <$> genId
genIds STAccum{} = VIAccum <$> genId
+genIds (STUser t) = VIUser <$> genIds (userRepTy t)
shidsToVec :: SNat n -> ValId (Tup (Replicate n TIx)) -> IdGen (Vec n Int)
shidsToVec SZ _ = pure VNil
diff --git a/src/CHAD/Drev.hs b/src/CHAD/Drev.hs
index bfa964b..6dc586c 100644
--- a/src/CHAD/Drev.hs
+++ b/src/CHAD/Drev.hs
@@ -392,6 +392,7 @@ expandSparse (STArr _ t) (SpArr s) epr e =
expandSparse (STScal STF32) SpScal _ e = e
expandSparse (STScal STF64) SpScal _ e = e
expandSparse (STAccum{}) _ _ _ = error "accumulators not allowed in source program"
+expandSparse _ SpUser _ e = e
subenvPlus :: SBool req1 -> SBool req2
-> SList SMTy env
@@ -601,6 +602,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
STF64 -> False
STBool -> True
STAccum{} -> False
+ STUser{} -> False
---------------------------- RETURN TRIPLE FROM CHAD ---------------------------
@@ -1378,11 +1380,15 @@ drev des accumMap sd = \case
EFold1InnerD1{} -> err_targetlang "EFold1InnerD1"
EFold1InnerD2{} -> err_targetlang "EFold1InnerD2"
+ EUser _ t@STUser{} _ -> err_user ("EUser " ++ show (typeOfProxy t))
+ EUnUser _ e | t@STUser{} <- typeOf e -> err_user ("EUnUser " ++ show (typeOfProxy t))
+
where
err_accum = error "Accumulator operations unsupported in the source program"
err_monoid = error "Monoid operations unsupported in the source program"
err_unsupported s = error $ "CHAD: unsupported " ++ s
err_targetlang s = error $ "CHAD: Target language operation " ++ s ++ " not supported in source program"
+ err_user s = error $ "CHAD: operations on user types must always be provided a custom derivative with ECustom, encountered " ++ s
contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs)
contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub))
diff --git a/src/CHAD/Drev/Accum.hs b/src/CHAD/Drev/Accum.hs
index 6f25f11..43305e6 100644
--- a/src/CHAD/Drev/Accum.hs
+++ b/src/CHAD/Drev/Accum.hs
@@ -21,6 +21,7 @@ d2zeroInfo STMaybe{} _ = ENil ext
d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e
d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext
d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program"
+d2zeroInfo (STUser t) e = euserD2ZeroInfo t (EUnUser ext e)
d2deepZeroInfo :: STy t -> Ex env (D1 t) -> Ex env (DeepZeroInfo (D2 t))
d2deepZeroInfo STNil _ = ENil ext
@@ -43,6 +44,7 @@ d2deepZeroInfo (STMaybe a) e =
d2deepZeroInfo (STArr _ t) e = emap (d2deepZeroInfo t (EVar ext (d1 t) IZ)) e
d2deepZeroInfo (STScal t) _ | Refl <- lemDeepZeroInfoScal t = ENil ext
d2deepZeroInfo STAccum{} _ = error "accumulators not allowed in source program"
+d2deepZeroInfo (STUser t) e = euserD2DeepZeroInfo t (EUnUser ext e)
-- The weakening is necessary because we need to initialise the created
-- accumulators with zeros. Those zeros are deep and need full primals. This
diff --git a/src/CHAD/Drev/Types.hs b/src/CHAD/Drev/Types.hs
index 367a974..e119de2 100644
--- a/src/CHAD/Drev/Types.hs
+++ b/src/CHAD/Drev/Types.hs
@@ -4,7 +4,8 @@
{-# LANGUAGE TypeOperators #-}
module CHAD.Drev.Types where
-import CHAD.AST.Accum
+import Data.Proxy
+
import CHAD.AST.Types
import CHAD.Data
@@ -17,6 +18,7 @@ type family D1 t where
D1 (TMaybe a) = TMaybe (D1 a)
D1 (TArr n t) = TArr n (D1 t)
D1 (TScal t) = TScal t
+ D1 (TUser t) = TUser t
type family D2 t where
D2 TNil = TNil
@@ -26,6 +28,7 @@ type family D2 t where
D2 (TMaybe t) = TMaybe (D2 t)
D2 (TArr n t) = TArr n (D2 t)
D2 (TScal t) = D2s t
+ D2 (TUser t) = TUser (UserD2 t)
type family D2s t where
D2s TI32 = TNil
@@ -55,6 +58,7 @@ d1 (STMaybe t) = STMaybe (d1 t)
d1 (STArr n t) = STArr n (d1 t)
d1 (STScal t) = STScal t
d1 STAccum{} = error "Accumulators not allowed in input program"
+d1 (STUser t) = STUser t
d1e :: SList STy env -> SList STy (D1E env)
d1e SNil = SNil
@@ -74,6 +78,7 @@ d2M (STScal t) = case t of
STF64 -> SMTScal STF64
STBool -> SMTNil
d2M STAccum{} = error "Accumulators not allowed in input program"
+d2M (STUser _) = SMTUser Proxy
d2 :: STy t -> STy (D2 t)
d2 = fromSMTy . d2M
@@ -147,6 +152,7 @@ d1Identity = \case
STArr _ t | Refl <- d1Identity t -> Refl
STScal _ -> Refl
STAccum{} -> error "Accumulators not allowed in input program"
+ STUser{} -> Refl
d1eIdentity :: SList STy env -> D1E env :~: env
d1eIdentity SNil = Refl
diff --git a/src/CHAD/Drev/Types/ToTan.hs b/src/CHAD/Drev/Types/ToTan.hs
index 019119c..51403f5 100644
--- a/src/CHAD/Drev/Types/ToTan.hs
+++ b/src/CHAD/Drev/Types/ToTan.hs
@@ -41,3 +41,4 @@ toTan typ primal der = case typ of
STScal sty -> case sty of
STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der
STAccum{} -> error "Accumulators not allowed in input program"
+ STUser{} -> error "User types not yet supported in forward AD"
diff --git a/src/CHAD/ForwardAD.hs b/src/CHAD/ForwardAD.hs
index 0ae88ce..933a259 100644
--- a/src/CHAD/ForwardAD.hs
+++ b/src/CHAD/ForwardAD.hs
@@ -57,6 +57,7 @@ tanty (STScal t) = case t of
STF64 -> STScal STF64
STBool -> STNil
tanty STAccum{} = error "Accumulators not allowed in input program"
+tanty STUser{} = error "User types not yet supported in forward AD"
tanenv :: SList STy env -> SList STy (TanE env)
tanenv SNil = SNil
@@ -79,6 +80,7 @@ zeroTan (STScal STF32) _ = 0.0
zeroTan (STScal STF64) _ = 0.0
zeroTan (STScal STBool) _ = ()
zeroTan STAccum{} _ = error "Accumulators not allowed in input program"
+zeroTan STUser{} _ = error "User types not yet supported in forward AD"
tanScalars :: STy t -> Rep (Tan t) -> [Double]
tanScalars STNil () = []
@@ -97,6 +99,7 @@ tanScalars (STScal STF32) x = [realToFrac x]
tanScalars (STScal STF64) x = [x]
tanScalars (STScal STBool) _ = []
tanScalars STAccum{} _ = error "Accumulators not allowed in input program"
+tanScalars STUser{} _ = []
tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double]
tanEScalars SNil SNil = []
@@ -128,6 +131,7 @@ unzipDN (STScal ty) d = case ty of
STF64 -> d
STBool -> (d, ())
unzipDN STAccum{} _ = error "Accumulators not allowed in input program"
+unzipDN STUser{} _ = error "User types not yet supported in forward AD"
dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double
dotprodTan STNil _ _ = 0.0
@@ -160,6 +164,7 @@ dotprodTan (STScal ty) x y = case ty of
STF64 -> x * y
STBool -> 0.0
dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program"
+dotprodTan STUser{} _ _ = 0.0
-- -- Primal expression must be duplicable
-- dnConstE :: STy t -> Ex env t -> Ex env (DN t)
@@ -198,6 +203,7 @@ dnConst (STScal t) = case t of
STF64 -> (,0.0)
STBool -> id
dnConst STAccum{} = error "Accumulators not allowed in input program"
+dnConst STUser{} = error "User types not yet supported in forward AD"
-- | Given a function that computes the forward derivative for a particular
-- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this
@@ -233,6 +239,7 @@ dnOnehots (STScal t) x = case t of
STF64 -> \f -> f (x, 1.0)
STBool -> \_ -> ()
dnOnehots STAccum{} _ = error "Accumulators not allowed in input program"
+dnOnehots ty@STUser{} x = const (zeroTan ty x)
dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env)
dnConstEnv SNil SNil = SNil
diff --git a/src/CHAD/ForwardAD/DualNumbers.hs b/src/CHAD/ForwardAD/DualNumbers.hs
index 540ec2b..4a07a2d 100644
--- a/src/CHAD/ForwardAD/DualNumbers.hs
+++ b/src/CHAD/ForwardAD/DualNumbers.hs
@@ -200,10 +200,14 @@ dfwdDN = \case
EFold1InnerD1{} -> err_targetlang "EFold1InnerD1"
EFold1InnerD2{} -> err_targetlang "EFold1InnerD2"
+
+ EUser{} -> err_user "EUser"
+ EUnUser{} -> err_user "EUnUser"
where
err_accum = error "Accumulator operations unsupported in the source program"
err_monoid = error "Monoid operations unsupported in the source program"
err_targetlang s = error $ "Target language operation " ++ s ++ " not supported in source program"
+ err_user s = error $ "User types not yet supported in forward AD (" ++ s ++ ")"
deriv_extremum :: ScalIsNumeric t ~ True
=> (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t)))
diff --git a/src/CHAD/ForwardAD/DualNumbers/Types.hs b/src/CHAD/ForwardAD/DualNumbers/Types.hs
index 5d5dd9e..6dcef1c 100644
--- a/src/CHAD/ForwardAD/DualNumbers/Types.hs
+++ b/src/CHAD/ForwardAD/DualNumbers/Types.hs
@@ -42,6 +42,7 @@ dn (STScal t) = case t of
STI64 -> STScal STI64
STBool -> STScal STBool
dn STAccum{} = error "Accum in source program"
+dn STUser{} = error "User types not yet supported in forward AD"
dne :: SList STy env -> SList STy (DNE env)
dne SNil = SNil
diff --git a/src/CHAD/Interpreter.hs b/src/CHAD/Interpreter.hs
index 6410b5b..8aa02d7 100644
--- a/src/CHAD/Interpreter.hs
+++ b/src/CHAD/Interpreter.hs
@@ -227,6 +227,8 @@ interpret'Rec env = \case
b' <- interpret' env b
return $ onehotM p t a' b'
EError _ _ s -> error $ "Interpreter: Program threw error: " ++ s
+ EUser _ _ e -> interpret' env e
+ EUnUser _ e -> interpret' env e
interpretOp :: SOp a t -> Rep a -> Rep t
interpretOp op arg = case op of
@@ -267,6 +269,9 @@ zeroM typ zi = case typ of
STI64 -> 0
STF32 -> 0.0
STF64 -> 0.0
+ SMTUser t ->
+ interpretOpen False (userZeroInfo t `SCons` SNil) (Value zi `SCons` SNil)
+ (euserZero t (EVar ext (userZeroInfo t) IZ))
deepZeroM :: SMTy t -> Rep (DeepZeroInfo t) -> Rep t
deepZeroM typ zi = case typ of
@@ -280,6 +285,9 @@ deepZeroM typ zi = case typ of
STI64 -> 0
STF32 -> 0.0
STF64 -> 0.0
+ SMTUser t ->
+ interpretOpen False (userDeepZeroInfo t `SCons` SNil) (Value zi `SCons` SNil)
+ (euserDeepZero t (EVar ext (userDeepZeroInfo t) IZ))
addM :: SMTy t -> Rep t -> Rep t -> Rep t
addM typ a b = case typ of
@@ -303,6 +311,9 @@ addM typ a b = case typ of
| sh1 == sh2 -> arrayGenerateLin sh1 (\i -> addM t (arrayIndexLinear a i) (arrayIndexLinear b i))
| otherwise -> error "Plus of inconsistently shaped arrays"
SMTScal sty -> numericIsNum sty $ a + b
+ SMTUser t ->
+ interpretOpen False (userRepTy t `SCons` userRepTy t `SCons` SNil) (Value a `SCons` Value b `SCons` SNil)
+ (euserPlus t (EVar ext (userRepTy t) IZ) (EVar ext (userRepTy t) (IS IZ)))
onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdxS p a) -> Rep b -> Rep a
onehotM SAPHere _ _ val = val
@@ -329,6 +340,7 @@ newAcDense typ val = case typ of
SMTMaybe t1 -> newIORef =<< traverse (newAcDense t1) val
SMTArr _ t1 -> arrayMapM (newAcDense t1) val
SMTScal _ -> newIORef val
+ SMTUser _ -> newIORef val
onehotArray :: Monad m
=> (Rep (AcIdxS p a) -> m v) -- ^ the "one"
@@ -348,6 +360,7 @@ readAc typ val = case typ of
SMTMaybe t -> traverse (readAc t) =<< readIORef val
SMTArr _ t -> traverse (readAc t) val
SMTScal _ -> readIORef val
+ SMTUser _ -> readIORef val
accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Sparse b c -> Rep c -> AcM s ()
accumAddSparseD typ prj ref idx sp val = case (typ, prj) of
@@ -408,6 +421,7 @@ accumAddDense typ ref sp val = case (typ, sp) of
forM_ [0 .. arraySize ref - 1] $ \i ->
accumAddDense t1 (arrayIndexLinear ref i) s (arrayIndexLinear val i)
(SMTScal sty, SpScal) -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ()))
+ (SMTUser t, SpUser) -> AcM $ atomicModifyIORef' ref (\x -> (addM (SMTUser t) x val, ()))
-- TODO: makeval is always 'error' now. Simplify?
realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s ()
diff --git a/src/CHAD/Interpreter/Rep.hs b/src/CHAD/Interpreter/Rep.hs
index fadc6be..32a1a48 100644
--- a/src/CHAD/Interpreter/Rep.hs
+++ b/src/CHAD/Interpreter/Rep.hs
@@ -27,6 +27,7 @@ type family Rep t where
Rep (TArr n t) = Array n (Rep t)
Rep (TScal sty) = ScalRep sty
Rep (TAccum t) = RepAc t
+ Rep (TUser t) = Rep (UserRep t)
-- Mutable, represents monoid types t.
type family RepAc t where
@@ -36,6 +37,7 @@ type family RepAc t where
RepAc (TMaybe t) = IORef (Maybe (RepAc t))
RepAc (TArr n t) = Array n (RepAc t)
RepAc (TScal sty) = IORef (ScalRep sty)
+ RepAc (TUser t) = IORef (Rep (UserRep t))
newtype Value t = Value { unValue :: Rep t }
@@ -73,6 +75,8 @@ showValue d (STScal sty) x = case sty of
STI64 -> showsPrec d x
STBool -> showsPrec d x
showValue _ (STAccum t) _ = showString $ "<accumulator for " ++ ppSMTy 0 t ++ ">"
+showValue d (STUser t) x =
+ showParen (d > 10) $ showString ("User[" ++ show (typeOfProxy t) ++ "] ") . showValue 11 (userRepTy t) x
showEnv :: SList STy env -> SList Value env -> String
showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]"
@@ -100,6 +104,7 @@ rnfRep (STScal t) x = case t of
STF64 -> rnf x
STBool -> rnf x
rnfRep STAccum{} _ = error "Cannot rnf accumulators"
+rnfRep (STUser t) x = rnfRep (userRepTy t) x
instance KnownTy t => NFData (Value t) where
rnf (Value x) = rnfRep (knownTy @t) x
diff --git a/src/CHAD/Simplify.hs b/src/CHAD/Simplify.hs
index ea253d6..bbc2db8 100644
--- a/src/CHAD/Simplify.hs
+++ b/src/CHAD/Simplify.hs
@@ -12,6 +12,8 @@
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
+
+{-# OPTIONS_GHC -fmax-pmcheck-models=50 #-}
module CHAD.Simplify (
simplifyN, simplifyFix,
SimplifyConfig(..), defaultSimplifyConfig, simplifyWith, simplifyFixWith,
@@ -307,6 +309,9 @@ simplify'Rec = \case
EPlus _ SMTMaybe{} ENothing{} e -> acted $ simplify' e
EPlus _ SMTMaybe{} e ENothing{} -> acted $ simplify' e
+ -- user types
+ EUnUser _ (EUser _ _ e) -> acted $ simplify' e
+
-- fallback recursion
EVar _ t i -> pure $ EVar ext t i
ELet _ a b -> [simprec| ELet ext *a *b |]
@@ -361,6 +366,8 @@ simplify'Rec = \case
EDeepZero _ t e -> [simprec| EDeepZero ext t *e |]
EPlus _ t a b -> [simprec| EPlus ext t *a *b |]
EError _ t s -> pure $ EError ext t s
+ EUser _ t e -> [simprec| EUser ext t *e |]
+ EUnUser _ e -> [simprec| EUnUser ext *e |]
-- | This can be made more precise by tracking (and not counting) adds on
-- locally eliminated accumulators.
@@ -410,6 +417,8 @@ hasAdds = \case
EPlus _ _ a b -> hasAdds a || hasAdds b
EOneHot _ _ _ a b -> hasAdds a || hasAdds b
EError _ _ _ -> False
+ EUser _ _ e -> hasAdds e
+ EUnUser _ e -> hasAdds e
checkAccumInScope :: SList STy env -> Bool
checkAccumInScope = \case SNil -> False
@@ -424,6 +433,7 @@ checkAccumInScope = \case SNil -> False
check (STArr _ t) = check t
check (STScal _) = False
check STAccum{} = True
+ check (STUser t) = check (userRepTy t)
data OneHotTerm dense env a where
OneHotTerm :: SAIDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Sparse b c -> Ex env c -> OneHotTerm dense env a