summaryrefslogtreecommitdiff
path: root/src/AST.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-08 20:29:29 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-08 20:29:29 +0100
commit4fcdb7118e0084f192753ea6c70394352a27d5ed (patch)
treec5e91ae438b6f4c3e5075bf591e5fbe28aa5d96b /src/AST.hs
parent83692cf41f76272423445c9cbbad65561ee3b50c (diff)
Custom derivatives
Diffstat (limited to 'src/AST.hs')
-rw-r--r--src/AST.hs15
1 files changed, 9 insertions, 6 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 6bae84a..08a5bba 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -26,7 +26,6 @@ import AST.Types
import AST.Weaken
import CHAD.Types
import Data
-import ForwardAD.DualNumbers.Types
-- | This index is flipped around from the usual direction: the smallest index
@@ -98,10 +97,14 @@ data Expr x env t where
EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t
-- custom derivatives
- ECustom :: x t -> STy a -> STy b
+ -- 'b' is the part of the input of the operation that derivatives should
+ -- be backpropagated to; 'a' is the inactive part. The dual field of
+ -- ECustom does not allow a derivative to be generated for 'a', and hence
+ -- none is propagated.
+ ECustom :: x t -> STy a -> STy b -> STy tape
-> Expr x '[b, a] t -- ^ regular operation
- -> Expr x '[DN b, a] (DN t) -- ^ dual-numbers forward derivative
- -> Expr x '[D2 t, D1 b, D1 a] (D2 b) -- ^ CHAD reverse derivative
+ -> Expr x '[D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass
+ -> Expr x '[D2 t, tape] (D2 b) -- ^ CHAD reverse derivative
-> Expr x env a -> Expr x env b
-> Expr x env t
@@ -211,7 +214,7 @@ typeOf = \case
EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx)
EOp _ op _ -> opt2 op
- ECustom _ _ _ e _ _ _ _ -> typeOf e
+ ECustom _ _ _ _ e _ _ _ _ -> typeOf e
EWith e1 e2 -> STPair (typeOf e2) (typeOf e1)
EAccum _ _ _ _ -> STNil
@@ -285,7 +288,7 @@ subst' f w = \case
EIdx x e es -> EIdx x (subst' f w e) (subst' f w es)
EShape x e -> EShape x (subst' f w e)
EOp x op e -> EOp x op (subst' f w e)
- ECustom x s t a b c e1 e2 -> ECustom x s t a b c (subst' f w e1) (subst' f w e2)
+ ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2)
EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
EAccum i e1 e2 e3 -> EAccum i (subst' f w e1) (subst' f w e2) (subst' f w e3)
EZero t -> EZero t