summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2025-03-09 23:09:00 +0100
committerTom Smeding <t.j.smeding@uu.nl>2025-03-09 23:09:00 +0100
commita590b1414baec157da3a1f6c5684b1a3bce8ecaf (patch)
tree45cd0f5559ee2294c1fb889d21ccba49f615d187 /src
parentf9906020ef838af0bb6683a3a078e23eac555e54 (diff)
test: Run gradientByForward with compiled DN fun
Diffstat (limited to 'src')
-rw-r--r--src/Example.hs2
-rw-r--r--src/ForwardAD.hs30
2 files changed, 23 insertions, 9 deletions
diff --git a/src/Example.hs b/src/Example.hs
index 6ce542e..e234ff4 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -185,6 +185,6 @@ neuralGo =
(primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False argument revderiv of
(primal', (((((), Just dlay1_1'), Just dlay2_1'), dlay3_1'), dinput_1')) -> (primal', dlay1_1', dlay2_1', dlay3_1', dinput_1')
_ -> undefined
- (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwd knownEnv neural argument 1.0
+ (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0
in trace (ppExpr knownEnv revderiv) $
(primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2))
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs
index b95385c..e867d66 100644
--- a/src/ForwardAD.hs
+++ b/src/ForwardAD.hs
@@ -7,12 +7,14 @@
module ForwardAD where
import Data.Bifunctor (bimap)
+import System.IO.Unsafe
-- import Debug.Trace
-- import AST.Pretty
import Array
import AST
+import Compile
import Data
import ForwardAD.DualNumbers
import Interpreter
@@ -212,11 +214,23 @@ dnOnehotEnvs (t `SCons` env) (Value x `SCons` val) =
Value (dnOnehots t x (f . (\oh -> Value oh `SCons` dnConstEnv env val)))
`SCons` dnOnehotEnvs env val (f . (\oh -> Value (dnConst t x) `SCons` oh))
-drevByFwd :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)
-drevByFwd env expr input dres =
- let outty = typeOf expr
- in -- trace ("fwd: running: " ++ ppExpr (dne env) (dfwdDN expr)) $
- dnOnehotEnvs env input $ \dnInput ->
- -- trace (showEnv (dne env) dnInput) $
- let (_, outtan) = unzipDN outty (interpretOpen False dnInput (dfwdDN expr))
- in dotprodTan outty outtan dres
+data FwdADArtifact env t = FwdADArtifact (SList STy env) (STy t) (SList Value (DNE env) -> Rep (DN t))
+
+makeFwdADArtifactInterp :: SList STy env -> Ex env t -> FwdADArtifact env t
+makeFwdADArtifactInterp env expr =
+ let dexpr = dfwdDN expr
+ in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False inp dexpr)
+
+{-# NOINLINE makeFwdADArtifactCompile #-}
+makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t)
+makeFwdADArtifactCompile env expr = FwdADArtifact env (typeOf expr) . (unsafePerformIO .) <$> compile (dne env) (dfwdDN expr)
+
+drevByFwdInterp :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)
+drevByFwdInterp env expr = drevByFwd (makeFwdADArtifactInterp env expr)
+
+drevByFwd :: FwdADArtifact env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)
+drevByFwd (FwdADArtifact env outty fun) input dres =
+ dnOnehotEnvs env input $ \dnInput ->
+ -- trace (showEnv (dne env) dnInput) $
+ let (_, outtan) = unzipDN outty (fun dnInput)
+ in dotprodTan outty outtan dres