diff options
Diffstat (limited to 'src/Language')
| -rw-r--r-- | src/Language/AST.hs | 15 | 
1 files changed, 15 insertions, 0 deletions
| diff --git a/src/Language/AST.hs b/src/Language/AST.hs index 4194913..0ed4e51 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -21,6 +21,7 @@ import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorM  import Array  import AST +import CHAD.Types  import Data @@ -60,6 +61,13 @@ data NExpr env t where    NEShape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx))    NEOp :: SOp a t -> NExpr env a -> NExpr env t +  -- custom derivatives +  NECustom :: Var n1 a -> Var n2 b -> NExpr ['(n2, b), '(n1, a)] t  -- ^ regular operation +           -> Var nf1 (D1 a) -> Var nf2 (D1 b) -> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape)  -- ^ CHAD forward pass +           -> Var nr1 tape -> Var nr2 (D2 t) -> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b)  -- ^ CHAD reverse derivative +           -> NExpr env a -> NExpr env b +           -> NExpr env t +    -- partiality    NEError :: STy a -> String -> NExpr env a @@ -169,6 +177,13 @@ fromNamedExpr val = \case    NEShape e -> EShape ext (go e)    NEOp op e -> EOp ext op (go e) +  NECustom n1@(Var _ ta) n2@(Var _ tb) a nf1 nf2 b nr1@(Var _ ttape) nr2 c e1 e2 -> +    ECustom ext ta tb ttape +            (fromNamedExpr (NTop `NPush` n1 `NPush` n2) a) +            (fromNamedExpr (NTop `NPush` nf1 `NPush` nf2) b) +            (fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c) +            (go e1) (go e2) +    NEError t s -> EError t s    NEUnnamed e args -> injectWrapLet (weakenExpr (wRaiseAbove args (envFromNEnv val)) e) args | 
