diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-03-04 21:35:17 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-04 21:45:50 +0100 | 
| commit | 2cce83edce315a1d04811f50e28d60ec20d70bbc (patch) | |
| tree | 0c21966ac34d5cbdd5fb7ab17e9792d0882b649b | |
| parent | d751deedfdc2ba5fbeb72ede5754587a1f677835 (diff) | |
Compile: maximum1i and minimum1i
| -rw-r--r-- | src/Compile.hs | 55 | 
1 files changed, 47 insertions, 8 deletions
diff --git a/src/Compile.hs b/src/Compile.hs index 0e6eee7..1355841 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE DataKinds #-}  {-# LANGUAGE GADTs #-}  {-# LANGUAGE LambdaCase #-}  {-# LANGUAGE PolyKinds #-} @@ -20,7 +21,7 @@ import Data.Set (Set)  import Data.Some  import qualified Data.Vector as V  import Foreign --- import System.IO (hPutStrLn, stderr) +import System.IO (hPutStrLn, stderr)  import Prelude hiding ((^))  import qualified Prelude @@ -44,11 +45,14 @@ let array = arrayGenerate (ShNil `ShCons` 10) (\(IxNil `IxCons` i) -> fromIntegr  -- In shape and index arrays, the innermost dimension is on the right (last index). +debug :: Bool +debug = toEnum 0 +  compile :: SList STy env -> Ex env t          -> IO (SList Value env -> IO (Rep t))  compile = \env expr -> do    let source = compileToString env expr -  -- hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ source ++ "\x1B[0m>>>" +  when debug $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ source ++ "\x1B[0m>>>"    lib <- buildKernel source ["kernel"]    let arg_metrics = reverse (unSList metricsSTy env) @@ -306,7 +310,8 @@ compileToString env expr =        (arg_offsets, result_offset') = computeStructOffsets arg_metrics        result_offset = align (alignmentSTy (typeOf expr)) result_offset'    in ($ "") $ compose -       [showString "#include <stdint.h>\n" +       [showString "#include <stdio.h>\n" +       ,showString "#include <stdint.h>\n"         ,showString "#include <stdlib.h>\n"         ,showString "#include <math.h>\n\n"         ,compose $ map (\sd -> printStructDecl sd . showString "\n") structs @@ -320,7 +325,7 @@ compileToString env expr =         ,showString ("  return ") . printCExpr 0 res . showString ";\n}\n\n"         ,showString "void kernel(void *data) {\n"          -- Some code here assumes that we're on a 64-bit system, so let's check that -       ,showString "  if (sizeof(void*) != 8 || sizeof(size_t) != 8) { abort(); }\n" +       ,showString "  if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); abort(); }\n"         ,showString $ "  *(" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ") = typed_kernel(" ++                         concat (map (\((arg, typ), off, idx) ->                                         "\n    *(" ++ typ ++ "*)(data + " ++ show off ++ ")" @@ -687,9 +692,9 @@ compile' env = \case      return (CELit resname) -  -- EMaximum1Inner _ e -> error "TODO" -- EMaximum1Inner ext (compile' e) +  EMaximum1Inner _ e -> compileExtremum "max" "maximum1i" ">" env e -  -- EMinimum1Inner _ e -> error "TODO" -- EMinimum1Inner ext (compile' e) +  EMinimum1Inner _ e -> compileExtremum "min" "minimum1i" "<" env e    EConst _ t x -> return $ CELit $ compileScal True t x @@ -754,8 +759,6 @@ compile' env = \case    EOneHot{} -> error "Compile: monoid operations should have been eliminated"    EFold1Inner{} -> error "Compile: not implemented: EFold1Inner" -  EMaximum1Inner{} -> error "Compile: not implemented: EMaximum1Inner" -  EMinimum1Inner{} -> error "Compile: not implemented: EMinimum1Inner"    EIdx1{} -> error "Compile: not implemented: EIdx1"    ECustom{} -> error "Compile: not implemented: ECustom"    EWith{} -> error "Compile: not implemented: EWith" @@ -940,6 +943,42 @@ compileScal pedantic typ x = case typ of    STF64 -> show x    STBool -> if x then "1" else "0" +compileExtremum :: String -> String -> String -> SList (Const String) env -> Ex env (TArr (S n) t) -> CompM CExpr +compileExtremum nameBase opName operator env e = do +  let STArr (SS n) t = typeOf e +  e' <- compile' env e +  argname <- genName' (nameBase ++ "arg") +  emit $ SVarDecl True (repSTy (STArr (SS n) t)) argname e' + +  shszname <- genName' "shsz" +  -- This n is one less than the shape of the thing we're querying, which is +  -- unexpected. But it's exactly what we want, so we do it anyway. +  emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) + +  resname <- allocArray (nameBase ++ "res") n t (CELit shszname) +                [CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] + +  lenname <- genName' "n" +  emit $ SVarDecl True (repSTy tIx) lenname +                  (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")) + +  emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); abort(); }" + +  ivar <- genName' "i" +  jvar <- genName' "j" +  xvar <- genName +  redvar <- genName' "red"  -- use "red", not "acc", to avoid confusion with accumulators +  emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ BList +           -- we have ScalIsNumeric, so it has 1 and (<) etc. in C +           [SVarDecl False (repSTy t) redvar (CELit (argname ++ ".buf->a[" ++ lenname ++ " * " ++ ivar ++ "]")) +           ,SLoop (repSTy tIx) jvar (CELit "1") (CELit lenname) $ BList +              [SVarDecl True (repSTy t) xvar (CELit (argname ++ ".buf->a[" ++ lenname ++ " * " ++ ivar ++ " + " ++ jvar ++ "]")) +              ,SAsg redvar $ CEIf (CEBinop (CELit xvar) operator (CELit redvar)) (CELit xvar) (CELit redvar) +              ] +           ,SAsg (resname ++ ".buf->a[" ++ ivar ++ "]") (CELit redvar)] + +  return (CELit resname) +  compose :: Foldable t => t (a -> a) -> a -> a  compose = foldr (.) id  | 
