diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-02-28 16:11:14 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-02-28 16:11:14 +0100 |
commit | e40e33bc16175b6a0180e3ef88d0a588819e6c37 (patch) | |
tree | fff0a39039edc117d86a5da1cd793bb633cc91c9 | |
parent | e9ab5d50093e48b64d758d67388ef32321ad1984 (diff) |
Compile to C and load using dlopen
-rw-r--r-- | chad-fast.cabal | 5 | ||||
-rw-r--r-- | src/Array.hs | 4 | ||||
-rw-r--r-- | src/Compile.hs | 198 | ||||
-rw-r--r-- | src/Compile/Exec.hs | 59 |
4 files changed, 259 insertions, 7 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index 737ff62..8b212a5 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -27,6 +27,7 @@ library CHAD.Top CHAD.Types Compile + Compile.Exec -- CompileCu Data Example @@ -49,9 +50,13 @@ library base >= 4.19 && < 4.21, containers, deepseq, + directory, -- template-haskell, prettyprinter, + process, + some, transformers, + unix, vector, ansi-terminal, diff --git a/src/Array.hs b/src/Array.hs index ef9bb8d..82c3f31 100644 --- a/src/Array.hs +++ b/src/Array.hs @@ -41,6 +41,10 @@ shapeSize :: Shape n -> Int shapeSize ShNil = 1 shapeSize (ShCons sh n) = shapeSize sh * n +shapeRank :: Shape n -> SNat n +shapeRank ShNil = SZ +shapeRank (sh `ShCons` _) = SS (shapeRank sh) + fromLinearIndex :: Shape n -> Int -> Index n fromLinearIndex ShNil 0 = IxNil fromLinearIndex ShNil _ = error "Index out of range" diff --git a/src/Compile.hs b/src/Compile.hs index 8d5fd13..92eb6d5 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -10,18 +10,51 @@ import Data.Bifunctor (first, second) import Data.Foldable (toList) import Data.Functor.Const import qualified Data.Functor.Product as Product +import Data.Functor.Product (Product) import Data.List (intersperse, intercalate) import qualified Data.Map.Strict as Map import qualified Data.Set as Set import Data.Set (Set) +import Data.Some +import qualified Data.Vector as V +import Foreign +import Array import AST import AST.Pretty (ppTy) +import Compile.Exec import Data +import Interpreter.Rep -- In shape and index arrays, the innermost dimension is on the right (last index). +-- TODO: array lifetimes in C? + + +compile :: SList STy env -> Ex env t + -> IO (SList Value env -> IO (Rep t)) +compile = \env expr -> do + lib <- buildKernel (compileToString env expr) ["kernel"] + + let arg_metrics = reverse (unSList metricsSTy env) + (arg_offsets, result_offset) = computeStructOffsets arg_metrics + result_type = typeOf expr + result_size = sizeofSTy result_type + + return $ \val -> do + allocaBytes (result_offset + result_size) $ \ptr -> do + let args = zip (reverse (unSList Some (slistZip env val))) arg_offsets + serialiseArguments args ptr $ do + callKernelFun "kernel" lib ptr + deserialise result_type ptr result_offset + where + serialiseArguments :: [(Some (Product STy Value), Int)] -> Ptr () -> IO r -> IO r + serialiseArguments ((Some (Product.Pair t (Value arg)), off) : args) ptr k = + serialise t arg ptr off $ + serialiseArguments args ptr k + serialiseArguments _ _ k = k + data StructDecl = StructDecl String -- ^ name @@ -116,7 +149,7 @@ repTy (TScal st) = case st of TI64 -> "int64_t" TF32 -> "float" TF64 -> "double" - TBool -> "bool" + TBool -> "uint8_t" repTy t = genStructName t repSTy :: STy t -> String @@ -218,20 +251,171 @@ emitStruct ty = do nameEnv :: SList f env -> SList (Const String) env nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1)) -compile :: SList STy env -> Ex env t -> String -compile env expr = +compileToString :: SList STy env -> Ex env t -> String +compileToString env expr = let args = nameEnv env (res, s) = runState (compile' args expr) (CompState mempty mempty 1) structs = genAllStructs (csStructs s <> Set.fromList (unSList unSTy env)) + + (arg_pairs, arg_metrics) = + unzip $ reverse (unSList (\(Product.Pair t (Const n)) -> ((n, repSTy t), metricsSTy t)) + (slistZip env args)) + (arg_offsets, result_offset') = computeStructOffsets arg_metrics + result_offset = align (alignmentSTy (typeOf expr)) result_offset' in ($ "") $ compose - [compose $ map (\sd -> printStructDecl sd . showString "\n") structs + [showString "#include <stdint.h>\n" + ,showString "#include <stdlib.h>\n\n" + ,compose $ map (\sd -> printStructDecl sd . showString "\n") structs ,showString "\n" ,showString $ - repSTy (typeOf expr) ++ " kernel(" ++ - intercalate ", " (reverse (unSList (\(Product.Pair t n) -> repSTy t ++ " " ++ getConst n) (slistZip env args))) ++ + "static " ++ repSTy (typeOf expr) ++ " typed_kernel(" ++ + intercalate ", " (reverse (unSList (\(Product.Pair t (Const n)) -> repSTy t ++ " " ++ n) (slistZip env args))) ++ ") {\n" ,compose $ map (\st -> showString " " . printStmt 1 st . showString "\n") (toList (csStmts s)) - ,showString (" return ") . printCExpr 0 res . showString ";\n}\n"] + ,showString (" return ") . printCExpr 0 res . showString ";\n}\n\n" + ,showString "void kernel(void *data) {\n" + -- Some code here assumes that we're on a 64-bit system, so let's check that + ,showString " if (sizeof(void*) != 8) { abort(); }\n" + ,showString $ " *(" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ") = typed_kernel(" ++ + concat (map (\((arg, typ), off, idx) -> + "\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")" + ++ (if idx < length arg_pairs - 1 then "," else "") + ++ " // " ++ arg) + (zip3 arg_pairs arg_offsets [0::Int ..])) ++ + "\n );\n" + ,showString "}\n"] + +-- | Takes list of metrics (alignment, sizeof). +-- Returns (offsets, size of struct). +computeStructOffsets :: [(Int, Int)] -> ([Int], Int) +computeStructOffsets = go 0 0 + where + go off maxal [(al, sz)] = + ([off], align (max maxal al) (off + sz)) + go off maxal ((al, sz) : pairs@((al2,_):_)) = + first (off :) $ go (align al2 (off + sz)) (max maxal al) pairs + go _ _ [] = ([], 0) + +-- | Assumes that this is called at the correct alignment. +serialise :: STy t -> Rep t -> Ptr () -> Int -> IO r -> IO r +serialise topty topval ptr off k = + -- TODO: this code is quadratic in the depth of the type because of the alignment/sizeOf calls + case (topty, topval) of + (STNil, ()) -> k + (STPair a b, (x, y)) -> + serialise a x ptr off $ + serialise b y ptr (align (alignmentSTy b) (off + sizeofSTy a)) k + (STEither a _, Left x) -> do + pokeByteOff ptr off (0 :: Word8) -- alignment of (a + b) is alignment of (union {a b}) + serialise a x ptr (off + alignmentSTy topty) k + (STEither _ b, Right y) -> do + pokeByteOff ptr off (1 :: Word8) + serialise b y ptr (off + alignmentSTy topty) k + (STMaybe _, Nothing) -> do + pokeByteOff ptr off (0 :: Word8) + k + (STMaybe t, Just x) -> do + pokeByteOff ptr off (1 :: Word8) + serialise t x ptr (off + alignmentSTy t) k + (STArr n t, Array sh vec) -> do + pokeShape ptr off n sh + let off1 = off + 8 * fromSNat n + eltsz = sizeofSTy t + allocaBytes (shapeSize sh * sizeofSTy t) $ \arrptr -> + let loop i + | i == shapeSize sh = k + | otherwise = + serialise t (vec V.! i) arrptr (off1 + i * eltsz) $ + loop (i+1) + in loop 0 + (STScal sty, x) -> case sty of + STI32 -> pokeByteOff ptr off (x :: Int32) >> k + STI64 -> pokeByteOff ptr off (x :: Int64) >> k + STF32 -> pokeByteOff ptr off (x :: Float) >> k + STF64 -> pokeByteOff ptr off (x :: Double) >> k + STBool -> pokeByteOff ptr off (fromIntegral (fromEnum x) :: Word8) >> k + (STAccum{}, _) -> error "Cannot serialise accumulators" + +-- | Assumes that this is called at the correct alignment. +deserialise :: STy t -> Ptr () -> Int -> IO (Rep t) +deserialise topty ptr off = + -- TODO: this code is quadratic in the depth of the type because of the alignment/sizeOf calls + case topty of + STNil -> return () + STPair a b -> do + x <- deserialise a ptr off + y <- deserialise b ptr (align (alignmentSTy b) (off + sizeofSTy a)) + return (x, y) + STEither a b -> do + tag <- peekByteOff @Word8 ptr off + if tag == 0 -- alignment of (a + b) is alignment of (union {a b}) + then Left <$> deserialise a ptr (off + alignmentSTy topty) + else Right <$> deserialise b ptr (off + alignmentSTy topty) + STMaybe t -> do + tag <- peekByteOff @Word8 ptr off + if tag == 0 + then return Nothing + else Just <$> deserialise t ptr (off + alignmentSTy t) + STArr n t -> do + sh <- peekShape ptr off n + let off1 = off + 8 * fromSNat n + eltsz = sizeofSTy t + Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t ptr (off1 + i * eltsz)) + STScal sty -> case sty of + STI32 -> peekByteOff @Int32 ptr off + STI64 -> peekByteOff @Int64 ptr off + STF32 -> peekByteOff @Float ptr off + STF64 -> peekByteOff @Double ptr off + STBool -> toEnum . fromIntegral <$> peekByteOff @Word8 ptr off + STAccum{} -> error "Cannot serialise accumulators" + +align :: Int -> Int -> Int +align a off = (off + a - 1) `div` a * a + +alignmentSTy :: STy t -> Int +alignmentSTy = fst . metricsSTy + +sizeofSTy :: STy t -> Int +sizeofSTy = snd . metricsSTy + +-- | Returns (alignment, sizeof) +metricsSTy :: STy t -> (Int, Int) +metricsSTy STNil = (1, 0) +metricsSTy (STPair a b) = + let (a1, s1) = metricsSTy a + (a2, s2) = metricsSTy b + in (max a1 a2, align (max a1 a2) (s1 + s2)) +metricsSTy (STEither a b) = + let (a1, s1) = metricsSTy a + (a2, s2) = metricsSTy b + in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned +metricsSTy (STMaybe t) = + let (a, s) = metricsSTy t + in (a, a + s) -- the union after the tag byte is aligned +metricsSTy (STArr n _) = (8, fromSNat n * 8 + 8) +metricsSTy (STScal sty) = case sty of + STI32 -> (4, 4) + STI64 -> (8, 8) + STF32 -> (4, 4) + STF64 -> (8, 8) + STBool -> (1, 1) -- compiled to uint8_t +metricsSTy (STAccum t) = metricsSTy t + +pokeShape :: Ptr () -> Int -> SNat n -> Shape n -> IO () +pokeShape ptr off = go . fromSNat + where + go :: Int -> Shape n -> IO () + go rank = \case + ShNil -> return () + sh `ShCons` n -> do + pokeByteOff ptr (off + (rank - 1) * 8) (fromIntegral n :: Int64) + go (rank - 1) sh + +peekShape :: Ptr () -> Int -> SNat n -> IO (Shape n) +peekShape ptr off = \case + SZ -> return ShNil + SS n -> ShCons <$> peekShape ptr off n + <*> (fromIntegral <$> peekByteOff @Int64 ptr (off + (fromSNat n) * 8)) compile' :: SList (Const String) env -> Ex env t -> CompM CExpr compile' env = \case diff --git a/src/Compile/Exec.hs b/src/Compile/Exec.hs new file mode 100644 index 0000000..163be2b --- /dev/null +++ b/src/Compile/Exec.hs @@ -0,0 +1,59 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TupleSections #-} +module Compile.Exec ( + KernelLib, + buildKernel, + callKernelFun, +) where + +import Data.IORef +import qualified Data.Map.Strict as Map +import Data.Map.Strict (Map) +import Foreign (Ptr) +import Foreign.Ptr (FunPtr) +import System.Directory (removeDirectoryRecursive) +import System.Environment (lookupEnv) +import System.IO (hPutStrLn, stderr) +import System.Posix.DynamicLinker +import System.Posix.Temp (mkdtemp) +import System.Process (readProcess) + + +-- The IORef wrapper is required for the finalizer to attach properly (see the 'Weak' docs) +data KernelLib = KernelLib !(IORef (Map String (FunPtr (Ptr () -> IO ())))) + +buildKernel :: String -> [String] -> IO KernelLib +buildKernel csource funnames = do + template <- (++ "/tmp.chad.") <$> getTempDir + path <- mkdtemp template + + let outso = path ++ "/out.so" + let args = ["-O3", "-march=native", "-shared", "-fPIC", "-x", "c", "-o", outso, "-"] + _ <- readProcess "gcc" args csource + + hPutStrLn stderr $ "[chad] loading kernel " ++ path + dl <- dlopen outso [RTLD_LAZY, RTLD_LOCAL] + + removeDirectoryRecursive path -- we keep a reference anyway because we have the file open now + + ptrs <- Map.fromList <$> sequence [(name,) <$> dlsym dl name | name <- funnames] + ref <- newIORef ptrs + _ <- mkWeakIORef ref (do hPutStrLn stderr $ "[chad] unloading kernel " ++ path + dlclose dl) + return (KernelLib ref) + +foreign import ccall "dynamic" + wrapKernelFun :: FunPtr (Ptr () -> IO ()) -> Ptr () -> IO () + +-- Ensure that keeping a reference to the returned function also keeps the 'KernelLib' alive +{-# NOINLINE callKernelFun #-} +callKernelFun :: String -> KernelLib -> Ptr () -> IO () +callKernelFun key (KernelLib ref) arg = do + mp <- readIORef ref + wrapKernelFun (mp Map.! key) arg + +getTempDir :: IO FilePath +getTempDir = + lookupEnv "TMPDIR" >>= \case + Just s | not (null s) -> return s + _ -> return "/tmp" |