1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Compile where
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict
import Control.Monad.Trans.Writer.CPS
import Data.Kind (Type)
import AST
import Data
-- | A generated function body: the statements to run, followed by the
-- expression whose value the body returns.
data Body = Body [Stm] Inline
  deriving stock (Show)
-- | A generated statement.
data Stm
  = -- | Variable declaration: type, name, and optional initialiser.
    VarDef String String (Maybe Inline)
  | -- | Kernel launch: number of blocks, block size, and the body of the
    -- kernel function.
    Launch Inline Inline Body
  deriving stock (Show)
-- | An inline CUDA expression.
data Inline
  = -- | Operator with two operands: left operand, operator text, right operand.
    IOp Inline String Inline
  | -- | Operator with one operand: operator text and operand.
    IUOp String Inline
  | -- | Literal, given as its source text.
    ILit String
  | -- | Reference to a variable by name.
    IVar String
  | -- | Call: the callee expression and its argument list.
    ICall Inline [Inline]
  deriving stock (Show)
-- | Compilation target: code for the host side or for the device side.
data Target
  = Host
  | Device
  deriving stock (Show)
-- | A generated global function definition.
--
-- Fields, in order:
--
--   1. the 'Target' the function is generated for,
--   2. the function's name,
--   3. its parameters, each a complete declaration (type followed by name),
--   4. its 'Body'.
data FunDef = FunDef Target String [String] Body
  deriving stock (Show)
-- | An environment holding one value of type @a@ for every entry of the
-- type-level list @tys@. 'EPush' adds a new entry at the head of the list,
-- so the most recently pushed value corresponds to the first entry.
type Env :: [Ty] -> Type -> Type
data Env tys a where
  -- | The empty environment.
  ETop :: Env '[] a
  -- | Prepend a value for a new head entry @ty@.
  EPush :: a -> Env tys a -> Env (ty : tys) a
-- | Look up the value stored for an index. 'IZ' selects the head entry of
-- the environment, and @'IS' i@ steps past the head and continues with @i@.
-- The index is inspected before the environment.
prj :: Env env v -> Idx env t -> v
prj env = \case
  IZ -> case env of
    EPush v _ -> v
  IS i -> case env of
    EPush _ rest -> prj rest i
-- | The code-generation monad. From outermost to innermost layer:
--
--   * 'StateT' over an 'Int' counter used to generate fresh IDs,
--   * 'WriterT' collecting generated global function definitions,
--   * 'Writer' collecting generated local statements.
newtype M a
  = M
      ( StateT
          Int
          ( WriterT
              [FunDef]
              (Writer [Stm])
          )
          a
      )
  deriving newtype (Functor, Applicative, Monad)
-- | Append a function definition to the global-definitions log (the
-- 'WriterT' layer directly beneath the ID counter).
emitFun :: FunDef -> M ()
emitFun = M . lift . tell . (: [])
-- | Append a statement to the local-statements log (the innermost 'Writer'
-- layer, two lifts below the ID counter).
emitStm :: Stm -> M ()
emitStm = M . lift . lift . tell . (: [])
-- | Run an action and return the local statements it emitted alongside its
-- result. The captured statements are removed from the surrounding
-- statement log. The ID counter and emitted function definitions pass
-- through unchanged.
captureStms :: M a -> M ([Stm], a)
captureStms (M inner) = M (mapStateT (mapWriterT (mapWriter moveStms)) inner)
  where
    -- Move the innermost log (the statements) into the result value and
    -- leave an empty statement log in its place.
    moveStms (((res, counter), funs), emitted) =
      ((((emitted, res), counter), funs), [])
-- | Return the current value of the ID counter and advance it by one.
genId :: M Int
genId = M (state bump)
  where
    bump :: Int -> (Int, Int)
    bump current = (current, current + 1)
-- | Generate a name made of the given prefix and a fresh ID.
--
-- If the prefix ends in a digit or underscore, an @_@ is placed between
-- the prefix and the ID, presumably so the ID cannot run into the prefix.
-- An empty prefix yields @x<id>_@.
genName :: String -> M String
genName prefix = render <$> genId
  where
    render :: Int -> String
    render n = prefix ++ sep ++ show n ++ suf
    (sep, suf)
      | null prefix = ("x", "_")
      | endsInDigitOrUnderscore = ("_", "")
      | otherwise = ("", "")
    endsInDigitOrUnderscore = case reverse prefix of
      c : _ -> c `elem` "0123456789_"
      [] -> False
-- | Compile an expression for the given 'Target'.
--
-- Auxiliary local statements are emitted with 'emitStm', and global
-- function definitions with 'emitFun'. The result is an 'Inline'
-- expression for the expression's value. @env@ maps each variable index in
-- scope to the name of the generated variable that holds it.
--
-- NOTE(review): work in progress. The typed holes (@_@) mark parts that are
-- not implemented yet, and the module will not compile until they are
-- filled.
compile :: Target -> Env env String -> Ex env t -> M Inline
compile tgt env = \case
  EVar _ _ i -> pure $ IVar (prj env i)
  -- A let becomes a local variable declaration. The body is then compiled
  -- with the fresh variable's name pushed onto the environment.
  ELet _ rhs e -> do
    rhsi <- compile tgt env rhs
    var <- genName "x"
    rhsty <- writeType (typeOf rhs)
    emitStm $ VarDef rhsty var (Just rhsi)
    compile tgt (EPush var env) e
  -- Rewritten to an 'EBuild' whose shape vector contains only @k@.
  EBuild1 x k e -> compile tgt env $ EBuild x (k :< VNil) e
  -- TODO: build with an empty shape vector.
  EBuild _ VNil e -> _
  EBuild _ k e -> case tgt of
    Host -> do
      fname <- genName "buildfun"
      -- The shape components are I64 values (they are multiplied with
      -- @OMul STI64@ below), so declare the kernel's shape parameters with
      -- the same rendered type rather than a narrower @int@.
      ixty <- writeType (STScal STI64)
      let n' = fromSNat (vecLength k)
          -- One parameter name per shape component: s0, s1, ...
          shapevars = ['s' : show i | i <- [0 .. n' - 1]]
          -- Expression for the product of the shape components; not used yet.
          _ = foldr (\a b -> EOp ext (OMul STI64) (EPair ext a b)) (EConst ext STI64 1) k
      -- TODO: kernel body.
      emitFun $ FunDef Device fname (map (\v -> ixty ++ " " ++ v) shapevars) _
      -- TODO: number of blocks and kernel body (block size is fixed at 32).
      emitStm $ Launch _ (ILit "32") _
      -- TODO: result expression.
      _
    -- TODO: builds compiled for the device.
    Device -> _
  -- Any expression form not handled above. Fail with a diagnostic instead of
  -- a bare 'undefined'.
  _ -> error "compile: unsupported expression form (not implemented yet)"
-- | Render a type as the C++/CUDA type name used in generated code.
--
-- Scalars map to fixed-width or builtin C++ types, arrays become pointers
-- to their element type, and pairs become @std::pair@. 'STNil' is rendered
-- as the name @Nil@ (presumably defined by the generated code's prelude;
-- verify).
writeType :: STy t -> M String
writeType ty = case ty of
  STScal sc -> case sc of
    STI32 -> pure "int32_t"
    STI64 -> pure "int64_t"
    STF32 -> pure "float"
    STF64 -> pure "double"
    STBool -> pure "bool"
  STNil -> pure "Nil"
  STArr _ elt -> do
    eltName <- writeType elt
    pure (eltName ++ "*")
  STPair l r -> do
    lName <- writeType l
    rName <- writeType r
    pure ("std::pair<" ++ lName ++ "," ++ rName ++ ">")
|