diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-09 22:58:52 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-09 22:58:52 +0100 |
commit | 34887168c0e2deb549e0e7c77e837ab269d894a2 (patch) | |
tree | 8e99e95fbd88569a79a634ed4cc9c787b10d21f1 | |
parent | 2b1562d33bb9496aa449ef9d52735af0ec61c15c (diff) |
Add Custom to Language
-rw-r--r-- | src/Language.hs | 10 | ||||
-rw-r--r-- | src/Language/AST.hs | 15 |
2 files changed, 25 insertions, 0 deletions
diff --git a/src/Language.hs b/src/Language.hs index 88cb1de..aa55140 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -2,6 +2,7 @@ {-# LANGUAGE ExplicitForAll #-} {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeApplications #-} module Language ( @@ -13,6 +14,7 @@ module Language ( import Array import AST +import CHAD.Types import Data import Language.AST @@ -126,6 +128,14 @@ oper = NEOp error_ :: KnownTy t => String -> NExpr env t error_ s = NEError knownTy s +custom :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t) + -> (Var nf1 (D1 a) :-> Var nf2 (D1 b) :-> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape)) + -> (Var nr1 tape :-> Var nr2 (D2 t) :-> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b)) + -> NExpr env a -> NExpr env b + -> NExpr env t +custom (n1 :-> n2 :-> a) (nf1 :-> nf2 :-> b) (nr1 :-> nr2 :-> c) e1 e2 = + NECustom n1 n2 a nf1 nf2 b nr1 nr2 c e1 e2 + (.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) a .== b = oper (OEq knownScalTy) (pair a b) infix 4 .== diff --git a/src/Language/AST.hs b/src/Language/AST.hs index 4194913..0ed4e51 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -21,6 +21,7 @@ import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorM import Array import AST +import CHAD.Types import Data @@ -60,6 +61,13 @@ data NExpr env t where NEShape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) NEOp :: SOp a t -> NExpr env a -> NExpr env t + -- custom derivatives + NECustom :: Var n1 a -> Var n2 b -> NExpr ['(n2, b), '(n1, a)] t -- ^ regular operation + -> Var nf1 (D1 a) -> Var nf2 (D1 b) -> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape) -- ^ CHAD forward pass + -> Var nr1 tape -> Var nr2 (D2 t) -> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b) -- ^ CHAD reverse derivative + -> NExpr env a -> NExpr env b + -> NExpr env t + -- partiality NEError :: STy a -> String -> NExpr env a @@ -169,6 +177,13 @@ fromNamedExpr val = \case NEShape e -> EShape ext (go e) NEOp op e -> EOp ext op (go e) + NECustom n1@(Var _ ta) n2@(Var _ tb) a nf1 nf2 b nr1@(Var _ ttape) nr2 c e1 e2 -> + ECustom ext ta tb ttape + (fromNamedExpr (NTop `NPush` n1 `NPush` n2) a) + (fromNamedExpr (NTop `NPush` nf1 `NPush` nf2) b) + (fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c) + (go e1) (go e2) + NEError t s -> EError t s NEUnnamed e args -> injectWrapLet (weakenExpr (wRaiseAbove args (envFromNEnv val)) e) args |