summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-02 12:57:07 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-02 12:57:07 +0100
commit52184258f8eef227c743726770e984202e101919 (patch)
tree8011e731e2280245400a247e633fa84a3a487af8
parent54762c98901b21468fa9ff4205107360c3096cd3 (diff)
Compile: index into arrays
-rw-r--r--src/Compile.hs84
-rw-r--r--src/Language/AST.hs3
2 files changed, 61 insertions, 26 deletions
diff --git a/src/Compile.hs b/src/Compile.hs
index 2a23561..4a6b9f9 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -33,9 +33,11 @@ import Data
import Interpreter.Rep
--- :m *Example Compile AST.UnMonoid
--- :seti -XOverloadedLabels -XGADTs
--- (($ SCons (Value 2) SNil) =<<) $ compile knownEnv $ fromNamed $ lambda @(TScal TF64) #x $ body $ constArr_ @TF64 (arrayGenerate (ShNil `ShCons` 10) (\(IxNil `IxCons` i) -> fromIntegral i))
+{-
+:m *Example Compile AST.UnMonoid
+:seti -XOverloadedLabels -XGADTs
+(($ SCons (Value 2) SNil) =<<) $ compile knownEnv $ fromNamed $ lambda @(TScal TF64) #x $ body $ constArr_ @TF64 (arrayGenerate (ShNil `ShCons` 10) (\(IxNil `IxCons` i) -> fromIntegral i))
+-}
-- In shape and index arrays, the innermost dimension is on the right (last index).
@@ -89,6 +91,7 @@ data CExpr
| CEStruct String [(String, CExpr)] -- ^ struct construction literal: `(name){.field=expr}`
| CEProj CExpr String -- ^ field projection: expr.field
| CEAddrOf CExpr -- ^ &expr
+ | CEIndex CExpr CExpr -- ^ expr[expr]
| CECall String [CExpr] -- ^ function(arg1, ..., argn)
| CEBinop CExpr String CExpr -- ^ expr + expr
| CEIf CExpr CExpr CExpr -- ^ expr ? expr : expr
@@ -104,13 +107,16 @@ printStmt indent = \case
SVarDecl cnst typ name rhs -> showString ((if cnst then "const " else "") ++ typ ++ " " ++ name ++ " = ") . printCExpr 0 rhs . showString ";"
SVarDeclUninit typ name -> showString (typ ++ " " ++ name ++ ";")
SAsg name rhs -> showString (name ++ " = ") . printCExpr 0 rhs . showString ";"
- SBlock stmts ->
- showString "{"
- . compose [showString ("\n" ++ replicate (2*indent+2) ' ') . printStmt (indent+1) stmt | stmt <- toList stmts]
- . showString ("\n" ++ replicate (2*indent) ' ' ++ "}")
+ SBlock stmts
+ | null stmts -> showString "{}"
+ | otherwise ->
+ showString "{"
+ . compose [showString ("\n" ++ replicate (2*indent+2) ' ') . printStmt (indent+1) stmt | stmt <- toList stmts]
+ . showString ("\n" ++ replicate (2*indent) ' ' ++ "}")
SIf cond b1 b2 ->
showString "if (" . printCExpr 0 cond . showString ") "
- . printStmt indent (SBlock b1) . showString " else " . printStmt indent (SBlock b2)
+ . printStmt indent (SBlock b1)
+ . (if null b2 then id else showString " else " . printStmt indent (SBlock b2))
SVerbatim s -> showString s
-- d values:
@@ -132,6 +138,7 @@ printCExpr d = \case
. showString "}"
CEProj e name -> printCExpr 99 e . showString ("." ++ name)
CEAddrOf e -> showParen (d > 80) $ showString "&" . printCExpr 80 e
+ CEIndex e1 e2 -> printCExpr 99 e1 . showString "[" . printCExpr 0 e2 . showString "]"
CECall n es ->
showString (n ++ "(") . compose (intersperse (showString ", ") (map (printCExpr 0) es)) . showString ")"
CEBinop e1 n e2 ->
@@ -290,7 +297,8 @@ compileToString env expr =
result_offset = align (alignmentSTy (typeOf expr)) result_offset'
in ($ "") $ compose
[showString "#include <stdint.h>\n"
- ,showString "#include <stdlib.h>\n\n"
+ ,showString "#include <stdlib.h>\n"
+ ,showString "#include <math.h>\n\n"
,compose $ map (\sd -> printStructDecl sd . showString "\n") structs
,showString "\n"
,compose [showString str . showString "\n\n" | str <- toList (csTopLevelDecls s)]
@@ -345,17 +353,20 @@ serialise topty topval ptr off k =
pokeByteOff ptr off (1 :: Word8)
serialise t x ptr (off + alignmentSTy t) k
(STArr n t, Array sh vec) -> do
- _ <- error "TODO serialisation of arrays is wrong after refcount introduction"
- pokeShape ptr off n sh
- let off1 = off + 8 * fromSNat n
- eltsz = sizeofSTy t
- allocaBytes (shapeSize sh * sizeofSTy t) $ \arrptr ->
- let loop i
+ let eltsz = sizeofSTy t
+ allocaBytes (fromSNat n * 8 + 8 + shapeSize sh * eltsz) $ \bufptr -> do
+ pokeByteOff ptr off bufptr
+
+ pokeShape bufptr 0 n sh
+ pokeByteOff @Word64 bufptr (8 * fromSNat n) (2 ^ 63)
+
+ let off1 = fromSNat n * 8 + 8
+ loop i
| i == shapeSize sh = k
| otherwise =
- serialise t (vec V.! i) arrptr (off1 + i * eltsz) $
+ serialise t (vec V.! i) bufptr (off1 + i * eltsz) $
loop (i+1)
- in loop 0
+ loop 0
(STScal sty, x) -> case sty of
STI32 -> pokeByteOff ptr off (x :: Int32) >> k
STI64 -> pokeByteOff ptr off (x :: Int64) >> k
@@ -424,7 +435,7 @@ metricsSTy (STEither a b) =
metricsSTy (STMaybe t) =
let (a, s) = metricsSTy t
in (a, a + s) -- the union after the tag byte is aligned
-metricsSTy (STArr n _) = (8, fromSNat n * 8 + 8)
+metricsSTy (STArr _ _) = (8, 8)
metricsSTy (STScal sty) = case sty of
STI32 -> (4, 4)
STI64 -> (8, 8)
@@ -453,9 +464,8 @@ compile' :: SList (Const String) env -> Ex env t -> CompM CExpr
compile' env = \case
EVar _ t i -> do
let Const var = slistIdx env i
- case t of
- STArr{} -> return $ CELit ("(++" ++ var ++ "->buf.refc, " ++ var ++ ")")
- _ -> return $ CELit var
+ incrementVarAlways Increment t var
+ return $ CELit var
ELet _ rhs body -> do
e <- compile' env rhs
@@ -595,7 +605,16 @@ compile' env = \case
-- EIdx1 _ a b -> error "TODO" -- EIdx1 ext (compile' a) (compile' b)
- -- EIdx _ a b -> error "TODO" -- EIdx ext (compile' a) (compile' b)
+ EIdx _ earr eidx -> do
+ let STArr n t = typeOf earr
+ arrname <- genName
+ idxname <- genName
+ emit . SVarDecl True (repSTy (typeOf earr)) arrname =<< compile' env earr
+ emit . SVarDecl True (repSTy (typeOf eidx)) idxname =<< compile' env eidx
+ resname <- genName
+ emit $ SVarDecl True (repSTy t) resname (CEIndex (CELit (arrname ++ ".buf->a")) (toLinearIdx n arrname idxname))
+ incrementVarAlways Decrement (STArr n t) arrname
+ return (CELit resname)
-- EShape _ e -> error "TODO" -- EShape ext (compile' e)
@@ -671,9 +690,10 @@ makeArrayTree (STAccum _) = ATNoop
incrementVar' :: Increment -> String -> ArrayTree -> CompM ()
incrementVar' inc path ATArray =
- let op = case inc of Increment -> "++"
- Decrement -> "--"
- in emit $ SVerbatim (path ++ "->buf.refc" ++ op ++ ";")
+ case inc of
+ Increment -> emit $ SVerbatim (path ++ ".buf->refc++;")
+ Decrement ->
+ emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) free(" ++ path ++ ".buf);"
incrementVar' _ _ ATNoop = pure ()
incrementVar' inc path (ATProj field t) = incrementVar' inc (path ++ "." ++ field) t
incrementVar' inc path (ATCondTag t1 t2) = do
@@ -682,6 +702,20 @@ incrementVar' inc path (ATCondTag t1 t2) = do
emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "0")) (BList stmts1) (BList stmts2)
incrementVar' inc path (ATBoth t1 t2) = incrementVar' inc path t1 >> incrementVar' inc path t2
+toLinearIdx :: SNat n -> String -> String -> CExpr
+toLinearIdx SZ _ _ = CELit "0"
+toLinearIdx (SS SZ) _ idxvar = CELit (idxvar ++ ".b")
+toLinearIdx (SS n) arrvar idxvar =
+ CEBinop (CEBinop (toLinearIdx n arrvar (idxvar ++ ".a"))
+ "*" (CEIndex (CELit (arrvar ++ ".buf->sh")) (CELit (show (fromSNat n)))))
+ "+" (CELit (idxvar ++ ".b"))
+
+-- fromLinearIdx :: SNat n -> String -> String -> CompM CExpr
+-- fromLinearIdx SZ _ _ = return $ CEStruct (repSTy STNil) []
+-- fromLinearIdx (SS n) arrvar idxvar = do
+-- name <- genName
+-- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")))
+-- _
compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr
compileOpGeneral op e1 = do
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
index 022e797..387915b 100644
--- a/src/Language/AST.hs
+++ b/src/Language/AST.hs
@@ -17,7 +17,7 @@ module Language.AST where
import Data.Kind (Type)
import Data.Type.Equality
import GHC.OverloadedLabels
-import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(Text))
+import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..))
import Array
import AST
@@ -77,6 +77,7 @@ deriving instance Show (NExpr env t)
type family Lookup name env where
Lookup "_" _ = TypeError (Text "Attempt to use variable with name '_'")
+ Lookup name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope")
Lookup name ('(name, t) : env) = t
Lookup name (_ : env) = Lookup name env