summaryrefslogtreecommitdiff
path: root/src/Language/AST.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Language/AST.hs')
-rw-r--r--src/Language/AST.hs15
1 files changed, 15 insertions, 0 deletions
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
index 4194913..0ed4e51 100644
--- a/src/Language/AST.hs
+++ b/src/Language/AST.hs
@@ -21,6 +21,7 @@ import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorM
import Array
import AST
+import CHAD.Types
import Data
@@ -60,6 +61,13 @@ data NExpr env t where
NEShape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx))
NEOp :: SOp a t -> NExpr env a -> NExpr env t
+ -- custom derivatives
+ NECustom :: Var n1 a -> Var n2 b -> NExpr ['(n2, b), '(n1, a)] t -- ^ regular operation
+ -> Var nf1 (D1 a) -> Var nf2 (D1 b) -> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape) -- ^ CHAD forward pass
+ -> Var nr1 tape -> Var nr2 (D2 t) -> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b) -- ^ CHAD reverse derivative
+ -> NExpr env a -> NExpr env b
+ -> NExpr env t
+
-- partiality
NEError :: STy a -> String -> NExpr env a
@@ -169,6 +177,13 @@ fromNamedExpr val = \case
NEShape e -> EShape ext (go e)
NEOp op e -> EOp ext op (go e)
+ NECustom n1@(Var _ ta) n2@(Var _ tb) a nf1 nf2 b nr1@(Var _ ttape) nr2 c e1 e2 ->
+ ECustom ext ta tb ttape
+ (fromNamedExpr (NTop `NPush` n1 `NPush` n2) a)
+ (fromNamedExpr (NTop `NPush` nf1 `NPush` nf2) b)
+ (fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c)
+ (go e1) (go e2)
+
NEError t s -> EError t s
NEUnnamed e args -> injectWrapLet (weakenExpr (wRaiseAbove args (envFromNEnv val)) e) args