summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-12-12 22:40:17 +0100
committerTom Smeding <tom@tomsmeding.com>2024-12-12 22:40:17 +0100
commit27e5422c541623fbee36f2eedf37bb3d2ca3d14c (patch)
treeaa571776e2afed4131b90787a7279c755ab9647c
parentf323076ddf6fbea9f7a1a4dfeec98629459c49fc (diff)
Improve Compile a little (still only scalars)
-rw-r--r--src/Compile.hs112
-rw-r--r--src/Data.hs9
2 files changed, 64 insertions, 57 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
index 83c25c3..0db0d0f 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -1,12 +1,14 @@
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE GADTs #-}
{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeApplications #-}
module Compile where
import Control.Monad.Trans.State.Strict
import Data.Foldable (toList)
import Data.Functor.Const
+import qualified Data.Functor.Product as Product
import Data.List (intersperse, intercalate)
import qualified Data.Map.Strict as Map
import Data.Map.Strict (Map)
@@ -26,7 +28,7 @@ data StructDecl = StructDecl
deriving (Show)
data Stmt
- = SVarDecl String String CExpr -- ^ type, variable name, right-hand side
+ = SVarDecl Bool String String CExpr -- ^ const, type, variable name, right-hand side
| SVarDeclUninit String String -- ^ type, variable name (no initialiser)
| SAsg String CExpr -- ^ variable name, right-hand side
| SBlock [Stmt]
@@ -50,7 +52,7 @@ printStructDecl (StructDecl name contents comment) =
printStmt :: Int -> Stmt -> ShowS
printStmt indent = \case
- SVarDecl typ name rhs -> showString (typ ++ " " ++ name ++ " = ") . printCExpr rhs . showString ";"
+ SVarDecl cnst typ name rhs -> showString ((if cnst then "const " else "") ++ typ ++ " " ++ name ++ " = ") . printCExpr rhs . showString ";"
SVarDeclUninit typ name -> showString (typ ++ " " ++ name ++ ";")
SAsg name rhs -> showString (name ++ " = ") . printCExpr rhs . showString ";"
SBlock stmts ->
@@ -92,15 +94,15 @@ genStructName = \t -> "ty_" ++ gen t where
-- all tags start with a letter, so the array mangling is unambiguous.
gen :: STy t -> String
gen STNil = "n"
- gen (STPair a b) = 'p' : gen a ++ gen b
- gen (STEither a b) = 'e' : gen a ++ gen b
- gen (STMaybe t) = 'm' : gen t
+ gen (STPair a b) = 'P' : gen a ++ gen b
+ gen (STEither a b) = 'E' : gen a ++ gen b
+ gen (STMaybe t) = 'M' : gen t
gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t
gen (STScal st) = case st of
- STI32 -> "i4"
- STI64 -> "i8"
- STF32 -> "f4"
- STF64 -> "f8"
+ STI32 -> "i"
+ STI64 -> "j"
+ STF32 -> "f"
+ STF64 -> "d"
STBool -> "b"
gen (STAccum t) = 'C' : gen t
@@ -110,20 +112,20 @@ genStruct topty = case topty of
Map.singleton (genStructName STNil) (StructDecl (genStructName STNil) "" com)
STPair a b ->
let name = genStructName (STPair a b)
- in Map.singleton name (StructDecl name (genStructName a ++ " a; " ++ genStructName b ++ " b;") com)
+ in Map.singleton name (StructDecl name (repTy a ++ " a; " ++ repTy b ++ " b;") com)
STEither a b ->
let name = genStructName (STEither a b) -- 0 -> a, 1 -> b
- in Map.singleton name (StructDecl name ("uint8_t tag; union { " ++ genStructName a ++ " a; " ++ genStructName b ++ " b; };") com)
+ in Map.singleton name (StructDecl name ("uint8_t tag; union { " ++ repTy a ++ " a; " ++ repTy b ++ " b; };") com)
STMaybe t ->
let name = genStructName (STMaybe t) -- 0 -> nothing, 1 -> just
- in Map.singleton name (StructDecl name ("uint8_t tag; " ++ genStructName t ++ " a;") com)
+ in Map.singleton name (StructDecl name ("uint8_t tag; " ++ repTy t ++ " a;") com)
STArr n t ->
let name = genStructName (STArr n t)
- in Map.singleton name (StructDecl name ("size_t sh[" ++ show (fromSNat n) ++ "]; " ++ genStructName t ++ " *a;") com)
+ in Map.singleton name (StructDecl name ("size_t sh[" ++ show (fromSNat n) ++ "]; " ++ repTy t ++ " *a;") com)
STScal _ -> mempty
STAccum t ->
let name = genStructName (STAccum t)
- in Map.singleton name (StructDecl name (genStructName t ++ " a;") com)
+ in Map.singleton name (StructDecl name (repTy t ++ " a;") com)
<> genStruct t
where
com = ppTy 0 topty
@@ -157,13 +159,20 @@ emitStruct ty = do
modify $ \s -> s { csStructs = genStruct ty <> csStructs s }
return (genStructName ty)
-compile :: SList (Const String) env -> Ex env t -> String
+nameEnv :: SList f env -> SList (Const String) env
+nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1))
+
+compile :: SList STy env -> Ex env t -> String
compile env expr =
- let (res, s) = runState (compile' env expr) (CompState mempty mempty 1)
+ let args = nameEnv env
+ (res, s) = runState (compile' args expr) (CompState mempty mempty 1)
in ($ "") $ compose
[compose $ map (\sd -> printStructDecl sd . showString "\n") (Map.elems (csStructs s))
,showString "\n"
- ,showString (genStructName (typeOf expr) ++ " kernel(" ++ intercalate ", " (reverse (unSList getConst env)) ++ ") {\n")
+ ,showString $
+ repTy (typeOf expr) ++ " kernel(" ++
+ intercalate ", " (reverse (unSList (\(Product.Pair t n) -> repTy t ++ " " ++ getConst n) (slistZip env args))) ++
+ ") {\n"
,compose $ map (\st -> showString " " . printStmt 1 st . showString "\n") (toList (csStmts s))
,showString (" return ") . printCExpr res . showString ";\n}\n"]
@@ -174,7 +183,7 @@ compile' env = \case
ELet _ rhs body -> do
e <- compile' env rhs
var <- genName
- emit $ SVarDecl (genStructName (typeOf rhs)) var e
+ emit $ SVarDecl True (repTy (typeOf rhs)) var e
compile' (Const var `SCons` env) body
EPair _ a b -> do
@@ -201,6 +210,17 @@ compile' env = \case
e2 <- compile' env e
return $ CEStruct name [("tag", CELit "1"), ("b", e2)]
+ ECase _ (EOp _ OIf e) a b -> do
+ e1 <- compile' env e
+ (e2, stmts2) <- scope $ compile' (Const undefined `SCons` env) a -- don't access that nil, stupid you
+ (e3, stmts3) <- scope $ compile' (Const undefined `SCons` env) b
+ retvar <- genName
+ emit $ SVarDeclUninit (repTy (typeOf a)) retvar
+ emit $ SIf e1
+ (stmts2 <> pure (SAsg retvar e2))
+ (stmts3 <> pure (SAsg retvar e3))
+ return (CELit retvar)
+
ECase _ e a b -> do
let STEither t1 t2 = typeOf e
e1 <- compile' env e
@@ -209,14 +229,14 @@ compile' env = \case
(e2, stmts2) <- scope $ compile' (Const fieldvar `SCons` env) a
(e3, stmts3) <- scope $ compile' (Const fieldvar `SCons` env) b
retvar <- genName
- emit $ SVarDeclUninit (genStructName (typeOf a)) retvar
- emit $ SBlock (pure (SVarDecl (genStructName (typeOf e)) var e1)
+ emit $ SVarDeclUninit (repTy (typeOf a)) retvar
+ emit $ SBlock (pure (SVarDecl True (repTy (typeOf e)) var e1)
<> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))
- (pure (SVarDecl (genStructName t1) fieldvar
+ (pure (SVarDecl True (repTy t1) fieldvar
(CEProj (CELit var) "a"))
<> stmts2
<> pure (SAsg retvar e2))
- (pure (SVarDecl (genStructName t2) fieldvar
+ (pure (SVarDecl True (repTy t2) fieldvar
(CEProj (CELit var) "b"))
<> stmts3
<> pure (SAsg retvar e3))))
@@ -238,12 +258,12 @@ compile' env = \case
(e2, stmts2) <- scope $ compile' env a
(e3, stmts3) <- scope $ compile' (Const fieldvar `SCons` env) b
retvar <- genName
- emit $ SVarDeclUninit (genStructName (typeOf a)) retvar
- emit $ SBlock (pure (SVarDecl (genStructName (typeOf e)) var e1)
+ emit $ SVarDeclUninit (repTy (typeOf a)) retvar
+ emit $ SBlock (pure (SVarDecl True (repTy (typeOf e)) var e1)
<> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0"))
(stmts2
<> pure (SAsg retvar e2))
- (pure (SVarDecl (genStructName (typeOf b)) fieldvar
+ (pure (SVarDecl True (repTy (typeOf b)) fieldvar
(CEProj (CELit var) "a"))
<> stmts3
<> pure (SAsg retvar e3))))
@@ -282,36 +302,14 @@ compile' env = \case
EShape _ e -> error "TODO" -- EShape ext (compile' e)
+ EOp _ op (EPair _ e1 e2) -> do
+ e1' <- compile' env e1
+ e2' <- compile' env e2
+ compileOpPair op e1' e2'
+
EOp _ op e -> do
- e1 <- compile' env e
- let unary cop = return @(State CompState) $ CECall cop [e1]
- let binary cop = do
- name <- genName
- emit $ SVarDecl (genStructName (typeOf e)) name e1
- return $ CEBinop (CEProj (CELit name) "a") cop (CEProj (CELit name) "b")
- case op of
- OAdd _ -> binary "+"
- OMul _ -> binary "*"
- ONeg _ -> unary "-"
- OLt _ -> binary "<"
- OLe _ -> binary "<="
- OEq _ -> binary "=="
- ONot -> unary "!"
- OAnd -> binary "&&"
- OOr -> binary "||"
- OIf -> do
- name <- emitStruct (STEither STNil STNil)
- _ <- emitStruct STNil
- return $ CEIf e1 (CEStruct name [("tag", CELit "0")])
- (CEStruct name [("tag", CELit "1")])
- ORound64 -> unary "(int64_t)round" -- ew
- OToFl64 -> unary "(double)"
- ORecip _ -> return $ CEBinop (CELit "1.0") "/" e1
- OExp STF32 -> unary "expf"
- OExp STF64 -> unary "exp"
- OLog STF32 -> unary "logf"
- OLog STF64 -> unary "log"
- OIDiv _ -> binary "/"
+ e' <- compile' env e
+ compileOpGeneral op e'
ECustom _ t1 t2 t3 a b c e1 e2 -> error "TODO" -- ECustom ext t1 t2 t3 (compile' a) (compile' b) (compile' c) (compile' e1) (compile' e2)
@@ -334,7 +332,7 @@ compileOpGeneral op e1 = do
let unary cop = return @(State CompState) $ CECall cop [e1]
let binary cop = do
name <- genName
- emit $ SVarDecl (genStructName (opt1 op)) name e1
+ emit $ SVarDecl True (repTy (opt1 op)) name e1
return $ CEBinop (CEProj (CELit name) "a") cop (CEProj (CELit name) "b")
case op of
OAdd _ -> binary "+"
diff --git a/src/Data.hs b/src/Data.hs
index fc39814..0be9046 100644
--- a/src/Data.hs
+++ b/src/Data.hs
@@ -10,6 +10,7 @@
{-# LANGUAGE TypeOperators #-}
module Data (module Data, (:~:)(Refl)) where
+import Data.Functor.Product
import Data.Type.Equality
import Unsafe.Coerce (unsafeCoerce)
@@ -30,6 +31,14 @@ slistMap :: (forall t. f t -> g t) -> SList f list -> SList g list
slistMap _ SNil = SNil
slistMap f (SCons x list) = SCons (f x) (slistMap f list)
+slistMapA :: Applicative m => (forall t. f t -> m (g t)) -> SList f list -> m (SList g list)
+slistMapA _ SNil = pure SNil
+slistMapA f (SCons x list) = SCons <$> f x <*> slistMapA f list
+
+slistZip :: SList f list -> SList g list -> SList (Product f g) list
+slistZip SNil SNil = SNil
+slistZip (x `SCons` l1) (y `SCons` l2) = Pair x y `SCons` slistZip l1 l2
+
unSList :: (forall t. f t -> a) -> SList f list -> [a]
unSList _ SNil = []
unSList f (x `SCons` l) = f x : unSList f l