summaryrefslogtreecommitdiff
path: root/src/CompileCu.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-12-06 21:23:05 +0100
committerTom Smeding <tom@tomsmeding.com>2024-12-06 21:23:05 +0100
commit263a3cef7543dfd447d5d75cc759fe95f7864105 (patch)
tree6da23b33013dbfe5315f337e443f12ecb67e9af3 /src/CompileCu.hs
parent3e266262ebe65bd5d775711b4d05bc9670a38a47 (diff)
Rename Compile -> CompileCu
Diffstat (limited to 'src/CompileCu.hs')
-rw-r--r--src/CompileCu.hs114
1 files changed, 114 insertions, 0 deletions
diff --git a/src/CompileCu.hs b/src/CompileCu.hs
new file mode 100644
index 0000000..749368a
--- /dev/null
+++ b/src/CompileCu.hs
@@ -0,0 +1,114 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+module Compile where
+
+import Control.Monad.Trans.Class (lift)
+import Control.Monad.Trans.State.Strict
+import Control.Monad.Trans.Writer.CPS
+import Data.Kind (Type)
+
+import AST
+import Data
+
+
+data Body = Body [Stm] Inline -- body, return expr
+ deriving (Show)
+
+data Stm
+ = VarDef String String (Maybe Inline) -- type, name, initialiser
+ | Launch Inline Inline Body -- num blocks, block size, kernel function body
+ deriving (Show)
+
+-- inline cuda expression
+data Inline
+ = IOp Inline String Inline
+ | IUOp String Inline
+ | ILit String
+ | IVar String
+ | ICall Inline [Inline]
+ deriving (Show)
+
+data Target = Host | Device
+ deriving (Show)
+
+data FunDef = FunDef Target String [String] Body -- name, params (full declarations), body
+ deriving (Show)
+
+type Env :: [Ty] -> Type -> Type
+data Env env v where
+ ETop :: Env '[] v
+ EPush :: v -> Env env v -> Env (t : env) v
+
+prj :: Env env v -> Idx env t -> v
+prj = \env idx -> go idx env
+ where go :: Idx env t -> Env env v -> v
+ go IZ (EPush v _) = v
+ go (IS i) (EPush _ env) = go i env
+
+newtype M a = M (StateT Int -- ID generator
+ (WriterT [FunDef] -- generated global function definitions
+ (Writer [Stm])) -- generated local statements
+ a)
+ deriving newtype (Functor, Applicative, Monad)
+
+emitFun :: FunDef -> M ()
+emitFun fd = M (lift (tell [fd]))
+
+emitStm :: Stm -> M ()
+emitStm stm = M (lift (lift (tell [stm])))
+
+captureStms :: M a -> M ([Stm], a)
+captureStms (M m) = M (mapStateT (mapWriterT (mapWriter (\(((x, i), fds), stms) -> ((((stms, x), i), fds), [])))) m)
+
+genId :: M Int
+genId = M (state (\i -> (i, i + 1)))
+
+genName :: String -> M String
+genName s = (\i -> s ++ sep ++ show i ++ suf) <$> genId
+ where (sep, suf) = case reverse s of
+ [] -> ("x", "_")
+ c : _ | c `elem` "0123456789_" -> ("_", "")
+ | otherwise -> ("", "")
+
+compile :: Target -> Env env String -> Ex env t -> M Inline
+compile tgt env = \case
+ EVar _ _ i -> pure $ IVar (prj env i)
+ ELet _ rhs e -> do
+ rhsi <- compile tgt env rhs
+ var <- genName "x"
+ rhsty <- writeType (typeOf rhs)
+ emitStm $ VarDef rhsty var (Just rhsi)
+ compile tgt (EPush var env) e
+
+ EBuild1 x k e -> compile tgt env $ EBuild x (k :< VNil) e
+ EBuild _ VNil e -> _
+ EBuild _ k e -> case tgt of
+ Host -> do
+ fname <- genName "buildfun"
+ let n' = fromSNat (vecLength k)
+ shapevars = ['s' : show i | i <- [0 .. n' - 1]]
+ _ = foldr (\a b -> EOp ext (OMul STI64) (EPair ext a b)) (EConst ext STI64 1) k
+ emitFun $ FunDef Device fname (map ("int " ++) shapevars) _
+ emitStm $ Launch _ (ILit "32") _
+ _
+ Device -> _
+
+ _ -> undefined
+
+writeType :: STy t -> M String
+writeType = \case
+ STArr _ t -> (++ "*") <$> writeType t
+ STNil -> pure "Nil"
+ STPair a b -> (\x y -> "std::pair<" ++ x ++ "," ++ y ++ ">") <$> writeType a <*> writeType b
+ STScal t -> case t of
+ STI32 -> pure "int32_t"
+ STI64 -> pure "int64_t"
+ STF32 -> pure "float"
+ STF64 -> pure "double"
+ STBool -> pure "bool"