summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-02-25 23:56:16 +0100
committerTom Smeding <tom@tomsmeding.com>2025-02-25 23:56:16 +0100
commit7fa10a9a07c7160531baf595d1111277c17a38b2 (patch)
tree24b7263da33490d954b063926d509e1a10193687
parent2c2b80264ae5777f0a759abb5571cbe68071c7e7 (diff)
Compile: Emit structs in proper order
-rw-r--r--src/AST.hs30
-rw-r--r--src/AST/Pretty.hs40
-rw-r--r--src/Compile.hs207
-rw-r--r--src/Data.hs9
-rw-r--r--src/Interpreter/Rep.hs2
-rw-r--r--test/Main.hs2
6 files changed, 179 insertions, 111 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 99c0681..e22d11f 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -292,23 +292,19 @@ extOf = \case
EOneHot x _ _ _ _ -> x
EError x _ _ -> x
--- unSNat :: SNat n -> Nat
--- unSNat SZ = Z
--- unSNat (SS n) = S (unSNat n)
-
--- unSTy :: STy t -> Ty
--- unSTy = \case
--- STNil -> TNil
--- STPair a b -> TPair (unSTy a) (unSTy b)
--- STEither a b -> TEither (unSTy a) (unSTy b)
--- STMaybe t -> TMaybe (unSTy t)
--- STArr n t -> TArr (unSNat n) (unSTy t)
--- STScal t -> TScal (unSScalTy t)
--- STAccum t -> TAccum (unSTy t)
-
--- unSEnv :: SList STy env -> [Ty]
--- unSEnv SNil = []
--- unSEnv (SCons t l) = unSTy t : unSEnv l
+unSTy :: STy t -> Ty
+unSTy = \case
+ STNil -> TNil
+ STPair a b -> TPair (unSTy a) (unSTy b)
+ STEither a b -> TEither (unSTy a) (unSTy b)
+ STMaybe t -> TMaybe (unSTy t)
+ STArr n t -> TArr (unSNat n) (unSTy t)
+ STScal t -> TScal (unSScalTy t)
+ STAccum t -> TAccum (unSTy t)
+
+unSEnv :: SList STy env -> [Ty]
+unSEnv SNil = []
+unSEnv (SCons t l) = unSTy t : unSEnv l
unSScalTy :: SScalTy t -> ScalTy
unSScalTy = \case
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 4190f32..35c78c1 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -7,7 +7,7 @@
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeOperators #-}
-module AST.Pretty (ppExpr, ppTy, PrettyX(..)) where
+module AST.Pretty (ppExpr, ppSTy, ppTy, PrettyX(..)) where
import Control.Monad (ap)
import Data.List (intersperse)
@@ -252,7 +252,7 @@ ppExpr' d val expr = case expr of
ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (show (fromSNat i)), e1', e2', e3']
EZero _ t -> return $ parens $
- annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "::" <+> ppTy' 0 t <> ppString ")"
+ annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "::" <+> ppSTy' 0 t <> ppString ")"
EPlus _ _ a b -> do
a' <- ppExpr' 11 val a
@@ -321,23 +321,29 @@ operator OExp{} = (Prefix, "exp")
operator OLog{} = (Prefix, "log")
operator OIDiv{} = (Infix, "`div`")
-ppTy :: Int -> STy t -> String
+ppSTy :: Int -> STy t -> String
+ppSTy d ty = ppTy d (unSTy ty)
+
+ppSTy' :: Int -> STy t -> Doc q
+ppSTy' d ty = ppTy' d (unSTy ty)
+
+ppTy :: Int -> Ty -> String
ppTy d ty = render $ ppTy' d ty
-ppTy' :: Int -> STy t -> Doc q
-ppTy' _ STNil = ppString "1"
-ppTy' d (STPair a b) = ppParen (d > 7) $ ppTy' 8 a <> ppString " * " <> ppTy' 8 b
-ppTy' d (STEither a b) = ppParen (d > 6) $ ppTy' 7 a <> ppString " + " <> ppTy' 7 b
-ppTy' d (STMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppTy' 11 t
-ppTy' d (STArr n t) = ppParen (d > 10) $
- ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppTy' 11 t
-ppTy' _ (STScal sty) = ppString $ case sty of
- STI32 -> "i32"
- STI64 -> "i64"
- STF32 -> "f32"
- STF64 -> "f64"
- STBool -> "bool"
-ppTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppTy' 11 t
+ppTy' :: Int -> Ty -> Doc q
+ppTy' _ TNil = ppString "1"
+ppTy' d (TPair a b) = ppParen (d > 7) $ ppTy' 8 a <> ppString " * " <> ppTy' 8 b
+ppTy' d (TEither a b) = ppParen (d > 6) $ ppTy' 7 a <> ppString " + " <> ppTy' 7 b
+ppTy' d (TMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppTy' 11 t
+ppTy' d (TArr n t) = ppParen (d > 10) $
+ ppString "Arr " <> ppString (show (fromNat n)) <> ppString " " <> ppTy' 11 t
+ppTy' _ (TScal sty) = ppString $ case sty of
+ TI32 -> "i32"
+ TI64 -> "i64"
+ TF32 -> "f32"
+ TF64 -> "f64"
+ TBool -> "bool"
+ppTy' d (TAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppTy' 11 t
ppString :: String -> Doc x
ppString = fromString
diff --git a/src/Compile.hs b/src/Compile.hs
index 05d51c1..95004b8 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -6,12 +6,14 @@
module Compile where
import Control.Monad.Trans.State.Strict
+import Data.Bifunctor (first, second)
import Data.Foldable (toList)
import Data.Functor.Const
import qualified Data.Functor.Product as Product
import Data.List (intersperse, intercalate)
import qualified Data.Map.Strict as Map
-import Data.Map.Strict (Map)
+import qualified Data.Set as Set
+import Data.Set (Set)
import AST
import AST.Pretty (ppTy)
@@ -52,86 +54,139 @@ printStructDecl (StructDecl name contents comment) =
printStmt :: Int -> Stmt -> ShowS
printStmt indent = \case
- SVarDecl cnst typ name rhs -> showString ((if cnst then "const " else "") ++ typ ++ " " ++ name ++ " = ") . printCExpr rhs . showString ";"
+ SVarDecl cnst typ name rhs -> showString ((if cnst then "const " else "") ++ typ ++ " " ++ name ++ " = ") . printCExpr 0 rhs . showString ";"
SVarDeclUninit typ name -> showString (typ ++ " " ++ name ++ ";")
- SAsg name rhs -> showString (name ++ " = ") . printCExpr rhs . showString ";"
+ SAsg name rhs -> showString (name ++ " = ") . printCExpr 0 rhs . showString ";"
SBlock stmts ->
showString "{"
. compose [showString ("\n" ++ replicate (2*indent+2) ' ') . printStmt (indent+1) stmt | stmt <- stmts]
. showString ("\n" ++ replicate (2*indent) ' ' ++ "}")
SIf cond b1 b2 ->
- showString "if (" . printCExpr cond . showString ") "
+ showString "if (" . printCExpr 0 cond . showString ") "
. printStmt indent (SBlock b1) . showString " else " . printStmt indent (SBlock b2)
SVerbatim s -> showString s
-printCExpr :: CExpr -> ShowS
-printCExpr = \case
+-- d values:
+-- * 0: top level
+-- * 1: in 1st or 2nd component of a ternary operator (technically same as top level, but readability)
+-- * 2-...: various operators (see precTable)
+-- * 98: inside unknown operator
+-- * 99: left of a field projection
+-- Unlisted operators are conservatively written with full parentheses.
+printCExpr :: Int -> CExpr -> ShowS
+printCExpr d = \case
CELit s -> showString s
CEStruct name pairs ->
- showString ("(" ++ name ++ "){")
- . compose (intersperse (showString ", ") [showString ("." ++ n ++ " = ") . printCExpr e
- | (n, e) <- pairs])
- . showString "}"
- CEProj e name -> showString "(" . printCExpr e . showString (")." ++ name)
+ showParen (d >= 99) $
+ showString ("(" ++ name ++ "){")
+ . compose (intersperse (showString ", ") [showString ("." ++ n ++ " = ") . printCExpr 0 e
+ | (n, e) <- pairs])
+ . showString "}"
+ CEProj e name -> printCExpr 99 e . showString ("." ++ name)
CECall n es ->
- showString (n ++ "(") . compose (intersperse (showString ", ") (map printCExpr es)) . showString ")"
+ showString (n ++ "(") . compose (intersperse (showString ", ") (map (printCExpr 0) es)) . showString ")"
CEBinop e1 n e2 ->
- showString "(" . printCExpr e1 . showString (") " ++ n ++ " (") . printCExpr e2 . showString ")"
+ let mprec = Map.lookup n precTable
+ p = maybe (-1) fst mprec -- precedence of this operator
+ (d1, d2) = maybe (98, 98) snd mprec -- precedences for the arguments
+ in showParen (d > p) $
+ printCExpr d1 e1 . showString (" " ++ n ++ " ") . printCExpr d2 e2
CEIf e1 e2 e3 ->
- printCExpr e1 . showString " ? " . printCExpr e2 . showString " : " . printCExpr e3
-
-repTy :: STy t -> String
-repTy (STScal st) = case st of
- STI32 -> "int32_t"
- STI64 -> "int64_t"
- STF32 -> "float"
- STF64 -> "double"
- STBool -> "bool"
+ showParen (d > 0) $
+ printCExpr 1 e1 . showString " ? " . printCExpr 1 e2 . showString " : " . printCExpr 0 e3
+ where
+ precTable = Map.fromList
+ [("||", (2, (2, 2)))
+ ,("&&", (3, (3, 3)))
+ ,("==", (4, (5, 5)))
+ ,("!=", (4, (5, 5)))
+ ,("<", (5, (6, 6)))
+ ,(">", (5, (6, 6)))
+ ,("<=", (5, (6, 6)))
+ ,(">=", (5, (6, 6)))
+ ,("+", (6, (6, 6)))
+ ,("-", (6, (6, 7)))
+ ,("*", (7, (7, 7)))
+ ,("/", (7, (7, 8)))
+ ,("%", (7, (7, 8)))]
+
+repTy :: Ty -> String
+repTy (TScal st) = case st of
+ TI32 -> "int32_t"
+ TI64 -> "int64_t"
+ TF32 -> "float"
+ TF64 -> "double"
+ TBool -> "bool"
repTy t = genStructName t
-genStructName :: STy t -> String
+repSTy :: STy t -> String
+repSTy = repTy . unSTy
+
+genStructName :: Ty -> String
genStructName = \t -> "ty_" ++ gen t where
-- all tags start with a letter, so the array mangling is unambiguous.
- gen :: STy t -> String
- gen STNil = "n"
- gen (STPair a b) = 'P' : gen a ++ gen b
- gen (STEither a b) = 'E' : gen a ++ gen b
- gen (STMaybe t) = 'M' : gen t
- gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t
- gen (STScal st) = case st of
- STI32 -> "i"
- STI64 -> "j"
- STF32 -> "f"
- STF64 -> "d"
- STBool -> "b"
- gen (STAccum t) = 'C' : gen t
-
-genStruct :: STy t -> Map String StructDecl
-genStruct topty = case topty of
- STNil ->
- Map.singleton (genStructName STNil) (StructDecl (genStructName STNil) "" com)
- STPair a b ->
- let name = genStructName (STPair a b)
- in Map.singleton name (StructDecl name (repTy a ++ " a; " ++ repTy b ++ " b;") com)
- STEither a b ->
- let name = genStructName (STEither a b) -- 0 -> a, 1 -> b
- in Map.singleton name (StructDecl name ("uint8_t tag; union { " ++ repTy a ++ " a; " ++ repTy b ++ " b; };") com)
- STMaybe t ->
- let name = genStructName (STMaybe t) -- 0 -> nothing, 1 -> just
- in Map.singleton name (StructDecl name ("uint8_t tag; " ++ repTy t ++ " a;") com)
- STArr n t ->
- let name = genStructName (STArr n t)
- in Map.singleton name (StructDecl name ("size_t sh[" ++ show (fromSNat n) ++ "]; " ++ repTy t ++ " *a;") com)
- STScal _ -> mempty
- STAccum t ->
- let name = genStructName (STAccum t)
- in Map.singleton name (StructDecl name (repTy t ++ " a;") com)
- <> genStruct t
+ gen :: Ty -> String
+ gen TNil = "n"
+ gen (TPair a b) = 'P' : gen a ++ gen b
+ gen (TEither a b) = 'E' : gen a ++ gen b
+ gen (TMaybe t) = 'M' : gen t
+ gen (TArr n t) = "A" ++ show (fromNat n) ++ gen t
+ gen (TScal st) = case st of
+ TI32 -> "i"
+ TI64 -> "j"
+ TF32 -> "f"
+ TF64 -> "d"
+ TBool -> "b"
+ gen (TAccum t) = 'C' : gen t
+
+genStruct :: String -> Ty -> Maybe StructDecl
+genStruct name topty = case topty of
+ TNil ->
+ Just $ StructDecl name "" com
+ TPair a b ->
+ Just $ StructDecl name (repTy a ++ " a; " ++ repTy b ++ " b;") com
+ TEither a b -> -- 0 -> a, 1 -> b
+ Just $ StructDecl name ("uint8_t tag; union { " ++ repTy a ++ " a; " ++ repTy b ++ " b; };") com
+ TMaybe t -> -- 0 -> nothing, 1 -> just
+ Just $ StructDecl name ("uint8_t tag; " ++ repTy t ++ " a;") com
+ TArr n t ->
+ Just $ StructDecl name ("size_t sh[" ++ show (fromNat n) ++ "]; " ++ repTy t ++ " *a;") com
+ TScal _ ->
+ Nothing
+ TAccum t ->
+ Just $ StructDecl name (repTy t ++ " a;") com
where
com = ppTy 0 topty
+-- State: (already-generated (skippable) struct names, the structs in declaration order)
+genStructs :: Ty -> State (Set String, Bag StructDecl) ()
+genStructs ty = do
+ let name = genStructName ty
+ seen <- gets ((name `Set.member`) . fst)
+
+ case (if seen then Nothing else genStruct name ty) of
+ Nothing -> pure ()
+
+ Just decl -> do
+ -- already mark this struct as generated now, so we don't generate it twice
+ modify (first (Set.insert name))
+
+ case ty of
+ TNil -> pure ()
+ TPair a b -> genStructs a >> genStructs b
+ TEither a b -> genStructs a >> genStructs b
+ TMaybe t -> genStructs t
+ TArr _ t -> genStructs t
+ TScal _ -> pure ()
+ TAccum t -> genStructs t
+
+ modify (second (<> pure decl))
+
+genAllStructs :: Foldable t => t Ty -> [StructDecl]
+genAllStructs tys = toList . snd $ execState (mapM_ genStructs tys) (mempty, mempty)
+
data CompState = CompState
- { csStructs :: Map String StructDecl
+ { csStructs :: Set Ty
, csStmts :: Bag Stmt
, csNextId :: Int }
deriving (Show)
@@ -156,8 +211,9 @@ scope m = do
emitStruct :: STy t -> CompM String
emitStruct ty = do
- modify $ \s -> s { csStructs = genStruct ty <> csStructs s }
- return (genStructName ty)
+ let ty' = unSTy ty
+ modify $ \s -> s { csStructs = Set.insert ty' (csStructs s) }
+ return (genStructName ty')
nameEnv :: SList f env -> SList (Const String) env
nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1))
@@ -166,15 +222,16 @@ compile :: SList STy env -> Ex env t -> String
compile env expr =
let args = nameEnv env
(res, s) = runState (compile' args expr) (CompState mempty mempty 1)
+ structs = genAllStructs (csStructs s <> Set.fromList (unSList unSTy env))
in ($ "") $ compose
- [compose $ map (\sd -> printStructDecl sd . showString "\n") (Map.elems (csStructs s))
+ [compose $ map (\sd -> printStructDecl sd . showString "\n") structs
,showString "\n"
,showString $
- repTy (typeOf expr) ++ " kernel(" ++
- intercalate ", " (reverse (unSList (\(Product.Pair t n) -> repTy t ++ " " ++ getConst n) (slistZip env args))) ++
+ repSTy (typeOf expr) ++ " kernel(" ++
+ intercalate ", " (reverse (unSList (\(Product.Pair t n) -> repSTy t ++ " " ++ getConst n) (slistZip env args))) ++
") {\n"
,compose $ map (\st -> showString " " . printStmt 1 st . showString "\n") (toList (csStmts s))
- ,showString (" return ") . printCExpr res . showString ";\n}\n"]
+ ,showString (" return ") . printCExpr 0 res . showString ";\n}\n"]
compile' :: SList (Const String) env -> Ex env t -> CompM CExpr
compile' env = \case
@@ -183,7 +240,7 @@ compile' env = \case
ELet _ rhs body -> do
e <- compile' env rhs
var <- genName
- emit $ SVarDecl True (repTy (typeOf rhs)) var e
+ emit $ SVarDecl True (repSTy (typeOf rhs)) var e
compile' (Const var `SCons` env) body
EPair _ a b -> do
@@ -215,7 +272,7 @@ compile' env = \case
(e2, stmts2) <- scope $ compile' (Const undefined `SCons` env) a -- don't access that nil, stupid you
(e3, stmts3) <- scope $ compile' (Const undefined `SCons` env) b
retvar <- genName
- emit $ SVarDeclUninit (repTy (typeOf a)) retvar
+ emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
emit $ SIf e1
(stmts2 <> pure (SAsg retvar e2))
(stmts3 <> pure (SAsg retvar e3))
@@ -229,14 +286,14 @@ compile' env = \case
(e2, stmts2) <- scope $ compile' (Const fieldvar `SCons` env) a
(e3, stmts3) <- scope $ compile' (Const fieldvar `SCons` env) b
retvar <- genName
- emit $ SVarDeclUninit (repTy (typeOf a)) retvar
- emit $ SBlock (pure (SVarDecl True (repTy (typeOf e)) var e1)
+ emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
+ emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1)
<> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))
- (pure (SVarDecl True (repTy t1) fieldvar
+ (pure (SVarDecl True (repSTy t1) fieldvar
(CEProj (CELit var) "a"))
<> stmts2
<> pure (SAsg retvar e2))
- (pure (SVarDecl True (repTy t2) fieldvar
+ (pure (SVarDecl True (repSTy t2) fieldvar
(CEProj (CELit var) "b"))
<> stmts3
<> pure (SAsg retvar e3))))
@@ -258,12 +315,12 @@ compile' env = \case
(e2, stmts2) <- scope $ compile' env a
(e3, stmts3) <- scope $ compile' (Const fieldvar `SCons` env) b
retvar <- genName
- emit $ SVarDeclUninit (repTy (typeOf a)) retvar
- emit $ SBlock (pure (SVarDecl True (repTy (typeOf e)) var e1)
+ emit $ SVarDeclUninit (repSTy (typeOf a)) retvar
+ emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1)
<> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))
(stmts2
<> pure (SAsg retvar e2))
- (pure (SVarDecl True (repTy (typeOf b)) fieldvar
+ (pure (SVarDecl True (repSTy (typeOf b)) fieldvar
(CEProj (CELit var) "a"))
<> stmts3
<> pure (SAsg retvar e3))))
@@ -332,7 +389,7 @@ compileOpGeneral op e1 = do
let unary cop = return @(State CompState) $ CECall cop [e1]
let binary cop = do
name <- genName
- emit $ SVarDecl True (repTy (opt1 op)) name e1
+ emit $ SVarDecl True (repSTy (opt1 op)) name e1
return $ CEBinop (CEProj (CELit name) "a") cop (CEProj (CELit name) "b")
case op of
OAdd _ -> binary "+"
diff --git a/src/Data.hs b/src/Data.hs
index 0be9046..8005737 100644
--- a/src/Data.hs
+++ b/src/Data.hs
@@ -77,6 +77,14 @@ fromSNat :: SNat n -> Int
fromSNat SZ = 0
fromSNat (SS n) = succ (fromSNat n)
+unSNat :: SNat n -> Nat
+unSNat SZ = Z
+unSNat (SS n) = S (unSNat n)
+
+fromNat :: Nat -> Int
+fromNat Z = 0
+fromNat (S m) = succ (fromNat m)
+
class KnownNat n where knownNat :: SNat n
instance KnownNat Z where knownNat = SZ
instance KnownNat n => KnownNat (S n) where knownNat = SS knownNat
@@ -124,6 +132,7 @@ unsafeCoerceRefl = unsafeCoerce Refl
data Bag t = BNone | BOne t | BTwo (Bag t) (Bag t) | BMany [Bag t]
deriving (Show, Functor, Foldable, Traversable)
+-- | This instance is mostly there just for 'pure'
instance Applicative Bag where
pure = BOne
BNone <*> _ = BNone
diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs
index 335ad1f..ac06915 100644
--- a/src/Interpreter/Rep.hs
+++ b/src/Interpreter/Rep.hs
@@ -76,7 +76,7 @@ showValue _ (STScal sty) x = case sty of
STI32 -> shows x
STI64 -> shows x
STBool -> shows x
-showValue _ (STAccum t) _ = showString $ "<accumulator for " ++ ppTy 0 t ++ ">"
+showValue _ (STAccum t) _ = showString $ "<accumulator for " ++ ppSTy 0 t ++ ">"
showEnv :: SList STy env -> SList Value env -> String
showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]"
diff --git a/test/Main.hs b/test/Main.hs
index b234aa2..dde2c3d 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -212,7 +212,7 @@ adTestGen expr envGenerator = property $ do
scFwd = envScalars env gradFwd
scCHAD = envScalars env gradCHAD
scCHAD_S = envScalars env gradCHAD_S
- annotate (concat (unSList (\t -> ppTy 0 t ++ " -> ") env) ++ ppTy 0 (typeOf expr))
+ annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))
-- annotate (ppExpr knownEnv expr)
-- annotate ppdterm
-- annotate ppdterm_S