{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE DeriveFunctor #-}
module Compile where

import Control.Monad (ap)
import Data.Kind (Type)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map

import AST


data Body = Body [Stm] Inline  -- body, return expr
  deriving (Show)

data Stm
  = VarDef String String (Maybe Inline)  -- type, name, initialiser
  | Launch Inline Inline Body  -- num blocks, block size, kernel function body
  deriving (Show)

-- inline CUDA expression
data Inline
  = IOp Inline String Inline
  | IUOp String Inline
  | ILit String
  | IVar String
  | ICall Inline [Inline]
  deriving (Show)

data Target = Host | Device
  deriving (Show)

data FunDef = FunDef Target String [String] Body  -- target, name, params (full declarations), body
  deriving (Show)

type Env :: [Ty] -> Type -> Type
data Env env v where
  ETop :: Env '[] v
  EPush :: v -> Env env v -> Env (t : env) v

prj :: Env env v -> Idx env t -> v
prj = \env idx -> go idx env
  where
    go :: Idx env t -> Env env v -> v
    go IZ (EPush v _) = v
    go (IS i) (EPush _ env) = go i env

-- State threaded through M: a fresh-name counter and a typedef cache mapping
-- each type to its (name, decl). Outputs: generated global function
-- definitions, generated local statements, the updated cache, the updated
-- counter, and the result.
newtype M a = M (Int -> Map Ty (String, String) -> ([FunDef], [Stm], Map Ty (String, String), Int, a))
  deriving (Functor)

instance Applicative M where
  pure x = M (\i m -> ([], [], m, i, x))
  (<*>) = ap

instance Monad M where
  M f >>= g = M (\i m ->
    let (d1, s1, m1, i1, x) = f i m
        (d2, s2, m2, i2, y) = let M h = g x in h i1 m1
    in (d1 <> d2, s1 <> s2, m2, i2, y))

emitFun :: FunDef -> M ()
emitFun fd = M (\i m -> ([fd], [], m, i, ()))

emitStm :: Stm -> M ()
emitStm stm = M (\i m -> ([], [stm], m, i, ()))

-- Run a computation, returning its emitted statements as part of the result
-- instead of emitting them; emitted function definitions still propagate.
captureStms :: M a -> M ([Stm], a)
captureStms (M f) = M (\i m ->
  let (d, s, m2, i2, x) = f i m
  in (d, [], m2, i2, (s, x)))

genId :: M Int
genId = M (\i m -> ([], [], m, i + 1, i))

getTypedef :: Ty -> M (Maybe String)
getTypedef t = M $ \i m -> ([], [], m, i, fst <$> Map.lookup t m)

putTypedef :: Ty -> String -> String -> M ()
putTypedef t name decl = M $ \i m -> ([], [], Map.insert t (name, decl) m, i, ())

genName :: String -> M String
genName s = (\i -> s ++ sep ++ show i ++ suf) <$> genId
  where
    (sep, suf) = case reverse s of
      [] -> ("x", "_")
      c : _ | c `elem` "0123456789_" -> ("_", "")
            | otherwise -> ("", "")

-- Function values are returned as a function-pointer-typed expression
compile :: Target -> Env env String -> Expr x env t -> M Inline
compile tgt env = \case
  EVar _ _ i -> pure $ IVar (prj env i)

  ELet _ rhs e -> do
    rhsi <- compile tgt env rhs
    var <- genName "x"
    rhsty <- writeType (typeOf rhs)
    emitStm $ VarDef rhsty var (Just rhsi)
    compile tgt (EPush var env) e

  EBuild1 x k e -> compile tgt env $ EBuild x (SS SZ) (k :< VNil) e

  EBuild x n k e -> case tgt of
    Host -> do
      fname <- genName "buildfun"
      let n' = fromNat (unSNat n)
          shapevars = ['s' : show i | i <- [0 .. n' - 1]]
      emitFun $ FunDef Device fname (map ("int " ++) shapevars) _  -- hole: kernel body
      emitStm $ Launch _ _ _  -- holes: num blocks, block size, kernel function body
      pure _  -- hole: the Inline result of the build
    Device -> _

writeType :: STy t -> M String
writeType = \case
  STArr _ t -> (++ "*") <$> writeType t
  STNil -> pure "Nil"
  STPair a b -> (\x y -> "std::pair<" ++ x ++ "," ++ y ++ ">") <$> writeType a <*> writeType b
  STScal t -> case t of
    STI32 -> pure "int32_t"
    STI64 -> pure "int64_t"
    STF32 -> pure "float"
    STF64 -> pure "double"
    STBool -> pure "bool"
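
-- A minimal sketch of how the M monad can be run, added for illustration:
-- `runM` is a hypothetical helper (not part of the original module). It
-- assumes the fresh-name counter starts at 0 and the typedef cache starts
-- empty, and it keeps only the emitted function definitions, the emitted
-- statements and the result, discarding the final cache and counter.
runM :: M a -> ([FunDef], [Stm], a)
runM (M f) =
  let (defs, stms, _cache, _counter, x) = f 0 Map.empty
  in (defs, stms, x)

-- Hypothetical usage examples (names as above are assumptions):
--
-- genName appends the current counter value, inserting "_" when the prefix
-- already ends in a digit or underscore, and producing "x<n>_" for an empty
-- prefix:
--   runM (mapM genName ["x", "s1", ""])
--     evaluates to ([], [], ["x0", "s1_1", "x2_"])
--
-- captureStms returns the statements emitted by its argument instead of
-- passing them on to the enclosing computation:
--   runM (captureStms (emitStm (VarDef "int" "x0" Nothing)))
--     evaluates to ([], [], ([VarDef "int" "x0" Nothing], ()))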