diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-10-30 15:56:35 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-10-30 15:56:35 +0100 |
| commit | 4d456e4d34b1e4fb3725051d1b8a0c376b704692 (patch) | |
| tree | 1385217efcc0b58ddb028e707e6a5a36b884ed65 /src/Compile.hs | |
| parent | 0e8e59c5f9af547cf1b79b9bae892e32700ace56 (diff) | |
Implement reshape
Diffstat (limited to 'src/Compile.hs')
| -rw-r--r-- | src/Compile.hs | 24 |
1 files changed, 23 insertions, 1 deletions
diff --git a/src/Compile.hs b/src/Compile.hs index 0ab7ea4..4e81c6a 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -920,6 +920,25 @@ compile' env = \case EMinimum1Inner _ e -> compileExtremum "min" "minimum1i" "<" env e + EReshape _ dim esh earg -> do + let STArr origDim eltty = typeOf earg + strname <- emitStruct (STArr dim eltty) + + shname <- compileAssign "reshsh" env esh + arrname <- compileAssign "resharg" env earg + + when emitChecks $ do + emit $ SIf (CEBinop (compileArrShapeSize origDim arrname) "!=" (CECast "size_t" (prodExpr (indexTupleComponents dim shname)))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: reshape on unequal sizes (%zu <- %zu)\\n\", " ++ + printCExpr 0 (prodExpr (indexTupleComponents dim shname)) ", " ++ + printCExpr 0 (compileArrShapeSize origDim arrname) "); return false;") + mempty + + return (CEStruct strname + [("buf", CEProj (CELit arrname) "buf") + ,("sh", CELit ("{" ++ intercalate ", " [printCExpr 0 e "" | e <- indexTupleComponents dim shname] ++ "}"))]) + EConst _ t x -> return $ CELit $ compileScal True t x EIdx0 _ e -> do @@ -1323,7 +1342,7 @@ compileShapeQuery (SS n) var = -- | Takes a variable name for the array, not the buffer. compileArrShapeSize :: SNat n -> String -> CExpr -compileArrShapeSize n var = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") (compileArrShapeComponents n var) +compileArrShapeSize n var = prodExpr (compileArrShapeComponents n var) -- | Takes a variable name for the array, not the buffer. compileArrShapeComponents :: SNat n -> String -> [CExpr] @@ -1347,6 +1366,9 @@ shapeTupFromLitVars = \n -> go n . reverse go (SS n) (var : vars) = CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) [("a", go n vars), ("b", CELit var)] go _ _ = error "shapeTupFromLitVars: SNat and list do not correspond" +prodExpr :: [CExpr] -> CExpr +prodExpr = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") + compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr compileOpGeneral op e1 = do let unary cop = return @CompM $ CECall cop [e1] |
