summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2025-03-09 23:09:00 +0100
committerTom Smeding <t.j.smeding@uu.nl>2025-03-09 23:09:00 +0100
commita590b1414baec157da3a1f6c5684b1a3bce8ecaf (patch)
tree45cd0f5559ee2294c1fb889d21ccba49f615d187
parentf9906020ef838af0bb6683a3a078e23eac555e54 (diff)
test: Run gradientByForward with compiled DN fun
-rw-r--r--src/Example.hs2
-rw-r--r--src/ForwardAD.hs30
-rw-r--r--test/Main.hs130
3 files changed, 109 insertions, 53 deletions
diff --git a/src/Example.hs b/src/Example.hs
index 6ce542e..e234ff4 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -185,6 +185,6 @@ neuralGo =
(primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False argument revderiv of
(primal', (((((), Just dlay1_1'), Just dlay2_1'), dlay3_1'), dinput_1')) -> (primal', dlay1_1', dlay2_1', dlay3_1', dinput_1')
_ -> undefined
- (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwd knownEnv neural argument 1.0
+ (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0
in trace (ppExpr knownEnv revderiv) $
(primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2))
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs
index b95385c..e867d66 100644
--- a/src/ForwardAD.hs
+++ b/src/ForwardAD.hs
@@ -7,12 +7,14 @@
module ForwardAD where
import Data.Bifunctor (bimap)
+import System.IO.Unsafe
-- import Debug.Trace
-- import AST.Pretty
import Array
import AST
+import Compile
import Data
import ForwardAD.DualNumbers
import Interpreter
@@ -212,11 +214,23 @@ dnOnehotEnvs (t `SCons` env) (Value x `SCons` val) =
Value (dnOnehots t x (f . (\oh -> Value oh `SCons` dnConstEnv env val)))
`SCons` dnOnehotEnvs env val (f . (\oh -> Value (dnConst t x) `SCons` oh))
-drevByFwd :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)
-drevByFwd env expr input dres =
- let outty = typeOf expr
- in -- trace ("fwd: running: " ++ ppExpr (dne env) (dfwdDN expr)) $
- dnOnehotEnvs env input $ \dnInput ->
- -- trace (showEnv (dne env) dnInput) $
- let (_, outtan) = unzipDN outty (interpretOpen False dnInput (dfwdDN expr))
- in dotprodTan outty outtan dres
+data FwdADArtifact env t = FwdADArtifact (SList STy env) (STy t) (SList Value (DNE env) -> Rep (DN t))
+
+makeFwdADArtifactInterp :: SList STy env -> Ex env t -> FwdADArtifact env t
+makeFwdADArtifactInterp env expr =
+ let dexpr = dfwdDN expr
+ in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False inp dexpr)
+
+{-# NOINLINE makeFwdADArtifactCompile #-}
+makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t)
+makeFwdADArtifactCompile env expr = FwdADArtifact env (typeOf expr) . (unsafePerformIO .) <$> compile (dne env) (dfwdDN expr)
+
+drevByFwdInterp :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)
+drevByFwdInterp env expr = drevByFwd (makeFwdADArtifactInterp env expr)
+
+drevByFwd :: FwdADArtifact env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)
+drevByFwd (FwdADArtifact env outty fun) input dres =
+ dnOnehotEnvs env input $ \dnInput ->
+ -- trace (showEnv (dne env) dnInput) $
+ let (_, outtan) = unzipDN outty (fun dnInput)
+ in dotprodTan outty outtan dres
diff --git a/test/Main.hs b/test/Main.hs
index 2b7e7d8..92dc446 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -36,6 +36,7 @@ import Compile
import qualified Example
import qualified Example.GMM as Example
import ForwardAD
+import ForwardAD.DualNumbers
import Interpreter
import Interpreter.Rep
import Language
@@ -64,8 +65,28 @@ gradientByCHAD' simplIters env term input =
second (second (toTanE env input)) $
gradientByCHAD simplIters env term input
-gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
-gradientByForward env term input = drevByFwd env term input 1.0
+gradientByForward :: FwdADArtifact env (TScal TF64) -> SList Value env -> SList Value (TanE env)
+gradientByForward art input = drevByFwd art input 1.0
+
+extendDN :: STy t -> Rep t -> Gen (Rep (DN t))
+extendDN STNil () = pure ()
+extendDN (STPair a b) (x, y) = (,) <$> extendDN a x <*> extendDN b y
+extendDN (STEither a _) (Left x) = Left <$> extendDN a x
+extendDN (STEither _ b) (Right y) = Right <$> extendDN b y
+extendDN (STMaybe _) Nothing = pure Nothing
+extendDN (STMaybe t) (Just x) = Just <$> extendDN t x
+extendDN (STArr _ t) arr = traverse (extendDN t) arr
+extendDN (STScal sty) x = case sty of
+ STF32 -> Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \d -> pure (x, d)
+ STF64 -> Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \d -> pure (x, d)
+ STI32 -> pure x
+ STI64 -> pure x
+ STBool -> pure x
+extendDN (STAccum _) _ = error "Accumulators not supported in input program"
+
+extendDNE :: SList STy env -> SList Value env -> Gen (SList Value (DNE env))
+extendDNE SNil SNil = pure SNil
+extendDNE (t `SCons` env) (Value x `SCons` val) = SCons <$> (Value <$> extendDN t x) <*> extendDNE env val
closeIsh' :: Double -> Double -> Double -> Bool
closeIsh' h a b =
@@ -185,53 +206,74 @@ adTestTp name tmpl term = adTestGen name term (evalStateT (genEnv knownEnv tmpl)
adTestGen :: forall env. KnownEnv env
=> TestName -> Ex env (TScal TF64) -> Gen (SList Value env) -> TestTree
adTestGen name expr envGenerator =
- withCompiled expr $ \getprimalfun ->
- testProperty name $ property $ do
- let env = knownEnv @env
-
- annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))
-
- let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr
- dtermChadS = simplifyFix dtermChad0
- dtermChadS20 = simplifyN 20 dtermChad0
-
- -- pack Text for less GC pressure (these values are retained for some reason)
- diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env dtermChadS20))
-
- input <- forAllWith (showEnv env) envGenerator
-
- let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env)
- unpackGrad = unTup vUnpair (d2e env) . Value
-
- let outPrimalI = interpretOpen False input expr
- outPrimal <- liftIO $ getprimalfun >>= ($ input)
- diff outPrimal (closeIsh' 1e-8) outPrimalI
-
- let (outChad0, gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0
- (outChadS, gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS
- gradChad0' = toTanE env input gradChad0
- gradChadS' = toTanE env input gradChadS
- scChad = envScalars env gradChad0'
- scChadS = envScalars env gradChadS'
- gradFwd = gradientByForward knownEnv expr input
- scFwd = envScalars env gradFwd
-
- -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0))
- -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS))
- -- annotate (ppExpr knownEnv expr)
- -- annotate (ppExpr env dtermChad0)
- -- annotate (ppExpr env dtermChadS)
- diff outChadS closeIsh outChad0
- diff outChadS closeIsh outPrimal
- diff scChadS (\x y -> and (zipWith closeIsh x y)) scChad
- diff scChadS (\x y -> and (zipWith closeIsh x y)) scFwd
+ let env = knownEnv @env in
+ withCompiled env expr $ \getprimalfun ->
+ testGroup name
+ [testProperty "compile primal" $ property $ do
+ primalfun <- liftIO getprimalfun
+ input <- forAllWith (showEnv env) envGenerator
+ let outPrimalI = interpretOpen False input expr
+ outPrimalC <- liftIO $ primalfun input
+ diff outPrimalI (closeIsh' 1e-8) outPrimalC
+
+ ,withCompiled (dne env) (dfwdDN expr) $ \getdnfun ->
+ testProperty "compile fwdAD" $ property $ do
+ dnfun <- liftIO getdnfun
+ input <- forAllWith (showEnv env) envGenerator
+ dinput <- forAllWith (showEnv (dne env)) $ extendDNE env input
+ let (outDNI1, outDNI2) = interpretOpen False dinput (dfwdDN expr)
+ (outDNC1, outDNC2) <- liftIO $ dnfun dinput
+ diff outDNI1 (closeIsh' 1e-8) outDNC1
+ diff outDNI2 (closeIsh' 1e-8) outDNC2
+
+ ,withResource (makeFwdADArtifactCompile env expr) (\_ -> pure ()) $ \getfwdartifactC ->
+ testProperty "chad" $ property $ do
+ primalfun <- liftIO getprimalfun
+ fwdartifactC <- liftIO getfwdartifactC
+
+ annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))
+
+ let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr
+ dtermChadS = simplifyFix dtermChad0
+ dtermChadS20 = simplifyN 20 dtermChad0
+
+ -- pack Text for less GC pressure (these values are retained for some reason)
+ diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env dtermChadS20))
+
+ input <- forAllWith (showEnv env) envGenerator
+
+ let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env)
+ unpackGrad = unTup vUnpair (d2e env) . Value
+
+ outPrimal <- liftIO $ primalfun input
+
+ let (outChad0, gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0
+ (outChadS, gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS
+ gradChad0' = toTanE env input gradChad0
+ gradChadS' = toTanE env input gradChadS
+ scChad = envScalars env gradChad0'
+ scChadS = envScalars env gradChadS'
+ gradFwd = gradientByForward fwdartifactC input
+ scFwd = envScalars env gradFwd
+
+ -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0))
+ -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS))
+ -- annotate (ppExpr knownEnv expr)
+ -- annotate (ppExpr env dtermChad0)
+ -- annotate (ppExpr env dtermChadS)
+ diff outChadS closeIsh outChad0
+ diff outChadS closeIsh outPrimal
+ diff scChadS (\x y -> and (zipWith closeIsh x y)) scChad
+ diff scChadS (\x y -> and (zipWith closeIsh x y)) scFwd
+ ]
+
where
envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]
envScalars SNil SNil = []
envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs
-withCompiled :: KnownEnv env => Ex env t -> (IO (SList Value env -> IO (Rep t)) -> TestTree) -> TestTree
-withCompiled expr = withResource (compile knownEnv expr) (\_ -> pure ())
+withCompiled :: SList STy env -> Ex env t -> (IO (SList Value env -> IO (Rep t)) -> TestTree) -> TestTree
+withCompiled env expr = withResource (compile env expr) (\_ -> pure ())
term_build1_sum :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
term_build1_sum = fromNamed $ lambda #x $ body $