diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2023-09-10 21:13:14 +0200 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2023-09-10 21:13:14 +0200 |
commit | 0bf9f5bb8a0873cad2e11faf83519b6e7ccf87d2 (patch) | |
tree | 41ddd52b0293319834b2130814414e76de434396 /src/Compile.hs |
Initial
Diffstat (limited to 'src/Compile.hs')
-rw-r--r-- | src/Compile.hs | 120 |
1 files changed, 120 insertions, 0 deletions
diff --git a/src/Compile.hs b/src/Compile.hs new file mode 100644 index 0000000..2fcff5d --- /dev/null +++ b/src/Compile.hs @@ -0,0 +1,120 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE DeriveFunctor #-} +module Compile where + +import Control.Monad (ap) +import Data.Kind (Type) +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as Map + +import AST + + +data Body = Body [Stm] Inline -- body, return expr + deriving (Show) + +data Stm + = VarDef String String (Maybe Inline) -- type, name, initialiser + | Launch Inline Inline Body -- num blocks, block size, kernel function body + deriving (Show) + +-- inline cuda expression +data Inline + = IOp Inline String Inline + | IUOp String Inline + | ILit String + | IVar String + | ICall Inline [Inline] + deriving (Show) + +data Target = Host | Device + deriving (Show) + +data FunDef = FunDef Target String [String] Body -- name, params (full declarations), body + deriving (Show) + +type Env :: [Ty] -> Type -> Type +data Env env v where + ETop :: Env '[] v + EPush :: v -> Env env v -> Env (t : env) v + +prj :: Env env v -> Idx env t -> v +prj = \env idx -> go idx env + where go :: Idx env t -> Env env v -> v + go IZ (EPush v _) = v + go (IS i) (EPush _ env) = go i env + +-- generated global function definitions, generated local statements, function typedef cache (name, decl) +newtype M a = M (Int -> Map Ty (String, String) -> ([FunDef], [Stm], Map Ty (String, String), Int, a)) + deriving (Functor) +instance Applicative M where + pure x = M (\i m -> ([], [], m, i, x)) + (<*>) = ap +instance Monad M where + M f >>= g = M (\i m -> let (d1, s1, m1, i1, x) = f i m + (d2, s2, m2, i2, y) = let M h = g x in h i1 m1 + in (d1 <> d2, s1 <> s2, m2, i2, y)) + +emitFun :: FunDef -> M () +emitFun fd = M (\i m -> ([fd], [], m, i, ())) + +emitStm :: Stm -> M () +emitStm stm = M (\i m -> ([], [stm], m, i, ())) + +captureStms :: M a -> M ([Stm], a) +captureStms (M f) = M (\i m -> let (d, s, m2, i2, x) = f i m + in (d, [], m2, i2, (s, x))) + +genId :: M Int +genId = M (\i m -> ([], [], m, i + 1, i)) + +getTypedef :: Ty -> M (Maybe String) +getTypedef t = M $ \i m -> ([], [], m, i, fst <$> Map.lookup t m) + +putTypedef :: Ty -> String -> String -> M () +putTypedef t name decl = M $ \i m -> ([], [], Map.insert t (name, decl) m, i, ()) + +genName :: String -> M String +genName s = (\i -> s ++ sep ++ show i ++ suf) <$> genId + where (sep, suf) = case reverse s of + [] -> ("x", "_") + c : _ | c `elem` "0123456789_" -> ("_", "") + | otherwise -> ("", "") + +-- Function values are returned as a function-pointer-typed expression +compile :: Target -> Env env String -> Expr x env t -> M Inline +compile tgt env = \case + EVar _ _ i -> pure $ IVar (prj env i) + ELet _ rhs e -> do + rhsi <- compile tgt env rhs + var <- genName "x" + rhsty <- writeType (typeOf rhs) + emitStm $ VarDef rhsty var (Just rhsi) + compile tgt (EPush var env) e + + EBuild1 x k e -> compile tgt env $ EBuild x (SS SZ) (k :< VNil) e + EBuild x n k e -> case tgt of + Host -> do + fname <- genName "buildfun" + let n' = fromNat (unSNat n) + shapevars = ['s' : show i | i <- [0 .. n' - 1]] + emitFun $ FunDef Device fname (map ("int " ++) shapevars) _ + emitStm $ Launch _ _ _ + _ + Device -> _ + +writeType :: STy t -> M String +writeType = \case + STArr _ t -> (++ "*") <$> writeType t + STNil -> pure "Nil" + STPair a b -> (\x y -> "std::pair<" ++ x ++ "," ++ y ++ ">") <$> writeType a <*> writeType b + STScal t -> case t of + STI32 -> pure "int32_t" + STI64 -> pure "int64_t" + STF32 -> pure "float" + STF64 -> pure "double" + STBool -> pure "bool" |