From 59ea6579c0cceeecaef7c27e39aab39828a4fbeb Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Fri, 20 Jun 2025 00:02:11 +0200 Subject: WIP parallel test suite --- chad-fast.cabal | 2 + test-framework/Test/Framework.hs | 294 ++++++++++++++++++++++++++++++--------- test/Main.hs | 31 +++-- 3 files changed, 249 insertions(+), 78 deletions(-) diff --git a/chad-fast.cabal b/chad-fast.cabal index b7270e4..94af651 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -83,7 +83,9 @@ library test-framework exposed-modules: Test.Framework build-depends: base, + concurrent-output, hedgehog, + stm, time, transformers hs-source-dirs: test-framework diff --git a/test-framework/Test/Framework.hs b/test-framework/Test/Framework.hs index 1b2b7d7..2d45630 100644 --- a/test-framework/Test/Framework.hs +++ b/test-framework/Test/Framework.hs @@ -6,7 +6,7 @@ module Test.Framework ( TestTree, testGroup, - testGroupCollapse, + groupSetCollapse, groupSetSequential, testProperty, withResource, withResource', @@ -18,13 +18,18 @@ module Test.Framework ( TestName, ) where -import Control.Monad (forM, when) -import Control.Monad.Trans.State.Strict +import Control.Concurrent +import Control.Exception (finally) +import Control.Monad (forM, when, replicateM) import Control.Monad.IO.Class +import Data.IORef import Data.List (isInfixOf, intercalate) import Data.Maybe (isJust, mapMaybe, fromJust) import Data.String (fromString) import Data.Time.Clock +import GHC.Conc (getNumProcessors) +import System.Console.Concurrent +import System.Console.Regions import System.Environment import System.Exit import System.IO (hFlush, hPutStrLn, stdout, stderr, hIsTerminalDevice) @@ -38,18 +43,31 @@ import qualified Hedgehog.Internal.Runner as H import qualified Hedgehog.Internal.Seed as H.Seed +type TestName = String + data TestTree - = Group Bool String [TestTree] + = Group GroupOpts String [TestTree] | forall a. Resource (IO a) (a -> IO ()) (a -> TestTree) | HP String H.Property -type TestName = String +data GroupOpts = GroupOpts + { goCollapse :: Bool + , goSequential :: Bool } + deriving (Show) + +defaultGroupOpts :: GroupOpts +defaultGroupOpts = GroupOpts False False testGroup :: String -> [TestTree] -> TestTree -testGroup = Group False +testGroup = Group defaultGroupOpts -testGroupCollapse :: String -> [TestTree] -> TestTree -testGroupCollapse = Group True +groupSetCollapse :: TestTree -> TestTree +groupSetCollapse (Group opts name trees) = Group opts { goCollapse = True } name trees +groupSetCollapse _ = error "groupSetCollapse: not called on a Group" + +groupSetSequential :: TestTree -> TestTree +groupSetSequential (Group opts name trees) = Group opts { goSequential = True } name trees +groupSetSequential _ = error "groupSetSequential: not called on a Group" -- | The @a -> TestTree@ function must use the @a@ only inside properties: when -- not actually running properties, it will be passed 'undefined'. @@ -84,8 +102,11 @@ computeMaxLen :: TestTree -> Int computeMaxLen = go 0 where go :: Int -> TestTree -> Int - go indent (Group True name trees) = maximum (2*indent + length name : map (go (indent+1)) trees) - go indent (Group False _ trees) = maximum (0 : map (go (indent+1)) trees) + go indent (Group opts name trees) + -- If we collapse, the name of the group gets prefixed before the final status message after collapsing. + | goCollapse opts = maximum (2*indent + length name : map (go (indent+1)) trees) + -- If we don't collapse, the group name does get printed but without any status message, so it doesn't need to get accounted for in maxlen. + | otherwise = maximum (0 : map (go (indent+1)) trees) go indent (Resource _ _ fun) = go indent (fun undefined) go indent (HP name _) = 2 * indent + length name @@ -97,22 +118,20 @@ data Stats = Stats initStats :: Stats initStats = Stats 0 0 -newtype M a = M (StateT Stats IO a) - deriving newtype (Functor, Applicative, Monad, MonadIO) - -modifyStats :: (Stats -> Stats) -> M () -modifyStats f = M (modify f) +modifyStats :: (?stats :: IORef Stats) => (Stats -> Stats) -> IO () +modifyStats f = atomicModifyIORef' ?stats (\s -> (f s, ())) data Options = Options { optsPattern :: String , optsHelp :: Bool , optsHedgehogReplay :: Maybe (H.Skip, H.Seed) , optsHedgehogShrinks :: Maybe Int + , optsParallel :: Bool } deriving (Show) defaultOptions :: Options -defaultOptions = Options "" False Nothing Nothing +defaultOptions = Options "" False Nothing Nothing False parseOptions :: [String] -> Options -> Either String Options parseOptions [] opts = pure opts @@ -134,6 +153,7 @@ parseOptions ("--hedgehog-shrinks":arg:args) opts = case readMaybe arg of Just n -> parseOptions args opts { optsHedgehogShrinks = Just n } Nothing -> Left "Invalid argument to '--hedgehog-shrinks'" +parseOptions ("--parallel":args) opts = parseOptions args opts { optsParallel = True } parseOptions (arg:_) _ = Left $ "Unrecognised argument: '" ++ arg ++ "'" printUsage :: IO () @@ -147,7 +167,10 @@ printUsage = do ," test looks like: '^group1/group2/testname$'." ," --hedgehog-replay '{skip} {seed}'" ," Skip to a particular generated Hedgehog test. Should be used" - ," with -p. Overrides 'propertySkip' in 'PropertyConfig' if set."] + ," with -p. Overrides 'propertySkip' in 'PropertyConfig' if set." + ," --hedgehog-shrinks NUM" + ," Limit the number of shrinking steps." + ," --parallel Run tests in parallel."] defaultMain :: TestTree -> IO () defaultMain tree = do @@ -165,58 +188,161 @@ runTests options = \tree' -> return (ExitFailure 1) Just tree -> do isterm <- hIsTerminalDevice stdout - let M m = let ?maxlen = computeMaxLen tree - ?istty = isterm - in go 0 id tree starttm <- getCurrentTime - (success, stats) <- runStateT m initStats + statsRef <- newIORef initStats + success <- let ?stats = statsRef + ?options = options + ?maxlen = computeMaxLen tree + ?istty = isterm + in if optsParallel options + then do nproc <- getNumProcessors + setNumCapabilities nproc + displayConsoleRegions $ + withWorkerPool nproc $ \ pool -> + let ?pool = pool in runTreePar Nothing 0 id tree + else isJust <$> runTreeSeq 0 id tree + stats <- readIORef statsRef endtm <- getCurrentTime let ?istty = isterm in printStats stats (diffUTCTime endtm starttm) - return (if isJust success then ExitSuccess else ExitFailure 1) - where - -- If all tests are successful, returns the number of output lines produced - go :: (?maxlen :: Int, ?istty :: Bool) => Int -> (String -> String) -> TestTree -> M (Maybe Int) - go indent path (Group collapse name trees) = do - liftIO $ putStrLn (replicate (2 * indent) ' ' ++ name) >> hFlush stdout - starttm <- liftIO getCurrentTime - mlns <- fmap (fmap sum . sequence) . forM trees $ - go (indent + 1) (path . (name++) . ('/':)) - endtm <- liftIO getCurrentTime - case mlns of - Just lns | collapse, ?istty -> do - let thislen = 2*indent + length name - liftIO $ putStrLn $ concat (replicate (lns+1) "\x1B[A\x1B[2K") ++ "\x1B[G" ++ - replicate (2 * indent) ' ' ++ name ++ ": " ++ replicate (?maxlen - thislen) ' ' ++ - "\x1B[32mOK\x1B[0m" ++ - prettyDuration False (realToFrac (diffUTCTime endtm starttm)) - return (Just 1) - _ -> return ((+1) <$> mlns) - go indent path (Resource make cleanup fun) = do - value <- liftIO make - success <- go indent path (fun value) - liftIO $ cleanup value - return success - go indent path (HP name (H.Property config test)) = do + return (if success then ExitSuccess else ExitFailure 1) + +-- If all tests are successful, returns the number of output lines produced +runTreeSeq :: (?options :: Options, ?stats :: IORef Stats,?maxlen :: Int, ?istty :: Bool) + => Int -> (String -> String) -> TestTree -> IO (Maybe Int) +runTreeSeq indent path (Group groupOpts name trees) = do + liftIO $ putStrLn (replicate (2 * indent) ' ' ++ name) >> hFlush stdout + starttm <- liftIO getCurrentTime + mlns <- fmap (fmap sum . sequence) . forM trees $ + runTreeSeq (indent + 1) (path . (name++) . ('/':)) + endtm <- liftIO getCurrentTime + case mlns of + Just lns | goCollapse groupOpts, ?istty -> do let thislen = 2*indent + length name - let outputPrefix = replicate (2*indent) ' ' ++ name ++ ": " ++ replicate (?maxlen - thislen) ' ' - when ?istty $ liftIO $ putStr outputPrefix >> hFlush stdout - - let (config', seedfun) = applyHedgehogOptions options config - seed <- seedfun + liftIO $ putStrLn $ concat (replicate (lns+1) "\x1B[A\x1B[2K") ++ "\x1B[G" ++ + replicate (2 * indent) ' ' ++ name ++ ": " ++ replicate (?maxlen - thislen) ' ' ++ + "\x1B[32mOK\x1B[0m" ++ + prettyDuration False (realToFrac (diffUTCTime endtm starttm)) + return (Just 1) + _ -> return ((+1) <$> mlns) +runTreeSeq indent path (Resource make cleanup fun) = do + value <- liftIO make + success <- runTreeSeq indent path (fun value) + liftIO $ cleanup value + return success +runTreeSeq indent path (HP name prop) = + runHP (\prefix -> when ?istty $ putStr prefix >> hFlush stdout) + (\_ -> outputProgress (?maxlen + 2)) + (\prefix rendered -> putStrLn ((if ?istty then "\x1B[K" else prefix) ++ rendered) >> hFlush stdout) + indent path name prop + +-- Assumes it's run within displayConsoleRegions. +runTreePar :: (?options :: Options, ?stats :: IORef Stats, ?pool :: WorkerPool, ?maxlen :: Int, ?istty :: Bool) + => Maybe (ConsoleRegion, String) -> Int -> (String -> String) -> TestTree -> IO Bool +runTreePar mregctx indent path (Group groupOpts name trees) = do + let run reg regPrefix sequential = do + setConsoleRegion reg name + starttm <- liftIO getCurrentTime + success <- fmap and . poolRunList ?pool . flip map trees $ + runTreeParSub reg (name ++ " > ") (indent + 1) (path . (name++) . ('/':)) + endtm <- liftIO getCurrentTime + + let thislen = 2*indent + length name + finishConsoleRegion reg $ + replicate (2 * indent) ' ' ++ name ++ ": " ++ replicate (?maxlen - thislen) ' ' ++ + ansi "\x1B[32mOK\x1B[0m" ++ + prettyDuration False (realToFrac (diffUTCTime endtm starttm)) + return success + + case (mregctx, goSequential groupOpts) of + (Nothing, True) -> do + outputConcurrent (replicate (2 * indent) ' ' ++ name ++ "\n") + fmap and . forM trees $ + runTreePar Nothing (indent + 1) (path . (name++) . ('/':)) + (_, False) -> do + regPrefix <- case mregctx of + Just (reg, regPrefix) -> do + setConsoleRegion reg (regPrefix ++ name) + return regPrefix + Nothing -> return "" starttm <- liftIO getCurrentTime - report <- liftIO $ H.checkReport config' 0 seed test (outputProgress (?maxlen + 2)) + success <- fmap and . poolRunList ?pool . flip map trees $ \tree -> + withConsoleRegion Linear $ \reg -> + runTreePar (Just (reg, regPrefix ++ name ++ " > ")) + (indent + 1) (path . (name++) . ('/':)) tree endtm <- liftIO getCurrentTime - liftIO $ do - when (not ?istty) $ putStr outputPrefix - printResult report (path name) (diffUTCTime endtm starttm) - hFlush stdout - - let ok = H.reportStatus report == H.OK - modifyStats $ \stats -> stats { statsOK = fromEnum ok + statsOK stats - , statsTotal = 1 + statsTotal stats } - return (if ok then Just 1 else Nothing) + let thislen = 2*indent + length name + finishConsoleRegion reg $ + replicate (2 * indent) ' ' ++ name ++ ": " ++ replicate (?maxlen - thislen) ' ' ++ + ansi "\x1B[32mOK\x1B[0m" ++ + prettyDuration False (realToFrac (diffUTCTime endtm starttm)) + return success + (Just (reg, regPrefix), sequential) -> + run reg regPrefix sequential + +runTreePar mregctx indent path (Resource make cleanup fun) = do + value <- liftIO make + success <- runTreePar mregctx indent path (fun value) + liftIO $ cleanup value + return success + +runTreePar mregctx indent path (HP name prop) = + let run reg regPrefix = + isJust <$> + runHP (\prefix -> setConsoleRegion reg (regPrefix ++ prefix)) + (\prefix -> outputProgressPar reg (regPrefix ++ prefix)) + (\prefix rendered -> finishConsoleRegion reg (regPrefix ++ prefix ++ rendered ++ "\n")) + indent path name prop + in case mregctx of + Nothing -> withConsoleRegion Linear $ \reg -> run reg "" + Just (reg, regPrefix) -> run reg regPrefix + +-- Sequential subcomputation in a parallel environment +runTreeParSub :: (?options :: Options, ?stats :: IORef Stats, ?maxlen :: Int, ?istty :: Bool) + => ConsoleRegion -> String -> Int -> (String -> String) -> TestTree -> IO Bool +runTreeParSub region regPrefix indent path (Group _ name trees) = + fmap and . forM trees $ + runTreeParSub region (regPrefix ++ name ++ " > ") (indent + 1) (path . (name++) . ('/':)) + +runTreeParSub region regPrefix indent path (Resource make cleanup fun) = do + value <- liftIO make + success <- runTreeParSub region regPrefix indent path (fun value) + liftIO $ cleanup value + return success + +runTreeParSub region regPrefix indent path (HP name prop) = do + isJust <$> + runHP (\prefix -> setConsoleRegion region (regPrefix ++ prefix)) + (\prefix -> outputProgressPar region (regPrefix ++ prefix)) + (\prefix rendered -> finishConsoleRegion region (regPrefix ++ prefix ++ rendered)) + indent path name prop + +runHP :: (?options :: Options, ?stats :: IORef Stats, ?maxlen :: Int, ?istty :: Bool) + => (String -> IO ()) + -> (String -> H.Report H.Progress -> IO ()) + -> (String -> String -> IO ()) + -> Int -> (String -> String) + -> String -> H.Property -> IO (Maybe Int) +runHP prefixPrinter progressPrinter resultPrinter indent path name (H.Property config test) = do + let thislen = 2*indent + length name + let outputPrefix = replicate (2*indent) ' ' ++ name ++ ": " ++ replicate (?maxlen - thislen) ' ' + liftIO $ prefixPrinter outputPrefix + + let (config', seedfun) = applyHedgehogOptions ?options config + seed <- seedfun + + starttm <- liftIO getCurrentTime + report <- liftIO $ H.checkReport config' 0 seed test (progressPrinter outputPrefix) + endtm <- liftIO getCurrentTime + + rendered <- liftIO $ renderResult report (path name) (diffUTCTime endtm starttm) + liftIO $ resultPrinter outputPrefix rendered + + let ok = H.reportStatus report == H.OK + modifyStats $ \stats -> stats { statsOK = fromEnum ok + statsOK stats + , statsTotal = 1 + statsTotal stats } + return (if ok then Just 1 else Nothing) applyHedgehogOptions :: MonadIO m => Options -> H.PropertyConfig -> (H.PropertyConfig, m H.Seed) applyHedgehogOptions opts config0 = @@ -236,18 +362,23 @@ outputProgress indent report hFlush stdout | otherwise = return () -printResult :: (?istty :: Bool) => H.Report H.Result -> String -> NominalDiffTime -> IO () -printResult report path timeTaken = do +outputProgressPar :: ConsoleRegion -> String -> H.Report H.Progress -> IO () +outputProgressPar region prefix report = do + str <- H.renderProgress H.EnableColor (Just (fromString "")) report + setConsoleRegion region (prefix ++ replace '\n' " " str) + +renderResult :: (?istty :: Bool) => H.Report H.Result -> String -> NominalDiffTime -> IO String +renderResult report path timeTaken = do str <- H.renderResult H.EnableColor (Just (fromString "")) report case H.reportStatus report of - H.OK -> putStrLn (ansi "\x1B[K" ++ str ++ prettyDuration False (realToFrac timeTaken)) + H.OK -> return (str ++ prettyDuration False (realToFrac timeTaken)) H.Failed failure -> do let H.Report { H.reportTests = count, H.reportDiscards = discards } = report replayInfo = H.skipCompress (H.SkipToShrink count discards (H.failureShrinkPath failure)) ++ " " ++ show (H.reportSeed report) suffix = "\n Flags to reproduce: `-p '" ++ path ++ "' --hedgehog-replay '" ++ replayInfo ++ "'`" - putStrLn (ansi "\x1B[K" ++ str ++ suffix) - _ -> putStrLn (ansi "\x1B[K" ++ str) + return (str ++ suffix) + _ -> return str printStats :: (?istty :: Bool) => Stats -> NominalDiffTime -> IO () printStats stats timeTaken @@ -259,6 +390,37 @@ printStats stats timeTaken in putStrLn $ ansi "\x1B[31m" ++ "Failed " ++ show nfailed ++ " out of " ++ show (statsTotal stats) ++ " tests." ++ prettyDuration True (realToFrac timeTaken) ++ ansi "\x1B[0m" +data WorkerPool = WorkerPool (Chan (Maybe PoolJob)) [ThreadId] +data PoolJob = forall a. PoolJob (IO a) (MVar a) + +withWorkerPool :: Int -> (WorkerPool -> IO a) -> IO a +withWorkerPool numWorkers k = do + chan <- newChan + pool <- WorkerPool chan <$> forM [0..numWorkers-1] (\i -> forkOn i (worker i chan)) + k pool `finally` replicateM numWorkers (writeChan chan Nothing) + where + worker :: Int -> Chan (Maybe PoolJob) -> IO () + worker idx chan = do + mjob <- readChan chan + case mjob of + Just (PoolJob action mvar) -> do + outputConcurrent $ "[" ++ show idx ++ "] got job\n" + action >>= putMVar mvar + worker idx chan + Nothing -> return () + +poolSubmit :: WorkerPool -> IO a -> MVar a -> IO () +poolSubmit (WorkerPool chan _) action mvar = writeChan chan (Just (PoolJob action mvar)) + +poolRunList :: WorkerPool -> [IO a] -> IO [a] +poolRunList pool actions = do + vars <- forM actions $ \act -> do + var <- newEmptyMVar + poolSubmit pool act var + return var + mapM takeMVar vars + + prettyDuration :: Bool -> Double -> String prettyDuration False x | x < 0.5 = "" prettyDuration _ x = diff --git a/test/Main.hs b/test/Main.hs index 0a57cbf..4cdab1c 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -20,6 +20,9 @@ import qualified Data.Map.Strict as Map import qualified Data.Text as T import Hedgehog import qualified Hedgehog.Gen as Gen +import qualified Hedgehog.Internal.Gen as IGen +import qualified Hedgehog.Internal.Tree as ITree +import qualified Hedgehog.Internal.Seed as ISeed import qualified Hedgehog.Range as Range import Test.Framework @@ -40,6 +43,7 @@ import Interpreter import Interpreter.Rep import Language import Simplify +import Data.Maybe (fromJust) data TypedValue t = TypedValue (STy t) (Rep t) @@ -63,18 +67,18 @@ simplifyIters iters env | Dict <- envKnown env = SimplIters n -> simplifyN n SimplFix -> simplifyFix --- In addition to the gradient, also returns the pretty-printed differentiated term. -gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env))) -gradientByCHAD simplIters env term input = - let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term - (out, grad) = interpretOpen False env input dterm - in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad))) +-- -- In addition to the gradient, also returns the pretty-printed differentiated term. +-- gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env))) +-- gradientByCHAD simplIters env term input = +-- let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term +-- (out, grad) = interpretOpen False env input dterm +-- in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad))) --- In addition to the gradient, also returns the pretty-printed differentiated term. -gradientByCHAD' :: SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (TanE env))) -gradientByCHAD' simplIters env term input = - second (second (toTanE env input)) $ - gradientByCHAD simplIters env term input +-- -- In addition to the gradient, also returns the pretty-printed differentiated term. +-- gradientByCHAD' :: SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (TanE env))) +-- gradientByCHAD' simplIters env term input = +-- second (second (toTanE env input)) $ +-- gradientByCHAD simplIters env term input gradientByForward :: FwdADArtifact env R -> SList Value env -> SList Value (TanE env) gradientByForward art input = drevByFwd art input 1.0 @@ -302,7 +306,7 @@ adTestGen name expr envGenerator = exprS = simplifyFix expr in withCompiled env expr $ \primalfun -> withCompiled env (simplifyFix expr) $ \primalSfun -> - testGroupCollapse name + groupSetCollapse $ groupSetSequential $ testGroup name [adTestGenPrimal env envGenerator expr exprS primalfun primalSfun ,adTestGenFwd env envGenerator exprS ,testGroup "chad" @@ -661,6 +665,9 @@ tests_AD = testGroup "AD" ,adTestGen "gmm" (Example.gmmObjective False) gen_gmm ] +gmminp :: SList Value [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] +gmminp = ITree.treeValue $ fromJust $ IGen.evalGen 30 (ISeed.from 3) gen_gmm + main :: IO () main = defaultMain $ testGroup "All" [tests_Compile -- cgit v1.2.3-70-g09d2