summaryrefslogtreecommitdiff
path: root/src/Compile.hs
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2023-09-10 21:13:14 +0200
committerTom Smeding <t.j.smeding@uu.nl>2023-09-10 21:13:14 +0200
commit0bf9f5bb8a0873cad2e11faf83519b6e7ccf87d2 (patch)
tree41ddd52b0293319834b2130814414e76de434396 /src/Compile.hs
Initial
Diffstat (limited to 'src/Compile.hs')
-rw-r--r--src/Compile.hs120
1 files changed, 120 insertions, 0 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
new file mode 100644
index 0000000..2fcff5d
--- /dev/null
+++ b/src/Compile.hs
@@ -0,0 +1,120 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE DeriveFunctor #-}
+module Compile where
+
+import Control.Monad (ap)
+import Data.Kind (Type)
+import Data.Map.Strict (Map)
+import qualified Data.Map.Strict as Map
+
+import AST
+
+
+data Body = Body [Stm] Inline -- body, return expr
+ deriving (Show)
+
+data Stm
+ = VarDef String String (Maybe Inline) -- type, name, initialiser
+ | Launch Inline Inline Body -- num blocks, block size, kernel function body
+ deriving (Show)
+
+-- inline cuda expression
+data Inline
+ = IOp Inline String Inline
+ | IUOp String Inline
+ | ILit String
+ | IVar String
+ | ICall Inline [Inline]
+ deriving (Show)
+
+data Target = Host | Device
+ deriving (Show)
+
+data FunDef = FunDef Target String [String] Body -- name, params (full declarations), body
+ deriving (Show)
+
+type Env :: [Ty] -> Type -> Type
+data Env env v where
+ ETop :: Env '[] v
+ EPush :: v -> Env env v -> Env (t : env) v
+
+prj :: Env env v -> Idx env t -> v
+prj = \env idx -> go idx env
+ where go :: Idx env t -> Env env v -> v
+ go IZ (EPush v _) = v
+ go (IS i) (EPush _ env) = go i env
+
+-- generated global function definitions, generated local statements, function typedef cache (name, decl)
+newtype M a = M (Int -> Map Ty (String, String) -> ([FunDef], [Stm], Map Ty (String, String), Int, a))
+ deriving (Functor)
+instance Applicative M where
+ pure x = M (\i m -> ([], [], m, i, x))
+ (<*>) = ap
+instance Monad M where
+ M f >>= g = M (\i m -> let (d1, s1, m1, i1, x) = f i m
+ (d2, s2, m2, i2, y) = let M h = g x in h i1 m1
+ in (d1 <> d2, s1 <> s2, m2, i2, y))
+
+emitFun :: FunDef -> M ()
+emitFun fd = M (\i m -> ([fd], [], m, i, ()))
+
+emitStm :: Stm -> M ()
+emitStm stm = M (\i m -> ([], [stm], m, i, ()))
+
+captureStms :: M a -> M ([Stm], a)
+captureStms (M f) = M (\i m -> let (d, s, m2, i2, x) = f i m
+ in (d, [], m2, i2, (s, x)))
+
+genId :: M Int
+genId = M (\i m -> ([], [], m, i + 1, i))
+
+getTypedef :: Ty -> M (Maybe String)
+getTypedef t = M $ \i m -> ([], [], m, i, fst <$> Map.lookup t m)
+
+putTypedef :: Ty -> String -> String -> M ()
+putTypedef t name decl = M $ \i m -> ([], [], Map.insert t (name, decl) m, i, ())
+
+genName :: String -> M String
+genName s = (\i -> s ++ sep ++ show i ++ suf) <$> genId
+ where (sep, suf) = case reverse s of
+ [] -> ("x", "_")
+ c : _ | c `elem` "0123456789_" -> ("_", "")
+ | otherwise -> ("", "")
+
+-- Function values are returned as a function-pointer-typed expression
+compile :: Target -> Env env String -> Expr x env t -> M Inline
+compile tgt env = \case
+ EVar _ _ i -> pure $ IVar (prj env i)
+ ELet _ rhs e -> do
+ rhsi <- compile tgt env rhs
+ var <- genName "x"
+ rhsty <- writeType (typeOf rhs)
+ emitStm $ VarDef rhsty var (Just rhsi)
+ compile tgt (EPush var env) e
+
+ EBuild1 x k e -> compile tgt env $ EBuild x (SS SZ) (k :< VNil) e
+ EBuild x n k e -> case tgt of
+ Host -> do
+ fname <- genName "buildfun"
+ let n' = fromNat (unSNat n)
+ shapevars = ['s' : show i | i <- [0 .. n' - 1]]
+ emitFun $ FunDef Device fname (map ("int " ++) shapevars) _
+ emitStm $ Launch _ _ _
+ _
+ Device -> _
+
+writeType :: STy t -> M String
+writeType = \case
+ STArr _ t -> (++ "*") <$> writeType t
+ STNil -> pure "Nil"
+ STPair a b -> (\x y -> "std::pair<" ++ x ++ "," ++ y ++ ">") <$> writeType a <*> writeType b
+ STScal t -> case t of
+ STI32 -> pure "int32_t"
+ STI64 -> pure "int64_t"
+ STF32 -> pure "float"
+ STF64 -> pure "double"
+ STBool -> pure "bool"