summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2025-03-03 17:33:09 +0100
committerTom Smeding <t.j.smeding@uu.nl>2025-03-03 17:33:09 +0100
commite34869318cd37fa73c12291141a5fea29248aede (patch)
tree1ba605bb8101d06c7f97e0a0753d62aafeeecec3
parent15cdfa67937ce20fc90ade59437be0a9e4d7a481 (diff)
Compile: sum1inner, unit, idx0
-rw-r--r--src/Compile.hs104
1 files changed, 81 insertions, 23 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
index 8c77cd6..4c07f3a 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -13,7 +13,7 @@ import Data.Foldable (toList)
import Data.Functor.Const
import qualified Data.Functor.Product as Product
import Data.Functor.Product (Product)
-import Data.List (intersperse, intercalate)
+import Data.List (foldl1', intersperse, intercalate)
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import Data.Set (Set)
@@ -90,6 +90,7 @@ data CExpr
= CELit String -- ^ inserted as-is, assumed no parentheses needed
| CEStruct String [(String, CExpr)] -- ^ struct construction literal: `(name){.field=expr}`
| CEProj CExpr String -- ^ field projection: expr.field
+ | CEPtrProj CExpr String -- ^ field projection through pointer: expr->field
| CEAddrOf CExpr -- ^ &expr
| CEIndex CExpr CExpr -- ^ expr[expr]
| CECall String [CExpr] -- ^ function(arg1, ..., argn)
@@ -143,6 +144,7 @@ printCExpr d = \case
| (n, e) <- pairs])
. showString "}"
CEProj e name -> printCExpr 99 e . showString ("." ++ name)
+ CEPtrProj e name -> printCExpr 99 e . showString ("->" ++ name)
CEAddrOf e -> showParen (d > 80) $ showString "&" . printCExpr 80 e
CEIndex e1 e2 -> printCExpr 99 e1 . showString "[" . printCExpr 0 e2 . showString "]"
CECall n es ->
@@ -594,26 +596,12 @@ compile' env = \case
return (CEStruct strname [("buf", CEAddrOf (CELit tldname))])
EBuild _ n esh efun -> do
- let arrty = STArr n (typeOf efun)
- strname <- emitStruct arrty
-
shname <- genName' "sh"
emit . SVarDecl True (repSTy (typeOf esh)) shname =<< compile' env esh
shsizename <- genName' "shsz"
emit $ SVarDecl True "size_t" shsizename (compileShapeSize n shname)
- bufpname <- genName' "bufp"
- emit $ SVarDecl True (repSTy arrty ++ "_buf *") bufpname
- (CECall "malloc" [CEBinop (CELit (show (fromSNat n * 8 + 8)))
- "+"
- (CEBinop (CELit shsizename)
- "*" (CELit (show (sizeofSTy (typeOf efun)))))])
- forM_ (zip (compileShapeTupIntoArray n shname) [0::Int ..]) $ \(rhs, i) ->
- emit $ SAsg (bufpname ++ "->sh[" ++ show i ++ "]") rhs
- emit $ SAsg (bufpname ++ "->refc") (CELit "1")
-
- arrname <- genName' "arr"
- emit $ SVarDecl True (repSTy arrty) arrname (CEStruct strname [("buf", CELit bufpname)])
+ arrname <- allocArray "arr" n (typeOf efun) (CELit shsizename) (compileShapeTupIntoArray n shname)
idxargname <- genName' "ix"
(funretval, funstmts) <- scope $ compile' (Const idxargname `SCons` env) efun
@@ -623,20 +611,57 @@ compile' env = \case
emit $ SBlock $
pure (SVarDecl False "size_t" linivar (CELit "0"))
<> compose [pure . SLoop (repSTy tIx) ivar (CELit "0")
- (CECast (repSTy tIx) (CEIndex (CELit (bufpname ++ "->sh")) (CELit (show dimidx))))
+ (CECast (repSTy tIx) (CEIndex (CELit (arrname ++ ".buf->sh")) (CELit (show dimidx))))
| (ivar, dimidx) <- zip ivars [0::Int ..]]
(pure (SVarDecl True (repSTy (typeOf esh)) idxargname
(shapeTupFromLitVars n ivars))
<> BList funstmts
- <> pure (SAsg (bufpname ++ "->a[" ++ linivar ++ "++]") funretval))
+ <> pure (SAsg (arrname ++ ".buf->a[" ++ linivar ++ "++]") funretval))
return (CELit arrname)
-- EFold1Inner _ a b c -> error "TODO" -- EFold1Inner ext (compile' a) (compile' b) (compile' c)
- -- ESum1Inner _ e -> error "TODO" -- ESum1Inner ext (compile' e)
+ ESum1Inner _ e -> do
+ let STArr (SS n) t = typeOf e
+ e' <- compile' env e
+ argname <- genName' "sumarg"
+ emit $ SVarDecl True (repSTy (STArr (SS n) t)) argname e'
+
+ shszname <- genName' "shsz"
+ -- This n is one less than the shape of the thing we're querying, which is
+ -- unexpected. But it's exactly what we want, so we do it anyway.
+ emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname)
+
+ resname <- allocArray "sumres" n t (CELit shszname)
+ [CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]]
+
+ lenname <- genName' "n"
+ emit $ SVarDecl True (repSTy tIx) lenname
+ (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))
+
+ ivar <- genName' "i"
+ jvar <- genName' "j"
+ accvar <- genName' "tot"
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ BList
+ -- we have ScalIsNumeric, so it has 0 and (+) in C
+ [SVarDecl False (repSTy t) accvar (CELit "0")
+ ,SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $
+ pure $ SVerbatim $ accvar ++ " += " ++ argname ++ ".buf->a[" ++ lenname ++ " * " ++ ivar ++ " + " ++ jvar ++ "];"
+ ,SAsg (resname ++ ".buf->a[" ++ ivar ++ "]") (CELit accvar)]
- -- EUnit _ e -> error "TODO" -- EUnit ext (compile' e)
+ return (CELit resname)
+
+ EUnit _ e -> do
+ e' <- compile' env e
+ let typ = STArr SZ (typeOf e)
+ strname <- emitStruct typ
+ name <- genName
+ emit $ SVarDecl True strname name (CEStruct strname
+ [("buf", CECall "malloc" [CELit (show (8 + sizeofSTy (typeOf e)))])])
+ emit $ SAsg (name ++ ".buf->refc") (CELit "1")
+ emit $ SAsg (name ++ ".buf->a[0]") e'
+ return (CELit name)
-- EReplicate1Inner _ a b -> error "TODO" -- EReplicate1Inner ext (compile' a) (compile' b)
@@ -646,7 +671,16 @@ compile' env = \case
EConst _ t x -> return $ CELit $ compileScal True t x
- -- EIdx0 _ e -> error "TODO" -- EIdx0 ext (compile' e)
+ EIdx0 _ e -> do
+ let STArr _ t = typeOf e
+ e' <- compile' env e
+ arrname <- genName
+ emit $ SVarDecl True (repSTy (STArr SZ t)) arrname e'
+ name <- genName
+ emit $ SVarDecl True (repSTy t) name
+ (CEIndex (CEPtrProj (CEProj (CELit arrname) "buf") "a") (CELit "0"))
+ incrementVarAlways Decrement (STArr SZ t) arrname
+ return (CELit name)
-- EIdx1 _ a b -> error "TODO" -- EIdx1 ext (compile' a) (compile' b)
@@ -655,7 +689,7 @@ compile' env = \case
arrname <- genName' "ixarr"
idxname <- genName' "ixix"
emit . SVarDecl True (repSTy (typeOf earr)) arrname =<< compile' env earr
- emit . SVarDecl True (repSTy (typeOf eidx)) idxname =<< compile' env eidx
+ when (fromSNat n > 0) $ emit . SVarDecl True (repSTy (typeOf eidx)) idxname =<< compile' env eidx
resname <- genName' "ixres"
emit $ SVarDecl True (repSTy t) resname (CEIndex (CELit (arrname ++ ".buf->a")) (toLinearIdx n arrname idxname))
incrementVarAlways Decrement (STArr n t) arrname
@@ -771,6 +805,23 @@ toLinearIdx (SS n) arrvar idxvar =
-- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")))
-- _
+-- | The shape must have the outer dimension at the head (and the inner dimension on the right).
+allocArray :: String -> SNat n -> STy t -> CExpr -> [CExpr] -> CompM String
+allocArray nameBase rank eltty shsz shape = do
+ when (length shape /= fromSNat rank) $
+ error "allocArray: shape does not match rank"
+ let arrty = STArr rank eltty
+ strname <- emitStruct arrty
+ arrname <- genName' nameBase
+ emit $ SVarDecl True strname arrname $ CEStruct strname
+ [("buf", CECall "malloc" [CEBinop (CELit (show (fromSNat rank * 8 + 8)))
+ "+"
+ (CEBinop shsz "*" (CELit (show (sizeofSTy eltty))))])]
+ forM_ (zip shape [0::Int ..]) $ \(dim, i) ->
+ emit $ SAsg (arrname ++ ".buf->sh[" ++ show i ++ "]") dim
+ emit $ SAsg (arrname ++ ".buf->refc") (CELit "1")
+ return arrname
+
compileShapeQuery :: SNat n -> String -> CExpr
compileShapeQuery SZ _ = CEStruct (repSTy STNil) []
compileShapeQuery (SS n) var =
@@ -779,10 +830,17 @@ compileShapeQuery (SS n) var =
,("b", CEIndex (CELit (var ++ ".buf->sh")) (CELit (show (fromSNat n))))]
compileShapeSize :: SNat n -> String -> CExpr
-compileShapeSize SZ _ = CELit "0"
+compileShapeSize SZ _ = CELit "1"
compileShapeSize (SS SZ) var = CELit (var ++ ".b")
compileShapeSize (SS n) var = CEBinop (compileShapeSize n (var ++ ".a")) "*" (CELit (var ++ ".b"))
+-- | Takes a variable name for the array, not the buffer.
+compileArrShapeSize :: SNat n -> String -> CExpr
+compileArrShapeSize SZ _ = CELit "1"
+compileArrShapeSize n var =
+ foldl1' (\a b -> CEBinop a "*" b) [CELit (var ++ ".buf->sh[" ++ show i ++ "]")
+ | i <- [0 .. fromSNat n - 1]]
+
compileShapeTupIntoArray :: SNat n -> String -> [CExpr]
compileShapeTupIntoArray = \n var -> map CELit (toList (go n var))
where