{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Compile where

import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict
import Control.Monad.Trans.Writer.CPS
import Data.Kind (Type)

import AST
import Data


data Body = Body [Stm] Inline  -- body, return expr
  deriving (Show)

data Stm
  = VarDef String String (Maybe Inline)  -- type, name, initialiser
  | Launch Inline Inline Body  -- num blocks, block size, kernel function body
  deriving (Show)

-- inline cuda expression
data Inline
  = IOp Inline String Inline
  | IUOp String Inline
  | ILit String
  | IVar String
  | ICall Inline [Inline]
  deriving (Show)

data Target = Host | Device
  deriving (Show)

data FunDef = FunDef Target String [String] Body  -- name, params (full declarations), body
  deriving (Show)

type Env :: [Ty] -> Type -> Type
data Env env v where
  ETop :: Env '[] v
  EPush :: v -> Env env v -> Env (t : env) v

prj :: Env env v -> Idx env t -> v
prj = \env idx -> go idx env
  where
    go :: Idx env t -> Env env v -> v
    go IZ (EPush v _) = v
    go (IS i) (EPush _ env) = go i env

newtype M a = M (StateT Int  -- ID generator
                  (WriterT [FunDef]  -- generated global function definitions
                    (Writer [Stm]))  -- generated local statements
                  a)
  deriving newtype (Functor, Applicative, Monad)

emitFun :: FunDef -> M ()
emitFun fd = M (lift (tell [fd]))

emitStm :: Stm -> M ()
emitStm stm = M (lift (lift (tell [stm])))

captureStms :: M a -> M ([Stm], a)
captureStms (M m) =
  M (mapStateT (mapWriterT (mapWriter
      (\(((x, i), fds), stms) -> ((((stms, x), i), fds), []))))
    m)

genId :: M Int
genId = M (state (\i -> (i, i + 1)))

genName :: String -> M String
genName s = (\i -> s ++ sep ++ show i ++ suf) <$> genId
  where
    (sep, suf) = case reverse s of
      [] -> ("x", "_")
      c : _ | c `elem` "0123456789_" -> ("_", "")
            | otherwise -> ("", "")

compile :: Target -> Env env String -> Ex env t -> M Inline
compile tgt env = \case
  EVar _ _ i -> pure $ IVar (prj env i)

  ELet _ rhs e -> do
    rhsi <- compile tgt env rhs
    var <- genName "x"
    rhsty <- writeType (typeOf rhs)
    emitStm $ VarDef rhsty var (Just rhsi)
    compile tgt (EPush var env) e

  EBuild1 x k e -> compile tgt env $ EBuild x (k :< VNil) e

  EBuild _ VNil e -> _

  EBuild _ k e ->
    case tgt of
      Host -> do
        fname <- genName "buildfun"
        let n' = fromSNat (vecLength k)
            shapevars = ['s' : show i | i <- [0 .. n' - 1]]
            _ = foldr (\a b -> EOp ext (OMul STI64) (EPair ext a b)) (EConst ext STI64 1) k
        emitFun $ FunDef Device fname (map ("int " ++) shapevars) _
        emitStm $ Launch _ (ILit "32") _
        _

      Device -> _

  _ -> undefined

writeType :: STy t -> M String
writeType = \case
  STArr _ t -> (++ "*") <$> writeType t
  STNil -> pure "Nil"
  STPair a b -> (\x y -> "std::pair<" ++ x ++ "," ++ y ++ ">") <$> writeType a <*> writeType b
  STScal t -> case t of
    STI32 -> pure "int32_t"
    STI64 -> pure "int64_t"
    STF32 -> pure "float"
    STF64 -> pure "double"
    STBool -> pure "bool"