From 34887168c0e2deb549e0e7c77e837ab269d894a2 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sat, 9 Nov 2024 22:58:52 +0100 Subject: Add Custom to Language --- src/Language.hs | 10 ++++++++++ src/Language/AST.hs | 15 +++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/src/Language.hs b/src/Language.hs index 88cb1de..aa55140 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -2,6 +2,7 @@ {-# LANGUAGE ExplicitForAll #-} {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeApplications #-} module Language ( @@ -13,6 +14,7 @@ module Language ( import Array import AST +import CHAD.Types import Data import Language.AST @@ -126,6 +128,14 @@ oper = NEOp error_ :: KnownTy t => String -> NExpr env t error_ s = NEError knownTy s +custom :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t) + -> (Var nf1 (D1 a) :-> Var nf2 (D1 b) :-> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape)) + -> (Var nr1 tape :-> Var nr2 (D2 t) :-> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b)) + -> NExpr env a -> NExpr env b + -> NExpr env t +custom (n1 :-> n2 :-> a) (nf1 :-> nf2 :-> b) (nr1 :-> nr2 :-> c) e1 e2 = + NECustom n1 n2 a nf1 nf2 b nr1 nr2 c e1 e2 + (.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) a .== b = oper (OEq knownScalTy) (pair a b) infix 4 .== diff --git a/src/Language/AST.hs b/src/Language/AST.hs index 4194913..0ed4e51 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -21,6 +21,7 @@ import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorM import Array import AST +import CHAD.Types import Data @@ -60,6 +61,13 @@ data NExpr env t where NEShape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) NEOp :: SOp a t -> NExpr env a -> NExpr env t + -- custom derivatives + NECustom :: Var n1 a -> Var n2 b -> NExpr ['(n2, b), '(n1, a)] t -- ^ regular operation + -> Var nf1 (D1 a) -> Var nf2 (D1 b) -> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape) -- ^ CHAD forward pass + -> Var nr1 tape -> Var nr2 (D2 t) -> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b) -- ^ CHAD reverse derivative + -> NExpr env a -> NExpr env b + -> NExpr env t + -- partiality NEError :: STy a -> String -> NExpr env a @@ -169,6 +177,13 @@ fromNamedExpr val = \case NEShape e -> EShape ext (go e) NEOp op e -> EOp ext op (go e) + NECustom n1@(Var _ ta) n2@(Var _ tb) a nf1 nf2 b nr1@(Var _ ttape) nr2 c e1 e2 -> + ECustom ext ta tb ttape + (fromNamedExpr (NTop `NPush` n1 `NPush` n2) a) + (fromNamedExpr (NTop `NPush` nf1 `NPush` nf2) b) + (fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c) + (go e1) (go e2) + NEError t s -> EError t s NEUnnamed e args -> injectWrapLet (weakenExpr (wRaiseAbove args (envFromNEnv val)) e) args -- cgit v1.2.3-70-g09d2