summaryrefslogtreecommitdiff
path: root/src/Compile.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-26 15:11:48 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-26 15:11:48 +0100
commita00234388d1b4e14481067d030bf90031258b756 (patch)
tree501b6778fc5779ce220aba1e22f56ae60f68d970 /src/Compile.hs
parent7971f6dff12bc7b66a5d4ae91a6791ac08872c31 (diff)
D2[Array] now has a Maybe instead of zero-size for zero
Remaining problem: 'add' in Compile doesn't use the D2 stuff
Diffstat (limited to 'src/Compile.hs')
-rw-r--r--src/Compile.hs221
1 files changed, 136 insertions, 85 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
index 09c3ed5..5501746 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -1,5 +1,7 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiWayIf #-}
@@ -7,8 +9,10 @@
{-# LANGUAGE TypeApplications #-}
module Compile (compile, debugCSource, debugRefc, emitChecks) where
+import Control.Applicative (empty)
import Control.Monad (forM_, when, replicateM)
import Control.Monad.Trans.Class (lift)
+import Control.Monad.Trans.Maybe
import Control.Monad.Trans.State.Strict
import Control.Monad.Trans.Writer.CPS
import Data.Bifunctor (first)
@@ -36,10 +40,12 @@ import qualified Prelude
import Array
import AST
-import AST.Pretty (ppTy)
+import AST.Pretty (ppSTy)
+import qualified CHAD.Types as CHAD
import Compile.Exec
import Data
import Interpreter.Rep
+import qualified Util.IdGen as IdGen
-- In shape and index arrays, the innermost dimension is on the right (last index).
@@ -188,62 +194,59 @@ printCExpr d = \case
,("/", (7, (7, 8)))
,("%", (7, (7, 8)))]
-repTy :: Ty -> String
-repTy (TScal st) = case st of
- TI32 -> "int32_t"
- TI64 -> "int64_t"
- TF32 -> "float"
- TF64 -> "double"
- TBool -> "uint8_t"
-repTy t = genStructName t
-
repSTy :: STy t -> String
-repSTy = repTy . unSTy
-
-genStructName :: Ty -> String
+repSTy (STScal st) = case st of
+ STI32 -> "int32_t"
+ STI64 -> "int64_t"
+ STF32 -> "float"
+ STF64 -> "double"
+ STBool -> "uint8_t"
+repSTy t = genStructName t
+
+genStructName :: STy t -> String
genStructName = \t -> "ty_" ++ gen t where
-- all tags start with a letter, so the array mangling is unambiguous.
- gen :: Ty -> String
- gen TNil = "n"
- gen (TPair a b) = 'P' : gen a ++ gen b
- gen (TEither a b) = 'E' : gen a ++ gen b
- gen (TMaybe t) = 'M' : gen t
- gen (TArr n t) = "A" ++ show (fromNat n) ++ gen t
- gen (TScal st) = case st of
- TI32 -> "i"
- TI64 -> "j"
- TF32 -> "f"
- TF64 -> "d"
- TBool -> "b"
- gen (TAccum t) = 'C' : gen t
+ gen :: STy t -> String
+ gen STNil = "n"
+ gen (STPair a b) = 'P' : gen a ++ gen b
+ gen (STEither a b) = 'E' : gen a ++ gen b
+ gen (STMaybe t) = 'M' : gen t
+ gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t
+ gen (STScal st) = case st of
+ STI32 -> "i"
+ STI64 -> "j"
+ STF32 -> "f"
+ STF64 -> "d"
+ STBool -> "b"
+ gen (STAccum t) = 'C' : gen t
-- | This function generates the actual struct declarations for each of the
-- types in our language. It thus implicitly "documents" the layout of the
-- types in the C translation.
-genStruct :: String -> Ty -> [StructDecl]
+genStruct :: String -> STy t -> [StructDecl]
genStruct name topty = case topty of
- TNil ->
+ STNil ->
[StructDecl name "" com]
- TPair a b ->
- [StructDecl name (repTy a ++ " a; " ++ repTy b ++ " b;") com]
- TEither a b -> -- 0 -> l, 1 -> r
- [StructDecl name ("uint8_t tag; union { " ++ repTy a ++ " l; " ++ repTy b ++ " r; };") com]
- TMaybe t -> -- 0 -> nothing, 1 -> just
- [StructDecl name ("uint8_t tag; " ++ repTy t ++ " j;") com]
- TArr n t ->
+ STPair a b ->
+ [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com]
+ STEither a b -> -- 0 -> l, 1 -> r
+ [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com]
+ STMaybe t -> -- 0 -> nothing, 1 -> just
+ [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com]
+ STArr n t ->
-- The buffer is trailed by a VLA for the actual array data.
- [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromNat n) ++ "]; size_t refc; " ++ repTy t ++ " xs[];") ""
+ [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromSNat n) ++ "]; size_t refc; " ++ repSTy t ++ " xs[];") ""
,StructDecl name (name ++ "_buf *buf;") com]
- TScal _ ->
+ STScal _ ->
[]
- TAccum t ->
- [StructDecl name (repTy t ++ " ac;") com]
+ STAccum t ->
+ [StructDecl name (repSTy (CHAD.d2 t) ++ " ac;") com]
where
- com = ppTy 0 topty
+ com = ppSTy 0 topty
-- State: already-generated (skippable) struct names
-- Writer: the structs in declaration order
-genStructs :: Ty -> WriterT (Bag StructDecl) (State (Set String)) ()
+genStructs :: STy t -> WriterT (Bag StructDecl) (State (Set String)) ()
genStructs ty = do
let name = genStructName ty
seen <- lift $ gets (name `Set.member`)
@@ -255,19 +258,19 @@ genStructs ty = do
-- twice (unnecessary because no recursive types, but y'know)
lift $ modify (Set.insert name)
- case ty of
- TNil -> pure ()
- TPair a b -> genStructs a >> genStructs b
- TEither a b -> genStructs a >> genStructs b
- TMaybe t -> genStructs t
- TArr _ t -> genStructs t
- TScal _ -> pure ()
- TAccum t -> genStructs t
+ () <- case ty of
+ STNil -> pure ()
+ STPair a b -> genStructs a >> genStructs b
+ STEither a b -> genStructs a >> genStructs b
+ STMaybe t -> genStructs t
+ STArr _ t -> genStructs t
+ STScal _ -> pure ()
+ STAccum t -> genStructs (CHAD.d2 t)
tell (BList (genStruct name ty))
genAllStructs :: Foldable t => t Ty -> [StructDecl]
-genAllStructs tys = toList $ evalState (execWriterT (mapM_ genStructs tys)) mempty
+genAllStructs tys = toList $ evalState (execWriterT (mapM_ (\t -> case reSTy t of Some t' -> genStructs t') tys)) mempty
data CompState = CompState
{ csStructs :: Set Ty
@@ -276,36 +279,48 @@ data CompState = CompState
, csNextId :: Int }
deriving (Show)
-type CompM a = State CompState a
+newtype CompM a = CompM (State CompState a)
+ deriving newtype (Functor, Applicative, Monad)
+
+runCompM :: CompM a -> (a, CompState)
+runCompM (CompM m) = runState m (CompState mempty mempty mempty 1)
-genId :: CompM Int
-genId = state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 })
+class Monad m => MonadNameGen m where genId :: m Int
+instance MonadNameGen CompM where genId = CompM $ state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 })
+instance MonadNameGen IdGen.IdGen where genId = IdGen.genId
+instance MonadNameGen m => MonadNameGen (MaybeT m) where genId = MaybeT (Just <$> genId)
-genName' :: String -> CompM String
+genName' :: MonadNameGen m => String -> m String
genName' "" = genName
genName' prefix = (prefix ++) . show <$> genId
-genName :: CompM String
+genName :: MonadNameGen m => m String
genName = genName' "x"
+onlyIdGen :: IdGen.IdGen a -> CompM a
+onlyIdGen m = CompM $ do
+ i1 <- gets csNextId
+ let (res, i2) = IdGen.runIdGen' i1 m
+ modify (\s -> s { csNextId = i2 })
+ return res
+
emit :: Stmt -> CompM ()
-emit stmt = modify $ \s -> s { csStmts = csStmts s <> pure stmt }
+emit stmt = CompM $ modify $ \s -> s { csStmts = csStmts s <> pure stmt }
scope :: CompM a -> CompM (a, Bag Stmt)
scope m = do
- stmts <- state $ \s -> (csStmts s, s { csStmts = mempty })
+ stmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = mempty })
res <- m
- innerStmts <- state $ \s -> (csStmts s, s { csStmts = stmts })
+ innerStmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = stmts })
return (res, innerStmts)
emitStruct :: STy t -> CompM String
-emitStruct ty = do
- let ty' = unSTy ty
- modify $ \s -> s { csStructs = Set.insert ty' (csStructs s) }
- return (genStructName ty')
+emitStruct ty = CompM $ do
+ modify $ \s -> s { csStructs = Set.insert (unSTy ty) (csStructs s) }
+ return (genStructName ty)
emitTLD :: String -> CompM ()
-emitTLD decl = modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl }
+emitTLD decl = CompM $ modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl }
nameEnv :: SList f env -> SList (Const String) env
nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1))
@@ -313,7 +328,7 @@ nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg"
compileToString :: SList STy env -> Ex env t -> String
compileToString env expr =
let args = nameEnv env
- (res, s) = runState (compile' args expr) (CompState mempty mempty mempty 1)
+ (res, s) = runCompM (compile' args expr)
structs = genAllStructs (csStructs s <> Set.fromList (unSList unSTy env))
(arg_pairs, arg_metrics) =
@@ -649,7 +664,7 @@ compile' env = \case
x0name <- compileAssign "foldx0" env ex0
arrname <- compileAssign "foldarr" env earr
- zeroRefcountCheck "fold1i" arrname
+ zeroRefcountCheck (typeOf earr) "fold1i" arrname
shszname <- genName' "shsz"
-- This n is one less than the shape of the thing we're querying, which is
@@ -694,7 +709,7 @@ compile' env = \case
let STArr (SS n) t = typeOf e
argname <- compileAssign "sumarg" env e
- zeroRefcountCheck "sum1i" argname
+ zeroRefcountCheck (typeOf e) "sum1i" argname
shszname <- genName' "shsz"
-- This n is one less than the shape of the thing we're querying, like EFold1Inner.
@@ -737,7 +752,7 @@ compile' env = \case
lenname <- compileAssign "replen" env elen
argname <- compileAssign "reparg" env earg
- zeroRefcountCheck "replicate1i" argname
+ zeroRefcountCheck (typeOf earg) "replicate1i" argname
shszname <- genName' "shsz"
emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
@@ -767,7 +782,7 @@ compile' env = \case
EIdx0 _ e -> do
let STArr _ t = typeOf e
arrname <- compileAssign "" env e
- zeroRefcountCheck "idx0" arrname
+ zeroRefcountCheck (typeOf e) "idx0" arrname
name <- genName
emit $ SVarDecl True (repSTy t) name
(CEIndex (CEPtrProj (CEProj (CELit arrname) "buf") "xs") (CELit "0"))
@@ -779,7 +794,7 @@ compile' env = \case
EIdx _ earr eidx -> do
let STArr n t = typeOf earr
arrname <- compileAssign "ixarr" env earr
- zeroRefcountCheck "idx" arrname
+ zeroRefcountCheck (typeOf earr) "idx" arrname
idxname <- if fromSNat n > 0 -- prevent an unused-varable warning
then compileAssign "ixix" env eidx
else return "" -- won't be used in this case
@@ -803,7 +818,7 @@ compile' env = \case
t = tTup (sreplicate n tIx)
_ <- emitStruct t
name <- compileAssign "" env e
- zeroRefcountCheck "shape" name
+ zeroRefcountCheck (typeOf e) "shape" name
resname <- genName
emit $ SVarDecl True (repSTy t) resname (compileShapeQuery n name)
incrementVarAlways Decrement (typeOf e) name
@@ -833,15 +848,15 @@ compile' env = \case
actyname <- emitStruct (STAccum t)
name1 <- compileAssign "" env e1
- zeroRefcountCheck "with" name1
+ zeroRefcountCheck (typeOf e1) "with" name1
- mcopy <- copyForWriting t name1
+ mcopy <- copyForWriting (CHAD.d2 t) name1
accname <- genName' "accum"
emit $ SVarDecl False actyname accname (CEStruct actyname [("ac", maybe (CELit name1) id mcopy)])
e2' <- compile' (Const accname `SCons` env) e2
- rettyname <- emitStruct (STPair (typeOf e2) t)
+ rettyname <- emitStruct (STPair (typeOf e2) (CHAD.d2 t))
return $ CEStruct rettyname [("a", e2'), ("b", CEProj (CELit accname) "ac")]
EAccum _ t prj eidx eval eacc -> do
@@ -1096,7 +1111,7 @@ shapeTupFromLitVars = \n -> go n . reverse
compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr
compileOpGeneral op e1 = do
- let unary cop = return @(State CompState) $ CECall cop [e1]
+ let unary cop = return @CompM $ CECall cop [e1]
let binary cop = do
name <- genName
emit $ SVarDecl True (repSTy (opt1 op)) name e1
@@ -1127,7 +1142,7 @@ compileOpGeneral op e1 = do
compileOpPair :: SOp a b -> CExpr -> CExpr -> CompM CExpr
compileOpPair op e1 e2 = do
- let binary cop = return @(State CompState) $ CEBinop e1 cop e2
+ let binary cop = return @CompM $ CEBinop e1 cop e2
case op of
OAdd _ -> binary "+"
OMul _ -> binary "*"
@@ -1153,7 +1168,7 @@ compileExtremum nameBase opName operator env e = do
let STArr (SS n) t = typeOf e
argname <- compileAssign (nameBase ++ "arg") env e
- zeroRefcountCheck opName argname
+ zeroRefcountCheck (typeOf e) opName argname
shszname <- genName' "shsz"
-- This n is one less than the shape of the thing we're querying, which is
@@ -1209,7 +1224,7 @@ copyForWriting topty var = case topty of
_ -> do
name <- genName
emit $ SVarDeclUninit (repSTy topty) name
- emit $ SIf (CEBinop (CELit var) "==" (CELit "0"))
+ emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0"))
(stmts1
<> pure (SAsg name (CEStruct (repSTy topty)
[("tag", CELit "0"), ("l", fromMaybe (CELit (var++".l")) e1)])))
@@ -1225,7 +1240,7 @@ copyForWriting topty var = case topty of
Just e1' -> do
name <- genName
emit $ SVarDeclUninit (repSTy topty) name
- emit $ SIf (CEBinop (CELit var) "==" (CELit "0"))
+ emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0"))
(pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "0")])))
(stmts1
<> pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "1"), ("j", e1')])))
@@ -1301,13 +1316,49 @@ copyForWriting topty var = case topty of
STAccum _ -> error "Compile: Nested accumulators not supported"
-zeroRefcountCheck :: String -> String -> CompM ()
-zeroRefcountCheck opname arrvar =
- when emitChecks $
- emit $ SVerbatim $
- "if (__builtin_expect(" ++ arrvar ++ ".buf->refc == 0, 0)) { " ++
- "fprintf(stderr, \"[chad-kernel] CHECK: '" ++ opname ++ "' got array " ++
- "%p with refc=0\\n\", " ++ arrvar ++ ".buf); abort(); }"
+zeroRefcountCheck :: STy t -> String -> String -> CompM ()
+zeroRefcountCheck toptyp opname topvar =
+ when emitChecks $ do
+ mstmts <- onlyIdGen $ runMaybeT (go toptyp topvar)
+ case mstmts of
+ Nothing -> return ()
+ Just stmts -> forM_ stmts emit
+ where
+ -- | If this returns 'Nothing', no statements need to be generated for this type.
+ go :: STy t -> String -> MaybeT IdGen.IdGen (Bag Stmt)
+ go STNil _ = empty
+ go (STPair a b) path = do
+ (s1, s2) <- combine (go a (path++".a")) (go b (path++".b"))
+ return (s1 <> s2)
+ go (STEither a b) path = do
+ (s1, s2) <- combine (go a (path++".l")) (go b (path++".r"))
+ return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "0")) s1 s2
+ go (STMaybe a) path = do
+ ss <- go a (path++".j")
+ return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) ss mempty
+ go (STArr n a) path = do
+ ivar <- genName' "i"
+ ss <- go a (path++".buf->xs["++ivar++"]")
+ shszname <- genName' "shsz"
+ let s1 = SVerbatim $
+ "if (__builtin_expect(" ++ path ++ ".buf->refc == 0, 0)) { " ++
+ "fprintf(stderr, \"[chad-kernel] CHECK: '" ++ opname ++ "' got array " ++
+ "%p with refc=0\\n\", " ++ path ++ ".buf); abort(); }"
+ let s2 = SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n path)
+ let s3 = SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) ss
+ return (BList [s1, s2, s3])
+ go STScal{} _ = empty
+ go STAccum{} _ = error "zeroRefcountCheck: passed an accumulator"
+
+ combine :: (Monoid a, Monoid b, Monad m) => MaybeT m a -> MaybeT m b -> MaybeT m (a, b)
+ combine (MaybeT a) (MaybeT b) = MaybeT $ do
+ x <- a
+ y <- b
+ return $ case (x, y) of
+ (Nothing, Nothing) -> Nothing
+ (Just x', Nothing) -> Just (x', mempty)
+ (Nothing, Just y') -> Just (mempty, y')
+ (Just x', Just y') -> Just (x', y')
compose :: Foldable t => t (a -> a) -> a -> a
compose = foldr (.) id