diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-11-08 12:37:51 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-08 12:37:51 +0100 | 
| commit | 83692cf41f76272423445c9cbbad65561ee3b50c (patch) | |
| tree | 49f56f498a68722a7302b4ce0b41402a9b9da9ef /src/AST.hs | |
| parent | 58d4d0b47f5e609e21132f48b727de37d06b6777 (diff) | |
WIP custom derivatives
Diffstat (limited to 'src/AST.hs')
| -rw-r--r-- | src/AST.hs | 12 | 
1 files changed, 12 insertions, 0 deletions
| @@ -26,6 +26,7 @@ import AST.Types  import AST.Weaken  import CHAD.Types  import Data +import ForwardAD.DualNumbers.Types  -- | This index is flipped around from the usual direction: the smallest index @@ -96,6 +97,14 @@ data Expr x env t where    EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx))    EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t +  -- custom derivatives +  ECustom :: x t -> STy a -> STy b +          -> Expr x '[b, a] t  -- ^ regular operation +          -> Expr x '[DN b, a] (DN t)  -- ^ dual-numbers forward derivative +          -> Expr x '[D2 t, D1 b, D1 a] (D2 b)  -- ^ CHAD reverse derivative +          -> Expr x env a -> Expr x env b +          -> Expr x env t +    -- accumulation effect    EWith :: Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t)    EAccum :: SNat i -> Expr x env (AcIdx t i) -> Expr x env (AcVal t i) -> Expr x env (TAccum t) -> Expr x env TNil @@ -202,6 +211,8 @@ typeOf = \case    EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx)    EOp _ op _ -> opt2 op +  ECustom _ _ _ e _ _ _ _ -> typeOf e +    EWith e1 e2 -> STPair (typeOf e2) (typeOf e1)    EAccum _ _ _ _ -> STNil @@ -274,6 +285,7 @@ subst' f w = \case    EIdx x e es -> EIdx x (subst' f w e) (subst' f w es)    EShape x e -> EShape x (subst' f w e)    EOp x op e -> EOp x op (subst' f w e) +  ECustom x s t a b c e1 e2 -> ECustom x s t a b c (subst' f w e1) (subst' f w e2)    EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)    EAccum i e1 e2 e3 -> EAccum i (subst' f w e1) (subst' f w e2) (subst' f w e3)    EZero t -> EZero t | 
