diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-21 23:18:05 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-21 23:18:05 +0100 |
commit | f3ef5df26404225ceb316ba626a94cbef4426f5e (patch) | |
tree | c91c5690e6030e6e8873b41412d9dade241a6457 | |
parent | f87bcb545ce7aae62a1121665a7050154858c75d (diff) |
Compile: First compilation of fold1i
-rw-r--r-- | src/Compile.hs | 66 |
1 files changed, 54 insertions, 12 deletions
diff --git a/src/Compile.hs b/src/Compile.hs index b9dbd41..a3b4be1 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -323,6 +323,7 @@ compileToString env expr = [showString "#include <stdio.h>\n" ,showString "#include <stdint.h>\n" ,showString "#include <stdlib.h>\n" + ,showString "#include <string.h>\n" ,showString "#include <math.h>\n\n" ,compose $ map (\sd -> printStructDecl sd . showString "\n") structs ,showString "\n" @@ -632,15 +633,56 @@ compile' env = \case return (CELit arrname) - -- EFold1Inner _ a b c -> error "TODO" -- EFold1Inner ext (compile' a) (compile' b) (compile' c) + EFold1Inner _ commut efun ex0 earr -> do + let STArr (SS n) t = typeOf earr + x0name <- compileAssign "foldx0" env ex0 + arrname <- compileAssign "foldarr" env earr + + shszname <- genName' "shsz" + -- This n is one less than the shape of the thing we're querying, which is + -- unexpected. But it's exactly what we want, so we do it anyway. + emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n arrname) + + resname <- allocArray "foldres" n t (CELit shszname) + [CELit (arrname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (arrname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")) + + ((), x0incrStmts) <- scope $ incrementVarAlways Increment t x0name + + ivar <- genName' "i" + jvar <- genName' "j" + accvar <- genName' "tot" + let arreltlit = arrname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ jvar ++ "];" + (funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun + ((), arreltIncrStmts) <- scope $ incrementVarAlways Increment t arreltlit + + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ + pure (SVarDecl False (repSTy t) accvar (CELit x0name)) + <> BList x0incrStmts -- we're copying x0 here + <> pure (SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + -- The combination function will consume the array element + -- and the accumulator. The accumulator is replaced by + -- what comes out of the function anyway, so that's + -- fine, but we do need to increment the array element. + BList arreltIncrStmts + <> BList funStmts + <> pure (SAsg accvar funres)) + <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) + + incrementVarAlways Decrement t x0name + incrementVarAlways Decrement (typeOf earr) arrname + + return (CELit resname) ESum1Inner _ e -> do let STArr (SS n) t = typeOf e argname <- compileAssign "sumarg" env e shszname <- genName' "shsz" - -- This n is one less than the shape of the thing we're querying, which is - -- unexpected. But it's exactly what we want, so we do it anyway. + -- This n is one less than the shape of the thing we're querying, like EFold1Inner. emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) resname <- allocArray "sumres" n t (CELit shszname) @@ -758,15 +800,17 @@ compile' env = \case return (CELit name) EWith _ t e1 e2 -> do + actyname <- emitStruct (STAccum t) name1 <- compileAssign "" env e1 mcopy <- copyForWriting t name1 accname <- genName' "accum" - emit $ SVarDecl False (repSTy (STAccum t)) accname (maybe (CELit name1) id mcopy) + emit $ SVarDecl False actyname accname (CEStruct actyname [("ac", maybe (CELit name1) id mcopy)]) e2' <- compile' (Const accname `SCons` env) e2 - return $ CEStruct (repSTy (STPair (typeOf e2) t)) [("a", e2'), ("b", CELit accname)] + rettyname <- emitStruct (STPair (typeOf e2) t) + return $ CEStruct rettyname [("a", e2'), ("b", CEProj (CELit accname) "ac")] EAccum _ t prj eidx eval eacc -> do nameidx <- compileAssign "acidx" env eidx @@ -826,7 +870,7 @@ compile' env = \case escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c] | ord c < 32 -> "\\x" ++ padleft 2 '0' (showHex (ord c) "") | otherwise -> [c] - emit $ SVerbatim $ "fprintf(stderr, \"ERROR: %s\\n\", " ++ escape s ++ "); exit(1);" + emit $ SVerbatim $ "fputs(\"ERROR: " ++ escape s ++ "\\n\", stderr); exit(1);" case t of STScal _ -> return (CELit "0") _ -> do @@ -837,7 +881,6 @@ compile' env = \case EPlus{} -> error "Compile: monoid operations should have been eliminated" EOneHot{} -> error "Compile: monoid operations should have been eliminated" - EFold1Inner{} -> error "Compile: not implemented: EFold1Inner" EIdx1{} -> error "Compile: not implemented: EIdx1" compileAssign :: String -> SList (Const String) env -> Ex env t -> CompM String @@ -1121,15 +1164,14 @@ copyForWriting topty var = case topty of -- If there are no nested arrays, we know that a refcount of 1 means that the -- whole thing is owned. Nested arrays have their own refcount, so with -- nesting we'd have to check the refcounts of all the nested arrays _too_; - -- at that point we might as well copy the whole thing. Furthermore, no - -- sub-arrays means that the whole thing is flat, and we can just memcpy if - -- necessary. + -- let's not do that. Furthermore, no sub-arrays means that the whole thing + -- is flat, and we can just memcpy if necessary. STArr n t | not (hasArrays t) -> do name <- genName shszname <- genName' "shsz" emit $ SVarDeclUninit (repSTy (STArr n t)) name - emit $ SIf (CEBinop (CELit (var ++ ".refc")) "==" (CELit "1")) + emit $ SIf (CEBinop (CELit (var ++ ".buf->refc")) "==" (CELit "1")) (pure (SAsg name (CELit var))) (let shbytes = fromSNat n * 8 databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy t))) @@ -1141,7 +1183,7 @@ copyForWriting topty var = case topty of show shbytes ++ ");" ,SAsg (name ++ ".buf->refc") (CELit "1") ,SVerbatim $ "memcpy(" ++ name ++ ".buf->xs, " ++ var ++ ".buf->xs, " ++ - printCExpr 0 databytes ")"]) + printCExpr 0 databytes ");"]) return (Just (CELit name)) STArr n t -> do |