diff options
| -rw-r--r-- | src/Compile.hs | 114 | ||||
| -rw-r--r-- | src/Language.hs | 8 | ||||
| -rw-r--r-- | src/Language/AST.hs | 12 | ||||
| -rw-r--r-- | test-framework/Test/Framework.hs | 90 | ||||
| -rw-r--r-- | test/Main.hs | 9 |
5 files changed, 198 insertions, 35 deletions
diff --git a/src/Compile.hs b/src/Compile.hs index 4e81c6a..f2063ee 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -34,6 +34,7 @@ import Foreign import GHC.Exts (int2Word#, addr2Int#) import GHC.Num (integerFromWord#) import GHC.Ptr (Ptr(..)) +import GHC.Stack (HasCallStack) import Numeric (showHex) import System.IO (hPutStrLn, stderr) import System.IO.Error (mkIOError, userErrorType) @@ -939,6 +940,117 @@ compile' env = \case [("buf", CEProj (CELit arrname) "buf") ,("sh", CELit ("{" ++ intercalate ", " [printCExpr 0 e "" | e <- indexTupleComponents dim shname] ++ "}"))]) + EFold1InnerD1 _ commut efun ex0 earr -> do + let STArr (SS n) t = typeOf earr + STPair _ bty = typeOf efun + + x0name <- compileAssign "foldd1x0" env ex0 + arrname <- compileAssign "foldd1arr" env earr + + zeroRefcountCheck (typeOf earr) "fold1iD1" arrname + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]")) + + shsz1name <- genName' "shszN" + emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n arrname) -- take init of arr's shape + shsz2name <- genName' "shszSN" + emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname)) + + resname <- allocArray "foldd1" Malloc "foldd1res" n t (Just (CELit shsz1name)) (compileArrShapeComponents n arrname) + storesname <- allocArray "foldd1" Malloc "foldd1stores" (SS n) bty (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) arrname) + + ((), x0incrStmts) <- scope $ incrementVarAlways "foldd1x0" Increment t x0name + + ivar <- genName' "i" + jvar <- genName' "j" + + accvar <- genName' "tot" + let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ jvar + arreltlit = arrname ++ ".buf->xs[" ++ eltidx ++ "]" + (funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun + funresvar <- genName' "res" + ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldd1elt" Increment t arreltlit + + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $ + pure (SVarDecl False (repSTy t) accvar (CELit x0name)) + <> x0incrStmts -- we're copying x0 here + <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + -- The combination function will consume the array element + -- and the accumulator. The accumulator is replaced by + -- what comes out of the function anyway, so that's + -- fine, but we do need to increment the array element. + arreltIncrStmts + <> funStmts + <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) + <> pure (SAsg accvar (CEProj (CELit funresvar) "a")) + <> pure (SAsg (storesname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b"))) + <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) + + incrementVarAlways "foldd1x0" Decrement t x0name + incrementVarAlways "foldd1arr" Decrement (typeOf earr) arrname + + strname <- emitStruct (STPair (STArr n t) (STArr (SS n) bty)) + return (CEStruct strname [("a", CELit resname), ("b", CELit storesname)]) + + EFold1InnerD2 _ commut efun estores ectg -> do + let STArr n t2 = typeOf ectg + STArr _ bty = typeOf estores + + storesname <- compileAssign "foldd2stores" env estores + ctgname <- compileAssign "foldd2ctg" env ectg + + zeroRefcountCheck (typeOf ectg) "fold1iD2" ctgname + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (storesname ++ ".sh[" ++ show (fromSNat n) ++ "]")) + + shsz1name <- genName' "shszN" + emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n storesname) -- take init of the shape + shsz2name <- genName' "shszSN" + emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname)) + + x0ctgname <- allocArray "foldd2" Malloc "foldd2x0ctg" n t2 (Just (CELit shsz1name)) (compileArrShapeComponents n storesname) + outctgname <- allocArray "foldd2" Malloc "foldd2outctg" (SS n) t2 (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) storesname) + + ivar <- genName' "i" + jvar <- genName' "j" + + accvar <- genName' "acc" + let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ lenname ++ "-1 - " ++ jvar + storeseltlit = storesname ++ ".buf->xs[" ++ eltidx ++ "]" + ctgeltlit = ctgname ++ ".buf->xs[" ++ ivar ++ "]" + (funres, funStmts) <- scope $ compile' (Const accvar `SCons` Const storeseltlit `SCons` env) efun + funresvar <- genName' "res" + ((), storeseltIncrStmts) <- scope $ incrementVarAlways "foldd2selt" Increment bty storeseltlit + ((), ctgeltIncrStmts) <- scope $ incrementVarAlways "foldd2celt" Increment bty ctgeltlit + + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $ + pure (SVarDecl False (repSTy t2) accvar (CELit ctgeltlit)) + <> ctgeltIncrStmts + -- we need to loop in reverse here, but we let jvar run in the + -- forward direction so that we can use SLoop. Note jvar is + -- reversed in eltidx above + <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + -- The combination function will consume the accumulator + -- and the stores element. The accumulator is replaced by + -- what comes out of the function anyway, so that's + -- fine, but we do need to increment the stores element. + storeseltIncrStmts + <> funStmts + <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) + <> pure (SAsg accvar (CEProj (CELit funresvar) "a")) + <> pure (SAsg (outctgname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b"))) + <> pure (SAsg (x0ctgname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) + + incrementVarAlways "foldd2stores" Decrement (STArr (SS n) bty) storesname + incrementVarAlways "foldd2ctg" Decrement (STArr n t2) ctgname + + strname <- emitStruct (STPair (STArr n t2) (STArr (SS n) t2)) + return (CEStruct strname [("a", CELit x0ctgname), ("b", CELit outctgname)]) + EConst _ t x -> return $ CELit $ compileScal True t x EIdx0 _ e -> do @@ -1311,7 +1423,7 @@ data AllocMethod = Malloc | Calloc deriving (Show) -- | The shape must have the outer dimension at the head (and the inner dimension on the right). -allocArray :: String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String +allocArray :: HasCallStack => String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String allocArray marker method nameBase rank eltty mshsz shape = do when (length shape /= fromSNat rank) $ error "allocArray: shape does not match rank" diff --git a/src/Language.hs b/src/Language.hs index d3c38d6..31b4b87 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -133,6 +133,14 @@ minimum1i e = NEMinimum1Inner e reshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t) reshape = NEReshape +fold1iD1 :: (Var name1 t1 :-> Var name2 t1 :-> NExpr ('(name2, t1) : '(name1, t1) : env) (TPair t1 b)) + -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) +fold1iD1 (v1 :-> v2 :-> e1) e2 e3 = NEFold1InnerD1 v1 v2 e1 e2 e3 + +fold1iD2 :: (Var name1 b :-> Var name2 t2 :-> NExpr ('(name2, t2) : '(name1, b) : env) (TPair t2 t2)) + -> NExpr env (TArr (S n) b) -> NExpr env (TArr n t2) -> NExpr env (TPair (TArr n t2) (TArr (S n) t2)) +fold1iD2 (v1 :-> v2 :-> e1) e2 e3 = NEFold1InnerD2 v1 v2 e1 e2 e3 + const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t) const_ x = let ty = knownScalTy diff --git a/src/Language/AST.hs b/src/Language/AST.hs index 325817d..c9d05c9 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -58,6 +58,15 @@ data NExpr env t where NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) NEReshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t) + NEFold1InnerD1 :: Var n1 t1 -> Var n2 t1 -> NExpr ('(n2, t1) : '(n1, t1) : env) (TPair t1 b) + -> NExpr env t1 + -> NExpr env (TArr (S n) t1) + -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) + NEFold1InnerD2 :: Var n1 b -> Var n2 t2 -> NExpr ('(n2, t2) : '(n1, b) : env) (TPair t2 t2) + -> NExpr env (TArr (S n) b) + -> NExpr env (TArr n t2) + -> NExpr env (TPair (TArr n t2) (TArr (S n) t2)) + -- expression operations NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t) NEIdx0 :: NExpr env (TArr Z t) -> NExpr env t @@ -208,6 +217,9 @@ fromNamedExpr val = \case NEMinimum1Inner e -> EMinimum1Inner ext (go e) NEReshape n a b -> EReshape ext n (go a) (go b) + NEFold1InnerD1 n1 n2 a b c -> EFold1InnerD1 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) + NEFold1InnerD2 n1 n2 a b c -> EFold1InnerD2 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) + NEConst t x -> EConst ext t x NEIdx0 e -> EIdx0 ext (go e) NEIdx1 a b -> EIdx1 ext (go a) (go b) diff --git a/test-framework/Test/Framework.hs b/test-framework/Test/Framework.hs index 5ceb866..80711b2 100644 --- a/test-framework/Test/Framework.hs +++ b/test-framework/Test/Framework.hs @@ -4,6 +4,7 @@ {-# LANGUAGE ImplicitParams #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} module Test.Framework ( TestTree, @@ -20,11 +21,11 @@ module Test.Framework ( TestName, ) where -import Control.Concurrent (setNumCapabilities, forkIO, killThread, forkOn) +import Control.Concurrent (setNumCapabilities, forkIO, forkOn, killThread) import Control.Concurrent.MVar import Control.Concurrent.STM -import Control.Exception (finally) -import Control.Monad (forM, when, forM_, replicateM_) +import Control.Exception (SomeException, throw, try, throwIO) +import Control.Monad (forM, when, forM_) import Control.Monad.IO.Class import Data.IORef import Data.List (isInfixOf, intercalate) @@ -49,6 +50,9 @@ import Hedgehog.Internal.Runner qualified as H import Hedgehog.Internal.Seed qualified as H.Seed +-- TODO: with GHC 9.12 we have tryWithContext and rethrowIO, which is better for rethrowing exceptions + + type TestName = String data TestTree @@ -107,6 +111,11 @@ filterTree (Options { optsPattern = pat }) = go [] renderPath comps = "^" ++ intercalate "/" (reverse comps) ++ "$" +treeNumTests :: TestTree -> Int +treeNumTests (Group _ _ ts) = sum (map treeNumTests ts) +treeNumTests (Resource _ _ _ fun) = treeNumTests (fun undefined) +treeNumTests HP{} = 1 + computeMaxLen :: TestTree -> Int computeMaxLen = go 0 where @@ -220,7 +229,7 @@ runTests options = \tree' -> else isJust <$> runTreeSeq 0 [] tree stats <- readIORef statsRef endtm <- getCurrentTime - let ?istty = isterm in printStats stats (diffUTCTime endtm starttm) + let ?istty = isterm in printStats (treeNumTests tree) stats (diffUTCTime endtm starttm) return (if success then ExitSuccess else ExitFailure 1) -- | Returns when all jobs in this tree have been scheduled. When all jobs are @@ -275,31 +284,34 @@ runTreePar topmparregion revidxlist revpath toptree@Resource{} topoutvar = runRe let pathitem = '[' : show depth ++ "](" ++ inhname ++ ")" path = intercalate "/" (reverse (pathitem : revpath)) idxlist = reverse revidxlist - -- outputConcurrent $ "! " ++ path ++ ": R Submitting\n" submitOrRunIn mparregion idxlist Nothing $ \makeRegion -> do setConsoleRegion makeRegion ('|' : path ++ " [R] making...") - -- outputConcurrent $ "! " ++ path ++ ": R Making\n" - value <- make -- TODO: catch exceptions - -- outputConcurrent $ "! " ++ path ++ ": R Made\n" - - -- outputConcurrent $ "! " ++ path ++ ": R Running subtree\n" - suboutvar <- newEmptyMVar - runResource (Just makeRegion) (depth + 1) (fun value) suboutvar -- will consume makeRegion - -- outputConcurrent $ "! " ++ path ++ ": R Scheduled subtree\n" + evalue <- try make + case evalue of + Left (err :: SomeException) -> do + finishConsoleRegion makeRegion $ + ansiRed ++ "Exception building resource at " ++ path ++ ":" ++ ansiReset ++ "\n" ++ show err + putMVar outvar False + Right value -> do + suboutvar <- newEmptyMVar + runResource (Just makeRegion) (depth + 1) (fun value) suboutvar -- will consume makeRegion - _ <- forkIO $ do - success <- readMVar suboutvar - -- outputConcurrent $ "! " ++ path ++ ": R Subtree done, scheduling cleanup\n" - poolSubmit ?pool idxlist (Just outvar) $ do - cleanupRegion <- openConsoleRegion Linear - setConsoleRegion cleanupRegion ('|' : path ++ " [R] cleanup...") - -- outputConcurrent $ "! " ++ path ++ ": R Cleaning up\n" - cleanup value -- TODO: catch exceptions - -- outputConcurrent $ "! " ++ path ++ ": R Cleanup done\n" - closeConsoleRegion cleanupRegion - return success - return () + _ <- forkIO $ do + success <- readMVar suboutvar + poolSubmit ?pool idxlist Nothing $ do + cleanupRegion <- openConsoleRegion Linear + setConsoleRegion cleanupRegion ('|' : path ++ " [R] cleanup...") + eres <- try (cleanup value) + case eres of + Left (err :: SomeException) -> do + finishConsoleRegion cleanupRegion $ + ansiRed ++ "Exception cleaning up resource at " ++ path ++ ":" ++ ansiReset ++ "\n" ++ show err + putMVar outvar False + Right () -> do + closeConsoleRegion cleanupRegion + putMVar outvar success + return () runResource mparregion _ tree outvar = runTreePar mparregion revidxlist revpath tree outvar runTreePar mparregion revidxlist revpath (HP name prop) outvar = do @@ -407,15 +419,19 @@ renderResult report path timeTaken = do return (str ++ suffix) _ -> return str -printStats :: (?istty :: Bool) => Stats -> NominalDiffTime -> IO () -printStats stats timeTaken - | statsOK stats == statsTotal stats = do +printStats :: (?istty :: Bool) => Int -> Stats -> NominalDiffTime -> IO () +printStats numTests stats timeTaken + | statsOK stats == numTests = do putStrLn $ ansiGreen ++ "All " ++ show (statsTotal stats) ++ " tests passed." ++ prettyDuration True (realToFrac timeTaken) ++ ansiReset + | statsOK stats == statsTotal stats = + putStrLn $ ansiRed ++ "Failed (" ++ show (numTests - statsTotal stats) ++ " tests could not run)." ++ + prettyDuration True (realToFrac timeTaken) ++ ansiReset | otherwise = let nfailed = statsTotal stats - statsOK stats - in putStrLn $ ansiRed ++ "Failed " ++ show nfailed ++ " out of " ++ show (statsTotal stats) ++ - " tests." ++ prettyDuration True (realToFrac timeTaken) ++ ansiReset + in putStrLn $ ansiRed ++ "Failed " ++ show nfailed ++ " out of " ++ show (statsTotal stats) ++ " tests" ++ + (if statsTotal stats /= numTests then " (" ++ show (numTests - statsTotal stats) ++ " could not run)" else "") ++ + "." ++ prettyDuration True (realToFrac timeTaken) ++ ansiReset newtype WorkerPool k = WorkerPool (TVar (PQ.MinPQueue k (Terminate PoolJob))) @@ -427,9 +443,17 @@ withWorkerPool :: Ord k => Int -> (WorkerPool k -> IO a) -> IO a withWorkerPool numWorkers k = do chan <- newTVarIO PQ.empty threads <- forM [0..numWorkers-1] (\i -> forkOn i (worker i chan)) - k (WorkerPool chan) `finally` do - replicateM_ numWorkers (atomically $ writeTVar chan PQ.empty) - forM_ threads killThread + eres <- try (k (WorkerPool chan)) + case eres of + Left (err :: SomeException) -> do + atomically $ writeTVar chan PQ.empty + forM_ threads killThread + throw err + Right res -> do + readTVarIO chan >>= \case + PQ.Empty -> return () + _ -> throwIO (userError "withWorkerPool: computation exited before all jobs were handled") + return res where worker :: Ord k => Int -> TVar (PQ.MinPQueue k (Terminate PoolJob)) -> IO () worker idx chan = do diff --git a/test/Main.hs b/test/Main.hs index 2acc9f8..4bc9082 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -396,7 +396,7 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e -- annotate (ppExpr knownEnv expr) -- annotate (ppExpr env dtermChad0) -- annotate (ppExpr env dtermChadS) - annotate (ppExpr env (simplifyFix (unMonoid dtermSChadS))) + annotate (ppExpr env dtermSChadSUSP) diff outChad0 closeIsh outPrimal diff outChadS closeIsh outPrimal diff outChadSUS closeIsh outPrimal @@ -562,6 +562,13 @@ tests_Compile = testGroup "Compile" nil) $ let_ #_ (accum SAPHere nil #x #ac) $ nil + + ,compileTest "foldd1" $ fromNamed $ lambda @(TVec R) #a $ body $ + fold1iD1 (#x :-> #y :-> pair (#x * #y) (pair #x #y)) 1 #a + + ,compileTest "fold-manual" $ fromNamed $ lambda @(TVec R) #a $ lambda #d $ body $ + let_ #pr (fold1iD1 (#x :-> #y :-> pair (#x * #y) (pair #x #y)) 1 #a) $ + fold1iD2 (#tape :-> #ctg :-> pair (snd_ #tape * #ctg) (fst_ #tape * #ctg)) (snd_ #pr) #d ] tests_AD :: TestTree |
