diff options
Diffstat (limited to 'src/CHAD/APIv1.hs')
| -rw-r--r-- | src/CHAD/APIv1.hs | 177 |
1 files changed, 177 insertions, 0 deletions
diff --git a/src/CHAD/APIv1.hs b/src/CHAD/APIv1.hs new file mode 100644 index 0000000..4e82130 --- /dev/null +++ b/src/CHAD/APIv1.hs @@ -0,0 +1,177 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.APIv1 ( + -- * Expressions and types + Ex, STy(..), SScalTy(..), Ty(..), ScalTy(..), + + -- * Reverse derivatives (Fast CHAD) + vjp, vjp', + D2, D2E, Tup, + CHADConfig(..), + + -- ** Primal type transform + -- | The primal type transform only important when working with special + -- operations like 'CHAD.Language.custom'. + D1, + + -- * Forward derivatives (dual numbers) + jvp, jvpDN, + Tan, DN, DNE, + + -- * Working with expressions + interpret, interpret1, + compile, compile1, + fullSimplify, + SList(..), Value(..), Rep, + KnownEnv(..), KnownTy(..), +) where + +import CHAD.AST +import CHAD.AST.Count +import CHAD.AST.UnMonoid +import CHAD.Compile qualified as Compile +import CHAD.Data +import CHAD.Drev.Top +import CHAD.Drev.Types +import CHAD.ForwardAD +import CHAD.ForwardAD.DualNumbers +import CHAD.Interpreter qualified as Interpreter +import CHAD.Simplify +import CHAD.Interpreter.Rep + + +-- | Compute a reverse derivative: a vector-Jacobian product. The type has been +-- simplified with the assumption that 'D1' is the identity. +vjp :: KnownEnv env => Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env))) +vjp = vjp' (chcSetAccum defaultConfig) + +-- | Same as 'vjp'', but supply CHAD configuration. +vjp' :: KnownEnv env => CHADConfig -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env))) +vjp' config term + | Dict <- styKnown (d2 (typeOf term)) = + fullSimplify $ + unMonoid . simplifyFix $ -- need to merge onehots and accums for unMonoid to do its work + chad' config knownEnv (simplifyFix term) + +jvpDN :: Ex env t -> Ex (DNE env) (DN t) +jvpDN = dfwdDN + +jvp :: forall s t. KnownTy s => Ex '[s] t -> Ex '[Tan s, s] (TPair t (Tan t)) +jvp term + | Dict <- styKnown (tanty (knownTy @s)) + = fullSimplify $ + elet (ezipDN knownTy) $ + elet (weakenExpr (WCopy WClosed) (jvpDN term)) $ + eunzipDN (typeOf term) + where + ezipDN :: forall env s'. STy s' -> Ex (Tan s' : s' : env) (DN s') + ezipDN STNil = ENil ext + ezipDN (STPair a b) = + EPair ext (subst (\_ t' -> \case IZ -> EFst ext (EVar ext (STPair (tanty a) (tanty b)) IZ) + IS IZ -> EFst ext (EVar ext (STPair a b) (IS IZ)) + IS (IS i) -> EVar ext t' (IS (IS i))) + (ezipDN @env a)) + (subst (\_ t' -> \case IZ -> ESnd ext (EVar ext (STPair (tanty a) (tanty b)) IZ) + IS IZ -> ESnd ext (EVar ext (STPair a b) (IS IZ)) + IS (IS i) -> EVar ext t' (IS (IS i))) + (ezipDN @env b)) + ezipDN (STEither a b) = + ecase (EVar ext (STEither a b) (IS IZ)) + (ecase (EVar ext (STEither (tanty a) (tanty b)) (IS IZ)) + (EInl ext (dn b) (ezipDN a)) + (EError ext (STEither (dn a) (dn b)) "jvp zip: either mismatch lr")) + (ecase (EVar ext (STEither (tanty a) (tanty b)) (IS IZ)) + (EError ext (STEither (dn a) (dn b)) "jvp zip: either mismatch rl") + (EInr ext (dn a) (ezipDN b))) + ezipDN (STLEither a b) = + elcase (EVar ext (STLEither a b) (IS IZ)) + (ELNil ext (dn a) (dn b)) + (elcase (EVar ext (STLEither (tanty a) (tanty b)) (IS IZ)) + (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch lN") + (ELInl ext (dn b) (ezipDN a)) + (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch lr")) + (elcase (EVar ext (STLEither (tanty a) (tanty b)) (IS IZ)) + (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch rN") + (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch rl") + (ELInr ext (dn a) (ezipDN b))) + ezipDN (STMaybe t) = + emaybe (EVar ext (STMaybe t) (IS IZ)) + (ENothing ext (dn t)) + (emaybe (EVar ext (STMaybe (tanty t)) (IS IZ)) + (EError ext (STMaybe (dn t)) "jvp zip: maybe mismatch jN") + (EJust ext (ezipDN t))) + ezipDN (STArr n t) = + ezipWith (ezipDN t) + (EVar ext (STArr n t) (IS IZ)) (EVar ext (STArr n (tanty t)) IZ) + ezipDN (STScal st) = case st of + STF32 -> EPair ext (EVar ext (STScal STF32) (IS IZ)) (EVar ext (tanty (STScal STF32)) IZ) + STF64 -> EPair ext (EVar ext (STScal STF64) (IS IZ)) (EVar ext (tanty (STScal STF64)) IZ) + STI32 -> EVar ext (STScal STI32) (IS IZ) + STI64 -> EVar ext (STScal STI64) (IS IZ) + STBool -> EVar ext (STScal STBool) (IS IZ) + ezipDN STAccum{} = error "jvp: Accumulators not supported in source program" + + eunzipDN :: forall env t'. STy t' -> Ex (DN t' : env) (TPair t' (Tan t')) + eunzipDN STNil = EPair ext (ENil ext) (ENil ext) + eunzipDN (STPair a b) = + eunPair (subst0 (EFst ext (EVar ext (STPair (dn a) (dn b)) IZ)) (eunzipDN a)) $ \w1 ea1 ea2 -> + eunPair (weakenExpr w1 (subst0 (ESnd ext (EVar ext (STPair (dn a) (dn b)) IZ)) (eunzipDN b))) $ \w2 eb1 eb2 -> + EPair ext (EPair ext (weakenExpr w2 ea1) eb1) (EPair ext (weakenExpr w2 ea2) eb2) + eunzipDN (STEither a b) = + ecase (EVar ext (STEither (dn a) (dn b)) IZ) + (eunPair (eunzipDN a) $ \_ a1 a2 -> + EPair ext (EInl ext b a1) (EInl ext (tanty b) a2)) + (eunPair (eunzipDN b) $ \_ b1 b2 -> + EPair ext (EInr ext a b1) (EInr ext (tanty a) b2)) + eunzipDN (STLEither a b) = + elcase (EVar ext (STLEither (dn a) (dn b)) IZ) + (EPair ext (ELNil ext a b) (ELNil ext (tanty a) (tanty b))) + (eunPair (eunzipDN a) $ \_ a1 a2 -> + EPair ext (ELInl ext b a1) (ELInl ext (tanty b) a2)) + (eunPair (eunzipDN b) $ \_ b1 b2 -> + EPair ext (ELInr ext a b1) (ELInr ext (tanty a) b2)) + eunzipDN (STMaybe t) = + emaybe (EVar ext (STMaybe (dn t)) IZ) + (EPair ext (ENothing ext t) (ENothing ext (tanty t))) + (eunPair (eunzipDN t) $ \_ e1 e2 -> + EPair ext (EJust ext e1) (EJust ext e2)) + eunzipDN (STArr n t) = + elet (emap (eunzipDN t) (EVar ext (STArr n (dn t)) IZ)) $ + EPair ext (emap (EFst ext (evar IZ)) (evar IZ)) + (emap (ESnd ext (evar IZ)) (evar IZ)) + eunzipDN (STScal st) = case st of + STF32 -> EVar ext (STPair (STScal STF32) (STScal STF32)) IZ + STF64 -> EVar ext (STPair (STScal STF64) (STScal STF64)) IZ + STI32 -> EPair ext (EVar ext (STScal STI32) IZ) (ENil ext) + STI64 -> EPair ext (EVar ext (STScal STI64) IZ) (ENil ext) + STBool -> EPair ext (EVar ext (STScal STBool) IZ) (ENil ext) + eunzipDN STAccum{} = error "jvp: Accumulators not supported in source program" + +-- | Interpret an expression in a given environment. +interpret :: KnownEnv env => SList Value env -> Ex env t -> Rep t +interpret = Interpreter.interpretOpen False knownEnv + +-- | Special case of 'interpret' for an expression with a single free variable. +interpret1 :: KnownTy s => Rep s -> Ex '[s] t -> Rep t +interpret1 x = interpret (Value x `SCons` SNil) + +-- | Compile an expression to C, load the resulting shared object into the +-- program and wrap it in a Haskell function. +compile :: KnownEnv env => Ex env t -> IO (SList Value env -> IO (Rep t)) +compile = Compile.compileStderr knownEnv + +-- | Special case of 'compile' for an expression with a single free variable. +compile1 :: KnownTy s => Ex '[s] t -> IO (Rep s -> IO (Rep t)) +compile1 term = do + f <- Compile.compileStderr knownEnv term + return (\x -> f (Value x `SCons` SNil)) + +-- | Simpify an expression. The 'vjp'/'jvp' functions already do this automatically. +fullSimplify :: KnownEnv env => Ex env t -> Ex env t +fullSimplify = simplifyFix . pruneExpr knownEnv . simplifyFix |
