diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-02 12:57:07 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-02 12:57:07 +0100 |
commit | 52184258f8eef227c743726770e984202e101919 (patch) | |
tree | 8011e731e2280245400a247e633fa84a3a487af8 | |
parent | 54762c98901b21468fa9ff4205107360c3096cd3 (diff) |
Compile: index into arrays
-rw-r--r-- | src/Compile.hs | 84 | ||||
-rw-r--r-- | src/Language/AST.hs | 3 |
2 files changed, 61 insertions, 26 deletions
diff --git a/src/Compile.hs b/src/Compile.hs index 2a23561..4a6b9f9 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -33,9 +33,11 @@ import Data import Interpreter.Rep --- :m *Example Compile AST.UnMonoid --- :seti -XOverloadedLabels -XGADTs --- (($ SCons (Value 2) SNil) =<<) $ compile knownEnv $ fromNamed $ lambda @(TScal TF64) #x $ body $ constArr_ @TF64 (arrayGenerate (ShNil `ShCons` 10) (\(IxNil `IxCons` i) -> fromIntegral i)) +{- +:m *Example Compile AST.UnMonoid +:seti -XOverloadedLabels -XGADTs +(($ SCons (Value 2) SNil) =<<) $ compile knownEnv $ fromNamed $ lambda @(TScal TF64) #x $ body $ constArr_ @TF64 (arrayGenerate (ShNil `ShCons` 10) (\(IxNil `IxCons` i) -> fromIntegral i)) +-} -- In shape and index arrays, the innermost dimension is on the right (last index). @@ -89,6 +91,7 @@ data CExpr | CEStruct String [(String, CExpr)] -- ^ struct construction literal: `(name){.field=expr}` | CEProj CExpr String -- ^ field projection: expr.field | CEAddrOf CExpr -- ^ &expr + | CEIndex CExpr CExpr -- ^ expr[expr] | CECall String [CExpr] -- ^ function(arg1, ..., argn) | CEBinop CExpr String CExpr -- ^ expr + expr | CEIf CExpr CExpr CExpr -- ^ expr ? expr : expr @@ -104,13 +107,16 @@ printStmt indent = \case SVarDecl cnst typ name rhs -> showString ((if cnst then "const " else "") ++ typ ++ " " ++ name ++ " = ") . printCExpr 0 rhs . showString ";" SVarDeclUninit typ name -> showString (typ ++ " " ++ name ++ ";") SAsg name rhs -> showString (name ++ " = ") . printCExpr 0 rhs . showString ";" - SBlock stmts -> - showString "{" - . compose [showString ("\n" ++ replicate (2*indent+2) ' ') . printStmt (indent+1) stmt | stmt <- toList stmts] - . showString ("\n" ++ replicate (2*indent) ' ' ++ "}") + SBlock stmts + | null stmts -> showString "{}" + | otherwise -> + showString "{" + . compose [showString ("\n" ++ replicate (2*indent+2) ' ') . printStmt (indent+1) stmt | stmt <- toList stmts] + . showString ("\n" ++ replicate (2*indent) ' ' ++ "}") SIf cond b1 b2 -> showString "if (" . printCExpr 0 cond . showString ") " - . printStmt indent (SBlock b1) . showString " else " . printStmt indent (SBlock b2) + . printStmt indent (SBlock b1) + . (if null b2 then id else showString " else " . printStmt indent (SBlock b2)) SVerbatim s -> showString s -- d values: @@ -132,6 +138,7 @@ printCExpr d = \case . showString "}" CEProj e name -> printCExpr 99 e . showString ("." ++ name) CEAddrOf e -> showParen (d > 80) $ showString "&" . printCExpr 80 e + CEIndex e1 e2 -> printCExpr 99 e1 . showString "[" . printCExpr 0 e2 . showString "]" CECall n es -> showString (n ++ "(") . compose (intersperse (showString ", ") (map (printCExpr 0) es)) . showString ")" CEBinop e1 n e2 -> @@ -290,7 +297,8 @@ compileToString env expr = result_offset = align (alignmentSTy (typeOf expr)) result_offset' in ($ "") $ compose [showString "#include <stdint.h>\n" - ,showString "#include <stdlib.h>\n\n" + ,showString "#include <stdlib.h>\n" + ,showString "#include <math.h>\n\n" ,compose $ map (\sd -> printStructDecl sd . showString "\n") structs ,showString "\n" ,compose [showString str . showString "\n\n" | str <- toList (csTopLevelDecls s)] @@ -345,17 +353,20 @@ serialise topty topval ptr off k = pokeByteOff ptr off (1 :: Word8) serialise t x ptr (off + alignmentSTy t) k (STArr n t, Array sh vec) -> do - _ <- error "TODO serialisation of arrays is wrong after refcount introduction" - pokeShape ptr off n sh - let off1 = off + 8 * fromSNat n - eltsz = sizeofSTy t - allocaBytes (shapeSize sh * sizeofSTy t) $ \arrptr -> - let loop i + let eltsz = sizeofSTy t + allocaBytes (fromSNat n * 8 + 8 + shapeSize sh * eltsz) $ \bufptr -> do + pokeByteOff ptr off bufptr + + pokeShape bufptr 0 n sh + pokeByteOff @Word64 bufptr (8 * fromSNat n) (2 ^ 63) + + let off1 = fromSNat n * 8 + 8 + loop i | i == shapeSize sh = k | otherwise = - serialise t (vec V.! i) arrptr (off1 + i * eltsz) $ + serialise t (vec V.! i) bufptr (off1 + i * eltsz) $ loop (i+1) - in loop 0 + loop 0 (STScal sty, x) -> case sty of STI32 -> pokeByteOff ptr off (x :: Int32) >> k STI64 -> pokeByteOff ptr off (x :: Int64) >> k @@ -424,7 +435,7 @@ metricsSTy (STEither a b) = metricsSTy (STMaybe t) = let (a, s) = metricsSTy t in (a, a + s) -- the union after the tag byte is aligned -metricsSTy (STArr n _) = (8, fromSNat n * 8 + 8) +metricsSTy (STArr _ _) = (8, 8) metricsSTy (STScal sty) = case sty of STI32 -> (4, 4) STI64 -> (8, 8) @@ -453,9 +464,8 @@ compile' :: SList (Const String) env -> Ex env t -> CompM CExpr compile' env = \case EVar _ t i -> do let Const var = slistIdx env i - case t of - STArr{} -> return $ CELit ("(++" ++ var ++ "->buf.refc, " ++ var ++ ")") - _ -> return $ CELit var + incrementVarAlways Increment t var + return $ CELit var ELet _ rhs body -> do e <- compile' env rhs @@ -595,7 +605,16 @@ compile' env = \case -- EIdx1 _ a b -> error "TODO" -- EIdx1 ext (compile' a) (compile' b) - -- EIdx _ a b -> error "TODO" -- EIdx ext (compile' a) (compile' b) + EIdx _ earr eidx -> do + let STArr n t = typeOf earr + arrname <- genName + idxname <- genName + emit . SVarDecl True (repSTy (typeOf earr)) arrname =<< compile' env earr + emit . SVarDecl True (repSTy (typeOf eidx)) idxname =<< compile' env eidx + resname <- genName + emit $ SVarDecl True (repSTy t) resname (CEIndex (CELit (arrname ++ ".buf->a")) (toLinearIdx n arrname idxname)) + incrementVarAlways Decrement (STArr n t) arrname + return (CELit resname) -- EShape _ e -> error "TODO" -- EShape ext (compile' e) @@ -671,9 +690,10 @@ makeArrayTree (STAccum _) = ATNoop incrementVar' :: Increment -> String -> ArrayTree -> CompM () incrementVar' inc path ATArray = - let op = case inc of Increment -> "++" - Decrement -> "--" - in emit $ SVerbatim (path ++ "->buf.refc" ++ op ++ ";") + case inc of + Increment -> emit $ SVerbatim (path ++ ".buf->refc++;") + Decrement -> + emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) free(" ++ path ++ ".buf);" incrementVar' _ _ ATNoop = pure () incrementVar' inc path (ATProj field t) = incrementVar' inc (path ++ "." ++ field) t incrementVar' inc path (ATCondTag t1 t2) = do @@ -682,6 +702,20 @@ incrementVar' inc path (ATCondTag t1 t2) = do emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "0")) (BList stmts1) (BList stmts2) incrementVar' inc path (ATBoth t1 t2) = incrementVar' inc path t1 >> incrementVar' inc path t2 +toLinearIdx :: SNat n -> String -> String -> CExpr +toLinearIdx SZ _ _ = CELit "0" +toLinearIdx (SS SZ) _ idxvar = CELit (idxvar ++ ".b") +toLinearIdx (SS n) arrvar idxvar = + CEBinop (CEBinop (toLinearIdx n arrvar (idxvar ++ ".a")) + "*" (CEIndex (CELit (arrvar ++ ".buf->sh")) (CELit (show (fromSNat n))))) + "+" (CELit (idxvar ++ ".b")) + +-- fromLinearIdx :: SNat n -> String -> String -> CompM CExpr +-- fromLinearIdx SZ _ _ = return $ CEStruct (repSTy STNil) [] +-- fromLinearIdx (SS n) arrvar idxvar = do +-- name <- genName +-- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))) +-- _ compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr compileOpGeneral op e1 = do diff --git a/src/Language/AST.hs b/src/Language/AST.hs index 022e797..387915b 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -17,7 +17,7 @@ module Language.AST where import Data.Kind (Type) import Data.Type.Equality import GHC.OverloadedLabels -import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(Text)) +import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..)) import Array import AST @@ -77,6 +77,7 @@ deriving instance Show (NExpr env t) type family Lookup name env where Lookup "_" _ = TypeError (Text "Attempt to use variable with name '_'") + Lookup name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope") Lookup name ('(name, t) : env) = t Lookup name (_ : env) = Lookup name env |