diff options
| author | Tom Smeding <t.j.smeding@uu.nl> | 2025-03-09 23:09:00 +0100 | 
|---|---|---|
| committer | Tom Smeding <t.j.smeding@uu.nl> | 2025-03-09 23:09:00 +0100 | 
| commit | a590b1414baec157da3a1f6c5684b1a3bce8ecaf (patch) | |
| tree | 45cd0f5559ee2294c1fb889d21ccba49f615d187 /test | |
| parent | f9906020ef838af0bb6683a3a078e23eac555e54 (diff) | |
test: Run gradientByForward with compiled DN fun
Diffstat (limited to 'test')
| -rw-r--r-- | test/Main.hs | 114 | 
1 files changed, 78 insertions, 36 deletions
| diff --git a/test/Main.hs b/test/Main.hs index 2b7e7d8..92dc446 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -36,6 +36,7 @@ import Compile  import qualified Example  import qualified Example.GMM as Example  import ForwardAD +import ForwardAD.DualNumbers  import Interpreter  import Interpreter.Rep  import Language @@ -64,8 +65,28 @@ gradientByCHAD' simplIters env term input =    second (second (toTanE env input)) $      gradientByCHAD simplIters env term input -gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env) -gradientByForward env term input = drevByFwd env term input 1.0 +gradientByForward :: FwdADArtifact env (TScal TF64) -> SList Value env -> SList Value (TanE env) +gradientByForward art input = drevByFwd art input 1.0 + +extendDN :: STy t -> Rep t -> Gen (Rep (DN t)) +extendDN STNil () = pure () +extendDN (STPair a b) (x, y) = (,) <$> extendDN a x <*> extendDN b y +extendDN (STEither a _) (Left x) = Left <$> extendDN a x +extendDN (STEither _ b) (Right y) = Right <$> extendDN b y +extendDN (STMaybe _) Nothing = pure Nothing +extendDN (STMaybe t) (Just x) = Just <$> extendDN t x +extendDN (STArr _ t) arr = traverse (extendDN t) arr +extendDN (STScal sty) x = case sty of +  STF32 -> Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \d -> pure (x, d) +  STF64 -> Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \d -> pure (x, d) +  STI32 -> pure x +  STI64 -> pure x +  STBool -> pure x +extendDN (STAccum _) _ = error "Accumulators not supported in input program" + +extendDNE :: SList STy env -> SList Value env -> Gen (SList Value (DNE env)) +extendDNE SNil SNil = pure SNil +extendDNE (t `SCons` env) (Value x `SCons` val) = SCons <$> (Value <$> extendDN t x) <*> extendDNE env val  closeIsh' :: Double -> Double -> Double -> Bool  closeIsh' h a b = @@ -185,53 +206,74 @@ adTestTp name tmpl term = adTestGen name term (evalStateT (genEnv knownEnv tmpl)  adTestGen :: forall env. KnownEnv env            => TestName -> Ex env (TScal TF64) -> Gen (SList Value env) -> TestTree  adTestGen name expr envGenerator = -  withCompiled expr $ \getprimalfun -> -  testProperty name $ property $ do -  let env = knownEnv @env +  let env = knownEnv @env in +  withCompiled env expr $ \getprimalfun -> +  testGroup name +    [testProperty "compile primal" $ property $ do +       primalfun <- liftIO getprimalfun +       input <- forAllWith (showEnv env) envGenerator +       let outPrimalI = interpretOpen False input expr +       outPrimalC <- liftIO $ primalfun input +       diff outPrimalI (closeIsh' 1e-8) outPrimalC + +    ,withCompiled (dne env) (dfwdDN expr) $ \getdnfun -> +     testProperty "compile fwdAD" $ property $ do +       dnfun <- liftIO getdnfun +       input <- forAllWith (showEnv env) envGenerator +       dinput <- forAllWith (showEnv (dne env)) $ extendDNE env input +       let (outDNI1, outDNI2) = interpretOpen False dinput (dfwdDN expr) +       (outDNC1, outDNC2) <- liftIO $ dnfun dinput +       diff outDNI1 (closeIsh' 1e-8) outDNC1 +       diff outDNI2 (closeIsh' 1e-8) outDNC2 + +    ,withResource (makeFwdADArtifactCompile env expr) (\_ -> pure ()) $ \getfwdartifactC -> +     testProperty "chad" $ property $ do +       primalfun <- liftIO getprimalfun +       fwdartifactC <- liftIO getfwdartifactC + +       annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) -  annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) +       let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr +           dtermChadS = simplifyFix dtermChad0 +           dtermChadS20 = simplifyN 20 dtermChad0 -  let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr -      dtermChadS = simplifyFix dtermChad0 -      dtermChadS20 = simplifyN 20 dtermChad0 +       -- pack Text for less GC pressure (these values are retained for some reason) +       diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env dtermChadS20)) -  -- pack Text for less GC pressure (these values are retained for some reason) -  diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env dtermChadS20)) +       input <- forAllWith (showEnv env) envGenerator -  input <- forAllWith (showEnv env) envGenerator +       let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env) +           unpackGrad = unTup vUnpair (d2e env) . Value -  let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env) -      unpackGrad = unTup vUnpair (d2e env) . Value +       outPrimal <- liftIO $ primalfun input -  let outPrimalI = interpretOpen False input expr -  outPrimal <- liftIO $ getprimalfun >>= ($ input) -  diff outPrimal (closeIsh' 1e-8) outPrimalI +       let (outChad0, gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0 +           (outChadS, gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS +           gradChad0' = toTanE env input gradChad0 +           gradChadS' = toTanE env input gradChadS +           scChad = envScalars env gradChad0' +           scChadS = envScalars env gradChadS' +           gradFwd = gradientByForward fwdartifactC input +           scFwd = envScalars env gradFwd -  let (outChad0, gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0 -      (outChadS, gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS -      gradChad0' = toTanE env input gradChad0 -      gradChadS' = toTanE env input gradChadS -      scChad = envScalars env gradChad0' -      scChadS = envScalars env gradChadS' -      gradFwd = gradientByForward knownEnv expr input -      scFwd = envScalars env gradFwd +       -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0)) +       -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS)) +       -- annotate (ppExpr knownEnv expr) +       -- annotate (ppExpr env dtermChad0) +       -- annotate (ppExpr env dtermChadS) +       diff outChadS closeIsh outChad0 +       diff outChadS closeIsh outPrimal +       diff scChadS (\x y -> and (zipWith closeIsh x y)) scChad +       diff scChadS (\x y -> and (zipWith closeIsh x y)) scFwd +    ] -  -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0)) -  -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS)) -  -- annotate (ppExpr knownEnv expr) -  -- annotate (ppExpr env dtermChad0) -  -- annotate (ppExpr env dtermChadS) -  diff outChadS closeIsh outChad0 -  diff outChadS closeIsh outPrimal -  diff scChadS (\x y -> and (zipWith closeIsh x y)) scChad -  diff scChadS (\x y -> and (zipWith closeIsh x y)) scFwd    where      envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]      envScalars SNil SNil = []      envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs -withCompiled :: KnownEnv env => Ex env t -> (IO (SList Value env -> IO (Rep t)) -> TestTree) -> TestTree -withCompiled expr = withResource (compile knownEnv expr) (\_ -> pure ()) +withCompiled :: SList STy env -> Ex env t -> (IO (SList Value env -> IO (Rep t)) -> TestTree) -> TestTree +withCompiled env expr = withResource (compile env expr) (\_ -> pure ())  term_build1_sum :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)  term_build1_sum = fromNamed $ lambda #x $ body $ | 
