diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-05 21:30:12 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-05 21:30:12 +0100 |
commit | bb88a7dfdb377c77264801db01d72f2e8b245199 (patch) | |
tree | debb1bff3ffcd2970c168e8fd22f5a1e3b5651b9 | |
parent | 76f047376405d97b113573db8b6997088e9b9383 (diff) |
Compile: Implement EWith (TODO EAccum)
That's going to be a mess
-rw-r--r-- | src/AST/Types.hs | 10 | ||||
-rw-r--r-- | src/CHAD/Accum.hs | 9 | ||||
-rw-r--r-- | src/Compile.hs | 137 |
3 files changed, 143 insertions, 13 deletions
diff --git a/src/AST/Types.hs b/src/AST/Types.hs index adcc760..acf7053 100644 --- a/src/AST/Types.hs +++ b/src/AST/Types.hs @@ -107,3 +107,13 @@ type family ScalIsIntegral t where ScalIsIntegral TF32 = False ScalIsIntegral TF64 = False ScalIsIntegral TBool = False + +-- | Returns true for arrays /and/ accumulators; +hasArrays :: STy t' -> Bool +hasArrays STNil = False +hasArrays (STPair a b) = hasArrays a || hasArrays b +hasArrays (STEither a b) = hasArrays a || hasArrays b +hasArrays (STMaybe t) = hasArrays t +hasArrays STArr{} = True +hasArrays STScal{} = False +hasArrays STAccum{} = True diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs index 659c45f..14a1d3b 100644 --- a/src/CHAD/Accum.hs +++ b/src/CHAD/Accum.hs @@ -8,15 +8,6 @@ import Data -hasArrays :: STy t' -> Bool -hasArrays STNil = False -hasArrays (STPair a b) = hasArrays a || hasArrays b -hasArrays (STEither a b) = hasArrays a || hasArrays b -hasArrays (STMaybe t) = hasArrays t -hasArrays STArr{} = True -hasArrays STScal{} = False -hasArrays STAccum{} = error "Accumulators not allowed in source program" - makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) makeAccumulators SNil e = e makeAccumulators (t `SCons` envpro) e = diff --git a/src/Compile.hs b/src/Compile.hs index 037b0d8..424b28d 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -16,6 +16,7 @@ import qualified Data.Functor.Product as Product import Data.Functor.Product (Product) import Data.List (foldl1', intersperse, intercalate) import qualified Data.Map.Strict as Map +import Data.Maybe (fromMaybe) import qualified Data.Set as Set import Data.Set (Set) import Data.Some @@ -745,9 +746,29 @@ compile' env = \case e' <- compile' env e compileOpGeneral op e' - -- ECustom _ t1 t2 t3 a b c e1 e2 -> error "TODO" -- ECustom ext t1 t2 t3 (compile' a) (compile' b) (compile' c) (compile' e1) (compile' e2) + ECustom _ t1 t2 _ earg _ _ e1 e2 -> do + e1' <- compile' env e1 + name1 <- genName + emit $ SVarDecl True (repSTy t1) name1 e1' + e2' <- compile' env e2 + name2 <- genName + emit $ SVarDecl True (repSTy t2) name2 e2' + compile' (Const name2 `SCons` Const name1 `SCons` SNil) earg + + EWith _ e1 e2 -> do + let t = typeOf e1 + + e1' <- compile' env e1 + name1 <- genName + emit $ SVarDecl True (repSTy t) name1 e1' + + mcopy <- copyForWriting t name1 + accname <- genName' "accum" + emit $ SVarDecl False (repSTy (STAccum t)) accname (maybe (CELit name1) id mcopy) + + e2' <- compile' (Const accname `SCons` env) e2 - -- EWith _ a b -> error "TODO" -- EWith (compile' a) (compile' b) + return $ CEStruct (repSTy (STPair (typeOf e2) t)) [("a", e2'), ("b", CELit accname)] -- EAccum _ n a b e -> error "TODO" -- EAccum n (compile' a) (compile' b) (compile' e) @@ -766,8 +787,6 @@ compile' env = \case EFold1Inner{} -> error "Compile: not implemented: EFold1Inner" EIdx1{} -> error "Compile: not implemented: EIdx1" - ECustom{} -> error "Compile: not implemented: ECustom" - EWith{} -> error "Compile: not implemented: EWith" EAccum{} -> error "Compile: not implemented: EAccum" data Increment = Increment | Decrement @@ -985,8 +1004,118 @@ compileExtremum nameBase opName operator env e = do return (CELit resname) +-- | If this returns Nothing, there was nothing to copy because making a simple +-- value copy in C already makes it suitable to write to. +copyForWriting :: STy t -> String -> CompM (Maybe CExpr) +copyForWriting topty var = case topty of + STNil -> return Nothing + + STPair a b -> do + e1 <- copyForWriting a (var ++ ".a") + e2 <- copyForWriting b (var ++ ".b") + case (e1, e2) of + (Nothing, Nothing) -> return Nothing + _ -> return $ Just $ CEStruct (repSTy topty) + [("a", fromMaybe (CELit (var++".a")) e1) + ,("b", fromMaybe (CELit (var++".b")) e2)] + + STEither a b -> do + (e1, stmts1) <- scope $ copyForWriting a (var ++ ".l") + (e2, stmts2) <- scope $ copyForWriting b (var ++ ".r") + case (e1, e2) of + (Nothing, Nothing) -> return Nothing + _ -> do + name <- genName + emit $ SVarDeclUninit (repSTy topty) name + emit $ SIf (CEBinop (CELit var) "==" (CELit "0")) + (BList stmts1 + <> pure (SAsg name (CEStruct (repSTy topty) + [("tag", CELit "0"), ("l", fromMaybe (CELit (var++".l")) e1)]))) + (BList stmts2 + <> pure (SAsg name (CEStruct (repSTy topty) + [("tag", CELit "1"), ("r", fromMaybe (CELit (var++".r")) e2)]))) + return (Just (CELit name)) + + STMaybe t -> do + (e1, stmts1) <- scope $ copyForWriting t (var ++ ".j") + case e1 of + Nothing -> return Nothing + Just e1' -> do + name <- genName + emit $ SVarDeclUninit (repSTy topty) name + emit $ SIf (CEBinop (CELit var) "==" (CELit "0")) + (pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "0")]))) + (BList stmts1 + <> pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "1"), ("j", e1')]))) + return (Just (CELit name)) + + -- If there are no nested arrays, we know that a refcount of 1 means that the + -- whole thing is owned. Nested arrays have their own refcount, so with + -- nesting we'd have to check the refcounts of all the nested arrays _too_; + -- at that point we might as well copy the whole thing. Furthermore, no + -- sub-arrays means that the whole thing is flat, and we can just memcpy if + -- necessary. + STArr n t | not (hasArrays t) -> do + name <- genName + shszname <- genName' "shsz" + emit $ SVarDeclUninit (repSTy (STArr n t)) name + + emit $ SIf (CEBinop (CELit (var ++ ".refc")) "==" (CELit "1")) + (pure (SAsg name (CELit var))) + (let shbytes = fromSNat n * 8 + databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy t))) + totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes + in BList + [SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var) + ,SAsg name (CEStruct (repSTy (STArr n t)) [("buf", CECall "malloc" [totalbytes])]) + ,SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++ + show shbytes ++ ");" + ,SAsg (name ++ ".buf->refc") (CELit "1") + ,SVerbatim $ "memcpy(" ++ name ++ ".buf->xs, " ++ var ++ ".buf->xs, " ++ + printCExpr 0 databytes ")"]) + return (Just (CELit name)) + + STArr n t -> do + shszname <- genName' "shsz" + emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var) + + let shbytes = fromSNat n * 8 + databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy t))) + totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes + + name <- genName + emit $ SVarDecl False (repSTy (STArr n t)) name + (CEStruct (repSTy (STArr n t)) [("buf", CECall "malloc" [totalbytes])]) + emit $ SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++ + show shbytes ++ ");" + emit $ SAsg (name ++ ".buf->refc") (CELit "1") + + -- put the arrays in variables to cut short the not-quite-var chain + dstvar <- genName' "cpydst" + emit $ SVarDecl True (repSTy t ++ " *") dstvar (CELit (name ++ ".buf->xs")) + srcvar <- genName' "cpysrc" + emit $ SVarDecl True (repSTy t ++ " *") srcvar (CELit (var ++ ".buf->xs")) + + ivar <- genName' "i" + + (cpye, cpystmts) <- scope $ copyForWriting t (srcvar ++ "[" ++ ivar ++ "]") + let cpye' = case cpye of + Just e -> e + Nothing -> error "copyForWriting: arrays cannot be copied as-is, bug" + + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ + BList cpystmts + <> pure (SAsg (dstvar ++ "[" ++ ivar ++ "]") cpye') + + return (Just (CELit name)) + + STScal _ -> return Nothing + + STAccum _ -> error "Compile: Nested accumulators not supported" + compose :: Foldable t => t (a -> a) -> a -> a compose = foldr (.) id +-- | Type-restricted. (^) :: Num a => a -> Int -> a (^) = (Prelude.^) |