diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-12-12 22:40:17 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-12-12 22:40:17 +0100 |
commit | 27e5422c541623fbee36f2eedf37bb3d2ca3d14c (patch) | |
tree | aa571776e2afed4131b90787a7279c755ab9647c | |
parent | f323076ddf6fbea9f7a1a4dfeec98629459c49fc (diff) |
Improve Compile a little (still only scalars)
-rw-r--r-- | src/Compile.hs | 112 | ||||
-rw-r--r-- | src/Data.hs | 9 |
2 files changed, 64 insertions, 57 deletions
diff --git a/src/Compile.hs b/src/Compile.hs index 83c25c3..0db0d0f 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -1,12 +1,14 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE GADTs #-} {-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} {-# LANGUAGE TypeApplications #-} module Compile where import Control.Monad.Trans.State.Strict import Data.Foldable (toList) import Data.Functor.Const +import qualified Data.Functor.Product as Product import Data.List (intersperse, intercalate) import qualified Data.Map.Strict as Map import Data.Map.Strict (Map) @@ -26,7 +28,7 @@ data StructDecl = StructDecl deriving (Show) data Stmt - = SVarDecl String String CExpr -- ^ type, variable name, right-hand side + = SVarDecl Bool String String CExpr -- ^ const, type, variable name, right-hand side | SVarDeclUninit String String -- ^ type, variable name (no initialiser) | SAsg String CExpr -- ^ variable name, right-hand side | SBlock [Stmt] @@ -50,7 +52,7 @@ printStructDecl (StructDecl name contents comment) = printStmt :: Int -> Stmt -> ShowS printStmt indent = \case - SVarDecl typ name rhs -> showString (typ ++ " " ++ name ++ " = ") . printCExpr rhs . showString ";" + SVarDecl cnst typ name rhs -> showString ((if cnst then "const " else "") ++ typ ++ " " ++ name ++ " = ") . printCExpr rhs . showString ";" SVarDeclUninit typ name -> showString (typ ++ " " ++ name ++ ";") SAsg name rhs -> showString (name ++ " = ") . printCExpr rhs . showString ";" SBlock stmts -> @@ -92,15 +94,15 @@ genStructName = \t -> "ty_" ++ gen t where -- all tags start with a letter, so the array mangling is unambiguous. gen :: STy t -> String gen STNil = "n" - gen (STPair a b) = 'p' : gen a ++ gen b - gen (STEither a b) = 'e' : gen a ++ gen b - gen (STMaybe t) = 'm' : gen t + gen (STPair a b) = 'P' : gen a ++ gen b + gen (STEither a b) = 'E' : gen a ++ gen b + gen (STMaybe t) = 'M' : gen t gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t gen (STScal st) = case st of - STI32 -> "i4" - STI64 -> "i8" - STF32 -> "f4" - STF64 -> "f8" + STI32 -> "i" + STI64 -> "j" + STF32 -> "f" + STF64 -> "d" STBool -> "b" gen (STAccum t) = 'C' : gen t @@ -110,20 +112,20 @@ genStruct topty = case topty of Map.singleton (genStructName STNil) (StructDecl (genStructName STNil) "" com) STPair a b -> let name = genStructName (STPair a b) - in Map.singleton name (StructDecl name (genStructName a ++ " a; " ++ genStructName b ++ " b;") com) + in Map.singleton name (StructDecl name (repTy a ++ " a; " ++ repTy b ++ " b;") com) STEither a b -> let name = genStructName (STEither a b) -- 0 -> a, 1 -> b - in Map.singleton name (StructDecl name ("uint8_t tag; union { " ++ genStructName a ++ " a; " ++ genStructName b ++ " b; };") com) + in Map.singleton name (StructDecl name ("uint8_t tag; union { " ++ repTy a ++ " a; " ++ repTy b ++ " b; };") com) STMaybe t -> let name = genStructName (STMaybe t) -- 0 -> nothing, 1 -> just - in Map.singleton name (StructDecl name ("uint8_t tag; " ++ genStructName t ++ " a;") com) + in Map.singleton name (StructDecl name ("uint8_t tag; " ++ repTy t ++ " a;") com) STArr n t -> let name = genStructName (STArr n t) - in Map.singleton name (StructDecl name ("size_t sh[" ++ show (fromSNat n) ++ "]; " ++ genStructName t ++ " *a;") com) + in Map.singleton name (StructDecl name ("size_t sh[" ++ show (fromSNat n) ++ "]; " ++ repTy t ++ " *a;") com) STScal _ -> mempty STAccum t -> let name = genStructName (STAccum t) - in Map.singleton name (StructDecl name (genStructName t ++ " a;") com) + in Map.singleton name (StructDecl name (repTy t ++ " a;") com) <> genStruct t where com = ppTy 0 topty @@ -157,13 +159,20 @@ emitStruct ty = do modify $ \s -> s { csStructs = genStruct ty <> csStructs s } return (genStructName ty) -compile :: SList (Const String) env -> Ex env t -> String +nameEnv :: SList f env -> SList (Const String) env +nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1)) + +compile :: SList STy env -> Ex env t -> String compile env expr = - let (res, s) = runState (compile' env expr) (CompState mempty mempty 1) + let args = nameEnv env + (res, s) = runState (compile' args expr) (CompState mempty mempty 1) in ($ "") $ compose [compose $ map (\sd -> printStructDecl sd . showString "\n") (Map.elems (csStructs s)) ,showString "\n" - ,showString (genStructName (typeOf expr) ++ " kernel(" ++ intercalate ", " (reverse (unSList getConst env)) ++ ") {\n") + ,showString $ + repTy (typeOf expr) ++ " kernel(" ++ + intercalate ", " (reverse (unSList (\(Product.Pair t n) -> repTy t ++ " " ++ getConst n) (slistZip env args))) ++ + ") {\n" ,compose $ map (\st -> showString " " . printStmt 1 st . showString "\n") (toList (csStmts s)) ,showString (" return ") . printCExpr res . showString ";\n}\n"] @@ -174,7 +183,7 @@ compile' env = \case ELet _ rhs body -> do e <- compile' env rhs var <- genName - emit $ SVarDecl (genStructName (typeOf rhs)) var e + emit $ SVarDecl True (repTy (typeOf rhs)) var e compile' (Const var `SCons` env) body EPair _ a b -> do @@ -201,6 +210,17 @@ compile' env = \case e2 <- compile' env e return $ CEStruct name [("tag", CELit "1"), ("b", e2)] + ECase _ (EOp _ OIf e) a b -> do + e1 <- compile' env e + (e2, stmts2) <- scope $ compile' (Const undefined `SCons` env) a -- don't access that nil, stupid you + (e3, stmts3) <- scope $ compile' (Const undefined `SCons` env) b + retvar <- genName + emit $ SVarDeclUninit (repTy (typeOf a)) retvar + emit $ SIf e1 + (stmts2 <> pure (SAsg retvar e2)) + (stmts3 <> pure (SAsg retvar e3)) + return (CELit retvar) + ECase _ e a b -> do let STEither t1 t2 = typeOf e e1 <- compile' env e @@ -209,14 +229,14 @@ compile' env = \case (e2, stmts2) <- scope $ compile' (Const fieldvar `SCons` env) a (e3, stmts3) <- scope $ compile' (Const fieldvar `SCons` env) b retvar <- genName - emit $ SVarDeclUninit (genStructName (typeOf a)) retvar - emit $ SBlock (pure (SVarDecl (genStructName (typeOf e)) var e1) + emit $ SVarDeclUninit (repTy (typeOf a)) retvar + emit $ SBlock (pure (SVarDecl True (repTy (typeOf e)) var e1) <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) - (pure (SVarDecl (genStructName t1) fieldvar + (pure (SVarDecl True (repTy t1) fieldvar (CEProj (CELit var) "a")) <> stmts2 <> pure (SAsg retvar e2)) - (pure (SVarDecl (genStructName t2) fieldvar + (pure (SVarDecl True (repTy t2) fieldvar (CEProj (CELit var) "b")) <> stmts3 <> pure (SAsg retvar e3)))) @@ -238,12 +258,12 @@ compile' env = \case (e2, stmts2) <- scope $ compile' env a (e3, stmts3) <- scope $ compile' (Const fieldvar `SCons` env) b retvar <- genName - emit $ SVarDeclUninit (genStructName (typeOf a)) retvar - emit $ SBlock (pure (SVarDecl (genStructName (typeOf e)) var e1) + emit $ SVarDeclUninit (repTy (typeOf a)) retvar + emit $ SBlock (pure (SVarDecl True (repTy (typeOf e)) var e1) <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) (stmts2 <> pure (SAsg retvar e2)) - (pure (SVarDecl (genStructName (typeOf b)) fieldvar + (pure (SVarDecl True (repTy (typeOf b)) fieldvar (CEProj (CELit var) "a")) <> stmts3 <> pure (SAsg retvar e3)))) @@ -282,36 +302,14 @@ compile' env = \case EShape _ e -> error "TODO" -- EShape ext (compile' e) + EOp _ op (EPair _ e1 e2) -> do + e1' <- compile' env e1 + e2' <- compile' env e2 + compileOpPair op e1' e2' + EOp _ op e -> do - e1 <- compile' env e - let unary cop = return @(State CompState) $ CECall cop [e1] - let binary cop = do - name <- genName - emit $ SVarDecl (genStructName (typeOf e)) name e1 - return $ CEBinop (CEProj (CELit name) "a") cop (CEProj (CELit name) "b") - case op of - OAdd _ -> binary "+" - OMul _ -> binary "*" - ONeg _ -> unary "-" - OLt _ -> binary "<" - OLe _ -> binary "<=" - OEq _ -> binary "==" - ONot -> unary "!" - OAnd -> binary "&&" - OOr -> binary "||" - OIf -> do - name <- emitStruct (STEither STNil STNil) - _ <- emitStruct STNil - return $ CEIf e1 (CEStruct name [("tag", CELit "0")]) - (CEStruct name [("tag", CELit "1")]) - ORound64 -> unary "(int64_t)round" -- ew - OToFl64 -> unary "(double)" - ORecip _ -> return $ CEBinop (CELit "1.0") "/" e1 - OExp STF32 -> unary "expf" - OExp STF64 -> unary "exp" - OLog STF32 -> unary "logf" - OLog STF64 -> unary "log" - OIDiv _ -> binary "/" + e' <- compile' env e + compileOpGeneral op e' ECustom _ t1 t2 t3 a b c e1 e2 -> error "TODO" -- ECustom ext t1 t2 t3 (compile' a) (compile' b) (compile' c) (compile' e1) (compile' e2) @@ -334,7 +332,7 @@ compileOpGeneral op e1 = do let unary cop = return @(State CompState) $ CECall cop [e1] let binary cop = do name <- genName - emit $ SVarDecl (genStructName (opt1 op)) name e1 + emit $ SVarDecl True (repTy (opt1 op)) name e1 return $ CEBinop (CEProj (CELit name) "a") cop (CEProj (CELit name) "b") case op of OAdd _ -> binary "+" diff --git a/src/Data.hs b/src/Data.hs index fc39814..0be9046 100644 --- a/src/Data.hs +++ b/src/Data.hs @@ -10,6 +10,7 @@ {-# LANGUAGE TypeOperators #-} module Data (module Data, (:~:)(Refl)) where +import Data.Functor.Product import Data.Type.Equality import Unsafe.Coerce (unsafeCoerce) @@ -30,6 +31,14 @@ slistMap :: (forall t. f t -> g t) -> SList f list -> SList g list slistMap _ SNil = SNil slistMap f (SCons x list) = SCons (f x) (slistMap f list) +slistMapA :: Applicative m => (forall t. f t -> m (g t)) -> SList f list -> m (SList g list) +slistMapA _ SNil = pure SNil +slistMapA f (SCons x list) = SCons <$> f x <*> slistMapA f list + +slistZip :: SList f list -> SList g list -> SList (Product f g) list +slistZip SNil SNil = SNil +slistZip (x `SCons` l1) (y `SCons` l2) = Pair x y `SCons` slistZip l1 l2 + unSList :: (forall t. f t -> a) -> SList f list -> [a] unSList _ SNil = [] unSList f (x `SCons` l) = f x : unSList f l |