diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-08 20:29:29 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-08 20:29:29 +0100 |
commit | 4fcdb7118e0084f192753ea6c70394352a27d5ed (patch) | |
tree | c5e91ae438b6f4c3e5075bf591e5fbe28aa5d96b /src/AST.hs | |
parent | 83692cf41f76272423445c9cbbad65561ee3b50c (diff) |
Custom derivatives
Diffstat (limited to 'src/AST.hs')
-rw-r--r-- | src/AST.hs | 15 |
1 files changed, 9 insertions, 6 deletions
@@ -26,7 +26,6 @@ import AST.Types import AST.Weaken import CHAD.Types import Data -import ForwardAD.DualNumbers.Types -- | This index is flipped around from the usual direction: the smallest index @@ -98,10 +97,14 @@ data Expr x env t where EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t -- custom derivatives - ECustom :: x t -> STy a -> STy b + -- 'b' is the part of the input of the operation that derivatives should + -- be backpropagated to; 'a' is the inactive part. The dual field of + -- ECustom does not allow a derivative to be generated for 'a', and hence + -- none is propagated. + ECustom :: x t -> STy a -> STy b -> STy tape -> Expr x '[b, a] t -- ^ regular operation - -> Expr x '[DN b, a] (DN t) -- ^ dual-numbers forward derivative - -> Expr x '[D2 t, D1 b, D1 a] (D2 b) -- ^ CHAD reverse derivative + -> Expr x '[D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass + -> Expr x '[D2 t, tape] (D2 b) -- ^ CHAD reverse derivative -> Expr x env a -> Expr x env b -> Expr x env t @@ -211,7 +214,7 @@ typeOf = \case EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx) EOp _ op _ -> opt2 op - ECustom _ _ _ e _ _ _ _ -> typeOf e + ECustom _ _ _ _ e _ _ _ _ -> typeOf e EWith e1 e2 -> STPair (typeOf e2) (typeOf e1) EAccum _ _ _ _ -> STNil @@ -285,7 +288,7 @@ subst' f w = \case EIdx x e es -> EIdx x (subst' f w e) (subst' f w es) EShape x e -> EShape x (subst' f w e) EOp x op e -> EOp x op (subst' f w e) - ECustom x s t a b c e1 e2 -> ECustom x s t a b c (subst' f w e1) (subst' f w e2) + ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2) EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) EAccum i e1 e2 e3 -> EAccum i (subst' f w e1) (subst' f w e2) (subst' f w e3) EZero t -> EZero t |