summaryrefslogtreecommitdiff
path: root/src/ForwardAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/ForwardAD.hs')
-rw-r--r--src/ForwardAD.hs30
1 files changed, 22 insertions, 8 deletions
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs
index b95385c..e867d66 100644
--- a/src/ForwardAD.hs
+++ b/src/ForwardAD.hs
@@ -7,12 +7,14 @@
module ForwardAD where
import Data.Bifunctor (bimap)
+import System.IO.Unsafe
-- import Debug.Trace
-- import AST.Pretty
import Array
import AST
+import Compile
import Data
import ForwardAD.DualNumbers
import Interpreter
@@ -212,11 +214,23 @@ dnOnehotEnvs (t `SCons` env) (Value x `SCons` val) =
Value (dnOnehots t x (f . (\oh -> Value oh `SCons` dnConstEnv env val)))
`SCons` dnOnehotEnvs env val (f . (\oh -> Value (dnConst t x) `SCons` oh))
-drevByFwd :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)
-drevByFwd env expr input dres =
- let outty = typeOf expr
- in -- trace ("fwd: running: " ++ ppExpr (dne env) (dfwdDN expr)) $
- dnOnehotEnvs env input $ \dnInput ->
- -- trace (showEnv (dne env) dnInput) $
- let (_, outtan) = unzipDN outty (interpretOpen False dnInput (dfwdDN expr))
- in dotprodTan outty outtan dres
+data FwdADArtifact env t = FwdADArtifact (SList STy env) (STy t) (SList Value (DNE env) -> Rep (DN t))
+
+makeFwdADArtifactInterp :: SList STy env -> Ex env t -> FwdADArtifact env t
+makeFwdADArtifactInterp env expr =
+ let dexpr = dfwdDN expr
+ in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False inp dexpr)
+
+{-# NOINLINE makeFwdADArtifactCompile #-}
+makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t)
+makeFwdADArtifactCompile env expr = FwdADArtifact env (typeOf expr) . (unsafePerformIO .) <$> compile (dne env) (dfwdDN expr)
+
+drevByFwdInterp :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)
+drevByFwdInterp env expr = drevByFwd (makeFwdADArtifactInterp env expr)
+
+drevByFwd :: FwdADArtifact env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)
+drevByFwd (FwdADArtifact env outty fun) input dres =
+ dnOnehotEnvs env input $ \dnInput ->
+ -- trace (showEnv (dne env) dnInput) $
+ let (_, outtan) = unzipDN outty (fun dnInput)
+ in dotprodTan outty outtan dres