{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
module CHAD.APIv1 (
  -- * Expressions and types
  Ex,
  STy(..),
  SScalTy(..),
  Ty(..),
  ScalTy(..),

  -- * Reverse derivatives (Fast CHAD)
  vjp,
  vjp',
  D2,
  D2E,
  Tup,
  CHADConfig(..),

  -- ** Primal type transform
  -- | The primal type transform is only important when working with special
  -- operations like 'CHAD.Language.custom'.
  D1,

  -- * Forward derivatives (dual numbers)
  jvp,
  jvpDN,
  Tan,
  DN,
  DNE,

  -- * Working with expressions
  interpret,
  interpret1,
  compile,
  compile1,
  fullSimplify,
  SList(..),
  Value(..),
  Rep,
  KnownEnv(..),
  KnownTy(..),
) where

import CHAD.AST
import CHAD.AST.Count
import CHAD.AST.UnMonoid
import CHAD.Compile qualified as Compile
import CHAD.Data
import CHAD.Drev.Top
import CHAD.Drev.Types
import CHAD.ForwardAD
import CHAD.ForwardAD.DualNumbers
import CHAD.Interpreter qualified as Interpreter
import CHAD.Interpreter.Rep
import CHAD.Simplify


-- | Compute a reverse derivative: a vector-Jacobian product. The type has been
-- simplified with the assumption that 'D1' is the identity.
--
-- This is 'vjp'' applied to 'defaultConfig' modified by 'chcSetAccum'.
vjp :: KnownEnv env => Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
vjp term = vjp' (chcSetAccum defaultConfig) term

-- | Same as 'vjp', but supply CHAD configuration.
--
-- The input term is simplified before differentiation, and the derivative
-- term is simplified again afterwards.
vjp' :: KnownEnv env => CHADConfig -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
vjp' config term =
  case styKnown (d2 (typeOf term)) of
    Dict ->
      fullSimplify
        . unMonoid
        -- need to merge onehots and accums for unMonoid to do its work
        . simplifyFix
        $ chad' config knownEnv (simplifyFix term)

-- | Forward derivative in dual-numbers form. This is 'dfwdDN' under an
-- API-level name; no simplification is applied here.
jvpDN :: Ex env t -> Ex (DNE env) (DN t)
jvpDN term = dfwdDN term

-- | Compute a forward derivative: a Jacobian-vector product.
--
-- The input is a term with a single free variable of type @s@. The output
-- term has two free variables: the tangent (index 0, type @'Tan' s@) and the
-- primal input (index 1, type @s@). It returns the primal result paired with
-- its tangent.
--
-- The term is built by zipping the primal and tangent inputs into a
-- dual-numbers value, running 'jvpDN' on that, and unzipping the dual-numbers
-- result; the whole is then passed through 'fullSimplify'.
jvp :: forall s t. KnownTy s => Ex '[s] t -> Ex '[Tan s, s] (TPair t (Tan t))
jvp term =
  case styKnown (tanty (knownTy @s)) of
    Dict ->
      fullSimplify
        . elet (ezipDN knownTy)
        . elet (weakenExpr (WCopy WClosed) (jvpDN term))
        $ eunzipDN (typeOf term)
  where
    -- Combine the primal value (variable 1) and the tangent value
    -- (variable 0) into a single dual-numbers value. For sum-like types the
    -- primal decides the branch; a tangent in a different branch yields a
    -- runtime 'EError'.
    ezipDN :: forall env s'. STy s' -> Ex (Tan s' : s' : env) (DN s')
    ezipDN ty = case ty of
      STNil -> ENil ext

      -- Zip each component separately, redirecting the two innermost
      -- variables of the recursive result to projections of the pair-typed
      -- primal/tangent variables.
      STPair tyL tyR ->
        EPair ext
          (subst
             (\_ varTy idx -> case idx of
                IZ -> EFst ext (EVar ext (STPair (tanty tyL) (tanty tyR)) IZ)
                IS IZ -> EFst ext (EVar ext (STPair tyL tyR) (IS IZ))
                IS (IS deeper) -> EVar ext varTy (IS (IS deeper)))
             (ezipDN @env tyL))
          (subst
             (\_ varTy idx -> case idx of
                IZ -> ESnd ext (EVar ext (STPair (tanty tyL) (tanty tyR)) IZ)
                IS IZ -> ESnd ext (EVar ext (STPair tyL tyR) (IS IZ))
                IS (IS deeper) -> EVar ext varTy (IS (IS deeper)))
             (ezipDN @env tyR))

      -- After the outer case binds the primal payload, the tangent variable
      -- has moved from index 0 to index 1.
      STEither tyL tyR ->
        ecase (EVar ext (STEither tyL tyR) (IS IZ))
          (ecase (EVar ext (STEither (tanty tyL) (tanty tyR)) (IS IZ))
             (EInl ext (dn tyR) (ezipDN tyL))
             (EError ext (STEither (dn tyL) (dn tyR)) "jvp zip: either mismatch lr"))
          (ecase (EVar ext (STEither (tanty tyL) (tanty tyR)) (IS IZ))
             (EError ext (STEither (dn tyL) (dn tyR)) "jvp zip: either mismatch rl")
             (EInr ext (dn tyL) (ezipDN tyR)))

      STLEither tyL tyR ->
        elcase (EVar ext (STLEither tyL tyR) (IS IZ))
          (ELNil ext (dn tyL) (dn tyR))
          (elcase (EVar ext (STLEither (tanty tyL) (tanty tyR)) (IS IZ))
             (EError ext (STLEither (dn tyL) (dn tyR)) "jvp zip: leither mismatch lN")
             (ELInl ext (dn tyR) (ezipDN tyL))
             (EError ext (STLEither (dn tyL) (dn tyR)) "jvp zip: leither mismatch lr"))
          (elcase (EVar ext (STLEither (tanty tyL) (tanty tyR)) (IS IZ))
             (EError ext (STLEither (dn tyL) (dn tyR)) "jvp zip: leither mismatch rN")
             (EError ext (STLEither (dn tyL) (dn tyR)) "jvp zip: leither mismatch rl")
             (ELInr ext (dn tyL) (ezipDN tyR)))

      STMaybe tyElt ->
        emaybe (EVar ext (STMaybe tyElt) (IS IZ))
          (ENothing ext (dn tyElt))
          (emaybe (EVar ext (STMaybe (tanty tyElt)) (IS IZ))
             (EError ext (STMaybe (dn tyElt)) "jvp zip: maybe mismatch jN")
             (EJust ext (ezipDN tyElt)))

      -- Arrays are zipped elementwise: primal array first, tangent array
      -- second.
      STArr rank tyElt ->
        ezipWith
          (ezipDN tyElt)
          (EVar ext (STArr rank tyElt) (IS IZ))
          (EVar ext (STArr rank (tanty tyElt)) IZ)

      -- Floating-point scalars become a (primal, tangent) pair.
      STScal STF32 ->
        EPair ext (EVar ext (STScal STF32) (IS IZ)) (EVar ext (tanty (STScal STF32)) IZ)
      STScal STF64 ->
        EPair ext (EVar ext (STScal STF64) (IS IZ)) (EVar ext (tanty (STScal STF64)) IZ)

      -- For the remaining scalars only the primal is used.
      STScal STI32 -> EVar ext (STScal STI32) (IS IZ)
      STScal STI64 -> EVar ext (STScal STI64) (IS IZ)
      STScal STBool -> EVar ext (STScal STBool) (IS IZ)

      STAccum{} -> error "jvp: Accumulators not supported in source program"

    -- Split a dual-numbers value (variable 0) back into a pair of the primal
    -- value and the tangent value.
    eunzipDN :: forall env t'. STy t' -> Ex (DN t' : env) (TPair t' (Tan t'))
    eunzipDN ty = case ty of
      STNil -> EPair ext (ENil ext) (ENil ext)

      -- Unzip both components, then regroup into (primals, tangents). The
      -- weakenings supplied by 'eunPair' carry the earlier results into the
      -- later scopes.
      STPair tyL tyR ->
        eunPair (subst0 (EFst ext (EVar ext (STPair (dn tyL) (dn tyR)) IZ)) (eunzipDN tyL)) $
          \wkL primalL tangentL ->
            eunPair (weakenExpr wkL (subst0 (ESnd ext (EVar ext (STPair (dn tyL) (dn tyR)) IZ)) (eunzipDN tyR))) $
              \wkR primalR tangentR ->
                EPair ext
                  (EPair ext (weakenExpr wkR primalL) primalR)
                  (EPair ext (weakenExpr wkR tangentL) tangentR)

      STEither tyL tyR ->
        ecase (EVar ext (STEither (dn tyL) (dn tyR)) IZ)
          (eunPair (eunzipDN tyL) $ \_ primal tangent ->
             EPair ext (EInl ext tyR primal) (EInl ext (tanty tyR) tangent))
          (eunPair (eunzipDN tyR) $ \_ primal tangent ->
             EPair ext (EInr ext tyL primal) (EInr ext (tanty tyL) tangent))

      STLEither tyL tyR ->
        elcase (EVar ext (STLEither (dn tyL) (dn tyR)) IZ)
          (EPair ext (ELNil ext tyL tyR) (ELNil ext (tanty tyL) (tanty tyR)))
          (eunPair (eunzipDN tyL) $ \_ primal tangent ->
             EPair ext (ELInl ext tyR primal) (ELInl ext (tanty tyR) tangent))
          (eunPair (eunzipDN tyR) $ \_ primal tangent ->
             EPair ext (ELInr ext tyL primal) (ELInr ext (tanty tyL) tangent))

      STMaybe tyElt ->
        emaybe (EVar ext (STMaybe (dn tyElt)) IZ)
          (EPair ext (ENothing ext tyElt) (ENothing ext (tanty tyElt)))
          (eunPair (eunzipDN tyElt) $ \_ primal tangent ->
             EPair ext (EJust ext primal) (EJust ext tangent))

      -- Unzip every element once into an array of pairs, then project that
      -- array twice to obtain the primal array and the tangent array.
      STArr rank tyElt ->
        elet (emap (eunzipDN tyElt) (EVar ext (STArr rank (dn tyElt)) IZ)) $
          EPair ext
            (emap (EFst ext (evar IZ)) (evar IZ))
            (emap (ESnd ext (evar IZ)) (evar IZ))

      -- A floating-point dual number already is the (primal, tangent) pair.
      STScal STF32 -> EVar ext (STPair (STScal STF32) (STScal STF32)) IZ
      STScal STF64 -> EVar ext (STPair (STScal STF64) (STScal STF64)) IZ

      -- The remaining scalars get a unit tangent.
      STScal STI32 -> EPair ext (EVar ext (STScal STI32) IZ) (ENil ext)
      STScal STI64 -> EPair ext (EVar ext (STScal STI64) IZ) (ENil ext)
      STScal STBool -> EPair ext (EVar ext (STScal STBool) IZ) (ENil ext)

      STAccum{} -> error "jvp: Accumulators not supported in source program"

-- | Interpret an expression in a given environment.
interpret :: KnownEnv env => SList Value env -> Ex env t -> Rep t
interpret env term = Interpreter.interpretOpen False knownEnv env term

-- | Special case of 'interpret' for an expression with a single free variable.
interpret1 :: KnownTy s => Rep s -> Ex '[s] t -> Rep t
interpret1 x term = interpret (SCons (Value x) SNil) term

-- | Compile an expression to C, load the resulting shared object into the
-- program and wrap it in a Haskell function.
compile :: KnownEnv env => Ex env t -> IO (SList Value env -> IO (Rep t))
compile term = Compile.compileStderr knownEnv term

-- | Special case of 'compile' for an expression with a single free variable.
compile1 :: KnownTy s => Ex '[s] t -> IO (Rep s -> IO (Rep t))
compile1 term =
  fmap
    (\run x -> run (SCons (Value x) SNil))
    (Compile.compileStderr knownEnv term)

-- | Simplify an expression. The 'vjp'/'jvp' functions already do this automatically.
--
-- Runs 'simplifyFix', then 'pruneExpr', then 'simplifyFix' again.
fullSimplify :: KnownEnv env => Ex env t -> Ex env t
fullSimplify term = simplifyFix (pruneExpr knownEnv (simplifyFix term))