diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-03-21 23:18:05 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-21 23:18:05 +0100 | 
| commit | f3ef5df26404225ceb316ba626a94cbef4426f5e (patch) | |
| tree | c91c5690e6030e6e8873b41412d9dade241a6457 | |
| parent | f87bcb545ce7aae62a1121665a7050154858c75d (diff) | |
Compile: First compilation of fold1i
| -rw-r--r-- | src/Compile.hs | 66 | 
1 files changed, 54 insertions, 12 deletions
| diff --git a/src/Compile.hs b/src/Compile.hs index b9dbd41..a3b4be1 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -323,6 +323,7 @@ compileToString env expr =         [showString "#include <stdio.h>\n"         ,showString "#include <stdint.h>\n"         ,showString "#include <stdlib.h>\n" +       ,showString "#include <string.h>\n"         ,showString "#include <math.h>\n\n"         ,compose $ map (\sd -> printStructDecl sd . showString "\n") structs         ,showString "\n" @@ -632,15 +633,56 @@ compile' env = \case      return (CELit arrname) -  -- EFold1Inner _ a b c -> error "TODO" -- EFold1Inner ext (compile' a) (compile' b) (compile' c) +  EFold1Inner _ commut efun ex0 earr -> do +    let STArr (SS n) t = typeOf earr +    x0name <- compileAssign "foldx0" env ex0 +    arrname <- compileAssign "foldarr" env earr + +    shszname <- genName' "shsz" +    -- This n is one less than the shape of the thing we're querying, which is +    -- unexpected. But it's exactly what we want, so we do it anyway. +    emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n arrname) + +    resname <- allocArray "foldres" n t (CELit shszname) +                  [CELit (arrname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] + +    lenname <- genName' "n" +    emit $ SVarDecl True (repSTy tIx) lenname +                    (CELit (arrname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")) + +    ((), x0incrStmts) <- scope $ incrementVarAlways Increment t x0name + +    ivar <- genName' "i" +    jvar <- genName' "j" +    accvar <- genName' "tot" +    let arreltlit = arrname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ jvar ++ "];" +    (funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun +    ((), arreltIncrStmts) <- scope $ incrementVarAlways Increment t arreltlit + +    emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ +             pure (SVarDecl False (repSTy t) accvar (CELit x0name)) +             <> BList x0incrStmts  -- we're copying x0 here +             <> pure (SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ +                        -- The combination function will consume the array element +                        -- and the accumulator. The accumulator is replaced by +                        -- what comes out of the function anyway, so that's +                        -- fine, but we do need to increment the array element. +                        BList arreltIncrStmts +                        <> BList funStmts +                        <> pure (SAsg accvar funres)) +             <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) + +    incrementVarAlways Decrement t x0name +    incrementVarAlways Decrement (typeOf earr) arrname + +    return (CELit resname)    ESum1Inner _ e -> do      let STArr (SS n) t = typeOf e      argname <- compileAssign "sumarg" env e      shszname <- genName' "shsz" -    -- This n is one less than the shape of the thing we're querying, which is -    -- unexpected. But it's exactly what we want, so we do it anyway. +    -- This n is one less than the shape of the thing we're querying, like EFold1Inner.      emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)      resname <- allocArray "sumres" n t (CELit shszname) @@ -758,15 +800,17 @@ compile' env = \case          return (CELit name)    EWith _ t e1 e2 -> do +    actyname <- emitStruct (STAccum t)      name1 <- compileAssign "" env e1      mcopy <- copyForWriting t name1      accname <- genName' "accum" -    emit $ SVarDecl False (repSTy (STAccum t)) accname (maybe (CELit name1) id mcopy) +    emit $ SVarDecl False actyname accname (CEStruct actyname [("ac", maybe (CELit name1) id mcopy)])      e2' <- compile' (Const accname `SCons` env) e2 -    return $ CEStruct (repSTy (STPair (typeOf e2) t)) [("a", e2'), ("b", CELit accname)] +    rettyname <- emitStruct (STPair (typeOf e2) t) +    return $ CEStruct rettyname [("a", e2'), ("b", CEProj (CELit accname) "ac")]    EAccum _ t prj eidx eval eacc -> do      nameidx <- compileAssign "acidx" env eidx @@ -826,7 +870,7 @@ compile' env = \case          escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c]                                        | ord c < 32      -> "\\x" ++ padleft 2 '0' (showHex (ord c) "")                                        | otherwise       -> [c] -    emit $ SVerbatim $ "fprintf(stderr, \"ERROR: %s\\n\", " ++ escape s ++ "); exit(1);" +    emit $ SVerbatim $ "fputs(\"ERROR: " ++ escape s ++ "\\n\", stderr); exit(1);"      case t of        STScal _ -> return (CELit "0")        _ -> do @@ -837,7 +881,6 @@ compile' env = \case    EPlus{} -> error "Compile: monoid operations should have been eliminated"    EOneHot{} -> error "Compile: monoid operations should have been eliminated" -  EFold1Inner{} -> error "Compile: not implemented: EFold1Inner"    EIdx1{} -> error "Compile: not implemented: EIdx1"  compileAssign :: String -> SList (Const String) env -> Ex env t -> CompM String @@ -1121,15 +1164,14 @@ copyForWriting topty var = case topty of    -- If there are no nested arrays, we know that a refcount of 1 means that the    -- whole thing is owned. Nested arrays have their own refcount, so with    -- nesting we'd have to check the refcounts of all the nested arrays _too_; -  -- at that point we might as well copy the whole thing. Furthermore, no -  -- sub-arrays means that the whole thing is flat, and we can just memcpy if -  -- necessary. +  -- let's not do that. Furthermore, no sub-arrays means that the whole thing +  -- is flat, and we can just memcpy if necessary.    STArr n t | not (hasArrays t) -> do      name <- genName      shszname <- genName' "shsz"      emit $ SVarDeclUninit (repSTy (STArr n t)) name -    emit $ SIf (CEBinop (CELit (var ++ ".refc")) "==" (CELit "1")) +    emit $ SIf (CEBinop (CELit (var ++ ".buf->refc")) "==" (CELit "1"))               (pure (SAsg name (CELit var)))               (let shbytes = fromSNat n * 8                    databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy t))) @@ -1141,7 +1183,7 @@ copyForWriting topty var = case topty of                                        show shbytes ++ ");"                 ,SAsg (name ++ ".buf->refc") (CELit "1")                 ,SVerbatim $ "memcpy(" ++ name ++ ".buf->xs, " ++ var ++ ".buf->xs, " ++ -                                      printCExpr 0 databytes ")"]) +                                      printCExpr 0 databytes ");"])      return (Just (CELit name))    STArr n t -> do | 
