aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD')
-rw-r--r--src/CHAD/APIv1.hs7
-rw-r--r--src/CHAD/AST.hs5
-rw-r--r--src/CHAD/ForwardAD.hs2
3 files changed, 10 insertions, 4 deletions
diff --git a/src/CHAD/APIv1.hs b/src/CHAD/APIv1.hs
index 1ba01b1..73d1580 100644
--- a/src/CHAD/APIv1.hs
+++ b/src/CHAD/APIv1.hs
@@ -22,7 +22,7 @@ module CHAD.APIv1 (
-- * Forward derivatives (dual numbers)
jvp, jvpDN,
- Tan, DN, DNE,
+ Tan, TanS, DN, DNS, DNE,
-- * Working with expressions
interpret, interpret1,
@@ -30,6 +30,7 @@ module CHAD.APIv1 (
fullSimplify,
SList(..), Value(..), Rep,
KnownEnv(..), KnownTy(..),
+ SNat(..),
) where
import CHAD.AST
@@ -51,7 +52,7 @@ import CHAD.Interpreter.Rep
vjp :: KnownEnv env => Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
vjp = vjp' (chcSetAccum defaultConfig)
--- | Same as 'vjp'', but supply CHAD configuration.
+-- | Same as 'vjp', but supply CHAD configuration.
vjp' :: KnownEnv env => CHADConfig -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
vjp' config term
| Dict <- styKnown (d2 (typeOf term)) =
@@ -172,6 +173,6 @@ compile1 term = do
f <- Compile.compileStderr knownEnv term
return (\x -> f (Value x `SCons` SNil))
--- | Simpify an expression. The 'vjp'/'jvp' functions already do this automatically.
+-- | Simplify an expression. The 'vjp'/'jvp' functions already do this automatically.
fullSimplify :: KnownEnv env => Ex env t -> Ex env t
fullSimplify = simplifyFix . pruneExpr knownEnv . simplifyFix
diff --git a/src/CHAD/AST.hs b/src/CHAD/AST.hs
index be7f95e..ce9eb20 100644
--- a/src/CHAD/AST.hs
+++ b/src/CHAD/AST.hs
@@ -142,6 +142,11 @@ deriving instance (forall ty. Show (x ty)) => Show (Expr x env t)
-- | A (well-typed, well-scoped) expression using De Bruijn indices. The full
-- 'Expr' type is parametrised on an indexed type of "additional info" (@x@);
-- 'Ex' sets this to nothing.
+--
+-- Construct expressions using the functions in "CHAD.Language".
+--
+-- Use 'CHAD.AST.Pretty.pprintExpr' or 'CHAD.AST.Pretty.ppExpr' to inspect
+-- expressions.
type Ex = Expr (Const ())
ext :: Const () a
diff --git a/src/CHAD/ForwardAD.hs b/src/CHAD/ForwardAD.hs
index 0ebc244..0ae88ce 100644
--- a/src/CHAD/ForwardAD.hs
+++ b/src/CHAD/ForwardAD.hs
@@ -22,7 +22,7 @@ import CHAD.Interpreter
import CHAD.Interpreter.Rep
--- | Tangent along a type (coincides with cotangent for these types)
+-- | Tangent along a type (coincides with the cotangent, t'CHAD.Drev.Types.D2', for these types)
type family Tan t where
Tan TNil = TNil
Tan (TPair a b) = TPair (Tan a) (Tan b)