diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-04 21:35:17 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-04 21:45:50 +0100 |
commit | 2cce83edce315a1d04811f50e28d60ec20d70bbc (patch) | |
tree | 0c21966ac34d5cbdd5fb7ab17e9792d0882b649b | |
parent | d751deedfdc2ba5fbeb72ede5754587a1f677835 (diff) |
Compile: maximum1i and minimum1i
-rw-r--r-- | src/Compile.hs | 55 |
1 files changed, 47 insertions, 8 deletions
diff --git a/src/Compile.hs b/src/Compile.hs index 0e6eee7..1355841 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} @@ -20,7 +21,7 @@ import Data.Set (Set) import Data.Some import qualified Data.Vector as V import Foreign --- import System.IO (hPutStrLn, stderr) +import System.IO (hPutStrLn, stderr) import Prelude hiding ((^)) import qualified Prelude @@ -44,11 +45,14 @@ let array = arrayGenerate (ShNil `ShCons` 10) (\(IxNil `IxCons` i) -> fromIntegr -- In shape and index arrays, the innermost dimension is on the right (last index). +debug :: Bool +debug = toEnum 0 + compile :: SList STy env -> Ex env t -> IO (SList Value env -> IO (Rep t)) compile = \env expr -> do let source = compileToString env expr - -- hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ source ++ "\x1B[0m>>>" + when debug $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ source ++ "\x1B[0m>>>" lib <- buildKernel source ["kernel"] let arg_metrics = reverse (unSList metricsSTy env) @@ -306,7 +310,8 @@ compileToString env expr = (arg_offsets, result_offset') = computeStructOffsets arg_metrics result_offset = align (alignmentSTy (typeOf expr)) result_offset' in ($ "") $ compose - [showString "#include <stdint.h>\n" + [showString "#include <stdio.h>\n" + ,showString "#include <stdint.h>\n" ,showString "#include <stdlib.h>\n" ,showString "#include <math.h>\n\n" ,compose $ map (\sd -> printStructDecl sd . showString "\n") structs @@ -320,7 +325,7 @@ compileToString env expr = ,showString (" return ") . printCExpr 0 res . showString ";\n}\n\n" ,showString "void kernel(void *data) {\n" -- Some code here assumes that we're on a 64-bit system, so let's check that - ,showString " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { abort(); }\n" + ,showString " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); abort(); }\n" ,showString $ " *(" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ") = typed_kernel(" ++ concat (map (\((arg, typ), off, idx) -> "\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")" @@ -687,9 +692,9 @@ compile' env = \case return (CELit resname) - -- EMaximum1Inner _ e -> error "TODO" -- EMaximum1Inner ext (compile' e) + EMaximum1Inner _ e -> compileExtremum "max" "maximum1i" ">" env e - -- EMinimum1Inner _ e -> error "TODO" -- EMinimum1Inner ext (compile' e) + EMinimum1Inner _ e -> compileExtremum "min" "minimum1i" "<" env e EConst _ t x -> return $ CELit $ compileScal True t x @@ -754,8 +759,6 @@ compile' env = \case EOneHot{} -> error "Compile: monoid operations should have been eliminated" EFold1Inner{} -> error "Compile: not implemented: EFold1Inner" - EMaximum1Inner{} -> error "Compile: not implemented: EMaximum1Inner" - EMinimum1Inner{} -> error "Compile: not implemented: EMinimum1Inner" EIdx1{} -> error "Compile: not implemented: EIdx1" ECustom{} -> error "Compile: not implemented: ECustom" EWith{} -> error "Compile: not implemented: EWith" @@ -940,6 +943,42 @@ compileScal pedantic typ x = case typ of STF64 -> show x STBool -> if x then "1" else "0" +compileExtremum :: String -> String -> String -> SList (Const String) env -> Ex env (TArr (S n) t) -> CompM CExpr +compileExtremum nameBase opName operator env e = do + let STArr (SS n) t = typeOf e + e' <- compile' env e + argname <- genName' (nameBase ++ "arg") + emit $ SVarDecl True (repSTy (STArr (SS n) t)) argname e' + + shszname <- genName' "shsz" + -- This n is one less than the shape of the thing we're querying, which is + -- unexpected. But it's exactly what we want, so we do it anyway. + emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) + + resname <- allocArray (nameBase ++ "res") n t (CELit shszname) + [CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")) + + emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); abort(); }" + + ivar <- genName' "i" + jvar <- genName' "j" + xvar <- genName + redvar <- genName' "red" -- use "red", not "acc", to avoid confusion with accumulators + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ BList + -- we have ScalIsNumeric, so it has 1 and (<) etc. in C + [SVarDecl False (repSTy t) redvar (CELit (argname ++ ".buf->a[" ++ lenname ++ " * " ++ ivar ++ "]")) + ,SLoop (repSTy tIx) jvar (CELit "1") (CELit lenname) $ BList + [SVarDecl True (repSTy t) xvar (CELit (argname ++ ".buf->a[" ++ lenname ++ " * " ++ ivar ++ " + " ++ jvar ++ "]")) + ,SAsg redvar $ CEIf (CEBinop (CELit xvar) operator (CELit redvar)) (CELit xvar) (CELit redvar) + ] + ,SAsg (resname ++ ".buf->a[" ++ ivar ++ "]") (CELit redvar)] + + return (CELit resname) + compose :: Foldable t => t (a -> a) -> a -> a compose = foldr (.) id |