summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-05 21:30:12 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-05 21:30:12 +0100
commitbb88a7dfdb377c77264801db01d72f2e8b245199 (patch)
treedebb1bff3ffcd2970c168e8fd22f5a1e3b5651b9
parent76f047376405d97b113573db8b6997088e9b9383 (diff)
Compile: Implement EWith (TODO EAccum)
That's going to be a mess
-rw-r--r--src/AST/Types.hs10
-rw-r--r--src/CHAD/Accum.hs9
-rw-r--r--src/Compile.hs137
3 files changed, 143 insertions, 13 deletions
diff --git a/src/AST/Types.hs b/src/AST/Types.hs
index adcc760..acf7053 100644
--- a/src/AST/Types.hs
+++ b/src/AST/Types.hs
@@ -107,3 +107,13 @@ type family ScalIsIntegral t where
ScalIsIntegral TF32 = False
ScalIsIntegral TF64 = False
ScalIsIntegral TBool = False
+
+-- | Returns true for arrays /and/ accumulators;
+hasArrays :: STy t' -> Bool
+hasArrays STNil = False
+hasArrays (STPair a b) = hasArrays a || hasArrays b
+hasArrays (STEither a b) = hasArrays a || hasArrays b
+hasArrays (STMaybe t) = hasArrays t
+hasArrays STArr{} = True
+hasArrays STScal{} = False
+hasArrays STAccum{} = True
diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs
index 659c45f..14a1d3b 100644
--- a/src/CHAD/Accum.hs
+++ b/src/CHAD/Accum.hs
@@ -8,15 +8,6 @@ import Data
-hasArrays :: STy t' -> Bool
-hasArrays STNil = False
-hasArrays (STPair a b) = hasArrays a || hasArrays b
-hasArrays (STEither a b) = hasArrays a || hasArrays b
-hasArrays (STMaybe t) = hasArrays t
-hasArrays STArr{} = True
-hasArrays STScal{} = False
-hasArrays STAccum{} = error "Accumulators not allowed in source program"
-
makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
makeAccumulators SNil e = e
makeAccumulators (t `SCons` envpro) e =
diff --git a/src/Compile.hs b/src/Compile.hs
index 037b0d8..424b28d 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -16,6 +16,7 @@ import qualified Data.Functor.Product as Product
import Data.Functor.Product (Product)
import Data.List (foldl1', intersperse, intercalate)
import qualified Data.Map.Strict as Map
+import Data.Maybe (fromMaybe)
import qualified Data.Set as Set
import Data.Set (Set)
import Data.Some
@@ -745,9 +746,29 @@ compile' env = \case
e' <- compile' env e
compileOpGeneral op e'
- -- ECustom _ t1 t2 t3 a b c e1 e2 -> error "TODO" -- ECustom ext t1 t2 t3 (compile' a) (compile' b) (compile' c) (compile' e1) (compile' e2)
+ ECustom _ t1 t2 _ earg _ _ e1 e2 -> do
+ e1' <- compile' env e1
+ name1 <- genName
+ emit $ SVarDecl True (repSTy t1) name1 e1'
+ e2' <- compile' env e2
+ name2 <- genName
+ emit $ SVarDecl True (repSTy t2) name2 e2'
+ compile' (Const name2 `SCons` Const name1 `SCons` SNil) earg
+
+ EWith _ e1 e2 -> do
+ let t = typeOf e1
+
+ e1' <- compile' env e1
+ name1 <- genName
+ emit $ SVarDecl True (repSTy t) name1 e1'
+
+ mcopy <- copyForWriting t name1
+ accname <- genName' "accum"
+ emit $ SVarDecl False (repSTy (STAccum t)) accname (maybe (CELit name1) id mcopy)
+
+ e2' <- compile' (Const accname `SCons` env) e2
- -- EWith _ a b -> error "TODO" -- EWith (compile' a) (compile' b)
+ return $ CEStruct (repSTy (STPair (typeOf e2) t)) [("a", e2'), ("b", CELit accname)]
-- EAccum _ n a b e -> error "TODO" -- EAccum n (compile' a) (compile' b) (compile' e)
@@ -766,8 +787,6 @@ compile' env = \case
EFold1Inner{} -> error "Compile: not implemented: EFold1Inner"
EIdx1{} -> error "Compile: not implemented: EIdx1"
- ECustom{} -> error "Compile: not implemented: ECustom"
- EWith{} -> error "Compile: not implemented: EWith"
EAccum{} -> error "Compile: not implemented: EAccum"
data Increment = Increment | Decrement
@@ -985,8 +1004,118 @@ compileExtremum nameBase opName operator env e = do
return (CELit resname)
+-- | If this returns Nothing, there was nothing to copy because making a simple
+-- value copy in C already makes it suitable to write to.
+copyForWriting :: STy t -> String -> CompM (Maybe CExpr)
+copyForWriting topty var = case topty of
+ STNil -> return Nothing
+
+ STPair a b -> do
+ e1 <- copyForWriting a (var ++ ".a")
+ e2 <- copyForWriting b (var ++ ".b")
+ case (e1, e2) of
+ (Nothing, Nothing) -> return Nothing
+ _ -> return $ Just $ CEStruct (repSTy topty)
+ [("a", fromMaybe (CELit (var++".a")) e1)
+ ,("b", fromMaybe (CELit (var++".b")) e2)]
+
+ STEither a b -> do
+ (e1, stmts1) <- scope $ copyForWriting a (var ++ ".l")
+ (e2, stmts2) <- scope $ copyForWriting b (var ++ ".r")
+ case (e1, e2) of
+ (Nothing, Nothing) -> return Nothing
+ _ -> do
+ name <- genName
+ emit $ SVarDeclUninit (repSTy topty) name
+ emit $ SIf (CEBinop (CELit var) "==" (CELit "0"))
+ (BList stmts1
+ <> pure (SAsg name (CEStruct (repSTy topty)
+ [("tag", CELit "0"), ("l", fromMaybe (CELit (var++".l")) e1)])))
+ (BList stmts2
+ <> pure (SAsg name (CEStruct (repSTy topty)
+ [("tag", CELit "1"), ("r", fromMaybe (CELit (var++".r")) e2)])))
+ return (Just (CELit name))
+
+ STMaybe t -> do
+ (e1, stmts1) <- scope $ copyForWriting t (var ++ ".j")
+ case e1 of
+ Nothing -> return Nothing
+ Just e1' -> do
+ name <- genName
+ emit $ SVarDeclUninit (repSTy topty) name
+ emit $ SIf (CEBinop (CELit var) "==" (CELit "0"))
+ (pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "0")])))
+ (BList stmts1
+ <> pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "1"), ("j", e1')])))
+ return (Just (CELit name))
+
+ -- If there are no nested arrays, we know that a refcount of 1 means that the
+ -- whole thing is owned. Nested arrays have their own refcount, so with
+ -- nesting we'd have to check the refcounts of all the nested arrays _too_;
+ -- at that point we might as well copy the whole thing. Furthermore, no
+ -- sub-arrays means that the whole thing is flat, and we can just memcpy if
+ -- necessary.
+ STArr n t | not (hasArrays t) -> do
+ name <- genName
+ shszname <- genName' "shsz"
+ emit $ SVarDeclUninit (repSTy (STArr n t)) name
+
+ emit $ SIf (CEBinop (CELit (var ++ ".refc")) "==" (CELit "1"))
+ (pure (SAsg name (CELit var)))
+ (let shbytes = fromSNat n * 8
+ databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy t)))
+ totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes
+ in BList
+ [SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var)
+ ,SAsg name (CEStruct (repSTy (STArr n t)) [("buf", CECall "malloc" [totalbytes])])
+ ,SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++
+ show shbytes ++ ");"
+ ,SAsg (name ++ ".buf->refc") (CELit "1")
+ ,SVerbatim $ "memcpy(" ++ name ++ ".buf->xs, " ++ var ++ ".buf->xs, " ++
+ printCExpr 0 databytes ")"])
+ return (Just (CELit name))
+
+ STArr n t -> do
+ shszname <- genName' "shsz"
+ emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var)
+
+ let shbytes = fromSNat n * 8
+ databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy t)))
+ totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes
+
+ name <- genName
+ emit $ SVarDecl False (repSTy (STArr n t)) name
+ (CEStruct (repSTy (STArr n t)) [("buf", CECall "malloc" [totalbytes])])
+ emit $ SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++
+ show shbytes ++ ");"
+ emit $ SAsg (name ++ ".buf->refc") (CELit "1")
+
+ -- put the arrays in variables to cut short the not-quite-var chain
+ dstvar <- genName' "cpydst"
+ emit $ SVarDecl True (repSTy t ++ " *") dstvar (CELit (name ++ ".buf->xs"))
+ srcvar <- genName' "cpysrc"
+ emit $ SVarDecl True (repSTy t ++ " *") srcvar (CELit (var ++ ".buf->xs"))
+
+ ivar <- genName' "i"
+
+ (cpye, cpystmts) <- scope $ copyForWriting t (srcvar ++ "[" ++ ivar ++ "]")
+ let cpye' = case cpye of
+ Just e -> e
+ Nothing -> error "copyForWriting: arrays cannot be copied as-is, bug"
+
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $
+ BList cpystmts
+ <> pure (SAsg (dstvar ++ "[" ++ ivar ++ "]") cpye')
+
+ return (Just (CELit name))
+
+ STScal _ -> return Nothing
+
+ STAccum _ -> error "Compile: Nested accumulators not supported"
+
compose :: Foldable t => t (a -> a) -> a -> a
compose = foldr (.) id
+-- | Type-restricted.
(^) :: Num a => a -> Int -> a
(^) = (Prelude.^)