diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-02-25 23:56:16 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-02-25 23:56:16 +0100 |
commit | 7fa10a9a07c7160531baf595d1111277c17a38b2 (patch) | |
tree | 24b7263da33490d954b063926d509e1a10193687 | |
parent | 2c2b80264ae5777f0a759abb5571cbe68071c7e7 (diff) |
Compile: Emit structs in proper order
-rw-r--r-- | src/AST.hs | 30 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 40 | ||||
-rw-r--r-- | src/Compile.hs | 207 | ||||
-rw-r--r-- | src/Data.hs | 9 | ||||
-rw-r--r-- | src/Interpreter/Rep.hs | 2 | ||||
-rw-r--r-- | test/Main.hs | 2 |
6 files changed, 179 insertions, 111 deletions
@@ -292,23 +292,19 @@ extOf = \case EOneHot x _ _ _ _ -> x EError x _ _ -> x --- unSNat :: SNat n -> Nat --- unSNat SZ = Z --- unSNat (SS n) = S (unSNat n) - --- unSTy :: STy t -> Ty --- unSTy = \case --- STNil -> TNil --- STPair a b -> TPair (unSTy a) (unSTy b) --- STEither a b -> TEither (unSTy a) (unSTy b) --- STMaybe t -> TMaybe (unSTy t) --- STArr n t -> TArr (unSNat n) (unSTy t) --- STScal t -> TScal (unSScalTy t) --- STAccum t -> TAccum (unSTy t) - --- unSEnv :: SList STy env -> [Ty] --- unSEnv SNil = [] --- unSEnv (SCons t l) = unSTy t : unSEnv l +unSTy :: STy t -> Ty +unSTy = \case + STNil -> TNil + STPair a b -> TPair (unSTy a) (unSTy b) + STEither a b -> TEither (unSTy a) (unSTy b) + STMaybe t -> TMaybe (unSTy t) + STArr n t -> TArr (unSNat n) (unSTy t) + STScal t -> TScal (unSScalTy t) + STAccum t -> TAccum (unSTy t) + +unSEnv :: SList STy env -> [Ty] +unSEnv SNil = [] +unSEnv (SCons t l) = unSTy t : unSEnv l unSScalTy :: SScalTy t -> ScalTy unSScalTy = \case diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 4190f32..35c78c1 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -7,7 +7,7 @@ {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeOperators #-} -module AST.Pretty (ppExpr, ppTy, PrettyX(..)) where +module AST.Pretty (ppExpr, ppSTy, ppTy, PrettyX(..)) where import Control.Monad (ap) import Data.List (intersperse) @@ -252,7 +252,7 @@ ppExpr' d val expr = case expr of ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (show (fromSNat i)), e1', e2', e3'] EZero _ t -> return $ parens $ - annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "::" <+> ppTy' 0 t <> ppString ")" + annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "::" <+> ppSTy' 0 t <> ppString ")" EPlus _ _ a b -> do a' <- ppExpr' 11 val a @@ -321,23 +321,29 @@ operator OExp{} = (Prefix, "exp") operator OLog{} = (Prefix, "log") operator OIDiv{} = (Infix, "`div`") -ppTy :: Int -> STy t -> String +ppSTy :: Int -> STy t -> String +ppSTy d ty = ppTy d (unSTy ty) + +ppSTy' :: Int -> STy t -> Doc q +ppSTy' d ty = ppTy' d (unSTy ty) + +ppTy :: Int -> Ty -> String ppTy d ty = render $ ppTy' d ty -ppTy' :: Int -> STy t -> Doc q -ppTy' _ STNil = ppString "1" -ppTy' d (STPair a b) = ppParen (d > 7) $ ppTy' 8 a <> ppString " * " <> ppTy' 8 b -ppTy' d (STEither a b) = ppParen (d > 6) $ ppTy' 7 a <> ppString " + " <> ppTy' 7 b -ppTy' d (STMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppTy' 11 t -ppTy' d (STArr n t) = ppParen (d > 10) $ - ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppTy' 11 t -ppTy' _ (STScal sty) = ppString $ case sty of - STI32 -> "i32" - STI64 -> "i64" - STF32 -> "f32" - STF64 -> "f64" - STBool -> "bool" -ppTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppTy' 11 t +ppTy' :: Int -> Ty -> Doc q +ppTy' _ TNil = ppString "1" +ppTy' d (TPair a b) = ppParen (d > 7) $ ppTy' 8 a <> ppString " * " <> ppTy' 8 b +ppTy' d (TEither a b) = ppParen (d > 6) $ ppTy' 7 a <> ppString " + " <> ppTy' 7 b +ppTy' d (TMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppTy' 11 t +ppTy' d (TArr n t) = ppParen (d > 10) $ + ppString "Arr " <> ppString (show (fromNat n)) <> ppString " " <> ppTy' 11 t +ppTy' _ (TScal sty) = ppString $ case sty of + TI32 -> "i32" + TI64 -> "i64" + TF32 -> "f32" + TF64 -> "f64" + TBool -> "bool" +ppTy' d (TAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppTy' 11 t ppString :: String -> Doc x ppString = fromString diff --git a/src/Compile.hs b/src/Compile.hs index 05d51c1..95004b8 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -6,12 +6,14 @@ module Compile where import Control.Monad.Trans.State.Strict +import Data.Bifunctor (first, second) import Data.Foldable (toList) import Data.Functor.Const import qualified Data.Functor.Product as Product import Data.List (intersperse, intercalate) import qualified Data.Map.Strict as Map -import Data.Map.Strict (Map) +import qualified Data.Set as Set +import Data.Set (Set) import AST import AST.Pretty (ppTy) @@ -52,86 +54,139 @@ printStructDecl (StructDecl name contents comment) = printStmt :: Int -> Stmt -> ShowS printStmt indent = \case - SVarDecl cnst typ name rhs -> showString ((if cnst then "const " else "") ++ typ ++ " " ++ name ++ " = ") . printCExpr rhs . showString ";" + SVarDecl cnst typ name rhs -> showString ((if cnst then "const " else "") ++ typ ++ " " ++ name ++ " = ") . printCExpr 0 rhs . showString ";" SVarDeclUninit typ name -> showString (typ ++ " " ++ name ++ ";") - SAsg name rhs -> showString (name ++ " = ") . printCExpr rhs . showString ";" + SAsg name rhs -> showString (name ++ " = ") . printCExpr 0 rhs . showString ";" SBlock stmts -> showString "{" . compose [showString ("\n" ++ replicate (2*indent+2) ' ') . printStmt (indent+1) stmt | stmt <- stmts] . showString ("\n" ++ replicate (2*indent) ' ' ++ "}") SIf cond b1 b2 -> - showString "if (" . printCExpr cond . showString ") " + showString "if (" . printCExpr 0 cond . showString ") " . printStmt indent (SBlock b1) . showString " else " . printStmt indent (SBlock b2) SVerbatim s -> showString s -printCExpr :: CExpr -> ShowS -printCExpr = \case +-- d values: +-- * 0: top level +-- * 1: in 1st or 2nd component of a ternary operator (technically same as top level, but readability) +-- * 2-...: various operators (see precTable) +-- * 98: inside unknown operator +-- * 99: left of a field projection +-- Unlisted operators are conservatively written with full parentheses. +printCExpr :: Int -> CExpr -> ShowS +printCExpr d = \case CELit s -> showString s CEStruct name pairs -> - showString ("(" ++ name ++ "){") - . compose (intersperse (showString ", ") [showString ("." ++ n ++ " = ") . printCExpr e - | (n, e) <- pairs]) - . showString "}" - CEProj e name -> showString "(" . printCExpr e . showString (")." ++ name) + showParen (d >= 99) $ + showString ("(" ++ name ++ "){") + . compose (intersperse (showString ", ") [showString ("." ++ n ++ " = ") . printCExpr 0 e + | (n, e) <- pairs]) + . showString "}" + CEProj e name -> printCExpr 99 e . showString ("." ++ name) CECall n es -> - showString (n ++ "(") . compose (intersperse (showString ", ") (map printCExpr es)) . showString ")" + showString (n ++ "(") . compose (intersperse (showString ", ") (map (printCExpr 0) es)) . showString ")" CEBinop e1 n e2 -> - showString "(" . printCExpr e1 . showString (") " ++ n ++ " (") . printCExpr e2 . showString ")" + let mprec = Map.lookup n precTable + p = maybe (-1) fst mprec -- precedence of this operator + (d1, d2) = maybe (98, 98) snd mprec -- precedences for the arguments + in showParen (d > p) $ + printCExpr d1 e1 . showString (" " ++ n ++ " ") . printCExpr d2 e2 CEIf e1 e2 e3 -> - printCExpr e1 . showString " ? " . printCExpr e2 . showString " : " . printCExpr e3 - -repTy :: STy t -> String -repTy (STScal st) = case st of - STI32 -> "int32_t" - STI64 -> "int64_t" - STF32 -> "float" - STF64 -> "double" - STBool -> "bool" + showParen (d > 0) $ + printCExpr 1 e1 . showString " ? " . printCExpr 1 e2 . showString " : " . printCExpr 0 e3 + where + precTable = Map.fromList + [("||", (2, (2, 2))) + ,("&&", (3, (3, 3))) + ,("==", (4, (5, 5))) + ,("!=", (4, (5, 5))) + ,("<", (5, (6, 6))) + ,(">", (5, (6, 6))) + ,("<=", (5, (6, 6))) + ,(">=", (5, (6, 6))) + ,("+", (6, (6, 6))) + ,("-", (6, (6, 7))) + ,("*", (7, (7, 7))) + ,("/", (7, (7, 8))) + ,("%", (7, (7, 8)))] + +repTy :: Ty -> String +repTy (TScal st) = case st of + TI32 -> "int32_t" + TI64 -> "int64_t" + TF32 -> "float" + TF64 -> "double" + TBool -> "bool" repTy t = genStructName t -genStructName :: STy t -> String +repSTy :: STy t -> String +repSTy = repTy . unSTy + +genStructName :: Ty -> String genStructName = \t -> "ty_" ++ gen t where -- all tags start with a letter, so the array mangling is unambiguous. - gen :: STy t -> String - gen STNil = "n" - gen (STPair a b) = 'P' : gen a ++ gen b - gen (STEither a b) = 'E' : gen a ++ gen b - gen (STMaybe t) = 'M' : gen t - gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t - gen (STScal st) = case st of - STI32 -> "i" - STI64 -> "j" - STF32 -> "f" - STF64 -> "d" - STBool -> "b" - gen (STAccum t) = 'C' : gen t - -genStruct :: STy t -> Map String StructDecl -genStruct topty = case topty of - STNil -> - Map.singleton (genStructName STNil) (StructDecl (genStructName STNil) "" com) - STPair a b -> - let name = genStructName (STPair a b) - in Map.singleton name (StructDecl name (repTy a ++ " a; " ++ repTy b ++ " b;") com) - STEither a b -> - let name = genStructName (STEither a b) -- 0 -> a, 1 -> b - in Map.singleton name (StructDecl name ("uint8_t tag; union { " ++ repTy a ++ " a; " ++ repTy b ++ " b; };") com) - STMaybe t -> - let name = genStructName (STMaybe t) -- 0 -> nothing, 1 -> just - in Map.singleton name (StructDecl name ("uint8_t tag; " ++ repTy t ++ " a;") com) - STArr n t -> - let name = genStructName (STArr n t) - in Map.singleton name (StructDecl name ("size_t sh[" ++ show (fromSNat n) ++ "]; " ++ repTy t ++ " *a;") com) - STScal _ -> mempty - STAccum t -> - let name = genStructName (STAccum t) - in Map.singleton name (StructDecl name (repTy t ++ " a;") com) - <> genStruct t + gen :: Ty -> String + gen TNil = "n" + gen (TPair a b) = 'P' : gen a ++ gen b + gen (TEither a b) = 'E' : gen a ++ gen b + gen (TMaybe t) = 'M' : gen t + gen (TArr n t) = "A" ++ show (fromNat n) ++ gen t + gen (TScal st) = case st of + TI32 -> "i" + TI64 -> "j" + TF32 -> "f" + TF64 -> "d" + TBool -> "b" + gen (TAccum t) = 'C' : gen t + +genStruct :: String -> Ty -> Maybe StructDecl +genStruct name topty = case topty of + TNil -> + Just $ StructDecl name "" com + TPair a b -> + Just $ StructDecl name (repTy a ++ " a; " ++ repTy b ++ " b;") com + TEither a b -> -- 0 -> a, 1 -> b + Just $ StructDecl name ("uint8_t tag; union { " ++ repTy a ++ " a; " ++ repTy b ++ " b; };") com + TMaybe t -> -- 0 -> nothing, 1 -> just + Just $ StructDecl name ("uint8_t tag; " ++ repTy t ++ " a;") com + TArr n t -> + Just $ StructDecl name ("size_t sh[" ++ show (fromNat n) ++ "]; " ++ repTy t ++ " *a;") com + TScal _ -> + Nothing + TAccum t -> + Just $ StructDecl name (repTy t ++ " a;") com where com = ppTy 0 topty +-- State: (already-generated (skippable) struct names, the structs in declaration order) +genStructs :: Ty -> State (Set String, Bag StructDecl) () +genStructs ty = do + let name = genStructName ty + seen <- gets ((name `Set.member`) . fst) + + case (if seen then Nothing else genStruct name ty) of + Nothing -> pure () + + Just decl -> do + -- already mark this struct as generated now, so we don't generate it twice + modify (first (Set.insert name)) + + case ty of + TNil -> pure () + TPair a b -> genStructs a >> genStructs b + TEither a b -> genStructs a >> genStructs b + TMaybe t -> genStructs t + TArr _ t -> genStructs t + TScal _ -> pure () + TAccum t -> genStructs t + + modify (second (<> pure decl)) + +genAllStructs :: Foldable t => t Ty -> [StructDecl] +genAllStructs tys = toList . snd $ execState (mapM_ genStructs tys) (mempty, mempty) + data CompState = CompState - { csStructs :: Map String StructDecl + { csStructs :: Set Ty , csStmts :: Bag Stmt , csNextId :: Int } deriving (Show) @@ -156,8 +211,9 @@ scope m = do emitStruct :: STy t -> CompM String emitStruct ty = do - modify $ \s -> s { csStructs = genStruct ty <> csStructs s } - return (genStructName ty) + let ty' = unSTy ty + modify $ \s -> s { csStructs = Set.insert ty' (csStructs s) } + return (genStructName ty') nameEnv :: SList f env -> SList (Const String) env nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1)) @@ -166,15 +222,16 @@ compile :: SList STy env -> Ex env t -> String compile env expr = let args = nameEnv env (res, s) = runState (compile' args expr) (CompState mempty mempty 1) + structs = genAllStructs (csStructs s <> Set.fromList (unSList unSTy env)) in ($ "") $ compose - [compose $ map (\sd -> printStructDecl sd . showString "\n") (Map.elems (csStructs s)) + [compose $ map (\sd -> printStructDecl sd . showString "\n") structs ,showString "\n" ,showString $ - repTy (typeOf expr) ++ " kernel(" ++ - intercalate ", " (reverse (unSList (\(Product.Pair t n) -> repTy t ++ " " ++ getConst n) (slistZip env args))) ++ + repSTy (typeOf expr) ++ " kernel(" ++ + intercalate ", " (reverse (unSList (\(Product.Pair t n) -> repSTy t ++ " " ++ getConst n) (slistZip env args))) ++ ") {\n" ,compose $ map (\st -> showString " " . printStmt 1 st . showString "\n") (toList (csStmts s)) - ,showString (" return ") . printCExpr res . showString ";\n}\n"] + ,showString (" return ") . printCExpr 0 res . showString ";\n}\n"] compile' :: SList (Const String) env -> Ex env t -> CompM CExpr compile' env = \case @@ -183,7 +240,7 @@ compile' env = \case ELet _ rhs body -> do e <- compile' env rhs var <- genName - emit $ SVarDecl True (repTy (typeOf rhs)) var e + emit $ SVarDecl True (repSTy (typeOf rhs)) var e compile' (Const var `SCons` env) body EPair _ a b -> do @@ -215,7 +272,7 @@ compile' env = \case (e2, stmts2) <- scope $ compile' (Const undefined `SCons` env) a -- don't access that nil, stupid you (e3, stmts3) <- scope $ compile' (Const undefined `SCons` env) b retvar <- genName - emit $ SVarDeclUninit (repTy (typeOf a)) retvar + emit $ SVarDeclUninit (repSTy (typeOf a)) retvar emit $ SIf e1 (stmts2 <> pure (SAsg retvar e2)) (stmts3 <> pure (SAsg retvar e3)) @@ -229,14 +286,14 @@ compile' env = \case (e2, stmts2) <- scope $ compile' (Const fieldvar `SCons` env) a (e3, stmts3) <- scope $ compile' (Const fieldvar `SCons` env) b retvar <- genName - emit $ SVarDeclUninit (repTy (typeOf a)) retvar - emit $ SBlock (pure (SVarDecl True (repTy (typeOf e)) var e1) + emit $ SVarDeclUninit (repSTy (typeOf a)) retvar + emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1) <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) - (pure (SVarDecl True (repTy t1) fieldvar + (pure (SVarDecl True (repSTy t1) fieldvar (CEProj (CELit var) "a")) <> stmts2 <> pure (SAsg retvar e2)) - (pure (SVarDecl True (repTy t2) fieldvar + (pure (SVarDecl True (repSTy t2) fieldvar (CEProj (CELit var) "b")) <> stmts3 <> pure (SAsg retvar e3)))) @@ -258,12 +315,12 @@ compile' env = \case (e2, stmts2) <- scope $ compile' env a (e3, stmts3) <- scope $ compile' (Const fieldvar `SCons` env) b retvar <- genName - emit $ SVarDeclUninit (repTy (typeOf a)) retvar - emit $ SBlock (pure (SVarDecl True (repTy (typeOf e)) var e1) + emit $ SVarDeclUninit (repSTy (typeOf a)) retvar + emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1) <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) (stmts2 <> pure (SAsg retvar e2)) - (pure (SVarDecl True (repTy (typeOf b)) fieldvar + (pure (SVarDecl True (repSTy (typeOf b)) fieldvar (CEProj (CELit var) "a")) <> stmts3 <> pure (SAsg retvar e3)))) @@ -332,7 +389,7 @@ compileOpGeneral op e1 = do let unary cop = return @(State CompState) $ CECall cop [e1] let binary cop = do name <- genName - emit $ SVarDecl True (repTy (opt1 op)) name e1 + emit $ SVarDecl True (repSTy (opt1 op)) name e1 return $ CEBinop (CEProj (CELit name) "a") cop (CEProj (CELit name) "b") case op of OAdd _ -> binary "+" diff --git a/src/Data.hs b/src/Data.hs index 0be9046..8005737 100644 --- a/src/Data.hs +++ b/src/Data.hs @@ -77,6 +77,14 @@ fromSNat :: SNat n -> Int fromSNat SZ = 0 fromSNat (SS n) = succ (fromSNat n) +unSNat :: SNat n -> Nat +unSNat SZ = Z +unSNat (SS n) = S (unSNat n) + +fromNat :: Nat -> Int +fromNat Z = 0 +fromNat (S m) = succ (fromNat m) + class KnownNat n where knownNat :: SNat n instance KnownNat Z where knownNat = SZ instance KnownNat n => KnownNat (S n) where knownNat = SS knownNat @@ -124,6 +132,7 @@ unsafeCoerceRefl = unsafeCoerce Refl data Bag t = BNone | BOne t | BTwo (Bag t) (Bag t) | BMany [Bag t] deriving (Show, Functor, Foldable, Traversable) +-- | This instance is mostly there just for 'pure' instance Applicative Bag where pure = BOne BNone <*> _ = BNone diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index 335ad1f..ac06915 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -76,7 +76,7 @@ showValue _ (STScal sty) x = case sty of STI32 -> shows x STI64 -> shows x STBool -> shows x -showValue _ (STAccum t) _ = showString $ "<accumulator for " ++ ppTy 0 t ++ ">" +showValue _ (STAccum t) _ = showString $ "<accumulator for " ++ ppSTy 0 t ++ ">" showEnv :: SList STy env -> SList Value env -> String showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" diff --git a/test/Main.hs b/test/Main.hs index b234aa2..dde2c3d 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -212,7 +212,7 @@ adTestGen expr envGenerator = property $ do scFwd = envScalars env gradFwd scCHAD = envScalars env gradCHAD scCHAD_S = envScalars env gradCHAD_S - annotate (concat (unSList (\t -> ppTy 0 t ++ " -> ") env) ++ ppTy 0 (typeOf expr)) + annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) -- annotate (ppExpr knownEnv expr) -- annotate ppdterm -- annotate ppdterm_S |