summaryrefslogtreecommitdiff
path: root/src/AST.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-08 12:37:51 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-08 12:37:51 +0100
commit83692cf41f76272423445c9cbbad65561ee3b50c (patch)
tree49f56f498a68722a7302b4ce0b41402a9b9da9ef /src/AST.hs
parent58d4d0b47f5e609e21132f48b727de37d06b6777 (diff)
WIP custom derivatives
Diffstat (limited to 'src/AST.hs')
-rw-r--r--src/AST.hs12
1 files changed, 12 insertions, 0 deletions
diff --git a/src/AST.hs b/src/AST.hs
index f603443..6bae84a 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -26,6 +26,7 @@ import AST.Types
import AST.Weaken
import CHAD.Types
import Data
+import ForwardAD.DualNumbers.Types
-- | This index is flipped around from the usual direction: the smallest index
@@ -96,6 +97,14 @@ data Expr x env t where
EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx))
EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t
+ -- custom derivatives
+ ECustom :: x t -> STy a -> STy b
+ -> Expr x '[b, a] t -- ^ regular operation
+ -> Expr x '[DN b, a] (DN t) -- ^ dual-numbers forward derivative
+ -> Expr x '[D2 t, D1 b, D1 a] (D2 b) -- ^ CHAD reverse derivative
+ -> Expr x env a -> Expr x env b
+ -> Expr x env t
+
-- accumulation effect
EWith :: Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t)
EAccum :: SNat i -> Expr x env (AcIdx t i) -> Expr x env (AcVal t i) -> Expr x env (TAccum t) -> Expr x env TNil
@@ -202,6 +211,8 @@ typeOf = \case
EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx)
EOp _ op _ -> opt2 op
+ ECustom _ _ _ e _ _ _ _ -> typeOf e
+
EWith e1 e2 -> STPair (typeOf e2) (typeOf e1)
EAccum _ _ _ _ -> STNil
@@ -274,6 +285,7 @@ subst' f w = \case
EIdx x e es -> EIdx x (subst' f w e) (subst' f w es)
EShape x e -> EShape x (subst' f w e)
EOp x op e -> EOp x op (subst' f w e)
+ ECustom x s t a b c e1 e2 -> ECustom x s t a b c (subst' f w e1) (subst' f w e2)
EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
EAccum i e1 e2 e3 -> EAccum i (subst' f w e1) (subst' f w e2) (subst' f w e3)
EZero t -> EZero t