diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2025-03-09 23:09:00 +0100 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2025-03-09 23:09:00 +0100 |
commit | a590b1414baec157da3a1f6c5684b1a3bce8ecaf (patch) | |
tree | 45cd0f5559ee2294c1fb889d21ccba49f615d187 | |
parent | f9906020ef838af0bb6683a3a078e23eac555e54 (diff) |
test: Run gradientByForward with compiled DN fun
-rw-r--r-- | src/Example.hs | 2 | ||||
-rw-r--r-- | src/ForwardAD.hs | 30 | ||||
-rw-r--r-- | test/Main.hs | 130 |
3 files changed, 109 insertions, 53 deletions
diff --git a/src/Example.hs b/src/Example.hs index 6ce542e..e234ff4 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -185,6 +185,6 @@ neuralGo = (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False argument revderiv of (primal', (((((), Just dlay1_1'), Just dlay2_1'), dlay3_1'), dinput_1')) -> (primal', dlay1_1', dlay2_1', dlay3_1', dinput_1') _ -> undefined - (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwd knownEnv neural argument 1.0 + (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0 in trace (ppExpr knownEnv revderiv) $ (primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2)) diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index b95385c..e867d66 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -7,12 +7,14 @@ module ForwardAD where import Data.Bifunctor (bimap) +import System.IO.Unsafe -- import Debug.Trace -- import AST.Pretty import Array import AST +import Compile import Data import ForwardAD.DualNumbers import Interpreter @@ -212,11 +214,23 @@ dnOnehotEnvs (t `SCons` env) (Value x `SCons` val) = Value (dnOnehots t x (f . (\oh -> Value oh `SCons` dnConstEnv env val))) `SCons` dnOnehotEnvs env val (f . (\oh -> Value (dnConst t x) `SCons` oh)) -drevByFwd :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env) -drevByFwd env expr input dres = - let outty = typeOf expr - in -- trace ("fwd: running: " ++ ppExpr (dne env) (dfwdDN expr)) $ - dnOnehotEnvs env input $ \dnInput -> - -- trace (showEnv (dne env) dnInput) $ - let (_, outtan) = unzipDN outty (interpretOpen False dnInput (dfwdDN expr)) - in dotprodTan outty outtan dres +data FwdADArtifact env t = FwdADArtifact (SList STy env) (STy t) (SList Value (DNE env) -> Rep (DN t)) + +makeFwdADArtifactInterp :: SList STy env -> Ex env t -> FwdADArtifact env t +makeFwdADArtifactInterp env expr = + let dexpr = dfwdDN expr + in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False inp dexpr) + +{-# NOINLINE makeFwdADArtifactCompile #-} +makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t) +makeFwdADArtifactCompile env expr = FwdADArtifact env (typeOf expr) . (unsafePerformIO .) <$> compile (dne env) (dfwdDN expr) + +drevByFwdInterp :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env) +drevByFwdInterp env expr = drevByFwd (makeFwdADArtifactInterp env expr) + +drevByFwd :: FwdADArtifact env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env) +drevByFwd (FwdADArtifact env outty fun) input dres = + dnOnehotEnvs env input $ \dnInput -> + -- trace (showEnv (dne env) dnInput) $ + let (_, outtan) = unzipDN outty (fun dnInput) + in dotprodTan outty outtan dres diff --git a/test/Main.hs b/test/Main.hs index 2b7e7d8..92dc446 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -36,6 +36,7 @@ import Compile import qualified Example import qualified Example.GMM as Example import ForwardAD +import ForwardAD.DualNumbers import Interpreter import Interpreter.Rep import Language @@ -64,8 +65,28 @@ gradientByCHAD' simplIters env term input = second (second (toTanE env input)) $ gradientByCHAD simplIters env term input -gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env) -gradientByForward env term input = drevByFwd env term input 1.0 +gradientByForward :: FwdADArtifact env (TScal TF64) -> SList Value env -> SList Value (TanE env) +gradientByForward art input = drevByFwd art input 1.0 + +extendDN :: STy t -> Rep t -> Gen (Rep (DN t)) +extendDN STNil () = pure () +extendDN (STPair a b) (x, y) = (,) <$> extendDN a x <*> extendDN b y +extendDN (STEither a _) (Left x) = Left <$> extendDN a x +extendDN (STEither _ b) (Right y) = Right <$> extendDN b y +extendDN (STMaybe _) Nothing = pure Nothing +extendDN (STMaybe t) (Just x) = Just <$> extendDN t x +extendDN (STArr _ t) arr = traverse (extendDN t) arr +extendDN (STScal sty) x = case sty of + STF32 -> Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \d -> pure (x, d) + STF64 -> Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \d -> pure (x, d) + STI32 -> pure x + STI64 -> pure x + STBool -> pure x +extendDN (STAccum _) _ = error "Accumulators not supported in input program" + +extendDNE :: SList STy env -> SList Value env -> Gen (SList Value (DNE env)) +extendDNE SNil SNil = pure SNil +extendDNE (t `SCons` env) (Value x `SCons` val) = SCons <$> (Value <$> extendDN t x) <*> extendDNE env val closeIsh' :: Double -> Double -> Double -> Bool closeIsh' h a b = @@ -185,53 +206,74 @@ adTestTp name tmpl term = adTestGen name term (evalStateT (genEnv knownEnv tmpl) adTestGen :: forall env. KnownEnv env => TestName -> Ex env (TScal TF64) -> Gen (SList Value env) -> TestTree adTestGen name expr envGenerator = - withCompiled expr $ \getprimalfun -> - testProperty name $ property $ do - let env = knownEnv @env - - annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) - - let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr - dtermChadS = simplifyFix dtermChad0 - dtermChadS20 = simplifyN 20 dtermChad0 - - -- pack Text for less GC pressure (these values are retained for some reason) - diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env dtermChadS20)) - - input <- forAllWith (showEnv env) envGenerator - - let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env) - unpackGrad = unTup vUnpair (d2e env) . Value - - let outPrimalI = interpretOpen False input expr - outPrimal <- liftIO $ getprimalfun >>= ($ input) - diff outPrimal (closeIsh' 1e-8) outPrimalI - - let (outChad0, gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0 - (outChadS, gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS - gradChad0' = toTanE env input gradChad0 - gradChadS' = toTanE env input gradChadS - scChad = envScalars env gradChad0' - scChadS = envScalars env gradChadS' - gradFwd = gradientByForward knownEnv expr input - scFwd = envScalars env gradFwd - - -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0)) - -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS)) - -- annotate (ppExpr knownEnv expr) - -- annotate (ppExpr env dtermChad0) - -- annotate (ppExpr env dtermChadS) - diff outChadS closeIsh outChad0 - diff outChadS closeIsh outPrimal - diff scChadS (\x y -> and (zipWith closeIsh x y)) scChad - diff scChadS (\x y -> and (zipWith closeIsh x y)) scFwd + let env = knownEnv @env in + withCompiled env expr $ \getprimalfun -> + testGroup name + [testProperty "compile primal" $ property $ do + primalfun <- liftIO getprimalfun + input <- forAllWith (showEnv env) envGenerator + let outPrimalI = interpretOpen False input expr + outPrimalC <- liftIO $ primalfun input + diff outPrimalI (closeIsh' 1e-8) outPrimalC + + ,withCompiled (dne env) (dfwdDN expr) $ \getdnfun -> + testProperty "compile fwdAD" $ property $ do + dnfun <- liftIO getdnfun + input <- forAllWith (showEnv env) envGenerator + dinput <- forAllWith (showEnv (dne env)) $ extendDNE env input + let (outDNI1, outDNI2) = interpretOpen False dinput (dfwdDN expr) + (outDNC1, outDNC2) <- liftIO $ dnfun dinput + diff outDNI1 (closeIsh' 1e-8) outDNC1 + diff outDNI2 (closeIsh' 1e-8) outDNC2 + + ,withResource (makeFwdADArtifactCompile env expr) (\_ -> pure ()) $ \getfwdartifactC -> + testProperty "chad" $ property $ do + primalfun <- liftIO getprimalfun + fwdartifactC <- liftIO getfwdartifactC + + annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) + + let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr + dtermChadS = simplifyFix dtermChad0 + dtermChadS20 = simplifyN 20 dtermChad0 + + -- pack Text for less GC pressure (these values are retained for some reason) + diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env dtermChadS20)) + + input <- forAllWith (showEnv env) envGenerator + + let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env) + unpackGrad = unTup vUnpair (d2e env) . Value + + outPrimal <- liftIO $ primalfun input + + let (outChad0, gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0 + (outChadS, gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS + gradChad0' = toTanE env input gradChad0 + gradChadS' = toTanE env input gradChadS + scChad = envScalars env gradChad0' + scChadS = envScalars env gradChadS' + gradFwd = gradientByForward fwdartifactC input + scFwd = envScalars env gradFwd + + -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0)) + -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS)) + -- annotate (ppExpr knownEnv expr) + -- annotate (ppExpr env dtermChad0) + -- annotate (ppExpr env dtermChadS) + diff outChadS closeIsh outChad0 + diff outChadS closeIsh outPrimal + diff scChadS (\x y -> and (zipWith closeIsh x y)) scChad + diff scChadS (\x y -> and (zipWith closeIsh x y)) scFwd + ] + where envScalars :: SList STy env' -> SList Value (TanE env') -> [Double] envScalars SNil SNil = [] envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs -withCompiled :: KnownEnv env => Ex env t -> (IO (SList Value env -> IO (Rep t)) -> TestTree) -> TestTree -withCompiled expr = withResource (compile knownEnv expr) (\_ -> pure ()) +withCompiled :: SList STy env -> Ex env t -> (IO (SList Value env -> IO (Rep t)) -> TestTree) -> TestTree +withCompiled env expr = withResource (compile env expr) (\_ -> pure ()) term_build1_sum :: Ex '[TArr N1 (TScal TF64)] (TScal TF64) term_build1_sum = fromNamed $ lambda #x $ body $ |