summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-02-28 16:11:14 +0100
committerTom Smeding <tom@tomsmeding.com>2025-02-28 16:11:14 +0100
commite40e33bc16175b6a0180e3ef88d0a588819e6c37 (patch)
treefff0a39039edc117d86a5da1cd793bb633cc91c9
parente9ab5d50093e48b64d758d67388ef32321ad1984 (diff)
Compile to C and load using dlopen
-rw-r--r--chad-fast.cabal5
-rw-r--r--src/Array.hs4
-rw-r--r--src/Compile.hs198
-rw-r--r--src/Compile/Exec.hs59
4 files changed, 259 insertions, 7 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index 737ff62..8b212a5 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -27,6 +27,7 @@ library
CHAD.Top
CHAD.Types
Compile
+ Compile.Exec
-- CompileCu
Data
Example
@@ -49,9 +50,13 @@ library
base >= 4.19 && < 4.21,
containers,
deepseq,
+ directory,
-- template-haskell,
prettyprinter,
+ process,
+ some,
transformers,
+ unix,
vector,
ansi-terminal,
diff --git a/src/Array.hs b/src/Array.hs
index ef9bb8d..82c3f31 100644
--- a/src/Array.hs
+++ b/src/Array.hs
@@ -41,6 +41,10 @@ shapeSize :: Shape n -> Int
shapeSize ShNil = 1
shapeSize (ShCons sh n) = shapeSize sh * n
+shapeRank :: Shape n -> SNat n
+shapeRank ShNil = SZ
+shapeRank (sh `ShCons` _) = SS (shapeRank sh)
+
fromLinearIndex :: Shape n -> Int -> Index n
fromLinearIndex ShNil 0 = IxNil
fromLinearIndex ShNil _ = error "Index out of range"
diff --git a/src/Compile.hs b/src/Compile.hs
index 8d5fd13..92eb6d5 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -10,18 +10,51 @@ import Data.Bifunctor (first, second)
import Data.Foldable (toList)
import Data.Functor.Const
import qualified Data.Functor.Product as Product
+import Data.Functor.Product (Product)
import Data.List (intersperse, intercalate)
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import Data.Set (Set)
+import Data.Some
+import qualified Data.Vector as V
+import Foreign
+import Array
import AST
import AST.Pretty (ppTy)
+import Compile.Exec
import Data
+import Interpreter.Rep
-- In shape and index arrays, the innermost dimension is on the right (last index).
+-- TODO: array lifetimes in C?
+
+
+compile :: SList STy env -> Ex env t
+ -> IO (SList Value env -> IO (Rep t))
+compile = \env expr -> do
+ lib <- buildKernel (compileToString env expr) ["kernel"]
+
+ let arg_metrics = reverse (unSList metricsSTy env)
+ (arg_offsets, result_offset) = computeStructOffsets arg_metrics
+ result_type = typeOf expr
+ result_size = sizeofSTy result_type
+
+ return $ \val -> do
+ allocaBytes (result_offset + result_size) $ \ptr -> do
+ let args = zip (reverse (unSList Some (slistZip env val))) arg_offsets
+ serialiseArguments args ptr $ do
+ callKernelFun "kernel" lib ptr
+ deserialise result_type ptr result_offset
+ where
+ serialiseArguments :: [(Some (Product STy Value), Int)] -> Ptr () -> IO r -> IO r
+ serialiseArguments ((Some (Product.Pair t (Value arg)), off) : args) ptr k =
+ serialise t arg ptr off $
+ serialiseArguments args ptr k
+ serialiseArguments _ _ k = k
+
data StructDecl = StructDecl
String -- ^ name
@@ -116,7 +149,7 @@ repTy (TScal st) = case st of
TI64 -> "int64_t"
TF32 -> "float"
TF64 -> "double"
- TBool -> "bool"
+ TBool -> "uint8_t"
repTy t = genStructName t
repSTy :: STy t -> String
@@ -218,20 +251,171 @@ emitStruct ty = do
nameEnv :: SList f env -> SList (Const String) env
nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1))
-compile :: SList STy env -> Ex env t -> String
-compile env expr =
+compileToString :: SList STy env -> Ex env t -> String
+compileToString env expr =
let args = nameEnv env
(res, s) = runState (compile' args expr) (CompState mempty mempty 1)
structs = genAllStructs (csStructs s <> Set.fromList (unSList unSTy env))
+
+ (arg_pairs, arg_metrics) =
+ unzip $ reverse (unSList (\(Product.Pair t (Const n)) -> ((n, repSTy t), metricsSTy t))
+ (slistZip env args))
+ (arg_offsets, result_offset') = computeStructOffsets arg_metrics
+ result_offset = align (alignmentSTy (typeOf expr)) result_offset'
in ($ "") $ compose
- [compose $ map (\sd -> printStructDecl sd . showString "\n") structs
+ [showString "#include <stdint.h>\n"
+ ,showString "#include <stdlib.h>\n\n"
+ ,compose $ map (\sd -> printStructDecl sd . showString "\n") structs
,showString "\n"
,showString $
- repSTy (typeOf expr) ++ " kernel(" ++
- intercalate ", " (reverse (unSList (\(Product.Pair t n) -> repSTy t ++ " " ++ getConst n) (slistZip env args))) ++
+ "static " ++ repSTy (typeOf expr) ++ " typed_kernel(" ++
+ intercalate ", " (reverse (unSList (\(Product.Pair t (Const n)) -> repSTy t ++ " " ++ n) (slistZip env args))) ++
") {\n"
,compose $ map (\st -> showString " " . printStmt 1 st . showString "\n") (toList (csStmts s))
- ,showString (" return ") . printCExpr 0 res . showString ";\n}\n"]
+ ,showString (" return ") . printCExpr 0 res . showString ";\n}\n\n"
+ ,showString "void kernel(void *data) {\n"
+ -- Some code here assumes that we're on a 64-bit system, so let's check that
+ ,showString " if (sizeof(void*) != 8) { abort(); }\n"
+ ,showString $ " *(" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ") = typed_kernel(" ++
+ concat (map (\((arg, typ), off, idx) ->
+ "\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")"
+ ++ (if idx < length arg_pairs - 1 then "," else "")
+ ++ " // " ++ arg)
+ (zip3 arg_pairs arg_offsets [0::Int ..])) ++
+ "\n );\n"
+ ,showString "}\n"]
+
+-- | Takes list of metrics (alignment, sizeof).
+-- Returns (offsets, size of struct).
+computeStructOffsets :: [(Int, Int)] -> ([Int], Int)
+computeStructOffsets = go 0 0
+ where
+ go off maxal [(al, sz)] =
+ ([off], align (max maxal al) (off + sz))
+ go off maxal ((al, sz) : pairs@((al2,_):_)) =
+ first (off :) $ go (align al2 (off + sz)) (max maxal al) pairs
+ go _ _ [] = ([], 0)
+
+-- | Assumes that this is called at the correct alignment.
+serialise :: STy t -> Rep t -> Ptr () -> Int -> IO r -> IO r
+serialise topty topval ptr off k =
+ -- TODO: this code is quadratic in the depth of the type because of the alignment/sizeOf calls
+ case (topty, topval) of
+ (STNil, ()) -> k
+ (STPair a b, (x, y)) ->
+ serialise a x ptr off $
+ serialise b y ptr (align (alignmentSTy b) (off + sizeofSTy a)) k
+ (STEither a _, Left x) -> do
+ pokeByteOff ptr off (0 :: Word8) -- alignment of (a + b) is alignment of (union {a b})
+ serialise a x ptr (off + alignmentSTy topty) k
+ (STEither _ b, Right y) -> do
+ pokeByteOff ptr off (1 :: Word8)
+ serialise b y ptr (off + alignmentSTy topty) k
+ (STMaybe _, Nothing) -> do
+ pokeByteOff ptr off (0 :: Word8)
+ k
+ (STMaybe t, Just x) -> do
+ pokeByteOff ptr off (1 :: Word8)
+ serialise t x ptr (off + alignmentSTy t) k
+ (STArr n t, Array sh vec) -> do
+ pokeShape ptr off n sh
+ let off1 = off + 8 * fromSNat n
+ eltsz = sizeofSTy t
+ allocaBytes (shapeSize sh * sizeofSTy t) $ \arrptr ->
+ let loop i
+ | i == shapeSize sh = k
+ | otherwise =
+ serialise t (vec V.! i) arrptr (off1 + i * eltsz) $
+ loop (i+1)
+ in loop 0
+ (STScal sty, x) -> case sty of
+ STI32 -> pokeByteOff ptr off (x :: Int32) >> k
+ STI64 -> pokeByteOff ptr off (x :: Int64) >> k
+ STF32 -> pokeByteOff ptr off (x :: Float) >> k
+ STF64 -> pokeByteOff ptr off (x :: Double) >> k
+ STBool -> pokeByteOff ptr off (fromIntegral (fromEnum x) :: Word8) >> k
+ (STAccum{}, _) -> error "Cannot serialise accumulators"
+
+-- | Assumes that this is called at the correct alignment.
+deserialise :: STy t -> Ptr () -> Int -> IO (Rep t)
+deserialise topty ptr off =
+ -- TODO: this code is quadratic in the depth of the type because of the alignment/sizeOf calls
+ case topty of
+ STNil -> return ()
+ STPair a b -> do
+ x <- deserialise a ptr off
+ y <- deserialise b ptr (align (alignmentSTy b) (off + sizeofSTy a))
+ return (x, y)
+ STEither a b -> do
+ tag <- peekByteOff @Word8 ptr off
+ if tag == 0 -- alignment of (a + b) is alignment of (union {a b})
+ then Left <$> deserialise a ptr (off + alignmentSTy topty)
+ else Right <$> deserialise b ptr (off + alignmentSTy topty)
+ STMaybe t -> do
+ tag <- peekByteOff @Word8 ptr off
+ if tag == 0
+ then return Nothing
+ else Just <$> deserialise t ptr (off + alignmentSTy t)
+ STArr n t -> do
+ sh <- peekShape ptr off n
+ let off1 = off + 8 * fromSNat n
+ eltsz = sizeofSTy t
+ Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t ptr (off1 + i * eltsz))
+ STScal sty -> case sty of
+ STI32 -> peekByteOff @Int32 ptr off
+ STI64 -> peekByteOff @Int64 ptr off
+ STF32 -> peekByteOff @Float ptr off
+ STF64 -> peekByteOff @Double ptr off
+ STBool -> toEnum . fromIntegral <$> peekByteOff @Word8 ptr off
+ STAccum{} -> error "Cannot serialise accumulators"
+
+align :: Int -> Int -> Int
+align a off = (off + a - 1) `div` a * a
+
+alignmentSTy :: STy t -> Int
+alignmentSTy = fst . metricsSTy
+
+sizeofSTy :: STy t -> Int
+sizeofSTy = snd . metricsSTy
+
+-- | Returns (alignment, sizeof)
+metricsSTy :: STy t -> (Int, Int)
+metricsSTy STNil = (1, 0)
+metricsSTy (STPair a b) =
+ let (a1, s1) = metricsSTy a
+ (a2, s2) = metricsSTy b
+ in (max a1 a2, align (max a1 a2) (s1 + s2))
+metricsSTy (STEither a b) =
+ let (a1, s1) = metricsSTy a
+ (a2, s2) = metricsSTy b
+ in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned
+metricsSTy (STMaybe t) =
+ let (a, s) = metricsSTy t
+ in (a, a + s) -- the union after the tag byte is aligned
+metricsSTy (STArr n _) = (8, fromSNat n * 8 + 8)
+metricsSTy (STScal sty) = case sty of
+ STI32 -> (4, 4)
+ STI64 -> (8, 8)
+ STF32 -> (4, 4)
+ STF64 -> (8, 8)
+ STBool -> (1, 1) -- compiled to uint8_t
+metricsSTy (STAccum t) = metricsSTy t
+
+pokeShape :: Ptr () -> Int -> SNat n -> Shape n -> IO ()
+pokeShape ptr off = go . fromSNat
+ where
+ go :: Int -> Shape n -> IO ()
+ go rank = \case
+ ShNil -> return ()
+ sh `ShCons` n -> do
+ pokeByteOff ptr (off + (rank - 1) * 8) (fromIntegral n :: Int64)
+ go (rank - 1) sh
+
+peekShape :: Ptr () -> Int -> SNat n -> IO (Shape n)
+peekShape ptr off = \case
+ SZ -> return ShNil
+ SS n -> ShCons <$> peekShape ptr off n
+ <*> (fromIntegral <$> peekByteOff @Int64 ptr (off + (fromSNat n) * 8))
compile' :: SList (Const String) env -> Ex env t -> CompM CExpr
compile' env = \case
diff --git a/src/Compile/Exec.hs b/src/Compile/Exec.hs
new file mode 100644
index 0000000..163be2b
--- /dev/null
+++ b/src/Compile/Exec.hs
@@ -0,0 +1,59 @@
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE TupleSections #-}
+module Compile.Exec (
+ KernelLib,
+ buildKernel,
+ callKernelFun,
+) where
+
+import Data.IORef
+import qualified Data.Map.Strict as Map
+import Data.Map.Strict (Map)
+import Foreign (Ptr)
+import Foreign.Ptr (FunPtr)
+import System.Directory (removeDirectoryRecursive)
+import System.Environment (lookupEnv)
+import System.IO (hPutStrLn, stderr)
+import System.Posix.DynamicLinker
+import System.Posix.Temp (mkdtemp)
+import System.Process (readProcess)
+
+
+-- The IORef wrapper is required for the finalizer to attach properly (see the 'Weak' docs)
+data KernelLib = KernelLib !(IORef (Map String (FunPtr (Ptr () -> IO ()))))
+
+buildKernel :: String -> [String] -> IO KernelLib
+buildKernel csource funnames = do
+ template <- (++ "/tmp.chad.") <$> getTempDir
+ path <- mkdtemp template
+
+ let outso = path ++ "/out.so"
+ let args = ["-O3", "-march=native", "-shared", "-fPIC", "-x", "c", "-o", outso, "-"]
+ _ <- readProcess "gcc" args csource
+
+ hPutStrLn stderr $ "[chad] loading kernel " ++ path
+ dl <- dlopen outso [RTLD_LAZY, RTLD_LOCAL]
+
+ removeDirectoryRecursive path -- we keep a reference anyway because we have the file open now
+
+ ptrs <- Map.fromList <$> sequence [(name,) <$> dlsym dl name | name <- funnames]
+ ref <- newIORef ptrs
+ _ <- mkWeakIORef ref (do hPutStrLn stderr $ "[chad] unloading kernel " ++ path
+ dlclose dl)
+ return (KernelLib ref)
+
+foreign import ccall "dynamic"
+ wrapKernelFun :: FunPtr (Ptr () -> IO ()) -> Ptr () -> IO ()
+
+-- Ensure that keeping a reference to the returned function also keeps the 'KernelLib' alive
+{-# NOINLINE callKernelFun #-}
+callKernelFun :: String -> KernelLib -> Ptr () -> IO ()
+callKernelFun key (KernelLib ref) arg = do
+ mp <- readIORef ref
+ wrapKernelFun (mp Map.! key) arg
+
+getTempDir :: IO FilePath
+getTempDir =
+ lookupEnv "TMPDIR" >>= \case
+ Just s | not (null s) -> return s
+ _ -> return "/tmp"