diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-02-25 23:56:16 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-02-25 23:56:16 +0100 | 
| commit | 7fa10a9a07c7160531baf595d1111277c17a38b2 (patch) | |
| tree | 24b7263da33490d954b063926d509e1a10193687 /src/Compile.hs | |
| parent | 2c2b80264ae5777f0a759abb5571cbe68071c7e7 (diff) | |
Compile: Emit structs in proper order
Diffstat (limited to 'src/Compile.hs')
| -rw-r--r-- | src/Compile.hs | 203 | 
1 files changed, 130 insertions, 73 deletions
| diff --git a/src/Compile.hs b/src/Compile.hs index 05d51c1..95004b8 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -6,12 +6,14 @@  module Compile where  import Control.Monad.Trans.State.Strict +import Data.Bifunctor (first, second)  import Data.Foldable (toList)  import Data.Functor.Const  import qualified Data.Functor.Product as Product  import Data.List (intersperse, intercalate)  import qualified Data.Map.Strict as Map -import Data.Map.Strict (Map) +import qualified Data.Set as Set +import Data.Set (Set)  import AST  import AST.Pretty (ppTy) @@ -52,86 +54,139 @@ printStructDecl (StructDecl name contents comment) =  printStmt :: Int -> Stmt -> ShowS  printStmt indent = \case -  SVarDecl cnst typ name rhs -> showString ((if cnst then "const " else "") ++ typ ++ " " ++ name ++ " = ") . printCExpr rhs . showString ";" +  SVarDecl cnst typ name rhs -> showString ((if cnst then "const " else "") ++ typ ++ " " ++ name ++ " = ") . printCExpr 0 rhs . showString ";"    SVarDeclUninit typ name -> showString (typ ++ " " ++ name ++ ";") -  SAsg name rhs -> showString (name ++ " = ") . printCExpr rhs . showString ";" +  SAsg name rhs -> showString (name ++ " = ") . printCExpr 0 rhs . showString ";"    SBlock stmts ->      showString "{"      . compose [showString ("\n" ++ replicate (2*indent+2) ' ') . printStmt (indent+1) stmt | stmt <- stmts]      . showString ("\n" ++ replicate (2*indent) ' ' ++ "}")    SIf cond b1 b2 -> -    showString "if (" . printCExpr cond . showString ") " +    showString "if (" . printCExpr 0 cond . showString ") "      . printStmt indent (SBlock b1) . showString " else " . printStmt indent (SBlock b2)    SVerbatim s -> showString s -printCExpr :: CExpr -> ShowS -printCExpr = \case +-- d values: +-- * 0: top level +-- * 1: in 1st or 2nd component of a ternary operator (technically same as top level, but readability) +-- * 2-...: various operators (see precTable) +-- * 98: inside unknown operator +-- * 99: left of a field projection +-- Unlisted operators are conservatively written with full parentheses. +printCExpr :: Int -> CExpr -> ShowS +printCExpr d = \case    CELit s -> showString s    CEStruct name pairs -> -    showString ("(" ++ name ++ "){") -    . compose (intersperse (showString ", ") [showString ("." ++ n ++ " = ") . printCExpr e -                                             | (n, e) <- pairs]) -    . showString "}" -  CEProj e name -> showString "(" . printCExpr e . showString (")." ++ name) +    showParen (d >= 99) $ +      showString ("(" ++ name ++ "){") +      . compose (intersperse (showString ", ") [showString ("." ++ n ++ " = ") . printCExpr 0 e +                                               | (n, e) <- pairs]) +      . showString "}" +  CEProj e name -> printCExpr 99 e . showString ("." ++ name)    CECall n es -> -    showString (n ++ "(") . compose (intersperse (showString ", ") (map printCExpr es)) . showString ")" +    showString (n ++ "(") . compose (intersperse (showString ", ") (map (printCExpr 0) es)) . showString ")"    CEBinop e1 n e2 -> -    showString "(" . printCExpr e1 . showString (") " ++ n ++ " (") . printCExpr e2 . showString ")" +    let mprec = Map.lookup n precTable +        p = maybe (-1) fst mprec  -- precedence of this operator +        (d1, d2) = maybe (98, 98) snd mprec  -- precedences for the arguments +    in showParen (d > p) $ +         printCExpr d1 e1 . showString (" " ++ n ++ " ") . printCExpr d2 e2    CEIf e1 e2 e3 -> -    printCExpr e1 . showString " ? " . printCExpr e2 . showString " : " . printCExpr e3 +    showParen (d > 0) $ +      printCExpr 1 e1 . showString " ? " . printCExpr 1 e2 . showString " : " . printCExpr 0 e3 +  where +    precTable = Map.fromList +      [("||", (2, (2, 2))) +      ,("&&", (3, (3, 3))) +      ,("==", (4, (5, 5))) +      ,("!=", (4, (5, 5))) +      ,("<", (5, (6, 6))) +      ,(">", (5, (6, 6))) +      ,("<=", (5, (6, 6))) +      ,(">=", (5, (6, 6))) +      ,("+", (6, (6, 6))) +      ,("-", (6, (6, 7))) +      ,("*", (7, (7, 7))) +      ,("/", (7, (7, 8))) +      ,("%", (7, (7, 8)))] -repTy :: STy t -> String -repTy (STScal st) = case st of -  STI32 -> "int32_t" -  STI64 -> "int64_t" -  STF32 -> "float" -  STF64 -> "double" -  STBool -> "bool" +repTy :: Ty -> String +repTy (TScal st) = case st of +  TI32 -> "int32_t" +  TI64 -> "int64_t" +  TF32 -> "float" +  TF64 -> "double" +  TBool -> "bool"  repTy t = genStructName t -genStructName :: STy t -> String +repSTy :: STy t -> String +repSTy = repTy . unSTy + +genStructName :: Ty -> String  genStructName = \t -> "ty_" ++ gen t where    -- all tags start with a letter, so the array mangling is unambiguous. -  gen :: STy t -> String -  gen STNil = "n" -  gen (STPair a b) = 'P' : gen a ++ gen b -  gen (STEither a b) = 'E' : gen a ++ gen b -  gen (STMaybe t) = 'M' : gen t -  gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t -  gen (STScal st) = case st of -    STI32 -> "i" -    STI64 -> "j" -    STF32 -> "f" -    STF64 -> "d" -    STBool -> "b" -  gen (STAccum t) = 'C' : gen t +  gen :: Ty -> String +  gen TNil = "n" +  gen (TPair a b) = 'P' : gen a ++ gen b +  gen (TEither a b) = 'E' : gen a ++ gen b +  gen (TMaybe t) = 'M' : gen t +  gen (TArr n t) = "A" ++ show (fromNat n) ++ gen t +  gen (TScal st) = case st of +    TI32 -> "i" +    TI64 -> "j" +    TF32 -> "f" +    TF64 -> "d" +    TBool -> "b" +  gen (TAccum t) = 'C' : gen t -genStruct :: STy t -> Map String StructDecl -genStruct topty = case topty of -  STNil -> -    Map.singleton (genStructName STNil) (StructDecl (genStructName STNil) "" com) -  STPair a b -> -    let name = genStructName (STPair a b) -    in Map.singleton name (StructDecl name (repTy a ++ " a; " ++ repTy b ++ " b;") com) -  STEither a b -> -    let name = genStructName (STEither a b)  -- 0 -> a, 1 -> b -    in Map.singleton name (StructDecl name ("uint8_t tag; union { " ++ repTy a ++ " a; " ++ repTy b ++ " b; };") com) -  STMaybe t -> -    let name = genStructName (STMaybe t)  -- 0 -> nothing, 1 -> just -    in Map.singleton name (StructDecl name ("uint8_t tag; " ++ repTy t ++ " a;") com) -  STArr n t -> -    let name = genStructName (STArr n t) -    in Map.singleton name (StructDecl name ("size_t sh[" ++ show (fromSNat n) ++ "]; " ++ repTy t ++ " *a;") com) -  STScal _ -> mempty -  STAccum t -> -    let name = genStructName (STAccum t) -    in Map.singleton name (StructDecl name (repTy t ++ " a;") com) -         <> genStruct t +genStruct :: String -> Ty -> Maybe StructDecl +genStruct name topty = case topty of +  TNil -> +    Just $ StructDecl name "" com +  TPair a b -> +    Just $ StructDecl name (repTy a ++ " a; " ++ repTy b ++ " b;") com +  TEither a b ->  -- 0 -> a, 1 -> b +    Just $ StructDecl name ("uint8_t tag; union { " ++ repTy a ++ " a; " ++ repTy b ++ " b; };") com +  TMaybe t ->  -- 0 -> nothing, 1 -> just +    Just $ StructDecl name ("uint8_t tag; " ++ repTy t ++ " a;") com +  TArr n t -> +    Just $ StructDecl name ("size_t sh[" ++ show (fromNat n) ++ "]; " ++ repTy t ++ " *a;") com +  TScal _ -> +    Nothing +  TAccum t -> +    Just $ StructDecl name (repTy t ++ " a;") com    where      com = ppTy 0 topty +-- State: (already-generated (skippable) struct names, the structs in declaration order) +genStructs :: Ty -> State (Set String, Bag StructDecl) () +genStructs ty = do +  let name = genStructName ty +  seen <- gets ((name `Set.member`) . fst) + +  case (if seen then Nothing else genStruct name ty) of +    Nothing -> pure () + +    Just decl -> do +      -- already mark this struct as generated now, so we don't generate it twice +      modify (first (Set.insert name)) + +      case ty of +        TNil -> pure () +        TPair a b -> genStructs a >> genStructs b +        TEither a b -> genStructs a >> genStructs b +        TMaybe t -> genStructs t +        TArr _ t -> genStructs t +        TScal _ -> pure () +        TAccum t -> genStructs t + +      modify (second (<> pure decl)) + +genAllStructs :: Foldable t => t Ty -> [StructDecl] +genAllStructs tys = toList . snd $ execState (mapM_ genStructs tys) (mempty, mempty) +  data CompState = CompState -  { csStructs :: Map String StructDecl +  { csStructs :: Set Ty    , csStmts :: Bag Stmt    , csNextId :: Int }    deriving (Show) @@ -156,8 +211,9 @@ scope m = do  emitStruct :: STy t -> CompM String  emitStruct ty = do -  modify $ \s -> s { csStructs = genStruct ty <> csStructs s } -  return (genStructName ty) +  let ty' = unSTy ty +  modify $ \s -> s { csStructs = Set.insert ty' (csStructs s) } +  return (genStructName ty')  nameEnv :: SList f env -> SList (Const String) env  nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1)) @@ -166,15 +222,16 @@ compile :: SList STy env -> Ex env t -> String  compile env expr =    let args = nameEnv env        (res, s) = runState (compile' args expr) (CompState mempty mempty 1) +      structs = genAllStructs (csStructs s <> Set.fromList (unSList unSTy env))    in ($ "") $ compose -       [compose $ map (\sd -> printStructDecl sd . showString "\n") (Map.elems (csStructs s)) +       [compose $ map (\sd -> printStructDecl sd . showString "\n") structs         ,showString "\n"         ,showString $ -          repTy (typeOf expr) ++ " kernel(" ++ -            intercalate ", " (reverse (unSList (\(Product.Pair t n) -> repTy t ++ " " ++ getConst n) (slistZip env args))) ++ +          repSTy (typeOf expr) ++ " kernel(" ++ +            intercalate ", " (reverse (unSList (\(Product.Pair t n) -> repSTy t ++ " " ++ getConst n) (slistZip env args))) ++              ") {\n"         ,compose $ map (\st -> showString "  " . printStmt 1 st . showString "\n") (toList (csStmts s)) -       ,showString ("  return ") . printCExpr res . showString ";\n}\n"] +       ,showString ("  return ") . printCExpr 0 res . showString ";\n}\n"]  compile' :: SList (Const String) env -> Ex env t -> CompM CExpr  compile' env = \case @@ -183,7 +240,7 @@ compile' env = \case    ELet _ rhs body -> do      e <- compile' env rhs      var <- genName -    emit $ SVarDecl True (repTy (typeOf rhs)) var e +    emit $ SVarDecl True (repSTy (typeOf rhs)) var e      compile' (Const var `SCons` env) body    EPair _ a b -> do @@ -215,7 +272,7 @@ compile' env = \case      (e2, stmts2) <- scope $ compile' (Const undefined `SCons` env) a  -- don't access that nil, stupid you      (e3, stmts3) <- scope $ compile' (Const undefined `SCons` env) b      retvar <- genName -    emit $ SVarDeclUninit (repTy (typeOf a)) retvar +    emit $ SVarDeclUninit (repSTy (typeOf a)) retvar      emit $ SIf e1               (stmts2 <> pure (SAsg retvar e2))               (stmts3 <> pure (SAsg retvar e3)) @@ -229,14 +286,14 @@ compile' env = \case      (e2, stmts2) <- scope $ compile' (Const fieldvar `SCons` env) a      (e3, stmts3) <- scope $ compile' (Const fieldvar `SCons` env) b      retvar <- genName -    emit $ SVarDeclUninit (repTy (typeOf a)) retvar -    emit $ SBlock (pure (SVarDecl True (repTy (typeOf e)) var e1) +    emit $ SVarDeclUninit (repSTy (typeOf a)) retvar +    emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1)                  <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) -                           (pure (SVarDecl True (repTy t1) fieldvar +                           (pure (SVarDecl True (repSTy t1) fieldvar                                             (CEProj (CELit var) "a"))                              <> stmts2                              <> pure (SAsg retvar e2)) -                           (pure (SVarDecl True (repTy t2) fieldvar +                           (pure (SVarDecl True (repSTy t2) fieldvar                                             (CEProj (CELit var) "b"))                              <> stmts3                              <> pure (SAsg retvar e3)))) @@ -258,12 +315,12 @@ compile' env = \case      (e2, stmts2) <- scope $ compile' env a      (e3, stmts3) <- scope $ compile' (Const fieldvar `SCons` env) b      retvar <- genName -    emit $ SVarDeclUninit (repTy (typeOf a)) retvar -    emit $ SBlock (pure (SVarDecl True (repTy (typeOf e)) var e1) +    emit $ SVarDeclUninit (repSTy (typeOf a)) retvar +    emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1)                  <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))                             (stmts2                              <> pure (SAsg retvar e2)) -                           (pure (SVarDecl True (repTy (typeOf b)) fieldvar +                           (pure (SVarDecl True (repSTy (typeOf b)) fieldvar                                             (CEProj (CELit var) "a"))                              <> stmts3                              <> pure (SAsg retvar e3)))) @@ -332,7 +389,7 @@ compileOpGeneral op e1 = do    let unary cop = return @(State CompState) $ CECall cop [e1]    let binary cop = do          name <- genName -        emit $ SVarDecl True (repTy (opt1 op)) name e1 +        emit $ SVarDecl True (repSTy (opt1 op)) name e1          return $ CEBinop (CEProj (CELit name) "a") cop (CEProj (CELit name) "b")    case op of      OAdd _ -> binary "+" | 
