summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-09 22:58:52 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-09 22:58:52 +0100
commit34887168c0e2deb549e0e7c77e837ab269d894a2 (patch)
tree8e99e95fbd88569a79a634ed4cc9c787b10d21f1
parent2b1562d33bb9496aa449ef9d52735af0ec61c15c (diff)
Add Custom to Language
-rw-r--r--src/Language.hs10
-rw-r--r--src/Language/AST.hs15
2 files changed, 25 insertions, 0 deletions
diff --git a/src/Language.hs b/src/Language.hs
index 88cb1de..aa55140 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -2,6 +2,7 @@
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
module Language (
@@ -13,6 +14,7 @@ module Language (
import Array
import AST
+import CHAD.Types
import Data
import Language.AST
@@ -126,6 +128,14 @@ oper = NEOp
error_ :: KnownTy t => String -> NExpr env t
error_ s = NEError knownTy s
+custom :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t)
+ -> (Var nf1 (D1 a) :-> Var nf2 (D1 b) :-> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape))
+ -> (Var nr1 tape :-> Var nr2 (D2 t) :-> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b))
+ -> NExpr env a -> NExpr env b
+ -> NExpr env t
+custom (n1 :-> n2 :-> a) (nf1 :-> nf2 :-> b) (nr1 :-> nr2 :-> c) e1 e2 =
+ NECustom n1 n2 a nf1 nf2 b nr1 nr2 c e1 e2
+
(.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
a .== b = oper (OEq knownScalTy) (pair a b)
infix 4 .==
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
index 4194913..0ed4e51 100644
--- a/src/Language/AST.hs
+++ b/src/Language/AST.hs
@@ -21,6 +21,7 @@ import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorM
import Array
import AST
+import CHAD.Types
import Data
@@ -60,6 +61,13 @@ data NExpr env t where
NEShape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx))
NEOp :: SOp a t -> NExpr env a -> NExpr env t
+ -- custom derivatives
+ NECustom :: Var n1 a -> Var n2 b -> NExpr ['(n2, b), '(n1, a)] t -- ^ regular operation
+ -> Var nf1 (D1 a) -> Var nf2 (D1 b) -> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape) -- ^ CHAD forward pass
+ -> Var nr1 tape -> Var nr2 (D2 t) -> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b) -- ^ CHAD reverse derivative
+ -> NExpr env a -> NExpr env b
+ -> NExpr env t
+
-- partiality
NEError :: STy a -> String -> NExpr env a
@@ -169,6 +177,13 @@ fromNamedExpr val = \case
NEShape e -> EShape ext (go e)
NEOp op e -> EOp ext op (go e)
+ NECustom n1@(Var _ ta) n2@(Var _ tb) a nf1 nf2 b nr1@(Var _ ttape) nr2 c e1 e2 ->
+ ECustom ext ta tb ttape
+ (fromNamedExpr (NTop `NPush` n1 `NPush` n2) a)
+ (fromNamedExpr (NTop `NPush` nf1 `NPush` nf2) b)
+ (fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c)
+ (go e1) (go e2)
+
NEError t s -> EError t s
NEUnnamed e args -> injectWrapLet (weakenExpr (wRaiseAbove args (envFromNEnv val)) e) args