From a590b1414baec157da3a1f6c5684b1a3bce8ecaf Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sun, 9 Mar 2025 23:09:00 +0100 Subject: test: Run gradientByForward with compiled DN fun --- src/Example.hs | 2 +- src/ForwardAD.hs | 30 ++++++++++++++++++++++-------- 2 files changed, 23 insertions(+), 9 deletions(-) (limited to 'src') diff --git a/src/Example.hs b/src/Example.hs index 6ce542e..e234ff4 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -185,6 +185,6 @@ neuralGo = (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False argument revderiv of (primal', (((((), Just dlay1_1'), Just dlay2_1'), dlay3_1'), dinput_1')) -> (primal', dlay1_1', dlay2_1', dlay3_1', dinput_1') _ -> undefined - (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwd knownEnv neural argument 1.0 + (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0 in trace (ppExpr knownEnv revderiv) $ (primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2)) diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index b95385c..e867d66 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -7,12 +7,14 @@ module ForwardAD where import Data.Bifunctor (bimap) +import System.IO.Unsafe -- import Debug.Trace -- import AST.Pretty import Array import AST +import Compile import Data import ForwardAD.DualNumbers import Interpreter @@ -212,11 +214,23 @@ dnOnehotEnvs (t `SCons` env) (Value x `SCons` val) = Value (dnOnehots t x (f . (\oh -> Value oh `SCons` dnConstEnv env val))) `SCons` dnOnehotEnvs env val (f . (\oh -> Value (dnConst t x) `SCons` oh)) -drevByFwd :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env) -drevByFwd env expr input dres = - let outty = typeOf expr - in -- trace ("fwd: running: " ++ ppExpr (dne env) (dfwdDN expr)) $ - dnOnehotEnvs env input $ \dnInput -> - -- trace (showEnv (dne env) dnInput) $ - let (_, outtan) = unzipDN outty (interpretOpen False dnInput (dfwdDN expr)) - in dotprodTan outty outtan dres +data FwdADArtifact env t = FwdADArtifact (SList STy env) (STy t) (SList Value (DNE env) -> Rep (DN t)) + +makeFwdADArtifactInterp :: SList STy env -> Ex env t -> FwdADArtifact env t +makeFwdADArtifactInterp env expr = + let dexpr = dfwdDN expr + in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False inp dexpr) + +{-# NOINLINE makeFwdADArtifactCompile #-} +makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t) +makeFwdADArtifactCompile env expr = FwdADArtifact env (typeOf expr) . (unsafePerformIO .) <$> compile (dne env) (dfwdDN expr) + +drevByFwdInterp :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env) +drevByFwdInterp env expr = drevByFwd (makeFwdADArtifactInterp env expr) + +drevByFwd :: FwdADArtifact env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env) +drevByFwd (FwdADArtifact env outty fun) input dres = + dnOnehotEnvs env input $ \dnInput -> + -- trace (showEnv (dne env) dnInput) $ + let (_, outtan) = unzipDN outty (fun dnInput) + in dotprodTan outty outtan dres -- cgit v1.2.3-70-g09d2