summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-04 21:35:17 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-04 21:45:50 +0100
commit2cce83edce315a1d04811f50e28d60ec20d70bbc (patch)
tree0c21966ac34d5cbdd5fb7ab17e9792d0882b649b
parentd751deedfdc2ba5fbeb72ede5754587a1f677835 (diff)
Compile: maximum1i and minimum1i
-rw-r--r--src/Compile.hs55
1 files changed, 47 insertions, 8 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
index 0e6eee7..1355841 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -1,3 +1,4 @@
+{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
@@ -20,7 +21,7 @@ import Data.Set (Set)
import Data.Some
import qualified Data.Vector as V
import Foreign
--- import System.IO (hPutStrLn, stderr)
+import System.IO (hPutStrLn, stderr)
import Prelude hiding ((^))
import qualified Prelude
@@ -44,11 +45,14 @@ let array = arrayGenerate (ShNil `ShCons` 10) (\(IxNil `IxCons` i) -> fromIntegr
-- In shape and index arrays, the innermost dimension is on the right (last index).
+debug :: Bool
+debug = toEnum 0
+
compile :: SList STy env -> Ex env t
-> IO (SList Value env -> IO (Rep t))
compile = \env expr -> do
let source = compileToString env expr
- -- hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ source ++ "\x1B[0m>>>"
+ when debug $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ source ++ "\x1B[0m>>>"
lib <- buildKernel source ["kernel"]
let arg_metrics = reverse (unSList metricsSTy env)
@@ -306,7 +310,8 @@ compileToString env expr =
(arg_offsets, result_offset') = computeStructOffsets arg_metrics
result_offset = align (alignmentSTy (typeOf expr)) result_offset'
in ($ "") $ compose
- [showString "#include <stdint.h>\n"
+ [showString "#include <stdio.h>\n"
+ ,showString "#include <stdint.h>\n"
,showString "#include <stdlib.h>\n"
,showString "#include <math.h>\n\n"
,compose $ map (\sd -> printStructDecl sd . showString "\n") structs
@@ -320,7 +325,7 @@ compileToString env expr =
,showString (" return ") . printCExpr 0 res . showString ";\n}\n\n"
,showString "void kernel(void *data) {\n"
-- Some code here assumes that we're on a 64-bit system, so let's check that
- ,showString " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { abort(); }\n"
+ ,showString " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); abort(); }\n"
,showString $ " *(" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ") = typed_kernel(" ++
concat (map (\((arg, typ), off, idx) ->
"\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")"
@@ -687,9 +692,9 @@ compile' env = \case
return (CELit resname)
- -- EMaximum1Inner _ e -> error "TODO" -- EMaximum1Inner ext (compile' e)
+ EMaximum1Inner _ e -> compileExtremum "max" "maximum1i" ">" env e
- -- EMinimum1Inner _ e -> error "TODO" -- EMinimum1Inner ext (compile' e)
+ EMinimum1Inner _ e -> compileExtremum "min" "minimum1i" "<" env e
EConst _ t x -> return $ CELit $ compileScal True t x
@@ -754,8 +759,6 @@ compile' env = \case
EOneHot{} -> error "Compile: monoid operations should have been eliminated"
EFold1Inner{} -> error "Compile: not implemented: EFold1Inner"
- EMaximum1Inner{} -> error "Compile: not implemented: EMaximum1Inner"
- EMinimum1Inner{} -> error "Compile: not implemented: EMinimum1Inner"
EIdx1{} -> error "Compile: not implemented: EIdx1"
ECustom{} -> error "Compile: not implemented: ECustom"
EWith{} -> error "Compile: not implemented: EWith"
@@ -940,6 +943,42 @@ compileScal pedantic typ x = case typ of
STF64 -> show x
STBool -> if x then "1" else "0"
+compileExtremum :: String -> String -> String -> SList (Const String) env -> Ex env (TArr (S n) t) -> CompM CExpr
+compileExtremum nameBase opName operator env e = do
+ let STArr (SS n) t = typeOf e
+ e' <- compile' env e
+ argname <- genName' (nameBase ++ "arg")
+ emit $ SVarDecl True (repSTy (STArr (SS n) t)) argname e'
+
+ shszname <- genName' "shsz"
+ -- This n is one less than the shape of the thing we're querying, which is
+ -- unexpected. But it's exactly what we want, so we do it anyway.
+ emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
+
+ resname <- allocArray (nameBase ++ "res") n t (CELit shszname)
+ [CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
+
+ lenname <- genName' "n"
+ emit $ SVarDecl True (repSTy tIx) lenname
+ (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))
+
+ emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); abort(); }"
+
+ ivar <- genName' "i"
+ jvar <- genName' "j"
+ xvar <- genName
+ redvar <- genName' "red" -- use "red", not "acc", to avoid confusion with accumulators
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ BList
+ -- we have ScalIsNumeric, so it has 1 and (<) etc. in C
+ [SVarDecl False (repSTy t) redvar (CELit (argname ++ ".buf->a[" ++ lenname ++ " * " ++ ivar ++ "]"))
+ ,SLoop (repSTy tIx) jvar (CELit "1") (CELit lenname) $ BList
+ [SVarDecl True (repSTy t) xvar (CELit (argname ++ ".buf->a[" ++ lenname ++ " * " ++ ivar ++ " + " ++ jvar ++ "]"))
+ ,SAsg redvar $ CEIf (CEBinop (CELit xvar) operator (CELit redvar)) (CELit xvar) (CELit redvar)
+ ]
+ ,SAsg (resname ++ ".buf->a[" ++ ivar ++ "]") (CELit redvar)]
+
+ return (CELit resname)
+
compose :: Foldable t => t (a -> a) -> a -> a
compose = foldr (.) id