diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-26 15:11:48 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-26 15:11:48 +0100 |
commit | a00234388d1b4e14481067d030bf90031258b756 (patch) | |
tree | 501b6778fc5779ce220aba1e22f56ae60f68d970 /src/Compile.hs | |
parent | 7971f6dff12bc7b66a5d4ae91a6791ac08872c31 (diff) |
D2[Array] now has a Maybe instead of zero-size for zero
Remaining problem: 'add' in Compile doesn't use the D2 stuff
Diffstat (limited to 'src/Compile.hs')
-rw-r--r-- | src/Compile.hs | 221 |
1 files changed, 136 insertions, 85 deletions
diff --git a/src/Compile.hs b/src/Compile.hs index 09c3ed5..5501746 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -1,5 +1,7 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE MultiWayIf #-} @@ -7,8 +9,10 @@ {-# LANGUAGE TypeApplications #-} module Compile (compile, debugCSource, debugRefc, emitChecks) where +import Control.Applicative (empty) import Control.Monad (forM_, when, replicateM) import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.Maybe import Control.Monad.Trans.State.Strict import Control.Monad.Trans.Writer.CPS import Data.Bifunctor (first) @@ -36,10 +40,12 @@ import qualified Prelude import Array import AST -import AST.Pretty (ppTy) +import AST.Pretty (ppSTy) +import qualified CHAD.Types as CHAD import Compile.Exec import Data import Interpreter.Rep +import qualified Util.IdGen as IdGen -- In shape and index arrays, the innermost dimension is on the right (last index). @@ -188,62 +194,59 @@ printCExpr d = \case ,("/", (7, (7, 8))) ,("%", (7, (7, 8)))] -repTy :: Ty -> String -repTy (TScal st) = case st of - TI32 -> "int32_t" - TI64 -> "int64_t" - TF32 -> "float" - TF64 -> "double" - TBool -> "uint8_t" -repTy t = genStructName t - repSTy :: STy t -> String -repSTy = repTy . unSTy - -genStructName :: Ty -> String +repSTy (STScal st) = case st of + STI32 -> "int32_t" + STI64 -> "int64_t" + STF32 -> "float" + STF64 -> "double" + STBool -> "uint8_t" +repSTy t = genStructName t + +genStructName :: STy t -> String genStructName = \t -> "ty_" ++ gen t where -- all tags start with a letter, so the array mangling is unambiguous. - gen :: Ty -> String - gen TNil = "n" - gen (TPair a b) = 'P' : gen a ++ gen b - gen (TEither a b) = 'E' : gen a ++ gen b - gen (TMaybe t) = 'M' : gen t - gen (TArr n t) = "A" ++ show (fromNat n) ++ gen t - gen (TScal st) = case st of - TI32 -> "i" - TI64 -> "j" - TF32 -> "f" - TF64 -> "d" - TBool -> "b" - gen (TAccum t) = 'C' : gen t + gen :: STy t -> String + gen STNil = "n" + gen (STPair a b) = 'P' : gen a ++ gen b + gen (STEither a b) = 'E' : gen a ++ gen b + gen (STMaybe t) = 'M' : gen t + gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t + gen (STScal st) = case st of + STI32 -> "i" + STI64 -> "j" + STF32 -> "f" + STF64 -> "d" + STBool -> "b" + gen (STAccum t) = 'C' : gen t -- | This function generates the actual struct declarations for each of the -- types in our language. It thus implicitly "documents" the layout of the -- types in the C translation. -genStruct :: String -> Ty -> [StructDecl] +genStruct :: String -> STy t -> [StructDecl] genStruct name topty = case topty of - TNil -> + STNil -> [StructDecl name "" com] - TPair a b -> - [StructDecl name (repTy a ++ " a; " ++ repTy b ++ " b;") com] - TEither a b -> -- 0 -> l, 1 -> r - [StructDecl name ("uint8_t tag; union { " ++ repTy a ++ " l; " ++ repTy b ++ " r; };") com] - TMaybe t -> -- 0 -> nothing, 1 -> just - [StructDecl name ("uint8_t tag; " ++ repTy t ++ " j;") com] - TArr n t -> + STPair a b -> + [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com] + STEither a b -> -- 0 -> l, 1 -> r + [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + STMaybe t -> -- 0 -> nothing, 1 -> just + [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com] + STArr n t -> -- The buffer is trailed by a VLA for the actual array data. - [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromNat n) ++ "]; size_t refc; " ++ repTy t ++ " xs[];") "" + [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromSNat n) ++ "]; size_t refc; " ++ repSTy t ++ " xs[];") "" ,StructDecl name (name ++ "_buf *buf;") com] - TScal _ -> + STScal _ -> [] - TAccum t -> - [StructDecl name (repTy t ++ " ac;") com] + STAccum t -> + [StructDecl name (repSTy (CHAD.d2 t) ++ " ac;") com] where - com = ppTy 0 topty + com = ppSTy 0 topty -- State: already-generated (skippable) struct names -- Writer: the structs in declaration order -genStructs :: Ty -> WriterT (Bag StructDecl) (State (Set String)) () +genStructs :: STy t -> WriterT (Bag StructDecl) (State (Set String)) () genStructs ty = do let name = genStructName ty seen <- lift $ gets (name `Set.member`) @@ -255,19 +258,19 @@ genStructs ty = do -- twice (unnecessary because no recursive types, but y'know) lift $ modify (Set.insert name) - case ty of - TNil -> pure () - TPair a b -> genStructs a >> genStructs b - TEither a b -> genStructs a >> genStructs b - TMaybe t -> genStructs t - TArr _ t -> genStructs t - TScal _ -> pure () - TAccum t -> genStructs t + () <- case ty of + STNil -> pure () + STPair a b -> genStructs a >> genStructs b + STEither a b -> genStructs a >> genStructs b + STMaybe t -> genStructs t + STArr _ t -> genStructs t + STScal _ -> pure () + STAccum t -> genStructs (CHAD.d2 t) tell (BList (genStruct name ty)) genAllStructs :: Foldable t => t Ty -> [StructDecl] -genAllStructs tys = toList $ evalState (execWriterT (mapM_ genStructs tys)) mempty +genAllStructs tys = toList $ evalState (execWriterT (mapM_ (\t -> case reSTy t of Some t' -> genStructs t') tys)) mempty data CompState = CompState { csStructs :: Set Ty @@ -276,36 +279,48 @@ data CompState = CompState , csNextId :: Int } deriving (Show) -type CompM a = State CompState a +newtype CompM a = CompM (State CompState a) + deriving newtype (Functor, Applicative, Monad) + +runCompM :: CompM a -> (a, CompState) +runCompM (CompM m) = runState m (CompState mempty mempty mempty 1) -genId :: CompM Int -genId = state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 }) +class Monad m => MonadNameGen m where genId :: m Int +instance MonadNameGen CompM where genId = CompM $ state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 }) +instance MonadNameGen IdGen.IdGen where genId = IdGen.genId +instance MonadNameGen m => MonadNameGen (MaybeT m) where genId = MaybeT (Just <$> genId) -genName' :: String -> CompM String +genName' :: MonadNameGen m => String -> m String genName' "" = genName genName' prefix = (prefix ++) . show <$> genId -genName :: CompM String +genName :: MonadNameGen m => m String genName = genName' "x" +onlyIdGen :: IdGen.IdGen a -> CompM a +onlyIdGen m = CompM $ do + i1 <- gets csNextId + let (res, i2) = IdGen.runIdGen' i1 m + modify (\s -> s { csNextId = i2 }) + return res + emit :: Stmt -> CompM () -emit stmt = modify $ \s -> s { csStmts = csStmts s <> pure stmt } +emit stmt = CompM $ modify $ \s -> s { csStmts = csStmts s <> pure stmt } scope :: CompM a -> CompM (a, Bag Stmt) scope m = do - stmts <- state $ \s -> (csStmts s, s { csStmts = mempty }) + stmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = mempty }) res <- m - innerStmts <- state $ \s -> (csStmts s, s { csStmts = stmts }) + innerStmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = stmts }) return (res, innerStmts) emitStruct :: STy t -> CompM String -emitStruct ty = do - let ty' = unSTy ty - modify $ \s -> s { csStructs = Set.insert ty' (csStructs s) } - return (genStructName ty') +emitStruct ty = CompM $ do + modify $ \s -> s { csStructs = Set.insert (unSTy ty) (csStructs s) } + return (genStructName ty) emitTLD :: String -> CompM () -emitTLD decl = modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl } +emitTLD decl = CompM $ modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl } nameEnv :: SList f env -> SList (Const String) env nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1)) @@ -313,7 +328,7 @@ nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" compileToString :: SList STy env -> Ex env t -> String compileToString env expr = let args = nameEnv env - (res, s) = runState (compile' args expr) (CompState mempty mempty mempty 1) + (res, s) = runCompM (compile' args expr) structs = genAllStructs (csStructs s <> Set.fromList (unSList unSTy env)) (arg_pairs, arg_metrics) = @@ -649,7 +664,7 @@ compile' env = \case x0name <- compileAssign "foldx0" env ex0 arrname <- compileAssign "foldarr" env earr - zeroRefcountCheck "fold1i" arrname + zeroRefcountCheck (typeOf earr) "fold1i" arrname shszname <- genName' "shsz" -- This n is one less than the shape of the thing we're querying, which is @@ -694,7 +709,7 @@ compile' env = \case let STArr (SS n) t = typeOf e argname <- compileAssign "sumarg" env e - zeroRefcountCheck "sum1i" argname + zeroRefcountCheck (typeOf e) "sum1i" argname shszname <- genName' "shsz" -- This n is one less than the shape of the thing we're querying, like EFold1Inner. @@ -737,7 +752,7 @@ compile' env = \case lenname <- compileAssign "replen" env elen argname <- compileAssign "reparg" env earg - zeroRefcountCheck "replicate1i" argname + zeroRefcountCheck (typeOf earg) "replicate1i" argname shszname <- genName' "shsz" emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) @@ -767,7 +782,7 @@ compile' env = \case EIdx0 _ e -> do let STArr _ t = typeOf e arrname <- compileAssign "" env e - zeroRefcountCheck "idx0" arrname + zeroRefcountCheck (typeOf e) "idx0" arrname name <- genName emit $ SVarDecl True (repSTy t) name (CEIndex (CEPtrProj (CEProj (CELit arrname) "buf") "xs") (CELit "0")) @@ -779,7 +794,7 @@ compile' env = \case EIdx _ earr eidx -> do let STArr n t = typeOf earr arrname <- compileAssign "ixarr" env earr - zeroRefcountCheck "idx" arrname + zeroRefcountCheck (typeOf earr) "idx" arrname idxname <- if fromSNat n > 0 -- prevent an unused-varable warning then compileAssign "ixix" env eidx else return "" -- won't be used in this case @@ -803,7 +818,7 @@ compile' env = \case t = tTup (sreplicate n tIx) _ <- emitStruct t name <- compileAssign "" env e - zeroRefcountCheck "shape" name + zeroRefcountCheck (typeOf e) "shape" name resname <- genName emit $ SVarDecl True (repSTy t) resname (compileShapeQuery n name) incrementVarAlways Decrement (typeOf e) name @@ -833,15 +848,15 @@ compile' env = \case actyname <- emitStruct (STAccum t) name1 <- compileAssign "" env e1 - zeroRefcountCheck "with" name1 + zeroRefcountCheck (typeOf e1) "with" name1 - mcopy <- copyForWriting t name1 + mcopy <- copyForWriting (CHAD.d2 t) name1 accname <- genName' "accum" emit $ SVarDecl False actyname accname (CEStruct actyname [("ac", maybe (CELit name1) id mcopy)]) e2' <- compile' (Const accname `SCons` env) e2 - rettyname <- emitStruct (STPair (typeOf e2) t) + rettyname <- emitStruct (STPair (typeOf e2) (CHAD.d2 t)) return $ CEStruct rettyname [("a", e2'), ("b", CEProj (CELit accname) "ac")] EAccum _ t prj eidx eval eacc -> do @@ -1096,7 +1111,7 @@ shapeTupFromLitVars = \n -> go n . reverse compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr compileOpGeneral op e1 = do - let unary cop = return @(State CompState) $ CECall cop [e1] + let unary cop = return @CompM $ CECall cop [e1] let binary cop = do name <- genName emit $ SVarDecl True (repSTy (opt1 op)) name e1 @@ -1127,7 +1142,7 @@ compileOpGeneral op e1 = do compileOpPair :: SOp a b -> CExpr -> CExpr -> CompM CExpr compileOpPair op e1 e2 = do - let binary cop = return @(State CompState) $ CEBinop e1 cop e2 + let binary cop = return @CompM $ CEBinop e1 cop e2 case op of OAdd _ -> binary "+" OMul _ -> binary "*" @@ -1153,7 +1168,7 @@ compileExtremum nameBase opName operator env e = do let STArr (SS n) t = typeOf e argname <- compileAssign (nameBase ++ "arg") env e - zeroRefcountCheck opName argname + zeroRefcountCheck (typeOf e) opName argname shszname <- genName' "shsz" -- This n is one less than the shape of the thing we're querying, which is @@ -1209,7 +1224,7 @@ copyForWriting topty var = case topty of _ -> do name <- genName emit $ SVarDeclUninit (repSTy topty) name - emit $ SIf (CEBinop (CELit var) "==" (CELit "0")) + emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0")) (stmts1 <> pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "0"), ("l", fromMaybe (CELit (var++".l")) e1)]))) @@ -1225,7 +1240,7 @@ copyForWriting topty var = case topty of Just e1' -> do name <- genName emit $ SVarDeclUninit (repSTy topty) name - emit $ SIf (CEBinop (CELit var) "==" (CELit "0")) + emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0")) (pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "0")]))) (stmts1 <> pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "1"), ("j", e1')]))) @@ -1301,13 +1316,49 @@ copyForWriting topty var = case topty of STAccum _ -> error "Compile: Nested accumulators not supported" -zeroRefcountCheck :: String -> String -> CompM () -zeroRefcountCheck opname arrvar = - when emitChecks $ - emit $ SVerbatim $ - "if (__builtin_expect(" ++ arrvar ++ ".buf->refc == 0, 0)) { " ++ - "fprintf(stderr, \"[chad-kernel] CHECK: '" ++ opname ++ "' got array " ++ - "%p with refc=0\\n\", " ++ arrvar ++ ".buf); abort(); }" +zeroRefcountCheck :: STy t -> String -> String -> CompM () +zeroRefcountCheck toptyp opname topvar = + when emitChecks $ do + mstmts <- onlyIdGen $ runMaybeT (go toptyp topvar) + case mstmts of + Nothing -> return () + Just stmts -> forM_ stmts emit + where + -- | If this returns 'Nothing', no statements need to be generated for this type. + go :: STy t -> String -> MaybeT IdGen.IdGen (Bag Stmt) + go STNil _ = empty + go (STPair a b) path = do + (s1, s2) <- combine (go a (path++".a")) (go b (path++".b")) + return (s1 <> s2) + go (STEither a b) path = do + (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) + return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "0")) s1 s2 + go (STMaybe a) path = do + ss <- go a (path++".j") + return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) ss mempty + go (STArr n a) path = do + ivar <- genName' "i" + ss <- go a (path++".buf->xs["++ivar++"]") + shszname <- genName' "shsz" + let s1 = SVerbatim $ + "if (__builtin_expect(" ++ path ++ ".buf->refc == 0, 0)) { " ++ + "fprintf(stderr, \"[chad-kernel] CHECK: '" ++ opname ++ "' got array " ++ + "%p with refc=0\\n\", " ++ path ++ ".buf); abort(); }" + let s2 = SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n path) + let s3 = SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) ss + return (BList [s1, s2, s3]) + go STScal{} _ = empty + go STAccum{} _ = error "zeroRefcountCheck: passed an accumulator" + + combine :: (Monoid a, Monoid b, Monad m) => MaybeT m a -> MaybeT m b -> MaybeT m (a, b) + combine (MaybeT a) (MaybeT b) = MaybeT $ do + x <- a + y <- b + return $ case (x, y) of + (Nothing, Nothing) -> Nothing + (Just x', Nothing) -> Just (x', mempty) + (Nothing, Just y') -> Just (mempty, y') + (Just x', Just y') -> Just (x', y') compose :: Foldable t => t (a -> a) -> a -> a compose = foldr (.) id |