aboutsummaryrefslogtreecommitdiff
path: root/src/Compile.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-10-30 15:56:35 +0100
committerTom Smeding <tom@tomsmeding.com>2025-10-30 15:56:35 +0100
commit4d456e4d34b1e4fb3725051d1b8a0c376b704692 (patch)
tree1385217efcc0b58ddb028e707e6a5a36b884ed65 /src/Compile.hs
parent0e8e59c5f9af547cf1b79b9bae892e32700ace56 (diff)
Implement reshape
Diffstat (limited to 'src/Compile.hs')
-rw-r--r--src/Compile.hs24
1 files changed, 23 insertions, 1 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
index 0ab7ea4..4e81c6a 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -920,6 +920,25 @@ compile' env = \case
EMinimum1Inner _ e -> compileExtremum "min" "minimum1i" "<" env e
+ EReshape _ dim esh earg -> do
+ let STArr origDim eltty = typeOf earg
+ strname <- emitStruct (STArr dim eltty)
+
+ shname <- compileAssign "reshsh" env esh
+ arrname <- compileAssign "resharg" env earg
+
+ when emitChecks $ do
+ emit $ SIf (CEBinop (compileArrShapeSize origDim arrname) "!=" (CECast "size_t" (prodExpr (indexTupleComponents dim shname))))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: reshape on unequal sizes (%zu <- %zu)\\n\", " ++
+ printCExpr 0 (prodExpr (indexTupleComponents dim shname)) ", " ++
+ printCExpr 0 (compileArrShapeSize origDim arrname) "); return false;")
+ mempty
+
+ return (CEStruct strname
+ [("buf", CEProj (CELit arrname) "buf")
+ ,("sh", CELit ("{" ++ intercalate ", " [printCExpr 0 e "" | e <- indexTupleComponents dim shname] ++ "}"))])
+
EConst _ t x -> return $ CELit $ compileScal True t x
EIdx0 _ e -> do
@@ -1323,7 +1342,7 @@ compileShapeQuery (SS n) var =
-- | Takes a variable name for the array, not the buffer.
compileArrShapeSize :: SNat n -> String -> CExpr
-compileArrShapeSize n var = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") (compileArrShapeComponents n var)
+compileArrShapeSize n var = prodExpr (compileArrShapeComponents n var)
-- | Takes a variable name for the array, not the buffer.
compileArrShapeComponents :: SNat n -> String -> [CExpr]
@@ -1347,6 +1366,9 @@ shapeTupFromLitVars = \n -> go n . reverse
go (SS n) (var : vars) = CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) [("a", go n vars), ("b", CELit var)]
go _ _ = error "shapeTupFromLitVars: SNat and list do not correspond"
+prodExpr :: [CExpr] -> CExpr
+prodExpr = foldl0' (\a b -> CEBinop a "*" b) (CELit "1")
+
compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr
compileOpGeneral op e1 = do
let unary cop = return @CompM $ CECall cop [e1]