summaryrefslogtreecommitdiff
path: root/src/Compile.hs
blob: 2fcff5d652a453d519614895b5be9d1e1595d461 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE DeriveFunctor #-}
module Compile where

import Control.Monad (ap)
import Data.Kind (Type)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map

import AST


-- | A block of generated code: the statements to execute, in order, followed
-- by the expression whose value the block yields.
data Body
  = Body
      [Stm]   -- ^ statements making up the body
      Inline  -- ^ the body's return expression
  deriving (Show)

-- | A generated statement.
data Stm
  = -- | Variable definition: C type, variable name, optional initialiser.
    VarDef String String (Maybe Inline)
  | -- | Kernel launch: number of blocks, block size, and the body of the
    -- kernel function.
    Launch Inline Inline Body
  deriving (Show)

-- | An inline CUDA expression, kept as a small syntax tree.
data Inline
  = -- | Two operands around an operator given as source text
    -- (presumably rendered infix).
    IOp Inline String Inline
  | -- | An operator given as source text, applied to one operand.
    IUOp String Inline
  | -- | A literal, stored as its source text.
    ILit String
  | -- | A reference to a variable by name.
    IVar String
  | -- | A call: the callee expression and its argument list.
    ICall Inline [Inline]
  deriving (Show)

-- | Which side of a CUDA program a piece of generated code belongs to.
data Target
  = Host    -- ^ host-side code
  | Device  -- ^ device-side code
  deriving (Show)

-- | A generated top-level function definition.
data FunDef
  = FunDef
      Target    -- ^ whether the function is host or device code
      String    -- ^ function name
      [String]  -- ^ parameters, each a full declaration (e.g. @"int s0"@)
      Body      -- ^ function body
  deriving (Show)

-- | An environment holding one value of type @v@ for every variable in the
-- type-level context @env@. The most recently pushed entry is at the front
-- and is the one selected by index 'IZ' in 'prj'.
type Env :: [Ty] -> Type -> Type
data Env env v where
  -- | The empty environment.
  ETop  :: Env '[] v
  -- | Push a value for a new variable of type @t@ onto the front.
  EPush :: v -> Env env v -> Env (t : env) v

-- | Fetch the value an environment stores for a typed variable index.
--
-- 'IZ' selects the front (most recently pushed) entry; @'IS' i@ drops the
-- front entry and continues with @i@. The index is inspected before the
-- environment, so the GADT refinement rules out 'ETop' in each branch.
prj :: Env env v -> Idx env t -> v
prj env idx = case idx of
  IZ -> case env of
    EPush here _ -> here
  IS rest -> case env of
    EPush _ older -> prj older rest

-- | The code-generation monad.
--
-- A computation receives the fresh-name counter and the function typedef
-- cache (type -> (name, decl)), and returns, in this order:
--
--   1. generated global function definitions,
--   2. generated local statements,
--   3. the updated typedef cache,
--   4. the updated counter,
--   5. the result value.
--
-- Definitions and statements from successive steps are concatenated with
-- '(<>)'; the cache and counter are threaded through like state.
newtype M a
  = M (Int -> Map Ty (String, String) -> ([FunDef], [Stm], Map Ty (String, String), Int, a))
  deriving (Functor)

instance Applicative M where
  pure result = M $ \ctr cache -> ([], [], cache, ctr, result)
  (<*>) = ap

instance Monad M where
  M run >>= next = M $ \ctr0 cache0 ->
    -- let-bound tuple patterns are lazy; nothing is forced here.
    let (defsA, stmsA, cache1, ctr1, a) = run ctr0 cache0
        M runNext = next a
        (defsB, stmsB, cache2, ctr2, b) = runNext ctr1 cache1
     in (defsA <> defsB, stmsA <> stmsB, cache2, ctr2, b)

-- | Append one global function definition; everything else passes through.
emitFun :: FunDef -> M ()
emitFun def = M $ \ctr cache -> ([def], [], cache, ctr, ())

-- | Append one statement to the current statement stream.
emitStm :: Stm -> M ()
emitStm stmt = M $ \ctr cache -> ([], [stmt], cache, ctr, ())

-- | Run a computation and return the statements it emitted, instead of
-- adding them to the surrounding statement stream. Function definitions,
-- the typedef cache and the counter still propagate outward.
captureStms :: M a -> M ([Stm], a)
captureStms (M run) = M $ \ctr cache ->
  let (defs, captured, cache', ctr', result) = run ctr cache
   in (defs, [], cache', ctr', (captured, result))

-- | Return the current counter value and advance the counter by one.
genId :: M Int
genId = M $ \ctr cache -> ([], [], cache, ctr + 1, ctr)

-- | Look up the typedef name cached for a type, if any.
getTypedef :: Ty -> M (Maybe String)
getTypedef ty = M $ \ctr cache ->
  ([], [], cache, ctr, fmap fst (Map.lookup ty cache))

-- | Record a typedef name and declaration for a type, replacing any
-- existing entry.
putTypedef :: Ty -> String -> String -> M ()
putTypedef ty name decl = M $ \ctr cache ->
  ([], [], Map.insert ty (name, decl) cache, ctr, ())

-- | Produce a fresh identifier from a base name and the counter.
--
-- The counter value is appended directly (@"x"@ -> @"x3"@), unless the
-- base already ends in a digit or underscore, in which case an underscore
-- separates them (@"s1"@ -> @"s1_3"@). An empty base yields @"x3_"@.
genName :: String -> M String
genName base = render <$> genId
  where
    render n = concat [base, sep, show n, suf]
    endsInDigitOrUnderscore =
      any (`elem` "0123456789_") (take 1 (reverse base))
    (sep, suf)
      | null base = ("x", "_")
      | endsInDigitOrUnderscore = ("_", "")
      | otherwise = ("", "")

-- | Translate an expression into CUDA, returning an inline expression for
-- its value. Statements needed to compute that value are emitted into the
-- current statement stream ('emitStm'); auxiliary functions are emitted as
-- global definitions ('emitFun'). The 'Env' maps each variable in scope to
-- the name of the C variable that holds it.
--
-- Function values are returned as a function-pointer-typed expression
--
-- NOTE(review): unfinished. The 'EBuild' case still contains typed holes
-- (@_@), so this does not compile yet. Only the constructors below are
-- handled here.
compile :: Target -> Env env String -> Expr x env t -> M Inline
compile tgt env = \case
  -- Variable: refer to the C variable bound for this index.
  EVar _ _ i -> pure $ IVar (prj env i)
  -- Let: compile the right-hand side, store it in a fresh C variable
  -- declared with the right-hand side's type, then compile the body with
  -- that variable pushed onto the environment.
  ELet _ rhs e -> do
    rhsi <- compile tgt env rhs
    var <- genName "x"
    rhsty <- writeType (typeOf rhs)
    emitStm $ VarDef rhsty var (Just rhsi)
    compile tgt (EPush var env) e

  -- Single-index build: rewritten as a general 'EBuild' with a shape of
  -- length one (@SS SZ@, @k :< VNil@).
  EBuild1 x k e -> compile tgt env $ EBuild x (SS SZ) (k :< VNil) e
  EBuild x n k e -> case tgt of
    -- From host code: generate a device function named "buildfun<N>" that
    -- takes one "int sI" parameter per shape dimension, then launch it.
    -- Body, launch configuration and result are still holes.
    Host -> do
      fname <- genName "buildfun"
      let n' = fromNat (unSNat n)
          shapevars = ['s' : show i | i <- [0 .. n' - 1]]
      emitFun $ FunDef Device fname (map ("int " ++) shapevars) _
      emitStm $ Launch _ _ _
      _
    -- TODO: builds inside device code are not implemented yet.
    Device -> _

-- | Render the C/CUDA spelling of a type.
--
-- Arrays become pointers to their element type (the first 'STArr' field is
-- not used), pairs become @std::pair<a,b>@, 'STNil' becomes @Nil@, and
-- scalars map to fixed C type names. The result is in 'M', but this
-- function neither emits code nor touches the typedef cache.
writeType :: STy t -> M String
writeType ty = case ty of
  STNil -> pure "Nil"
  STArr _ elemTy -> fmap (++ "*") (writeType elemTy)
  STPair l r -> do
    lName <- writeType l
    rName <- writeType r
    pure (concat ["std::pair<", lName, ",", rName, ">"])
  STScal scal -> case scal of
    STI32 -> pure "int32_t"
    STI64 -> pure "int64_t"
    STF32 -> pure "float"
    STF64 -> pure "double"
    STBool -> pure "bool"