aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/APIv1.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-11 23:56:47 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-11 23:56:47 +0100
commitdc61318a22e3492774ab6f6345c9a369222ef2f6 (patch)
tree6b7b0bd1d194666b3ba6eb5f85e620bf850fee69 /src/CHAD/APIv1.hs
parentcd135319f65f40a554d864b2a878a4ef44043a98 (diff)
User-facing API suggestion
Diffstat (limited to 'src/CHAD/APIv1.hs')
-rw-r--r--src/CHAD/APIv1.hs177
1 files changed, 177 insertions, 0 deletions
diff --git a/src/CHAD/APIv1.hs b/src/CHAD/APIv1.hs
new file mode 100644
index 0000000..4e82130
--- /dev/null
+++ b/src/CHAD/APIv1.hs
@@ -0,0 +1,177 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE EmptyCase #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.APIv1 (
+ -- * Expressions and types
+ Ex, STy(..), SScalTy(..), Ty(..), ScalTy(..),
+
+ -- * Reverse derivatives (Fast CHAD)
+ vjp, vjp',
+ D2, D2E, Tup,
+ CHADConfig(..),
+
+ -- ** Primal type transform
+ -- | The primal type transform only important when working with special
+ -- operations like 'CHAD.Language.custom'.
+ D1,
+
+ -- * Forward derivatives (dual numbers)
+ jvp, jvpDN,
+ Tan, DN, DNE,
+
+ -- * Working with expressions
+ interpret, interpret1,
+ compile, compile1,
+ fullSimplify,
+ SList(..), Value(..), Rep,
+ KnownEnv(..), KnownTy(..),
+) where
+
+import CHAD.AST
+import CHAD.AST.Count
+import CHAD.AST.UnMonoid
+import CHAD.Compile qualified as Compile
+import CHAD.Data
+import CHAD.Drev.Top
+import CHAD.Drev.Types
+import CHAD.ForwardAD
+import CHAD.ForwardAD.DualNumbers
+import CHAD.Interpreter qualified as Interpreter
+import CHAD.Simplify
+import CHAD.Interpreter.Rep
+
+
+-- | Compute a reverse derivative: a vector-Jacobian product. The type has been
+-- simplified with the assumption that 'D1' is the identity.
+vjp :: KnownEnv env => Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
+vjp = vjp' (chcSetAccum defaultConfig)
+
+-- | Same as 'vjp'', but supply CHAD configuration.
+vjp' :: KnownEnv env => CHADConfig -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
+vjp' config term
+ | Dict <- styKnown (d2 (typeOf term)) =
+ fullSimplify $
+ unMonoid . simplifyFix $ -- need to merge onehots and accums for unMonoid to do its work
+ chad' config knownEnv (simplifyFix term)
+
+jvpDN :: Ex env t -> Ex (DNE env) (DN t)
+jvpDN = dfwdDN
+
+jvp :: forall s t. KnownTy s => Ex '[s] t -> Ex '[Tan s, s] (TPair t (Tan t))
+jvp term
+ | Dict <- styKnown (tanty (knownTy @s))
+ = fullSimplify $
+ elet (ezipDN knownTy) $
+ elet (weakenExpr (WCopy WClosed) (jvpDN term)) $
+ eunzipDN (typeOf term)
+ where
+ ezipDN :: forall env s'. STy s' -> Ex (Tan s' : s' : env) (DN s')
+ ezipDN STNil = ENil ext
+ ezipDN (STPair a b) =
+ EPair ext (subst (\_ t' -> \case IZ -> EFst ext (EVar ext (STPair (tanty a) (tanty b)) IZ)
+ IS IZ -> EFst ext (EVar ext (STPair a b) (IS IZ))
+ IS (IS i) -> EVar ext t' (IS (IS i)))
+ (ezipDN @env a))
+ (subst (\_ t' -> \case IZ -> ESnd ext (EVar ext (STPair (tanty a) (tanty b)) IZ)
+ IS IZ -> ESnd ext (EVar ext (STPair a b) (IS IZ))
+ IS (IS i) -> EVar ext t' (IS (IS i)))
+ (ezipDN @env b))
+ ezipDN (STEither a b) =
+ ecase (EVar ext (STEither a b) (IS IZ))
+ (ecase (EVar ext (STEither (tanty a) (tanty b)) (IS IZ))
+ (EInl ext (dn b) (ezipDN a))
+ (EError ext (STEither (dn a) (dn b)) "jvp zip: either mismatch lr"))
+ (ecase (EVar ext (STEither (tanty a) (tanty b)) (IS IZ))
+ (EError ext (STEither (dn a) (dn b)) "jvp zip: either mismatch rl")
+ (EInr ext (dn a) (ezipDN b)))
+ ezipDN (STLEither a b) =
+ elcase (EVar ext (STLEither a b) (IS IZ))
+ (ELNil ext (dn a) (dn b))
+ (elcase (EVar ext (STLEither (tanty a) (tanty b)) (IS IZ))
+ (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch lN")
+ (ELInl ext (dn b) (ezipDN a))
+ (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch lr"))
+ (elcase (EVar ext (STLEither (tanty a) (tanty b)) (IS IZ))
+ (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch rN")
+ (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch rl")
+ (ELInr ext (dn a) (ezipDN b)))
+ ezipDN (STMaybe t) =
+ emaybe (EVar ext (STMaybe t) (IS IZ))
+ (ENothing ext (dn t))
+ (emaybe (EVar ext (STMaybe (tanty t)) (IS IZ))
+ (EError ext (STMaybe (dn t)) "jvp zip: maybe mismatch jN")
+ (EJust ext (ezipDN t)))
+ ezipDN (STArr n t) =
+ ezipWith (ezipDN t)
+ (EVar ext (STArr n t) (IS IZ)) (EVar ext (STArr n (tanty t)) IZ)
+ ezipDN (STScal st) = case st of
+ STF32 -> EPair ext (EVar ext (STScal STF32) (IS IZ)) (EVar ext (tanty (STScal STF32)) IZ)
+ STF64 -> EPair ext (EVar ext (STScal STF64) (IS IZ)) (EVar ext (tanty (STScal STF64)) IZ)
+ STI32 -> EVar ext (STScal STI32) (IS IZ)
+ STI64 -> EVar ext (STScal STI64) (IS IZ)
+ STBool -> EVar ext (STScal STBool) (IS IZ)
+ ezipDN STAccum{} = error "jvp: Accumulators not supported in source program"
+
+ eunzipDN :: forall env t'. STy t' -> Ex (DN t' : env) (TPair t' (Tan t'))
+ eunzipDN STNil = EPair ext (ENil ext) (ENil ext)
+ eunzipDN (STPair a b) =
+ eunPair (subst0 (EFst ext (EVar ext (STPair (dn a) (dn b)) IZ)) (eunzipDN a)) $ \w1 ea1 ea2 ->
+ eunPair (weakenExpr w1 (subst0 (ESnd ext (EVar ext (STPair (dn a) (dn b)) IZ)) (eunzipDN b))) $ \w2 eb1 eb2 ->
+ EPair ext (EPair ext (weakenExpr w2 ea1) eb1) (EPair ext (weakenExpr w2 ea2) eb2)
+ eunzipDN (STEither a b) =
+ ecase (EVar ext (STEither (dn a) (dn b)) IZ)
+ (eunPair (eunzipDN a) $ \_ a1 a2 ->
+ EPair ext (EInl ext b a1) (EInl ext (tanty b) a2))
+ (eunPair (eunzipDN b) $ \_ b1 b2 ->
+ EPair ext (EInr ext a b1) (EInr ext (tanty a) b2))
+ eunzipDN (STLEither a b) =
+ elcase (EVar ext (STLEither (dn a) (dn b)) IZ)
+ (EPair ext (ELNil ext a b) (ELNil ext (tanty a) (tanty b)))
+ (eunPair (eunzipDN a) $ \_ a1 a2 ->
+ EPair ext (ELInl ext b a1) (ELInl ext (tanty b) a2))
+ (eunPair (eunzipDN b) $ \_ b1 b2 ->
+ EPair ext (ELInr ext a b1) (ELInr ext (tanty a) b2))
+ eunzipDN (STMaybe t) =
+ emaybe (EVar ext (STMaybe (dn t)) IZ)
+ (EPair ext (ENothing ext t) (ENothing ext (tanty t)))
+ (eunPair (eunzipDN t) $ \_ e1 e2 ->
+ EPair ext (EJust ext e1) (EJust ext e2))
+ eunzipDN (STArr n t) =
+ elet (emap (eunzipDN t) (EVar ext (STArr n (dn t)) IZ)) $
+ EPair ext (emap (EFst ext (evar IZ)) (evar IZ))
+ (emap (ESnd ext (evar IZ)) (evar IZ))
+ eunzipDN (STScal st) = case st of
+ STF32 -> EVar ext (STPair (STScal STF32) (STScal STF32)) IZ
+ STF64 -> EVar ext (STPair (STScal STF64) (STScal STF64)) IZ
+ STI32 -> EPair ext (EVar ext (STScal STI32) IZ) (ENil ext)
+ STI64 -> EPair ext (EVar ext (STScal STI64) IZ) (ENil ext)
+ STBool -> EPair ext (EVar ext (STScal STBool) IZ) (ENil ext)
+ eunzipDN STAccum{} = error "jvp: Accumulators not supported in source program"
+
+-- | Interpret an expression in a given environment.
+interpret :: KnownEnv env => SList Value env -> Ex env t -> Rep t
+interpret = Interpreter.interpretOpen False knownEnv
+
+-- | Special case of 'interpret' for an expression with a single free variable.
+interpret1 :: KnownTy s => Rep s -> Ex '[s] t -> Rep t
+interpret1 x = interpret (Value x `SCons` SNil)
+
+-- | Compile an expression to C, load the resulting shared object into the
+-- program and wrap it in a Haskell function.
+compile :: KnownEnv env => Ex env t -> IO (SList Value env -> IO (Rep t))
+compile = Compile.compileStderr knownEnv
+
+-- | Special case of 'compile' for an expression with a single free variable.
+compile1 :: KnownTy s => Ex '[s] t -> IO (Rep s -> IO (Rep t))
+compile1 term = do
+ f <- Compile.compileStderr knownEnv term
+ return (\x -> f (Value x `SCons` SNil))
+
+-- | Simpify an expression. The 'vjp'/'jvp' functions already do this automatically.
+fullSimplify :: KnownEnv env => Ex env t -> Ex env t
+fullSimplify = simplifyFix . pruneExpr knownEnv . simplifyFix