diff options
Diffstat (limited to 'test/Main.hs')
| -rw-r--r-- | test/Main.hs | 209 |
1 files changed, 133 insertions, 76 deletions
diff --git a/test/Main.hs b/test/Main.hs index 0a57cbf..05597cc 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE OverloadedStrings #-} @@ -11,35 +12,38 @@ {-# LANGUAGE UndecidableInstances #-} module Main where +import Control.Monad (when) import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.State import Data.Bifunctor import Data.Int (Int64) import Data.Map.Strict (Map) -import qualified Data.Map.Strict as Map -import qualified Data.Text as T +import Data.Map.Strict qualified as Map +import Data.Text qualified as T import Hedgehog -import qualified Hedgehog.Gen as Gen -import qualified Hedgehog.Range as Range +import Hedgehog.Gen qualified as Gen +import Hedgehog.Range qualified as Range import Test.Framework -import Array -import AST hiding ((.>)) -import AST.Pretty -import AST.UnMonoid -import CHAD.Top -import CHAD.Types -import CHAD.Types.ToTan -import Compile -import qualified Example -import qualified Example.GMM as Example -import Example.Types -import ForwardAD -import ForwardAD.DualNumbers -import Interpreter -import Interpreter.Rep -import Language -import Simplify +import CHAD.Array +import CHAD.AST hiding ((.>)) +import CHAD.AST.Count (pruneExpr) +import CHAD.AST.Pretty +import CHAD.AST.UnMonoid +import CHAD.Compile +import CHAD.Data +import CHAD.Drev.Top +import CHAD.Drev.Types +import CHAD.Drev.Types.ToTan +import CHAD.Example qualified as Example +import CHAD.Example.GMM qualified as Example +import CHAD.Example.Types +import CHAD.ForwardAD +import CHAD.ForwardAD.DualNumbers +import CHAD.Interpreter +import CHAD.Interpreter.Rep +import CHAD.Language +import CHAD.Simplify data TypedValue t = TypedValue (STy t) (Rep t) @@ -63,18 +67,18 @@ simplifyIters iters env | Dict <- envKnown env = SimplIters n -> simplifyN n SimplFix -> simplifyFix --- In addition to the gradient, also returns the pretty-printed differentiated term. -gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env))) -gradientByCHAD simplIters env term input = - let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term - (out, grad) = interpretOpen False env input dterm - in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad))) +-- -- In addition to the gradient, also returns the pretty-printed differentiated term. +-- gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env))) +-- gradientByCHAD simplIters env term input = +-- let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term +-- (out, grad) = interpretOpen False env input dterm +-- in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad))) --- In addition to the gradient, also returns the pretty-printed differentiated term. -gradientByCHAD' :: SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (TanE env))) -gradientByCHAD' simplIters env term input = - second (second (toTanE env input)) $ - gradientByCHAD simplIters env term input +-- -- In addition to the gradient, also returns the pretty-printed differentiated term. +-- gradientByCHAD' :: SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (TanE env))) +-- gradientByCHAD' simplIters env term input = +-- second (second (toTanE env input)) $ +-- gradientByCHAD simplIters env term input gradientByForward :: FwdADArtifact env R -> SList Value env -> SList Value (TanE env) gradientByForward art input = drevByFwd art input 1.0 @@ -92,8 +96,8 @@ extendDN (STMaybe _) Nothing = pure Nothing extendDN (STMaybe t) (Just x) = Just <$> extendDN t x extendDN (STArr _ t) arr = traverse (extendDN t) arr extendDN (STScal sty) x = case sty of - STF32 -> Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \d -> pure (x, d) - STF64 -> Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \d -> pure (x, d) + STF32 -> Gen.realFloat (Range.constant (-1) 1) >>= \d -> pure (x, d) + STF64 -> Gen.realFloat (Range.constant (-1) 1) >>= \d -> pure (x, d) STI32 -> pure x STI64 -> pure x STBool -> pure x @@ -217,8 +221,8 @@ genShape = \n tpl -> do shapeDiv :: Shape n -> DimNames n -> Int -> Shape n shapeDiv ShNil _ _ = ShNil - shapeDiv (ShNil `ShCons` n) ( C _ lo) f = ShNil `ShCons` (max lo (n `div` f)) - shapeDiv (sh@ShCons{} `ShCons` n) (tpl :$ C _ lo) f = shapeDiv sh tpl f `ShCons` (max lo (n `div` f)) + shapeDiv (ShNil `ShCons` n) ( C _ lo) f = ShNil `ShCons` max lo (n `div` f) + shapeDiv (sh@ShCons{} `ShCons` n) (tpl :$ C _ lo) f = shapeDiv sh tpl f `ShCons` max lo (n `div` f) shapeDiv (ShNil `ShCons` n) NC f = ShNil `ShCons` (n `div` f) shapeDiv (sh@ShCons{} `ShCons` n) (tpl :$ NC) f = shapeDiv sh tpl f `ShCons` (n `div` f) @@ -240,8 +244,8 @@ genValue topty tpl = case topty of ,liftV Just <$> genValue t (emptyTpl t)] STArr n t -> genShape n tpl >>= lift . genArray t STScal sty -> case sty of - STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10) - STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10) + STF32 -> Value <$> Gen.realFloat (Range.constant (-10) 10) + STF64 -> Value <$> Gen.realFloat (Range.constant (-10) 10) STI32 -> genInt STI64 -> genInt STBool -> Gen.choice [return (Value False), return (Value True)] @@ -302,7 +306,7 @@ adTestGen name expr envGenerator = exprS = simplifyFix expr in withCompiled env expr $ \primalfun -> withCompiled env (simplifyFix expr) $ \primalSfun -> - testGroupCollapse name + groupSetCollapse $ testGroup name [adTestGenPrimal env envGenerator expr exprS primalfun primalSfun ,adTestGenFwd env envGenerator exprS ,testGroup "chad" @@ -349,17 +353,23 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env exprS dtermSChadS = simplifyFix dtermSChad0 dtermSChadSUS = simplifyFix $ unMonoid dtermSChadS + dtermSChadSUSP = simplifyFix $ pruneExpr env dtermSChadSUS in - withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC -> - withCompiled env dtermSChadSUS $ \dcompSChadSUS -> + withResource' (do (fun, output) <- makeFwdADArtifactCompile env exprS + when (not (null output)) $ + outputWarningText $ "Forward AD compile GCC output: <<<\n" ++ output ++ ">>>" + return fun) $ \fwdartifactC -> + withCompiled env dtermSChadSUSP $ \dcompSChadSUSP -> testProperty testname $ property $ do annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) -- check simplifier convergence; pack Text for less GC pressure (these values are retained for some reason) - diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0))) - diff (T.pack (ppExpr env dtermChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid $ simplifyN 20 dtermChad0))) - diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0))) - diff (T.pack (ppExpr env dtermSChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid $ simplifyN 20 dtermSChad0))) + let dtermChad20 = simplifyN 20 dtermChad0 + diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env dtermChad20)) + diff (T.pack (ppExpr env dtermChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid dtermChad20))) + let dtermSChad20 = simplifyN 20 dtermSChad0 + diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env dtermSChad20)) + diff (T.pack (ppExpr env dtermSChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid dtermSChad20))) input <- forAllWith (showEnv env) envGenerator outPrimal <- evalIO $ primalSfun input @@ -369,46 +379,54 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e let tansFwd = TypedEnv (tanenv env) $ gradientByForward fwdartifactC input - let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0 - (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS - (outChadSUS , gradChadSUS) = second unpackGrad $ interpretOpen False env input dtermChadSUS - (outSChad0 , gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0 - (outSChadS , gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS - (outSChadSUS, gradSChadSUS) = second unpackGrad $ interpretOpen False env input dtermSChadSUS - tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0 - tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS - tansChadSUS = TypedEnv (tanenv env) $ toTanE env input gradChadSUS - tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0 - tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS - tansSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradSChadSUS + let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0 + (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS + (outChadSUS , gradChadSUS) = second unpackGrad $ interpretOpen False env input dtermChadSUS + (outSChad0 , gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0 + (outSChadS , gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS + (outSChadSUS , gradSChadSUS) = second unpackGrad $ interpretOpen False env input dtermSChadSUS + (outSChadSUSP, gradSChadSUSP) = second unpackGrad $ interpretOpen False env input dtermSChadSUSP + tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0 + tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS + tansChadSUS = TypedEnv (tanenv env) $ toTanE env input gradChadSUS + tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0 + tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS + tansSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradSChadSUS + tansSChadSUSP = TypedEnv (tanenv env) $ toTanE env input gradSChadSUSP - (outCompSChadSUS, gradCompSChadSUS) <- second unpackGrad <$> evalIO (dcompSChadSUS input) - let tansCompSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradCompSChadSUS + (outCompSChadSUSP, gradCompSChadSUSP) <- second unpackGrad <$> evalIO (dcompSChadSUSP input) + let tansCompSChadSUSP = TypedEnv (tanenv env) $ toTanE env input gradCompSChadSUSP -- annotate (showEnv (d2e env) gradChad0) -- annotate (showEnv (d2e env) gradChadS) -- annotate (ppExpr knownEnv expr) -- annotate (ppExpr env dtermChad0) -- annotate (ppExpr env dtermChadS) - annotate (ppExpr env (simplifyFix (unMonoid dtermSChadS))) - diff outChad0 closeIsh outPrimal - diff outChadS closeIsh outPrimal - diff outChadSUS closeIsh outPrimal - diff outSChad0 closeIsh outPrimal - diff outSChadS closeIsh outPrimal - diff outSChadSUS closeIsh outPrimal - diff outCompSChadSUS closeIsh outPrimal + annotate (ppExpr env dtermSChadSUSP) + diff outChad0 closeIsh outPrimal + diff outChadS closeIsh outPrimal + diff outChadSUS closeIsh outPrimal + diff outSChad0 closeIsh outPrimal + diff outSChadS closeIsh outPrimal + diff outSChadSUS closeIsh outPrimal + diff outSChadSUSP closeIsh outPrimal + diff outCompSChadSUSP closeIsh outPrimal let closeIshE' e1 e2 = closeIshE (tanenv env) (unTypedEnv e1) (unTypedEnv e2) - diff tansChad closeIshE' tansFwd - diff tansChadS closeIshE' tansFwd - diff tansChadSUS closeIshE' tansFwd - diff tansSChad closeIshE' tansFwd - diff tansSChadS closeIshE' tansFwd - diff tansSChadSUS closeIshE' tansFwd - diff tansCompSChadSUS closeIshE' tansFwd + diff tansChad closeIshE' tansFwd + diff tansChadS closeIshE' tansFwd + diff tansChadSUS closeIshE' tansFwd + diff tansSChad closeIshE' tansFwd + diff tansSChadS closeIshE' tansFwd + diff tansSChadSUS closeIshE' tansFwd + diff tansSChadSUSP closeIshE' tansFwd + diff tansCompSChadSUSP closeIshE' tansFwd withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree -withCompiled env expr = withResource (compile env expr) (\_ -> pure ()) +withCompiled env expr = withResource' $ do + (fun, output) <- compile env expr + when (not (null output)) $ + outputWarningText $ "Kernel compilation GCC output: <<<\n" ++ output ++ ">>>" + return fun gen_gmm :: Gen (SList Value [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64]) gen_gmm = do @@ -423,7 +441,7 @@ gen_gmm = do vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD) vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2)) vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD) - vgamma <- Gen.realFloat (Range.linearFracFrom 0 (-10) 10) + vgamma <- Gen.realFloat (Range.constant (-10) 10) vm <- Gen.integral (Range.linear 0 5) let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi) k2 = 0.5 * vgamma * vgamma @@ -554,6 +572,13 @@ tests_Compile = testGroup "Compile" nil) $ let_ #_ (accum SAPHere nil #x #ac) $ nil + + ,compileTest "foldd1" $ fromNamed $ lambda @(TVec R) #a $ body $ + fold1iD1 (#x :-> #y :-> pair (#x * #y) (pair #x #y)) 1 #a + + ,compileTest "fold-manual" $ fromNamed $ lambda @(TVec R) #a $ lambda #d $ body $ + let_ #pr (fold1iD1 (#x :-> #y :-> pair (#x * #y) (pair #x #y)) 1 #a) $ + fold1iD2 (#tape :-> #ctg :-> pair (snd_ #tape * #ctg) (fst_ #tape * #ctg)) (snd_ #pr) #d ] tests_AD :: TestTree @@ -659,6 +684,38 @@ tests_AD = testGroup "AD" ,adTestGen "gmm-wrong" (Example.gmmObjective True) gen_gmm ,adTestGen "gmm" (Example.gmmObjective False) gen_gmm + + ,adTestTp "uniform-free" (C "" 0 :& ()) Example.exUniformFree + + ,adTest "reshape1" $ fromNamed $ lambda @(TMat R) #a $ body $ + let_ #sh (shape #a) $ + let_ #n (snd_ #sh * snd_ (fst_ #sh)) $ + idx0 $ sum1i $ reshape (SS SZ) (pair nil #n) #a + + ,adTestTp "reshape2" (C "" 1 :$ NC) $ fromNamed $ lambda @(TMat R) #a $ body $ + let_ #sh (shape #a) $ + let_ #innern (snd_ #sh) $ + let_ #n (#innern * snd_ (fst_ #sh)) $ + let_ #flata (reshape (SS SZ) (pair nil #n) #a) $ + -- ensure the input array to EReshape is shared + idx0 $ sum1i $ + build1 #n (#i :-> #flata ! pair nil #i + #a ! pair (pair nil 0) (#i `mod_` #innern)) + + ,adTest "fold-sum" $ fromNamed $ lambda @(TArr N1 R) #a $ body $ + idx0 $ fold1i (#x :-> #y :-> #x + #y) 0 #a + + ,adTest "fold-prod" $ fromNamed $ lambda @(TArr N1 R) #a $ body $ + idx0 $ fold1i (#x :-> #y :-> #x * #y) 1 #a + + ,adTest "fold-freevar" $ fromNamed $ lambda @(TArr N1 R) #a $ body $ + let_ #v 2 $ + idx0 $ fold1i (#x :-> #y :-> #x * #y + #v) 1 #a + + ,adTestTp "fold-freearr" (C "" 1) $ fromNamed $ lambda @(TArr N1 R) #a $ body $ + idx0 $ fold1i (#x :-> #y :-> #x * #y + #a ! pair nil 0) 1 #a + + ,adTest "map" $ fromNamed $ lambda @(TArr N1 R) #a $ body $ + idx0 $ sum1i $ map_ (#x :-> 2 * #x) #a ] main :: IO () |
