diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-12-06 21:23:05 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-12-06 21:23:05 +0100 | 
| commit | 263a3cef7543dfd447d5d75cc759fe95f7864105 (patch) | |
| tree | 6da23b33013dbfe5315f337e443f12ecb67e9af3 /src/CompileCu.hs | |
| parent | 3e266262ebe65bd5d775711b4d05bc9670a38a47 (diff) | |
Rename Compile -> CompileCu
Diffstat (limited to 'src/CompileCu.hs')
| -rw-r--r-- | src/CompileCu.hs | 114 | 
1 files changed, 114 insertions, 0 deletions
| diff --git a/src/CompileCu.hs b/src/CompileCu.hs new file mode 100644 index 0000000..749368a --- /dev/null +++ b/src/CompileCu.hs @@ -0,0 +1,114 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +module Compile where + +import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.State.Strict +import Control.Monad.Trans.Writer.CPS +import Data.Kind (Type) + +import AST +import Data + + +data Body = Body [Stm] Inline  -- body, return expr +  deriving (Show) + +data Stm +  = VarDef String String (Maybe Inline)  -- type, name, initialiser +  | Launch Inline Inline Body  -- num blocks, block size, kernel function body +  deriving (Show) + +-- inline cuda expression +data Inline +  = IOp Inline String Inline +  | IUOp String Inline +  | ILit String +  | IVar String +  | ICall Inline [Inline] +  deriving (Show) + +data Target = Host | Device +  deriving (Show) + +data FunDef = FunDef Target String [String] Body  -- name, params (full declarations), body +  deriving (Show) + +type Env :: [Ty] -> Type -> Type +data Env env v where +  ETop :: Env '[] v +  EPush :: v -> Env env v -> Env (t : env) v + +prj :: Env env v -> Idx env t -> v +prj = \env idx -> go idx env +  where go :: Idx env t -> Env env v -> v +        go IZ (EPush v _) = v +        go (IS i) (EPush _ env) = go i env + +newtype M a = M (StateT Int         -- ID generator +                 (WriterT [FunDef]  -- generated global function definitions +                  (Writer [Stm]))   -- generated local statements +                 a) +  deriving newtype (Functor, Applicative, Monad) + +emitFun :: FunDef -> M () +emitFun fd = M (lift (tell [fd])) + +emitStm :: Stm -> M () +emitStm stm = M (lift (lift (tell [stm]))) + +captureStms :: M a -> M ([Stm], a) +captureStms (M m) = M (mapStateT (mapWriterT (mapWriter (\(((x, i), fds), stms) -> ((((stms, x), i), fds), [])))) m) + +genId :: M Int +genId = M (state (\i -> (i, i + 1))) + +genName :: String -> M String +genName s = (\i -> s ++ sep ++ show i ++ suf) <$> genId +  where (sep, suf) = case reverse s of +                       [] -> ("x", "_") +                       c : _ | c `elem` "0123456789_" -> ("_", "") +                             | otherwise              -> ("", "") + +compile :: Target -> Env env String -> Ex env t -> M Inline +compile tgt env = \case +  EVar _ _ i -> pure $ IVar (prj env i) +  ELet _ rhs e -> do +    rhsi <- compile tgt env rhs +    var <- genName "x" +    rhsty <- writeType (typeOf rhs) +    emitStm $ VarDef rhsty var (Just rhsi) +    compile tgt (EPush var env) e + +  EBuild1 x k e -> compile tgt env $ EBuild x (k :< VNil) e +  EBuild _ VNil e -> _ +  EBuild _ k e -> case tgt of +    Host -> do +      fname <- genName "buildfun" +      let n' = fromSNat (vecLength k) +          shapevars = ['s' : show i | i <- [0 .. n' - 1]] +          _ = foldr (\a b -> EOp ext (OMul STI64) (EPair ext a b)) (EConst ext STI64 1) k +      emitFun $ FunDef Device fname (map ("int " ++) shapevars) _ +      emitStm $ Launch _ (ILit "32") _ +      _ +    Device -> _ + +  _ -> undefined + +writeType :: STy t -> M String +writeType = \case +  STArr _ t -> (++ "*") <$> writeType t +  STNil -> pure "Nil" +  STPair a b -> (\x y -> "std::pair<" ++ x ++ "," ++ y ++ ">") <$> writeType a <*> writeType b +  STScal t -> case t of +    STI32 -> pure "int32_t" +    STI64 -> pure "int64_t" +    STF32 -> pure "float" +    STF64 -> pure "double" +    STBool -> pure "bool" | 
