aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/Compile.hs114
-rw-r--r--src/Language.hs8
-rw-r--r--src/Language/AST.hs12
-rw-r--r--test-framework/Test/Framework.hs90
-rw-r--r--test/Main.hs9
5 files changed, 198 insertions, 35 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
index 4e81c6a..f2063ee 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -34,6 +34,7 @@ import Foreign
import GHC.Exts (int2Word#, addr2Int#)
import GHC.Num (integerFromWord#)
import GHC.Ptr (Ptr(..))
+import GHC.Stack (HasCallStack)
import Numeric (showHex)
import System.IO (hPutStrLn, stderr)
import System.IO.Error (mkIOError, userErrorType)
@@ -939,6 +940,117 @@ compile' env = \case
[("buf", CEProj (CELit arrname) "buf")
,("sh", CELit ("{" ++ intercalate ", " [printCExpr 0 e "" | e <- indexTupleComponents dim shname] ++ "}"))])
+ EFold1InnerD1 _ commut efun ex0 earr -> do
+ let STArr (SS n) t = typeOf earr
+ STPair _ bty = typeOf efun
+
+ x0name <- compileAssign "foldd1x0" env ex0
+ arrname <- compileAssign "foldd1arr" env earr
+
+ zeroRefcountCheck (typeOf earr) "fold1iD1" arrname
+
+ lenname <- genName' "n"
+ emit $ SVarDecl True (repSTy tIx) lenname
+ (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
+
+ shsz1name <- genName' "shszN"
+ emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n arrname) -- take init of arr's shape
+ shsz2name <- genName' "shszSN"
+ emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname))
+
+ resname <- allocArray "foldd1" Malloc "foldd1res" n t (Just (CELit shsz1name)) (compileArrShapeComponents n arrname)
+ storesname <- allocArray "foldd1" Malloc "foldd1stores" (SS n) bty (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) arrname)
+
+ ((), x0incrStmts) <- scope $ incrementVarAlways "foldd1x0" Increment t x0name
+
+ ivar <- genName' "i"
+ jvar <- genName' "j"
+
+ accvar <- genName' "tot"
+ let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ jvar
+ arreltlit = arrname ++ ".buf->xs[" ++ eltidx ++ "]"
+ (funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun
+ funresvar <- genName' "res"
+ ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldd1elt" Increment t arreltlit
+
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $
+ pure (SVarDecl False (repSTy t) accvar (CELit x0name))
+ <> x0incrStmts -- we're copying x0 here
+ <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
+ -- The combination function will consume the array element
+ -- and the accumulator. The accumulator is replaced by
+ -- what comes out of the function anyway, so that's
+ -- fine, but we do need to increment the array element.
+ arreltIncrStmts
+ <> funStmts
+ <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres)
+ <> pure (SAsg accvar (CEProj (CELit funresvar) "a"))
+ <> pure (SAsg (storesname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b")))
+ <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar))
+
+ incrementVarAlways "foldd1x0" Decrement t x0name
+ incrementVarAlways "foldd1arr" Decrement (typeOf earr) arrname
+
+ strname <- emitStruct (STPair (STArr n t) (STArr (SS n) bty))
+ return (CEStruct strname [("a", CELit resname), ("b", CELit storesname)])
+
+ EFold1InnerD2 _ commut efun estores ectg -> do
+ let STArr n t2 = typeOf ectg
+ STArr _ bty = typeOf estores
+
+ storesname <- compileAssign "foldd2stores" env estores
+ ctgname <- compileAssign "foldd2ctg" env ectg
+
+ zeroRefcountCheck (typeOf ectg) "fold1iD2" ctgname
+
+ lenname <- genName' "n"
+ emit $ SVarDecl True (repSTy tIx) lenname
+ (CELit (storesname ++ ".sh[" ++ show (fromSNat n) ++ "]"))
+
+ shsz1name <- genName' "shszN"
+ emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n storesname) -- take init of the shape
+ shsz2name <- genName' "shszSN"
+ emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname))
+
+ x0ctgname <- allocArray "foldd2" Malloc "foldd2x0ctg" n t2 (Just (CELit shsz1name)) (compileArrShapeComponents n storesname)
+ outctgname <- allocArray "foldd2" Malloc "foldd2outctg" (SS n) t2 (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) storesname)
+
+ ivar <- genName' "i"
+ jvar <- genName' "j"
+
+ accvar <- genName' "acc"
+ let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ lenname ++ "-1 - " ++ jvar
+ storeseltlit = storesname ++ ".buf->xs[" ++ eltidx ++ "]"
+ ctgeltlit = ctgname ++ ".buf->xs[" ++ ivar ++ "]"
+ (funres, funStmts) <- scope $ compile' (Const accvar `SCons` Const storeseltlit `SCons` env) efun
+ funresvar <- genName' "res"
+ ((), storeseltIncrStmts) <- scope $ incrementVarAlways "foldd2selt" Increment bty storeseltlit
+ ((), ctgeltIncrStmts) <- scope $ incrementVarAlways "foldd2celt" Increment bty ctgeltlit
+
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $
+ pure (SVarDecl False (repSTy t2) accvar (CELit ctgeltlit))
+ <> ctgeltIncrStmts
+ -- we need to loop in reverse here, but we let jvar run in the
+ -- forward direction so that we can use SLoop. Note jvar is
+ -- reversed in eltidx above
+ <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
+ -- The combination function will consume the accumulator
+ -- and the stores element. The accumulator is replaced by
+ -- what comes out of the function anyway, so that's
+ -- fine, but we do need to increment the stores element.
+ storeseltIncrStmts
+ <> funStmts
+ <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres)
+ <> pure (SAsg accvar (CEProj (CELit funresvar) "a"))
+ <> pure (SAsg (outctgname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b")))
+ <> pure (SAsg (x0ctgname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar))
+
+ incrementVarAlways "foldd2stores" Decrement (STArr (SS n) bty) storesname
+ incrementVarAlways "foldd2ctg" Decrement (STArr n t2) ctgname
+
+ strname <- emitStruct (STPair (STArr n t2) (STArr (SS n) t2))
+ return (CEStruct strname [("a", CELit x0ctgname), ("b", CELit outctgname)])
+
EConst _ t x -> return $ CELit $ compileScal True t x
EIdx0 _ e -> do
@@ -1311,7 +1423,7 @@ data AllocMethod = Malloc | Calloc
deriving (Show)
-- | The shape must have the outer dimension at the head (and the inner dimension on the right).
-allocArray :: String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String
+allocArray :: HasCallStack => String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String
allocArray marker method nameBase rank eltty mshsz shape = do
when (length shape /= fromSNat rank) $
error "allocArray: shape does not match rank"
diff --git a/src/Language.hs b/src/Language.hs
index d3c38d6..31b4b87 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -133,6 +133,14 @@ minimum1i e = NEMinimum1Inner e
reshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t)
reshape = NEReshape
+fold1iD1 :: (Var name1 t1 :-> Var name2 t1 :-> NExpr ('(name2, t1) : '(name1, t1) : env) (TPair t1 b))
+ -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b))
+fold1iD1 (v1 :-> v2 :-> e1) e2 e3 = NEFold1InnerD1 v1 v2 e1 e2 e3
+
+fold1iD2 :: (Var name1 b :-> Var name2 t2 :-> NExpr ('(name2, t2) : '(name1, b) : env) (TPair t2 t2))
+ -> NExpr env (TArr (S n) b) -> NExpr env (TArr n t2) -> NExpr env (TPair (TArr n t2) (TArr (S n) t2))
+fold1iD2 (v1 :-> v2 :-> e1) e2 e3 = NEFold1InnerD2 v1 v2 e1 e2 e3
+
const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t)
const_ x =
let ty = knownScalTy
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
index 325817d..c9d05c9 100644
--- a/src/Language/AST.hs
+++ b/src/Language/AST.hs
@@ -58,6 +58,15 @@ data NExpr env t where
NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
NEReshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t)
+ NEFold1InnerD1 :: Var n1 t1 -> Var n2 t1 -> NExpr ('(n2, t1) : '(n1, t1) : env) (TPair t1 b)
+ -> NExpr env t1
+ -> NExpr env (TArr (S n) t1)
+ -> NExpr env (TPair (TArr n t1) (TArr (S n) b))
+ NEFold1InnerD2 :: Var n1 b -> Var n2 t2 -> NExpr ('(n2, t2) : '(n1, b) : env) (TPair t2 t2)
+ -> NExpr env (TArr (S n) b)
+ -> NExpr env (TArr n t2)
+ -> NExpr env (TPair (TArr n t2) (TArr (S n) t2))
+
-- expression operations
NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t)
NEIdx0 :: NExpr env (TArr Z t) -> NExpr env t
@@ -208,6 +217,9 @@ fromNamedExpr val = \case
NEMinimum1Inner e -> EMinimum1Inner ext (go e)
NEReshape n a b -> EReshape ext n (go a) (go b)
+ NEFold1InnerD1 n1 n2 a b c -> EFold1InnerD1 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c)
+ NEFold1InnerD2 n1 n2 a b c -> EFold1InnerD2 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c)
+
NEConst t x -> EConst ext t x
NEIdx0 e -> EIdx0 ext (go e)
NEIdx1 a b -> EIdx1 ext (go a) (go b)
diff --git a/test-framework/Test/Framework.hs b/test-framework/Test/Framework.hs
index 5ceb866..80711b2 100644
--- a/test-framework/Test/Framework.hs
+++ b/test-framework/Test/Framework.hs
@@ -4,6 +4,7 @@
{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
module Test.Framework (
TestTree,
@@ -20,11 +21,11 @@ module Test.Framework (
TestName,
) where
-import Control.Concurrent (setNumCapabilities, forkIO, killThread, forkOn)
+import Control.Concurrent (setNumCapabilities, forkIO, forkOn, killThread)
import Control.Concurrent.MVar
import Control.Concurrent.STM
-import Control.Exception (finally)
-import Control.Monad (forM, when, forM_, replicateM_)
+import Control.Exception (SomeException, throw, try, throwIO)
+import Control.Monad (forM, when, forM_)
import Control.Monad.IO.Class
import Data.IORef
import Data.List (isInfixOf, intercalate)
@@ -49,6 +50,9 @@ import Hedgehog.Internal.Runner qualified as H
import Hedgehog.Internal.Seed qualified as H.Seed
+-- TODO: with GHC 9.12 we have tryWithContext and rethrowIO, which is better for rethrowing exceptions
+
+
type TestName = String
data TestTree
@@ -107,6 +111,11 @@ filterTree (Options { optsPattern = pat }) = go []
renderPath comps = "^" ++ intercalate "/" (reverse comps) ++ "$"
+treeNumTests :: TestTree -> Int
+treeNumTests (Group _ _ ts) = sum (map treeNumTests ts)
+treeNumTests (Resource _ _ _ fun) = treeNumTests (fun undefined)
+treeNumTests HP{} = 1
+
computeMaxLen :: TestTree -> Int
computeMaxLen = go 0
where
@@ -220,7 +229,7 @@ runTests options = \tree' ->
else isJust <$> runTreeSeq 0 [] tree
stats <- readIORef statsRef
endtm <- getCurrentTime
- let ?istty = isterm in printStats stats (diffUTCTime endtm starttm)
+ let ?istty = isterm in printStats (treeNumTests tree) stats (diffUTCTime endtm starttm)
return (if success then ExitSuccess else ExitFailure 1)
-- | Returns when all jobs in this tree have been scheduled. When all jobs are
@@ -275,31 +284,34 @@ runTreePar topmparregion revidxlist revpath toptree@Resource{} topoutvar = runRe
let pathitem = '[' : show depth ++ "](" ++ inhname ++ ")"
path = intercalate "/" (reverse (pathitem : revpath))
idxlist = reverse revidxlist
- -- outputConcurrent $ "! " ++ path ++ ": R Submitting\n"
submitOrRunIn mparregion idxlist Nothing $ \makeRegion -> do
setConsoleRegion makeRegion ('|' : path ++ " [R] making...")
- -- outputConcurrent $ "! " ++ path ++ ": R Making\n"
- value <- make -- TODO: catch exceptions
- -- outputConcurrent $ "! " ++ path ++ ": R Made\n"
-
- -- outputConcurrent $ "! " ++ path ++ ": R Running subtree\n"
- suboutvar <- newEmptyMVar
- runResource (Just makeRegion) (depth + 1) (fun value) suboutvar -- will consume makeRegion
- -- outputConcurrent $ "! " ++ path ++ ": R Scheduled subtree\n"
+ evalue <- try make
+ case evalue of
+ Left (err :: SomeException) -> do
+ finishConsoleRegion makeRegion $
+ ansiRed ++ "Exception building resource at " ++ path ++ ":" ++ ansiReset ++ "\n" ++ show err
+ putMVar outvar False
+ Right value -> do
+ suboutvar <- newEmptyMVar
+ runResource (Just makeRegion) (depth + 1) (fun value) suboutvar -- will consume makeRegion
- _ <- forkIO $ do
- success <- readMVar suboutvar
- -- outputConcurrent $ "! " ++ path ++ ": R Subtree done, scheduling cleanup\n"
- poolSubmit ?pool idxlist (Just outvar) $ do
- cleanupRegion <- openConsoleRegion Linear
- setConsoleRegion cleanupRegion ('|' : path ++ " [R] cleanup...")
- -- outputConcurrent $ "! " ++ path ++ ": R Cleaning up\n"
- cleanup value -- TODO: catch exceptions
- -- outputConcurrent $ "! " ++ path ++ ": R Cleanup done\n"
- closeConsoleRegion cleanupRegion
- return success
- return ()
+ _ <- forkIO $ do
+ success <- readMVar suboutvar
+ poolSubmit ?pool idxlist Nothing $ do
+ cleanupRegion <- openConsoleRegion Linear
+ setConsoleRegion cleanupRegion ('|' : path ++ " [R] cleanup...")
+ eres <- try (cleanup value)
+ case eres of
+ Left (err :: SomeException) -> do
+ finishConsoleRegion cleanupRegion $
+ ansiRed ++ "Exception cleaning up resource at " ++ path ++ ":" ++ ansiReset ++ "\n" ++ show err
+ putMVar outvar False
+ Right () -> do
+ closeConsoleRegion cleanupRegion
+ putMVar outvar success
+ return ()
runResource mparregion _ tree outvar = runTreePar mparregion revidxlist revpath tree outvar
runTreePar mparregion revidxlist revpath (HP name prop) outvar = do
@@ -407,15 +419,19 @@ renderResult report path timeTaken = do
return (str ++ suffix)
_ -> return str
-printStats :: (?istty :: Bool) => Stats -> NominalDiffTime -> IO ()
-printStats stats timeTaken
- | statsOK stats == statsTotal stats = do
+printStats :: (?istty :: Bool) => Int -> Stats -> NominalDiffTime -> IO ()
+printStats numTests stats timeTaken
+ | statsOK stats == numTests = do
putStrLn $ ansiGreen ++ "All " ++ show (statsTotal stats) ++
" tests passed." ++ prettyDuration True (realToFrac timeTaken) ++ ansiReset
+ | statsOK stats == statsTotal stats =
+ putStrLn $ ansiRed ++ "Failed (" ++ show (numTests - statsTotal stats) ++ " tests could not run)." ++
+ prettyDuration True (realToFrac timeTaken) ++ ansiReset
| otherwise =
let nfailed = statsTotal stats - statsOK stats
- in putStrLn $ ansiRed ++ "Failed " ++ show nfailed ++ " out of " ++ show (statsTotal stats) ++
- " tests." ++ prettyDuration True (realToFrac timeTaken) ++ ansiReset
+ in putStrLn $ ansiRed ++ "Failed " ++ show nfailed ++ " out of " ++ show (statsTotal stats) ++ " tests" ++
+ (if statsTotal stats /= numTests then " (" ++ show (numTests - statsTotal stats) ++ " could not run)" else "") ++
+ "." ++ prettyDuration True (realToFrac timeTaken) ++ ansiReset
newtype WorkerPool k = WorkerPool (TVar (PQ.MinPQueue k (Terminate PoolJob)))
@@ -427,9 +443,17 @@ withWorkerPool :: Ord k => Int -> (WorkerPool k -> IO a) -> IO a
withWorkerPool numWorkers k = do
chan <- newTVarIO PQ.empty
threads <- forM [0..numWorkers-1] (\i -> forkOn i (worker i chan))
- k (WorkerPool chan) `finally` do
- replicateM_ numWorkers (atomically $ writeTVar chan PQ.empty)
- forM_ threads killThread
+ eres <- try (k (WorkerPool chan))
+ case eres of
+ Left (err :: SomeException) -> do
+ atomically $ writeTVar chan PQ.empty
+ forM_ threads killThread
+ throw err
+ Right res -> do
+ readTVarIO chan >>= \case
+ PQ.Empty -> return ()
+ _ -> throwIO (userError "withWorkerPool: computation exited before all jobs were handled")
+ return res
where
worker :: Ord k => Int -> TVar (PQ.MinPQueue k (Terminate PoolJob)) -> IO ()
worker idx chan = do
diff --git a/test/Main.hs b/test/Main.hs
index 2acc9f8..4bc9082 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -396,7 +396,7 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e
-- annotate (ppExpr knownEnv expr)
-- annotate (ppExpr env dtermChad0)
-- annotate (ppExpr env dtermChadS)
- annotate (ppExpr env (simplifyFix (unMonoid dtermSChadS)))
+ annotate (ppExpr env dtermSChadSUSP)
diff outChad0 closeIsh outPrimal
diff outChadS closeIsh outPrimal
diff outChadSUS closeIsh outPrimal
@@ -562,6 +562,13 @@ tests_Compile = testGroup "Compile"
nil) $
let_ #_ (accum SAPHere nil #x #ac) $
nil
+
+ ,compileTest "foldd1" $ fromNamed $ lambda @(TVec R) #a $ body $
+ fold1iD1 (#x :-> #y :-> pair (#x * #y) (pair #x #y)) 1 #a
+
+ ,compileTest "fold-manual" $ fromNamed $ lambda @(TVec R) #a $ lambda #d $ body $
+ let_ #pr (fold1iD1 (#x :-> #y :-> pair (#x * #y) (pair #x #y)) 1 #a) $
+ fold1iD2 (#tape :-> #ctg :-> pair (snd_ #tape * #ctg) (fst_ #tape * #ctg)) (snd_ #pr) #d
]
tests_AD :: TestTree