diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-03-05 21:30:12 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-05 21:30:12 +0100 | 
| commit | bb88a7dfdb377c77264801db01d72f2e8b245199 (patch) | |
| tree | debb1bff3ffcd2970c168e8fd22f5a1e3b5651b9 /src/Compile.hs | |
| parent | 76f047376405d97b113573db8b6997088e9b9383 (diff) | |
Compile: Implement EWith (TODO EAccum)
That's going to be a mess
Diffstat (limited to 'src/Compile.hs')
| -rw-r--r-- | src/Compile.hs | 137 | 
1 files changed, 133 insertions, 4 deletions
| diff --git a/src/Compile.hs b/src/Compile.hs index 037b0d8..424b28d 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -16,6 +16,7 @@ import qualified Data.Functor.Product as Product  import Data.Functor.Product (Product)  import Data.List (foldl1', intersperse, intercalate)  import qualified Data.Map.Strict as Map +import Data.Maybe (fromMaybe)  import qualified Data.Set as Set  import Data.Set (Set)  import Data.Some @@ -745,9 +746,29 @@ compile' env = \case      e' <- compile' env e      compileOpGeneral op e' -  -- ECustom _ t1 t2 t3 a b c e1 e2 -> error "TODO" -- ECustom ext t1 t2 t3 (compile' a) (compile' b) (compile' c) (compile' e1) (compile' e2) +  ECustom _ t1 t2 _ earg _ _ e1 e2 -> do +    e1' <- compile' env e1 +    name1 <- genName +    emit $ SVarDecl True (repSTy t1) name1 e1' +    e2' <- compile' env e2 +    name2 <- genName +    emit $ SVarDecl True (repSTy t2) name2 e2' +    compile' (Const name2 `SCons` Const name1 `SCons` SNil) earg + +  EWith _ e1 e2 -> do +    let t = typeOf e1 + +    e1' <- compile' env e1 +    name1 <- genName +    emit $ SVarDecl True (repSTy t) name1 e1' -  -- EWith _ a b -> error "TODO" -- EWith (compile' a) (compile' b) +    mcopy <- copyForWriting t name1 +    accname <- genName' "accum" +    emit $ SVarDecl False (repSTy (STAccum t)) accname (maybe (CELit name1) id mcopy) + +    e2' <- compile' (Const accname `SCons` env) e2 + +    return $ CEStruct (repSTy (STPair (typeOf e2) t)) [("a", e2'), ("b", CELit accname)]    -- EAccum _ n a b e -> error "TODO" -- EAccum n (compile' a) (compile' b) (compile' e) @@ -766,8 +787,6 @@ compile' env = \case    EFold1Inner{} -> error "Compile: not implemented: EFold1Inner"    EIdx1{} -> error "Compile: not implemented: EIdx1" -  ECustom{} -> error "Compile: not implemented: ECustom" -  EWith{} -> error "Compile: not implemented: EWith"    EAccum{} -> error "Compile: not implemented: EAccum"  data Increment = Increment | Decrement @@ -985,8 +1004,118 @@ compileExtremum nameBase opName operator env e = do    return (CELit resname) +-- | If this returns Nothing, there was nothing to copy because making a simple +-- value copy in C already makes it suitable to write to. +copyForWriting :: STy t -> String -> CompM (Maybe CExpr) +copyForWriting topty var = case topty of +  STNil -> return Nothing + +  STPair a b -> do +    e1 <- copyForWriting a (var ++ ".a") +    e2 <- copyForWriting b (var ++ ".b") +    case (e1, e2) of +      (Nothing, Nothing) -> return Nothing +      _ -> return $ Just $ CEStruct (repSTy topty) +                             [("a", fromMaybe (CELit (var++".a")) e1) +                             ,("b", fromMaybe (CELit (var++".b")) e2)] + +  STEither a b -> do +    (e1, stmts1) <- scope $ copyForWriting a (var ++ ".l") +    (e2, stmts2) <- scope $ copyForWriting b (var ++ ".r") +    case (e1, e2) of +      (Nothing, Nothing) -> return Nothing +      _ -> do +        name <- genName +        emit $ SVarDeclUninit (repSTy topty) name +        emit $ SIf (CEBinop (CELit var) "==" (CELit "0")) +                 (BList stmts1 +                  <> pure (SAsg name (CEStruct (repSTy topty) +                                        [("tag", CELit "0"), ("l", fromMaybe (CELit (var++".l")) e1)]))) +                 (BList stmts2 +                  <> pure (SAsg name (CEStruct (repSTy topty) +                                        [("tag", CELit "1"), ("r", fromMaybe (CELit (var++".r")) e2)]))) +        return (Just (CELit name)) + +  STMaybe t -> do +    (e1, stmts1) <- scope $ copyForWriting t (var ++ ".j") +    case e1 of +      Nothing -> return Nothing +      Just e1' -> do +        name <- genName +        emit $ SVarDeclUninit (repSTy topty) name +        emit $ SIf (CEBinop (CELit var) "==" (CELit "0")) +                 (pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "0")]))) +                 (BList stmts1 +                  <> pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "1"), ("j", e1')]))) +        return (Just (CELit name)) + +  -- If there are no nested arrays, we know that a refcount of 1 means that the +  -- whole thing is owned. Nested arrays have their own refcount, so with +  -- nesting we'd have to check the refcounts of all the nested arrays _too_; +  -- at that point we might as well copy the whole thing. Furthermore, no +  -- sub-arrays means that the whole thing is flat, and we can just memcpy if +  -- necessary. +  STArr n t | not (hasArrays t) -> do +    name <- genName +    shszname <- genName' "shsz" +    emit $ SVarDeclUninit (repSTy (STArr n t)) name + +    emit $ SIf (CEBinop (CELit (var ++ ".refc")) "==" (CELit "1")) +             (pure (SAsg name (CELit var))) +             (let shbytes = fromSNat n * 8 +                  databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy t))) +                  totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes +              in BList +               [SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var) +               ,SAsg name (CEStruct (repSTy (STArr n t)) [("buf", CECall "malloc" [totalbytes])]) +               ,SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++ +                                      show shbytes ++ ");" +               ,SAsg (name ++ ".buf->refc") (CELit "1") +               ,SVerbatim $ "memcpy(" ++ name ++ ".buf->xs, " ++ var ++ ".buf->xs, " ++ +                                      printCExpr 0 databytes ")"]) +    return (Just (CELit name)) + +  STArr n t -> do +    shszname <- genName' "shsz" +    emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var) + +    let shbytes = fromSNat n * 8 +        databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy t))) +        totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes + +    name <- genName +    emit $ SVarDecl False (repSTy (STArr n t)) name +             (CEStruct (repSTy (STArr n t)) [("buf", CECall "malloc" [totalbytes])]) +    emit $ SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++ +                                 show shbytes ++ ");" +    emit $ SAsg (name ++ ".buf->refc") (CELit "1") + +    -- put the arrays in variables to cut short the not-quite-var chain +    dstvar <- genName' "cpydst" +    emit $ SVarDecl True (repSTy t ++ " *") dstvar (CELit (name ++ ".buf->xs")) +    srcvar <- genName' "cpysrc" +    emit $ SVarDecl True (repSTy t ++ " *") srcvar (CELit (var ++ ".buf->xs")) + +    ivar <- genName' "i" + +    (cpye, cpystmts) <- scope $ copyForWriting t (srcvar ++ "[" ++ ivar ++ "]") +    let cpye' = case cpye of +                  Just e -> e +                  Nothing -> error "copyForWriting: arrays cannot be copied as-is, bug" + +    emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ +             BList cpystmts +             <> pure (SAsg (dstvar ++ "[" ++ ivar ++ "]") cpye') + +    return (Just (CELit name)) + +  STScal _ -> return Nothing + +  STAccum _ -> error "Compile: Nested accumulators not supported" +  compose :: Foldable t => t (a -> a) -> a -> a  compose = foldr (.) id +-- | Type-restricted.  (^) :: Num a => a -> Int -> a  (^) = (Prelude.^) | 
