aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-11 23:56:47 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-11 23:56:47 +0100
commitdc61318a22e3492774ab6f6345c9a369222ef2f6 (patch)
tree6b7b0bd1d194666b3ba6eb5f85e620bf850fee69 /src/CHAD
parentcd135319f65f40a554d864b2a878a4ef44043a98 (diff)
User-facing API suggestion
Diffstat (limited to 'src/CHAD')
-rw-r--r--src/CHAD/APIv1.hs177
-rw-r--r--src/CHAD/AST.hs3
-rw-r--r--src/CHAD/AST/Types.hs2
-rw-r--r--src/CHAD/Language.hs176
-rw-r--r--src/CHAD/Language/AST.hs62
5 files changed, 407 insertions, 13 deletions
diff --git a/src/CHAD/APIv1.hs b/src/CHAD/APIv1.hs
new file mode 100644
index 0000000..4e82130
--- /dev/null
+++ b/src/CHAD/APIv1.hs
@@ -0,0 +1,177 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE EmptyCase #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.APIv1 (
+ -- * Expressions and types
+ Ex, STy(..), SScalTy(..), Ty(..), ScalTy(..),
+
+ -- * Reverse derivatives (Fast CHAD)
+ vjp, vjp',
+ D2, D2E, Tup,
+ CHADConfig(..),
+
+ -- ** Primal type transform
+ -- | The primal type transform only important when working with special
+ -- operations like 'CHAD.Language.custom'.
+ D1,
+
+ -- * Forward derivatives (dual numbers)
+ jvp, jvpDN,
+ Tan, DN, DNE,
+
+ -- * Working with expressions
+ interpret, interpret1,
+ compile, compile1,
+ fullSimplify,
+ SList(..), Value(..), Rep,
+ KnownEnv(..), KnownTy(..),
+) where
+
+import CHAD.AST
+import CHAD.AST.Count
+import CHAD.AST.UnMonoid
+import CHAD.Compile qualified as Compile
+import CHAD.Data
+import CHAD.Drev.Top
+import CHAD.Drev.Types
+import CHAD.ForwardAD
+import CHAD.ForwardAD.DualNumbers
+import CHAD.Interpreter qualified as Interpreter
+import CHAD.Simplify
+import CHAD.Interpreter.Rep
+
+
+-- | Compute a reverse derivative: a vector-Jacobian product. The type has been
+-- simplified with the assumption that 'D1' is the identity.
+vjp :: KnownEnv env => Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
+vjp = vjp' (chcSetAccum defaultConfig)
+
+-- | Same as 'vjp'', but supply CHAD configuration.
+vjp' :: KnownEnv env => CHADConfig -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
+vjp' config term
+ | Dict <- styKnown (d2 (typeOf term)) =
+ fullSimplify $
+ unMonoid . simplifyFix $ -- need to merge onehots and accums for unMonoid to do its work
+ chad' config knownEnv (simplifyFix term)
+
+jvpDN :: Ex env t -> Ex (DNE env) (DN t)
+jvpDN = dfwdDN
+
+jvp :: forall s t. KnownTy s => Ex '[s] t -> Ex '[Tan s, s] (TPair t (Tan t))
+jvp term
+ | Dict <- styKnown (tanty (knownTy @s))
+ = fullSimplify $
+ elet (ezipDN knownTy) $
+ elet (weakenExpr (WCopy WClosed) (jvpDN term)) $
+ eunzipDN (typeOf term)
+ where
+ ezipDN :: forall env s'. STy s' -> Ex (Tan s' : s' : env) (DN s')
+ ezipDN STNil = ENil ext
+ ezipDN (STPair a b) =
+ EPair ext (subst (\_ t' -> \case IZ -> EFst ext (EVar ext (STPair (tanty a) (tanty b)) IZ)
+ IS IZ -> EFst ext (EVar ext (STPair a b) (IS IZ))
+ IS (IS i) -> EVar ext t' (IS (IS i)))
+ (ezipDN @env a))
+ (subst (\_ t' -> \case IZ -> ESnd ext (EVar ext (STPair (tanty a) (tanty b)) IZ)
+ IS IZ -> ESnd ext (EVar ext (STPair a b) (IS IZ))
+ IS (IS i) -> EVar ext t' (IS (IS i)))
+ (ezipDN @env b))
+ ezipDN (STEither a b) =
+ ecase (EVar ext (STEither a b) (IS IZ))
+ (ecase (EVar ext (STEither (tanty a) (tanty b)) (IS IZ))
+ (EInl ext (dn b) (ezipDN a))
+ (EError ext (STEither (dn a) (dn b)) "jvp zip: either mismatch lr"))
+ (ecase (EVar ext (STEither (tanty a) (tanty b)) (IS IZ))
+ (EError ext (STEither (dn a) (dn b)) "jvp zip: either mismatch rl")
+ (EInr ext (dn a) (ezipDN b)))
+ ezipDN (STLEither a b) =
+ elcase (EVar ext (STLEither a b) (IS IZ))
+ (ELNil ext (dn a) (dn b))
+ (elcase (EVar ext (STLEither (tanty a) (tanty b)) (IS IZ))
+ (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch lN")
+ (ELInl ext (dn b) (ezipDN a))
+ (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch lr"))
+ (elcase (EVar ext (STLEither (tanty a) (tanty b)) (IS IZ))
+ (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch rN")
+ (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch rl")
+ (ELInr ext (dn a) (ezipDN b)))
+ ezipDN (STMaybe t) =
+ emaybe (EVar ext (STMaybe t) (IS IZ))
+ (ENothing ext (dn t))
+ (emaybe (EVar ext (STMaybe (tanty t)) (IS IZ))
+ (EError ext (STMaybe (dn t)) "jvp zip: maybe mismatch jN")
+ (EJust ext (ezipDN t)))
+ ezipDN (STArr n t) =
+ ezipWith (ezipDN t)
+ (EVar ext (STArr n t) (IS IZ)) (EVar ext (STArr n (tanty t)) IZ)
+ ezipDN (STScal st) = case st of
+ STF32 -> EPair ext (EVar ext (STScal STF32) (IS IZ)) (EVar ext (tanty (STScal STF32)) IZ)
+ STF64 -> EPair ext (EVar ext (STScal STF64) (IS IZ)) (EVar ext (tanty (STScal STF64)) IZ)
+ STI32 -> EVar ext (STScal STI32) (IS IZ)
+ STI64 -> EVar ext (STScal STI64) (IS IZ)
+ STBool -> EVar ext (STScal STBool) (IS IZ)
+ ezipDN STAccum{} = error "jvp: Accumulators not supported in source program"
+
+ eunzipDN :: forall env t'. STy t' -> Ex (DN t' : env) (TPair t' (Tan t'))
+ eunzipDN STNil = EPair ext (ENil ext) (ENil ext)
+ eunzipDN (STPair a b) =
+ eunPair (subst0 (EFst ext (EVar ext (STPair (dn a) (dn b)) IZ)) (eunzipDN a)) $ \w1 ea1 ea2 ->
+ eunPair (weakenExpr w1 (subst0 (ESnd ext (EVar ext (STPair (dn a) (dn b)) IZ)) (eunzipDN b))) $ \w2 eb1 eb2 ->
+ EPair ext (EPair ext (weakenExpr w2 ea1) eb1) (EPair ext (weakenExpr w2 ea2) eb2)
+ eunzipDN (STEither a b) =
+ ecase (EVar ext (STEither (dn a) (dn b)) IZ)
+ (eunPair (eunzipDN a) $ \_ a1 a2 ->
+ EPair ext (EInl ext b a1) (EInl ext (tanty b) a2))
+ (eunPair (eunzipDN b) $ \_ b1 b2 ->
+ EPair ext (EInr ext a b1) (EInr ext (tanty a) b2))
+ eunzipDN (STLEither a b) =
+ elcase (EVar ext (STLEither (dn a) (dn b)) IZ)
+ (EPair ext (ELNil ext a b) (ELNil ext (tanty a) (tanty b)))
+ (eunPair (eunzipDN a) $ \_ a1 a2 ->
+ EPair ext (ELInl ext b a1) (ELInl ext (tanty b) a2))
+ (eunPair (eunzipDN b) $ \_ b1 b2 ->
+ EPair ext (ELInr ext a b1) (ELInr ext (tanty a) b2))
+ eunzipDN (STMaybe t) =
+ emaybe (EVar ext (STMaybe (dn t)) IZ)
+ (EPair ext (ENothing ext t) (ENothing ext (tanty t)))
+ (eunPair (eunzipDN t) $ \_ e1 e2 ->
+ EPair ext (EJust ext e1) (EJust ext e2))
+ eunzipDN (STArr n t) =
+ elet (emap (eunzipDN t) (EVar ext (STArr n (dn t)) IZ)) $
+ EPair ext (emap (EFst ext (evar IZ)) (evar IZ))
+ (emap (ESnd ext (evar IZ)) (evar IZ))
+ eunzipDN (STScal st) = case st of
+ STF32 -> EVar ext (STPair (STScal STF32) (STScal STF32)) IZ
+ STF64 -> EVar ext (STPair (STScal STF64) (STScal STF64)) IZ
+ STI32 -> EPair ext (EVar ext (STScal STI32) IZ) (ENil ext)
+ STI64 -> EPair ext (EVar ext (STScal STI64) IZ) (ENil ext)
+ STBool -> EPair ext (EVar ext (STScal STBool) IZ) (ENil ext)
+ eunzipDN STAccum{} = error "jvp: Accumulators not supported in source program"
+
+-- | Interpret an expression in a given environment.
+interpret :: KnownEnv env => SList Value env -> Ex env t -> Rep t
+interpret = Interpreter.interpretOpen False knownEnv
+
+-- | Special case of 'interpret' for an expression with a single free variable.
+interpret1 :: KnownTy s => Rep s -> Ex '[s] t -> Rep t
+interpret1 x = interpret (Value x `SCons` SNil)
+
+-- | Compile an expression to C, load the resulting shared object into the
+-- program and wrap it in a Haskell function.
+compile :: KnownEnv env => Ex env t -> IO (SList Value env -> IO (Rep t))
+compile = Compile.compileStderr knownEnv
+
+-- | Special case of 'compile' for an expression with a single free variable.
+compile1 :: KnownTy s => Ex '[s] t -> IO (Rep s -> IO (Rep t))
+compile1 term = do
+ f <- Compile.compileStderr knownEnv term
+ return (\x -> f (Value x `SCons` SNil))
+
+-- | Simpify an expression. The 'vjp'/'jvp' functions already do this automatically.
+fullSimplify :: KnownEnv env => Ex env t -> Ex env t
+fullSimplify = simplifyFix . pruneExpr knownEnv . simplifyFix
diff --git a/src/CHAD/AST.hs b/src/CHAD/AST.hs
index 2f4b5c2..be7f95e 100644
--- a/src/CHAD/AST.hs
+++ b/src/CHAD/AST.hs
@@ -139,6 +139,9 @@ data Expr x env t where
EError :: x a -> STy a -> String -> Expr x env a
deriving instance (forall ty. Show (x ty)) => Show (Expr x env t)
+-- | A (well-typed, well-scoped) expression using De Bruijn indices. The full
+-- 'Expr' type is parametrised on an indexed type of "additional info" (@x@);
+-- 'Ex' sets this to nothing.
type Ex = Expr (Const ())
ext :: Const () a
diff --git a/src/CHAD/AST/Types.hs b/src/CHAD/AST/Types.hs
index 059077d..f0feb55 100644
--- a/src/CHAD/AST/Types.hs
+++ b/src/CHAD/AST/Types.hs
@@ -31,6 +31,8 @@ type data Ty
type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
+-- | Scalar types happen to be bundled in 'SScalTy' as this is sometimes
+-- convenient, but such scalar types are not special in any way.
type STy :: Ty -> Type
data STy t where
STNil :: STy TNil
diff --git a/src/CHAD/Language.hs b/src/CHAD/Language.hs
index ef89284..6621eef 100644
--- a/src/CHAD/Language.hs
+++ b/src/CHAD/Language.hs
@@ -6,12 +6,65 @@
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
module CHAD.Language (
+ -- * Named expressions
fromNamed,
- NExpr,
- Ex,
- module CHAD.Language,
- module CHAD.AST.Types,
+ NExpr, NFun,
+
+ -- * Functions
+ lambda,
+ body,
+ inline,
+ (.$),
+
+ -- * Basic language constructs
+ let_,
+ pair, fst_, snd_, nil,
+ inl, inr, case_,
+ nothing, just, maybe_,
+
+ -- * Array operations
+ constArr_,
+ build1, build2, build,
+ map_,
+ fold1i, fold1i',
+ sum1i,
+ unit,
+ replicate1i,
+ maximum1i, minimum1i,
+ reshape,
+ fold1iD1, fold1iD1',
+ fold1iD2,
+
+ -- * Scalar operations
+ -- | Note that 'NExpr' is also an instance of some numeric classes like 'Num' and 'Floating'.
+ const_,
+ idx0,
+ (!),
+ shape,
+ length_,
+ error_,
+ (.==), (.<), (CHAD.Language..>), (.<=), (.>=),
+ not_, and_, or_,
+ mod_, round_, toFloat_, idiv,
+
+ -- * Control flow
+ if_,
+
+ -- * Special operations
+ custom,
+ recompute,
+ with, accum, accumS,
+ oper, oper2,
+
+ -- * Helper types
+ (:->)(..),
+
+ -- * Reexports
+ TIx,
Lookup,
+ Ex,
+ Ty(..),
+ SNat(..), Nat(..), N0, N1, N2, N3,
) where
import GHC.TypeLits (withSomeSSymbol, symbolVal, SSymbol, pattern SSymbol)
@@ -19,34 +72,56 @@ import GHC.TypeLits (withSomeSSymbol, symbolVal, SSymbol, pattern SSymbol)
import CHAD.Array
import CHAD.AST
import CHAD.AST.Sparse.Types
-import CHAD.AST.Types
import CHAD.Data
import CHAD.Drev.Types
import CHAD.Language.AST
+-- | Helper type, used for e.g. 'case_' and 'build'.
data a :-> b = a :-> b
deriving (Show)
infixr 0 :->
+-- | See 'fromNamed' for a usage example.
body :: NExpr env t -> NFun env env t
body = NBody
+-- | See 'fromNamed' for a usage example.
lambda :: forall a name env env' t. Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t
lambda = NLam
+-- | Inline a function here, with the given list of expressions as arguments.
+-- While this is a normal 'SList', the @params@ list is reversed from the
+-- natural argument order of the function; the '(.$)' helper operator serves to
+-- "fix" the order.
+--
+-- @
+-- let fun = 'lambda' \@(TScal TF64) #x $ 'lambda' \@(TScal TBool) #b $ 'body' $ if_ #b #x (#x + 1)
+-- in 'inline' fun ('SNil' .$ 16 .$ 'const_' True)
+-- @
+--
+-- Note that no 'const_' is needed for the @16@, because 'NExpr' implements
+-- 'Num'.
inline :: NFun '[] params t -> SList (NExpr env) (UnName params) -> NExpr env t
inline = inlineNFun
--- To be used to construct the argument list for 'inline'.
---
--- > let fun = lambda @(TScal TF64) #x $ lambda @(TScal TF64) #y $ body $ #x + #y
--- > in inline fun (SNil .$ 16 .$ 26)
+-- | Helper for constructing the argument list for 'inline';
+-- @(.$) = flip 'SCons'@. See 'inline'.
(.$) :: SList f list -> f a -> SList f (a : list)
(.$) = flip SCons
+-- | The first 'Var' argument is the left-hand side of this let-binding. For example:
+--
+-- @
+-- 'fromNamed' $ 'lambda' \@(TScal TI64) #a $ 'body' $
+-- 'let_' #x (#a + 1) $
+-- #x * #a
+-- @
+--
+-- This produces an expression of type @'Ex' '[TScal TI64] (TScal TI64)@ that
+-- corresponds to the Haskell code @\\a -> let x = a + 1 in x * a@.
let_ :: forall a t env name. Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t
let_ = NELet
@@ -68,6 +143,14 @@ inl = NEInl knownTy
inr :: KnownTy a => NExpr env b -> NExpr env (TEither a b)
inr = NEInr knownTy
+-- | A @case@ expression on @Either@s. For example, the following expression
+-- will evaluate to 10 + 1 = 11:
+--
+-- @
+-- 'case_' ('inl' 10)
+-- (#x :-> #x + 1)
+-- (#y :-> #y * 2)
+-- @
case_ :: NExpr env (TEither a b) -> (Var name1 a :-> NExpr ('(name1, a) : env) c) -> (Var name2 b :-> NExpr ('(name2, b) : env) c) -> NExpr env c
case_ e (v1 :-> e1) (v2 :-> e2) = NECase e v1 e1 v2 e2
@@ -77,18 +160,33 @@ nothing = NENothing knownTy
just :: NExpr env a -> NExpr env (TMaybe a)
just = NEJust
+-- | Analogue of the 'Prelude.maybe' function in the Haskell Prelude:
+--
+-- @
+-- 'maybe_' 2 (#x :-> #x * 3) (...)
+-- @
+--
+-- will return 2 if @(...)@ is @Nothing@ and @x + 3@ if it is @Just x@.
maybe_ :: NExpr env b -> (Var name a :-> NExpr ('(name, a) : env) b) -> NExpr env (TMaybe a) -> NExpr env b
maybe_ a (v :-> b) c = NEMaybe a v b c
+-- | To construct 'Array' values, see "CHAD.Array".
constArr_ :: forall t n env. (KnownNat n, KnownScalTy t) => Array n (ScalRep t) -> NExpr env (TArr n (TScal t))
constArr_ x =
let ty = knownScalTy
in case scalRepIsShow ty of
Dict -> NEConstArr knownNat ty x
+-- | Special case of 'build' for 1-dimensional arrays. This produces the array
+-- [0.0, 1.0, 2.0]:
+--
+-- @
+-- 'build1' 3 (#i :-> 'toFloat_' #i)
+-- @
build1 :: NExpr env TIx -> (Var name TIx :-> NExpr ('(name, TIx) : env) t) -> NExpr env (TArr (S Z) t)
build1 a (v :-> b) = NEBuild (SS SZ) (pair nil a) #idx (let_ v (snd_ #idx) (NEDrop (SS SZ) b))
+-- | Special case of 'build' for 2-dimensional arrays.
build2 :: NExpr env TIx -> NExpr env TIx
-> (Var name1 TIx :-> Var name2 TIx :-> NExpr ('(name2, TIx) : '(name1, TIx) : env) t)
-> NExpr env (TArr (S (S Z)) t)
@@ -100,6 +198,15 @@ build2 a1 a2 (v1 :-> v2 :-> b) =
let_ v2 (NEDrop SZ (snd_ #idx)) $
NEDrop (SS (SS SZ)) b)
+-- | General n-dimensional elementwise array constructor. A 3-dimensional index
+-- looks like @((((), i1), i2), i3)@; other dimensionalities are analogous. The
+-- innermost dimension (i.e. whose index variable varies the fastest in the
+-- standard memory layout) is the right-most index, i.e. @i3@ in 3D example. To
+-- create a 10-by-10 table of (row, column) pairs:
+--
+-- @
+-- 'build' ('SS' ('SS' 'SZ')) ('pair' ('pair' 'nil' 10) 10) (#i :-> #j :-> 'pair' #i #j)
+-- @
build :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> (Var name (Tup (Replicate n TIx)) :-> NExpr ('(name, Tup (Replicate n TIx)) : env) t) -> NExpr env (TArr n t)
build n a (v :-> b) = NEBuild n a v b
@@ -108,6 +215,7 @@ map_ :: forall n a b env name. (KnownNat n, KnownTy a)
-> NExpr env (TArr n a) -> NExpr env (TArr n b)
map_ (v :-> a) b = NEMap v a b
+-- | Fold over the innermost dimension of an array, thus reducing its dimensionality by one.
fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
fold1i (v1@(Var s1@SSymbol t) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 =
withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) ->
@@ -120,6 +228,10 @@ fold1i (v1@(Var s1@SSymbol t) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 =
NEDrop (SS (SS SZ)) e1)
e2 e3
+-- | The underlying AST constructor for a fold takes a function with /one/
+-- argument: a pair of inputs. 'fold1i'' directly returns this AST constructor
+-- in case it is helpful for testing. The 'fold1i' function is a convenience
+-- wrapper around 'fold1i''.
fold1i' :: (Var name (TPair t t) :-> NExpr ('(name, TPair t t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
fold1i' (v :-> e1) e2 e3 = NEFold1Inner v e1 e2 e3
@@ -141,6 +253,7 @@ minimum1i e = NEMinimum1Inner e
reshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t)
reshape = NEReshape
+-- | 'fold1iD1'' with a curried combination function.
fold1iD1 :: (Var name1 t1 :-> Var name2 t1 :-> NExpr ('(name2, t1) : '(name1, t1) : env) (TPair t1 b))
-> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b))
fold1iD1 (v1@(Var s1@SSymbol t1) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 =
@@ -154,10 +267,12 @@ fold1iD1 (v1@(Var s1@SSymbol t1) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 =
NEDrop (SS (SS SZ)) e1)
e2 e3
+-- | Primal of a fold. Not supported in the input program for reverse differentiation.
fold1iD1' :: (Var name (TPair t1 t1) :-> NExpr ('(name, TPair t1 t1) : env) (TPair t1 b))
-> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b))
fold1iD1' (v1 :-> e1) e2 e3 = NEFold1InnerD1 v1 e1 e2 e3
+-- | Reverse pass of a fold. Not supported in the input program for reverse differentiation.
fold1iD2 :: (Var name1 b :-> Var name2 t2 :-> NExpr ('(name2, t2) : '(name1, b) : env) (TPair t2 t2))
-> NExpr env (TArr (S n) b) -> NExpr env (TArr n t2) -> NExpr env (TPair (TArr n t2) (TArr (S n) t2))
fold1iD2 (v1 :-> v2 :-> e1) e2 e3 = NEFold1InnerD2 v1 v2 e1 e2 e3
@@ -175,6 +290,9 @@ idx0 = NEIdx0
-- (.!) = NEIdx1
-- infixl 9 .!
+-- | Index an array. Note that the index is a tuple, just like the argument to
+-- the function in 'build'. To index a 2-dimensional array @a@ at row @i@ and
+-- column @j@, write @a '!' 'pair' ('pair' 'nil' i) j@.
(!) :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t
(!) = NEIdx
infixl 9 !
@@ -182,6 +300,7 @@ infixl 9 !
shape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx))
shape = NEShape
+-- | Convenience special case of 'shape' for single-dimensional arrays.
length_ :: NExpr env (TArr N1 t) -> NExpr env TIx
length_ e = snd_ (shape e)
@@ -194,6 +313,30 @@ oper2 op a b = NEOp op (pair a b)
error_ :: KnownTy t => String -> NExpr env t
error_ s = NEError knownTy s
+-- | Specify a custom reverse derivative for a subexpression. Morally, the type
+-- of this combinator should be read as follows:
+--
+-- @
+-- custom :: (a -> b -> t) -- normal semantics
+-- -> (D1 a -> D1 b -> (D1 t, tape)) -- forward pass
+-- -> (tape -> D2 t -> D2 b) -- reverse pass
+-- -> a -> b -- arguments
+-- -> t -- result
+-- @
+--
+-- In normal evaluation, or when forward-differentiating, the first argument is
+-- taken and the second and third are ignored. When reverse-differentiating
+-- using CHAD, however, the /first/ argument is ignored and the second and
+-- third arguments are respectively put in the forward and the reverse passes
+-- of the derivative program. The @tape@ value may be used to remember primals
+-- for the reverse pass.
+--
+-- This combinator allows for "inactive" and "active" inputs to the operation;
+-- derivatives to the "inactive" input are not propagated. The active input
+-- (whose derivatives /are/ propagated) has type @b@; the inactive input has
+-- type @a@.
+--
+-- No accumulators are allowed inside @a@, @b@ and @tape@.
custom :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t)
-> (Var nf1 (D1 a) :-> Var nf2 (D1 b) :-> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape))
-> (Var nr1 tape :-> Var nr2 (D2 t) :-> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b))
@@ -202,15 +345,30 @@ custom :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t)
custom (n1 :-> n2 :-> a) (nf1 :-> nf2 :-> b) (nr1 :-> nr2 :-> c) e1 e2 =
NECustom n1 n2 a nf1 nf2 b nr1 nr2 c e1 e2
+-- | Semantically the identity, but when reverse differentiating using CHAD,
+-- the contained expression is recomputed in the reverse pass. This is a
+-- light-weight form of checkpointing, with the goal of reducing the number
+-- primal values being stored and thus reducing memory use and memory traffic.
+--
+-- Note that free variables of the contained expression do still need to be
+-- stored, as we do need to be able to recompute the expression in the reverse
+-- pass.
recompute :: NExpr env a -> NExpr env a
recompute = NERecompute
+-- | Introduce an accumulator. The initial value is not allowed to be sparse!
+-- See 'CHAD.AST.EWith'. Not supported in the input program for reverse
+-- differentiation.
with :: forall t a env acname. KnownMTy t => NExpr env t -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a t)
with a (n :-> b) = NEWith (knownMTy @t) a n b
+-- | Accumulate to an accumulator. Not supported in the input program for
+-- reverse differentiation.
accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil
accum p a b c = NEAccum knownMTy p a (spDense (acPrjTy p knownMTy)) b c
+-- | Accumulate to an accumulator with additional sparsity. Not supported in
+-- the input program for reverse differentiation.
accumS :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil
accumS p a sp b c = NEAccum knownMTy p a sp b c
diff --git a/src/CHAD/Language/AST.hs b/src/CHAD/Language/AST.hs
index b270844..502a2b3 100644
--- a/src/CHAD/Language/AST.hs
+++ b/src/CHAD/Language/AST.hs
@@ -28,6 +28,8 @@ import CHAD.Data
import CHAD.Drev.Types
+-- | A named expression: variables have names, not De Bruijn indices.
+-- Otherwise essentially identical to 'Expr'.
type NExpr :: [(Symbol, Ty)] -> Ty -> Type
data NExpr env t where
-- lambda calculus
@@ -99,7 +101,14 @@ data NExpr env t where
NEUnnamed :: Ex unenv t -> SList (NExpr env) unenv -> NExpr env t
deriving instance Show (NExpr env t)
+-- | Look up the type of a name in a named environment.
type Lookup name env = Lookup1 (name == "_") name env
+-- | This curious stack of type families is used instead of normal pattern
+-- matching so the decidable boolean predicate "==" is used. This means that
+-- introducing evidence of @(name1 == name2) ~ False@ may allow a certain
+-- lookup to reduce even if the names in question are not statically known.
+-- This flexibility is used with e.g. 'assertSymbolDistinct' in
+-- 'CHAD.Language.fold1i'.
type family Lookup1 eqblank name env where
Lookup1 True _ _ = TypeError (Text "Attempt to use variable with name '_'")
Lookup1 False name env = Lookup2 name env
@@ -160,10 +169,20 @@ data NEnv env where
NTop :: NEnv '[]
NPush :: NEnv env -> Var name t -> NEnv ('(name, t) : env)
--- | First (outermost) parameter on the outside, on the left.
--- * env: environment of this function (grows as you go deeper inside lambdas)
--- * env': environment of the body of the function
--- * params: parameters of the function (difference between env and env'), first (outermost) argument at the head of the list
+-- | A named /function/. These can be used in only two ways: they can be
+-- converted to an unnamed 'Expr' using 'fromNamed', and they can be inlined
+-- using 'CHAD.Language.inline'.
+--
+-- * @env@: environment of this function (smaller than @env'@; grows as you descend under lambdas)
+-- * @env'@: environment of the body of the function
+--
+-- For example, a function @(\\(x :: a) (y :: b) -> _ :: c)@ with two free
+-- variables, @u :: t1@ and @v :: t2@, would be represented with a value of the
+-- following type:
+--
+-- @
+-- NFun '['("v", t2), '("u", t1)] '['("y", b), '("x", a), '("v", t2), '("u", t1)] c
+-- @
data NFun env env' t where
NLam :: Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t
NBody :: NExpr env' t -> NFun env' env' t
@@ -179,6 +198,41 @@ envFromNEnv (NPush env (Var _ t)) = t `SCons` envFromNEnv env
inlineNFun :: NFun '[] envB t -> SList (NExpr env) (UnName envB) -> NExpr env t
inlineNFun fun args = NEUnnamed (fromNamed fun) args
+-- | Convert a named function to an unnamed expression with free variables,
+-- ready for consumption by the rest of this library. The function must be
+-- closed (meaning that the function as a whole cannot have free variables),
+-- and the arguments of the function are realised as free variables of the
+-- resulting expression. Typical usage looks as follows:
+--
+-- @
+-- {-# LANGUAGE OverloadedLabels #-}
+-- import CHAD.Language
+-- 'fromNamed' $ 'CHAD.Language.lambda' \@(TScal TF64) #x $ 'CHAD.Language.lambda' \@(TScal TI64) #i $ 'CHAD.Language.body' $ #x + 'CHAD.Language.toFloat_' #i
+-- :: 'Ex' '[TScal TI64, TScal TF64] (TScal TF64)
+-- @
+--
+-- The rest of the library generally considers expressions with free variables
+-- to stand in for "functions", by considering the free variables as the
+-- function's inputs.
+--
+-- Note that while environments normally grow to the right (e.g. in type theory
+-- notation), as they as type-level lists here, they grow to the /left/. This
+-- is why the second (innermost) argument of the example, @i@, ends up at the
+-- head of the environment of the constructed expression.
+--
+-- __Type applications__: The type applications to 'CHAD.Language.lambda' above
+-- are good practice, but not always necessary; if GHC can infer the type of
+-- the argument from the body of the expression, the type application is
+-- unnecessary.
+--
+-- __Variables__: The major element of syntactic sugar in this module is using
+-- OverloadedLabels for variable names. Variables are represented in 'NExpr'
+-- (and thus 'NFun') using the 'Var' type; you should never have to manually
+-- construct a 'Var'. Instead, 'Var' implements 'IsLabel' and as such can be
+-- produced with the syntax @#name@, where "name" is the name of the variable.
+-- This syntax produces a polymorphic variable reference whose (embedded) type
+-- is left to GHC's type inference engine using a 'KnownTy' constraint. See
+-- also 'CHAD.Language.let_'.
fromNamed :: NFun '[] env t -> Ex (UnName env) t
fromNamed = fromNamedFun NTop