diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2025-03-03 17:33:09 +0100 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2025-03-03 17:33:09 +0100 |
commit | e34869318cd37fa73c12291141a5fea29248aede (patch) | |
tree | 1ba605bb8101d06c7f97e0a0753d62aafeeecec3 | |
parent | 15cdfa67937ce20fc90ade59437be0a9e4d7a481 (diff) |
Compile: sum1inner, unit, idx0
-rw-r--r-- | src/Compile.hs | 104 |
1 files changed, 81 insertions, 23 deletions
diff --git a/src/Compile.hs b/src/Compile.hs index 8c77cd6..4c07f3a 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -13,7 +13,7 @@ import Data.Foldable (toList) import Data.Functor.Const import qualified Data.Functor.Product as Product import Data.Functor.Product (Product) -import Data.List (intersperse, intercalate) +import Data.List (foldl1', intersperse, intercalate) import qualified Data.Map.Strict as Map import qualified Data.Set as Set import Data.Set (Set) @@ -90,6 +90,7 @@ data CExpr = CELit String -- ^ inserted as-is, assumed no parentheses needed | CEStruct String [(String, CExpr)] -- ^ struct construction literal: `(name){.field=expr}` | CEProj CExpr String -- ^ field projection: expr.field + | CEPtrProj CExpr String -- ^ field projection through pointer: expr->field | CEAddrOf CExpr -- ^ &expr | CEIndex CExpr CExpr -- ^ expr[expr] | CECall String [CExpr] -- ^ function(arg1, ..., argn) @@ -143,6 +144,7 @@ printCExpr d = \case | (n, e) <- pairs]) . showString "}" CEProj e name -> printCExpr 99 e . showString ("." ++ name) + CEPtrProj e name -> printCExpr 99 e . showString ("->" ++ name) CEAddrOf e -> showParen (d > 80) $ showString "&" . printCExpr 80 e CEIndex e1 e2 -> printCExpr 99 e1 . showString "[" . printCExpr 0 e2 . showString "]" CECall n es -> @@ -594,26 +596,12 @@ compile' env = \case return (CEStruct strname [("buf", CEAddrOf (CELit tldname))]) EBuild _ n esh efun -> do - let arrty = STArr n (typeOf efun) - strname <- emitStruct arrty - shname <- genName' "sh" emit . SVarDecl True (repSTy (typeOf esh)) shname =<< compile' env esh shsizename <- genName' "shsz" emit $ SVarDecl True "size_t" shsizename (compileShapeSize n shname) - bufpname <- genName' "bufp" - emit $ SVarDecl True (repSTy arrty ++ "_buf *") bufpname - (CECall "malloc" [CEBinop (CELit (show (fromSNat n * 8 + 8))) - "+" - (CEBinop (CELit shsizename) - "*" (CELit (show (sizeofSTy (typeOf efun)))))]) - forM_ (zip (compileShapeTupIntoArray n shname) [0::Int ..]) $ \(rhs, i) -> - emit $ SAsg (bufpname ++ "->sh[" ++ show i ++ "]") rhs - emit $ SAsg (bufpname ++ "->refc") (CELit "1") - - arrname <- genName' "arr" - emit $ SVarDecl True (repSTy arrty) arrname (CEStruct strname [("buf", CELit bufpname)]) + arrname <- allocArray "arr" n (typeOf efun) (CELit shsizename) (compileShapeTupIntoArray n shname) idxargname <- genName' "ix" (funretval, funstmts) <- scope $ compile' (Const idxargname `SCons` env) efun @@ -623,20 +611,57 @@ compile' env = \case emit $ SBlock $ pure (SVarDecl False "size_t" linivar (CELit "0")) <> compose [pure . SLoop (repSTy tIx) ivar (CELit "0") - (CECast (repSTy tIx) (CEIndex (CELit (bufpname ++ "->sh")) (CELit (show dimidx)))) + (CECast (repSTy tIx) (CEIndex (CELit (arrname ++ ".buf->sh")) (CELit (show dimidx)))) | (ivar, dimidx) <- zip ivars [0::Int ..]] (pure (SVarDecl True (repSTy (typeOf esh)) idxargname (shapeTupFromLitVars n ivars)) <> BList funstmts - <> pure (SAsg (bufpname ++ "->a[" ++ linivar ++ "++]") funretval)) + <> pure (SAsg (arrname ++ ".buf->a[" ++ linivar ++ "++]") funretval)) return (CELit arrname) -- EFold1Inner _ a b c -> error "TODO" -- EFold1Inner ext (compile' a) (compile' b) (compile' c) - -- ESum1Inner _ e -> error "TODO" -- ESum1Inner ext (compile' e) + ESum1Inner _ e -> do + let STArr (SS n) t = typeOf e + e' <- compile' env e + argname <- genName' "sumarg" + emit $ SVarDecl True (repSTy (STArr (SS n) t)) argname e' + + shszname <- genName' "shsz" + -- This n is one less than the shape of the thing we're querying, which is + -- unexpected. But it's exactly what we want, so we do it anyway. + emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) + + resname <- allocArray "sumres" n t (CELit shszname) + [CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")) + + ivar <- genName' "i" + jvar <- genName' "j" + accvar <- genName' "tot" + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ BList + -- we have ScalIsNumeric, so it has 0 and (+) in C + [SVarDecl False (repSTy t) accvar (CELit "0") + ,SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + pure $ SVerbatim $ accvar ++ " += " ++ argname ++ ".buf->a[" ++ lenname ++ " * " ++ ivar ++ " + " ++ jvar ++ "];" + ,SAsg (resname ++ ".buf->a[" ++ ivar ++ "]") (CELit accvar)] - -- EUnit _ e -> error "TODO" -- EUnit ext (compile' e) + return (CELit resname) + + EUnit _ e -> do + e' <- compile' env e + let typ = STArr SZ (typeOf e) + strname <- emitStruct typ + name <- genName + emit $ SVarDecl True strname name (CEStruct strname + [("buf", CECall "malloc" [CELit (show (8 + sizeofSTy (typeOf e)))])]) + emit $ SAsg (name ++ ".buf->refc") (CELit "1") + emit $ SAsg (name ++ ".buf->a[0]") e' + return (CELit name) -- EReplicate1Inner _ a b -> error "TODO" -- EReplicate1Inner ext (compile' a) (compile' b) @@ -646,7 +671,16 @@ compile' env = \case EConst _ t x -> return $ CELit $ compileScal True t x - -- EIdx0 _ e -> error "TODO" -- EIdx0 ext (compile' e) + EIdx0 _ e -> do + let STArr _ t = typeOf e + e' <- compile' env e + arrname <- genName + emit $ SVarDecl True (repSTy (STArr SZ t)) arrname e' + name <- genName + emit $ SVarDecl True (repSTy t) name + (CEIndex (CEPtrProj (CEProj (CELit arrname) "buf") "a") (CELit "0")) + incrementVarAlways Decrement (STArr SZ t) arrname + return (CELit name) -- EIdx1 _ a b -> error "TODO" -- EIdx1 ext (compile' a) (compile' b) @@ -655,7 +689,7 @@ compile' env = \case arrname <- genName' "ixarr" idxname <- genName' "ixix" emit . SVarDecl True (repSTy (typeOf earr)) arrname =<< compile' env earr - emit . SVarDecl True (repSTy (typeOf eidx)) idxname =<< compile' env eidx + when (fromSNat n > 0) $ emit . SVarDecl True (repSTy (typeOf eidx)) idxname =<< compile' env eidx resname <- genName' "ixres" emit $ SVarDecl True (repSTy t) resname (CEIndex (CELit (arrname ++ ".buf->a")) (toLinearIdx n arrname idxname)) incrementVarAlways Decrement (STArr n t) arrname @@ -771,6 +805,23 @@ toLinearIdx (SS n) arrvar idxvar = -- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))) -- _ +-- | The shape must have the outer dimension at the head (and the inner dimension on the right). +allocArray :: String -> SNat n -> STy t -> CExpr -> [CExpr] -> CompM String +allocArray nameBase rank eltty shsz shape = do + when (length shape /= fromSNat rank) $ + error "allocArray: shape does not match rank" + let arrty = STArr rank eltty + strname <- emitStruct arrty + arrname <- genName' nameBase + emit $ SVarDecl True strname arrname $ CEStruct strname + [("buf", CECall "malloc" [CEBinop (CELit (show (fromSNat rank * 8 + 8))) + "+" + (CEBinop shsz "*" (CELit (show (sizeofSTy eltty))))])] + forM_ (zip shape [0::Int ..]) $ \(dim, i) -> + emit $ SAsg (arrname ++ ".buf->sh[" ++ show i ++ "]") dim + emit $ SAsg (arrname ++ ".buf->refc") (CELit "1") + return arrname + compileShapeQuery :: SNat n -> String -> CExpr compileShapeQuery SZ _ = CEStruct (repSTy STNil) [] compileShapeQuery (SS n) var = @@ -779,10 +830,17 @@ compileShapeQuery (SS n) var = ,("b", CEIndex (CELit (var ++ ".buf->sh")) (CELit (show (fromSNat n))))] compileShapeSize :: SNat n -> String -> CExpr -compileShapeSize SZ _ = CELit "0" +compileShapeSize SZ _ = CELit "1" compileShapeSize (SS SZ) var = CELit (var ++ ".b") compileShapeSize (SS n) var = CEBinop (compileShapeSize n (var ++ ".a")) "*" (CELit (var ++ ".b")) +-- | Takes a variable name for the array, not the buffer. +compileArrShapeSize :: SNat n -> String -> CExpr +compileArrShapeSize SZ _ = CELit "1" +compileArrShapeSize n var = + foldl1' (\a b -> CEBinop a "*" b) [CELit (var ++ ".buf->sh[" ++ show i ++ "]") + | i <- [0 .. fromSNat n - 1]] + compileShapeTupIntoArray :: SNat n -> String -> [CExpr] compileShapeTupIntoArray = \n var -> map CELit (toList (go n var)) where |