diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-03-02 00:21:36 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-02 00:21:36 +0100 | 
| commit | 0fffb5731271a551afcf08878cb021ead8dd1dae (patch) | |
| tree | fd4fff6010db0fb63515930708b3d8bfd234c367 /src | |
| parent | 0ebdcff2aa06ee95f95597f2984e2cd335899d37 (diff) | |
compile: WIP reference-counted arrays
Diffstat (limited to 'src')
| -rw-r--r-- | src/Array.hs | 7 | ||||
| -rw-r--r-- | src/Compile.hs | 193 | ||||
| -rw-r--r-- | src/Compile/Exec.hs | 2 | ||||
| -rw-r--r-- | src/Data.hs | 3 | 
4 files changed, 158 insertions, 47 deletions
| diff --git a/src/Array.hs b/src/Array.hs index 82c3f31..059600f 100644 --- a/src/Array.hs +++ b/src/Array.hs @@ -63,6 +63,13 @@ emptyShape (SS m) = emptyShape m `ShCons` 0  enumShape :: Shape n -> [Index n]  enumShape sh = map (fromLinearIndex sh) [0 .. shapeSize sh - 1] +shapeToList :: Shape n -> [Int] +shapeToList = go [] +  where +    go :: [Int] -> Shape n -> [Int] +    go suff ShNil = suff +    go suff (sh `ShCons` n) = go (n:suff) sh +  -- | TODO: this Vector is a boxed vector, which is horrendously inefficient.  data Array (n :: Nat) t = Array (Shape n) (Vector t) diff --git a/src/Compile.hs b/src/Compile.hs index 582a1df..564f697 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -67,8 +67,8 @@ data Stmt    = SVarDecl Bool String String CExpr  -- ^ const, type, variable name, right-hand side    | SVarDeclUninit String String  -- ^ type, variable name (no initialiser)    | SAsg String CExpr  -- ^ variable name, right-hand side -  | SBlock [Stmt] -  | SIf CExpr [Stmt] [Stmt] +  | SBlock (Bag Stmt) +  | SIf CExpr (Bag Stmt) (Bag Stmt)    | SVerbatim String  -- ^ no implicit ';', just printed as-is    deriving (Show) @@ -76,6 +76,7 @@ data CExpr    = CELit String  -- ^ inserted as-is, assumed no parentheses needed    | CEStruct String [(String, CExpr)]  -- ^ struct construction literal: `(name){.field=expr}`    | CEProj CExpr String  -- ^ field projection: expr.field +  | CEAddrOf CExpr  -- ^ &expr    | CECall String [CExpr]  -- ^ function(arg1, ..., argn)    | CEBinop CExpr String CExpr  -- ^ expr + expr    | CEIf CExpr CExpr CExpr  -- ^ expr ? expr : expr @@ -93,7 +94,7 @@ printStmt indent = \case    SAsg name rhs -> showString (name ++ " = ") . printCExpr 0 rhs . showString ";"    SBlock stmts ->      showString "{" -    . compose [showString ("\n" ++ replicate (2*indent+2) ' ') . printStmt (indent+1) stmt | stmt <- stmts] +    . compose [showString ("\n" ++ replicate (2*indent+2) ' ') . printStmt (indent+1) stmt | stmt <- toList stmts]      . showString ("\n" ++ replicate (2*indent) ' ' ++ "}")    SIf cond b1 b2 ->      showString "if (" . printCExpr 0 cond . showString ") " @@ -104,6 +105,7 @@ printStmt indent = \case  -- * 0: top level  -- * 1: in 1st or 2nd component of a ternary operator (technically same as top level, but readability)  -- * 2-...: various operators (see precTable) +-- * 80: address-of operator (&)  -- * 98: inside unknown operator  -- * 99: left of a field projection  -- Unlisted operators are conservatively written with full parentheses. @@ -117,6 +119,7 @@ printCExpr d = \case                                                 | (n, e) <- pairs])        . showString "}"    CEProj e name -> printCExpr 99 e . showString ("." ++ name) +  CEAddrOf e -> showParen (d > 80) $ showString "&" . printCExpr 80 e    CECall n es ->      showString (n ++ "(") . compose (intersperse (showString ", ") (map (printCExpr 0) es)) . showString ")"    CEBinop e1 n e2 -> @@ -173,22 +176,23 @@ genStructName = \t -> "ty_" ++ gen t where      TBool -> "b"    gen (TAccum t) = 'C' : gen t -genStruct :: String -> Ty -> Maybe StructDecl +genStruct :: String -> Ty -> [StructDecl]  genStruct name topty = case topty of    TNil -> -    Just $ StructDecl name "" com +    [StructDecl name "" com]    TPair a b -> -    Just $ StructDecl name (repTy a ++ " a; " ++ repTy b ++ " b;") com +    [StructDecl name (repTy a ++ " a; " ++ repTy b ++ " b;") com]    TEither a b ->  -- 0 -> a, 1 -> b -    Just $ StructDecl name ("uint8_t tag; union { " ++ repTy a ++ " a; " ++ repTy b ++ " b; };") com +    [StructDecl name ("uint8_t tag; union { " ++ repTy a ++ " a; " ++ repTy b ++ " b; };") com]    TMaybe t ->  -- 0 -> nothing, 1 -> just -    Just $ StructDecl name ("uint8_t tag; " ++ repTy t ++ " a;") com +    [StructDecl name ("uint8_t tag; " ++ repTy t ++ " a;") com]    TArr n t -> -    Just $ StructDecl name ("size_t sh[" ++ show (fromNat n) ++ "]; " ++ repTy t ++ " *a;") com +    [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromNat n) ++ "]; size_t refc; " ++ repTy t ++ " *a;") com +    ,StructDecl name (name ++ "_buf *buf;") com]    TScal _ -> -    Nothing +    []    TAccum t -> -    Just $ StructDecl name (repTy t ++ " a;") com +    [StructDecl name (repTy t ++ " a;") com]    where      com = ppTy 0 topty @@ -202,7 +206,8 @@ genStructs ty = do    if seen      then pure ()      else do -      -- already mark this struct as generated now, so we don't generate it twice +      -- already mark this struct as generated now, so we don't generate it +      -- twice (unnecessary because no recursive types, but y'know)        lift $ modify (Set.insert name)        case ty of @@ -214,13 +219,14 @@ genStructs ty = do          TScal _ -> pure ()          TAccum t -> genStructs t -      tell (maybe mempty pure (genStruct name ty)) +      tell (BList (genStruct name ty))  genAllStructs :: Foldable t => t Ty -> [StructDecl]  genAllStructs tys = toList $ evalState (execWriterT (mapM_ genStructs tys)) mempty  data CompState = CompState    { csStructs :: Set Ty +  , csTopLevelDecls :: Bag String    , csStmts :: Bag Stmt    , csNextId :: Int }    deriving (Show) @@ -230,8 +236,11 @@ type CompM a = State CompState a  genId :: CompM Int  genId = state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 }) +genName' :: String -> CompM String +genName' prefix = (prefix ++) . show <$> genId +  genName :: CompM String -genName = ('x' :) . show <$> genId +genName = genName' "x"  emit :: Stmt -> CompM ()  emit stmt = modify $ \s -> s { csStmts = csStmts s <> pure stmt } @@ -249,13 +258,16 @@ emitStruct ty = do    modify $ \s -> s { csStructs = Set.insert ty' (csStructs s) }    return (genStructName ty') +emitTLD :: String -> CompM () +emitTLD decl = modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl } +  nameEnv :: SList f env -> SList (Const String) env  nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1))  compileToString :: SList STy env -> Ex env t -> String  compileToString env expr =    let args = nameEnv env -      (res, s) = runState (compile' args expr) (CompState mempty mempty 1) +      (res, s) = runState (compile' args expr) (CompState mempty mempty mempty 1)        structs = genAllStructs (csStructs s <> Set.fromList (unSList unSTy env))        (arg_pairs, arg_metrics) = @@ -268,6 +280,7 @@ compileToString env expr =         ,showString "#include <stdlib.h>\n\n"         ,compose $ map (\sd -> printStructDecl sd . showString "\n") structs         ,showString "\n" +       ,compose [showString str . showString "\n\n" | str <- toList (csTopLevelDecls s)]         ,showString $            "static " ++ repSTy (typeOf expr) ++ " typed_kernel(" ++              intercalate ", " (reverse (unSList (\(Product.Pair t (Const n)) -> repSTy t ++ " " ++ n) (slistZip env args))) ++ @@ -319,6 +332,7 @@ serialise topty topval ptr off k =        pokeByteOff ptr off (1 :: Word8)        serialise t x ptr (off + alignmentSTy t) k      (STArr n t, Array sh vec) -> do +      _ <- error "TODO serialisation of arrays is wrong after refcount introduction"        pokeShape ptr off n sh        let off1 = off + 8 * fromSNat n            eltsz = sizeofSTy t @@ -358,6 +372,7 @@ deserialise topty ptr off =          then return Nothing          else Just <$> deserialise t ptr (off + alignmentSTy t)      STArr n t -> do +      _ <- error "TODO deserialisation of arrays is wrong after refcount introduction"        sh <- peekShape ptr off n        let off1 = off + 8 * fromSNat n            eltsz = sizeofSTy t @@ -420,13 +435,19 @@ peekShape ptr off = \case  compile' :: SList (Const String) env -> Ex env t -> CompM CExpr  compile' env = \case -  EVar _ _ i -> return $ CELit (getConst (slistIdx env i)) +  EVar _ t i -> do +    let Const var = slistIdx env i +    case t of +      STArr{} -> return $ CELit ("(++" ++ var ++ ".refc, " ++ var ++ ")") +      _ -> return $ CELit var    ELet _ rhs body -> do      e <- compile' env rhs      var <- genName      emit $ SVarDecl True (repSTy (typeOf rhs)) var e -    compile' (Const var `SCons` env) body +    rete <- compile' (Const var `SCons` env) body +    releaseVarAlways (typeOf rhs) var +    return rete    EPair _ a b -> do      name <- emitStruct (STPair (typeOf a) (typeOf b)) @@ -434,9 +455,25 @@ compile' env = \case      e2 <- compile' env b      return $ CEStruct name [("a", e1), ("b", e2)] -  EFst _ e -> CEProj <$> compile' env e <*> pure "a" +  EFst _ e -> do +    let STPair _ t2 = typeOf e +    e' <- compile' env e +    case releaseVar t2 of +      Nothing -> return $ CEProj e' "a" +      Just f -> do var <- genName +                   emit $ SVarDecl True (repSTy (typeOf e)) var e' +                   f (var ++ ".b") +                   return $ CEProj (CELit var) "a" -  ESnd _ e -> CEProj <$> compile' env e <*> pure "b" +  ESnd _ e -> do +    let STPair t1 _ = typeOf e +    e' <- compile' env e +    case releaseVar t1 of +      Nothing -> return $ CEProj e' "b" +      Just f -> do var <- genName +                   emit $ SVarDecl True (repSTy (typeOf e)) var e' +                   f (var ++ ".a") +                   return $ CEProj (CELit var) "b"    ENil _ -> do      name <- emitStruct STNil @@ -459,28 +496,28 @@ compile' env = \case      retvar <- genName      emit $ SVarDeclUninit (repSTy (typeOf a)) retvar      emit $ SIf e1 -             (stmts2 <> pure (SAsg retvar e2)) -             (stmts3 <> pure (SAsg retvar e3)) +             (BList stmts2 <> pure (SAsg retvar e2)) +             (BList stmts3 <> pure (SAsg retvar e3))      return (CELit retvar)    ECase _ e a b -> do      let STEither t1 t2 = typeOf e      e1 <- compile' env e      var <- genName -    fieldvar <- genName -    (e2, stmts2) <- scope $ compile' (Const fieldvar `SCons` env) a -    (e3, stmts3) <- scope $ compile' (Const fieldvar `SCons` env) b +    -- I know those are not variable names, but it's fine, probably +    (e2, stmts2) <- scope $ compile' (Const (var ++ ".a") `SCons` env) a +    (e3, stmts3) <- scope $ compile' (Const (var ++ ".b") `SCons` env) b +    ((), stmtsRel1) <- scope $ releaseVarAlways t1 (var ++ ".a") +    ((), stmtsRel2) <- scope $ releaseVarAlways t2 (var ++ ".b")      retvar <- genName      emit $ SVarDeclUninit (repSTy (typeOf a)) retvar      emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1)                  <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) -                           (pure (SVarDecl True (repSTy t1) fieldvar -                                           (CEProj (CELit var) "a")) -                            <> stmts2 +                           (BList stmts2 +                            <> BList stmtsRel1                              <> pure (SAsg retvar e2)) -                           (pure (SVarDecl True (repSTy t2) fieldvar -                                           (CEProj (CELit var) "b")) -                            <> stmts3 +                           (BList stmts3 +                            <> BList stmtsRel2                              <> pure (SAsg retvar e3))))      return (CELit retvar) @@ -494,26 +531,37 @@ compile' env = \case      return $ CEStruct name [("tag", CELit "1"), ("a", e1)]    EMaybe _ a b e -> do +    let STMaybe t = typeOf e      e1 <- compile' env e      var <- genName -    fieldvar <- genName      (e2, stmts2) <- scope $ compile' env a -    (e3, stmts3) <- scope $ compile' (Const fieldvar `SCons` env) b +    (e3, stmts3) <- scope $ compile' (Const (var ++ ".a") `SCons` env) b +    ((), stmtsRel) <- scope $ releaseVarAlways t (var ++ ".a")      retvar <- genName      emit $ SVarDeclUninit (repSTy (typeOf a)) retvar      emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1)                  <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) -                           (stmts2 +                           (BList stmts2                              <> pure (SAsg retvar e2)) -                           (pure (SVarDecl True (repSTy (typeOf b)) fieldvar -                                           (CEProj (CELit var) "a")) -                            <> stmts3 +                           (BList stmts3 +                            <> BList stmtsRel                              <> pure (SAsg retvar e3))))      return (CELit retvar) -  -- EConstArr _ n t arr -> do -  --   name <- emitStruct (STArr n (STScal t)) -  --   error "TODO" +  EConstArr _ n t (Array sh vec) -> do +    strname <- emitStruct (STArr n (STScal t)) +    tldname <- genName' "carray" +    emitTLD $ "static const " ++ repSTy (STScal t) ++ " " ++ +              tldname ++ "[" ++ show (shapeSize sh) ++ "] = {" ++ +              intercalate "," (map (compileScal False t) (toList vec)) ++ +              "};" +    -- Give it a refcount of _half_ the size_t max, so that it can be +    -- incremented and decremented at will and will "never" reach anything +    -- where something happens +    emitTLD $ "static " ++ strname ++ "_buf " ++ tldname ++ "_buf = " ++ +              "(" ++ strname ++ "_buf){.sh = {" ++ intercalate "," (map show (shapeToList sh)) ++ "}, " ++ +              ".refc = SIZE_MAX/2, .a = " ++ tldname ++ "};" +    return (CEStruct strname [("buf", CEAddrOf (CELit (tldname ++ "_buf")))])    -- EBuild _ n a b -> error "TODO" -- genStruct (STArr n (typeOf b)) <> EBuild ext n (compile' a) (compile' b) @@ -529,12 +577,7 @@ compile' env = \case    -- EMinimum1Inner _ e -> error "TODO" -- EMinimum1Inner ext (compile' e) -  EConst _ t x -> case t of -    STI32 -> return $ CELit $ "(int32_t)" ++ show x -    STI64 -> return $ CELit $ "(int64_t)" ++ show x -    STF32 -> return $ CELit $ show x ++ "f" -    STF64 -> return $ CELit $ show x -    STBool -> return $ CELit $ if x then "1" else "0" +  EConst _ t x -> return $ CELit $ compileScal True t x    -- EIdx0 _ e -> error "TODO" -- EIdx0 ext (compile' e) @@ -571,6 +614,57 @@ compile' env = \case    _ -> error "Compile: not implemented" +-- | Decrement reference counts in the components of the given variable. +releaseVar :: STy a -> Maybe (String -> CompM ()) +releaseVar ty = +  let tree = makeReleaseTree ty +  in case tree of RTNoop -> Nothing +                  _ -> Just $ \var -> releaseVar' var tree + +releaseVarAlways :: STy a -> String -> CompM () +releaseVarAlways ty var = maybe (pure ()) ($ var) (releaseVar ty) + +data ReleaseTree = RTArray  -- ^ we've arrived at an array we need to decrement the refcount of +                 | RTNoop  -- ^ don't do anything here +                 | RTProj String ReleaseTree  -- ^ descend one field deeper +                 | RTCondTag ReleaseTree ReleaseTree  -- ^ if tag is 0, first; if 1, second +                 | RTBoth ReleaseTree ReleaseTree  -- ^ do both these paths + +smartRTProj :: String -> ReleaseTree -> ReleaseTree +smartRTProj _ RTNoop = RTNoop +smartRTProj field t = RTProj field t + +smartRTCondTag :: ReleaseTree -> ReleaseTree -> ReleaseTree +smartRTCondTag RTNoop RTNoop = RTNoop +smartRTCondTag t t' = RTCondTag t t' + +smartRTBoth :: ReleaseTree -> ReleaseTree -> ReleaseTree +smartRTBoth RTNoop t = t +smartRTBoth t RTNoop = t +smartRTBoth t t' = RTBoth t t' + +makeReleaseTree :: STy a -> ReleaseTree +makeReleaseTree STNil = RTNoop +makeReleaseTree (STPair a b) = smartRTBoth (smartRTProj "a" (makeReleaseTree a)) +                                           (smartRTProj "b" (makeReleaseTree b)) +makeReleaseTree (STEither a b) = smartRTCondTag (smartRTProj "a" (makeReleaseTree a)) +                                                (smartRTProj "b" (makeReleaseTree b)) +makeReleaseTree (STMaybe t) = smartRTCondTag RTNoop (makeReleaseTree t) +makeReleaseTree (STArr _ _) = RTArray +makeReleaseTree (STScal _) = RTNoop +makeReleaseTree (STAccum _) = RTNoop + +releaseVar' :: String -> ReleaseTree -> CompM () +releaseVar' path RTArray = emit $ SVerbatim (path ++ "--;") +releaseVar' _ RTNoop = pure () +releaseVar' path (RTProj field t) = releaseVar' (path ++ "." ++ field) t +releaseVar' path (RTCondTag t1 t2) = do +  ((), stmts1) <- scope $ releaseVar' path t1 +  ((), stmts2) <- scope $ releaseVar' path t2 +  emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "0")) (BList stmts1) (BList stmts2) +releaseVar' path (RTBoth t1 t2) = releaseVar' path t1 >> releaseVar' path t2 + +  compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr  compileOpGeneral op e1 = do    let unary cop = return @(State CompState) $ CECall cop [e1] @@ -616,5 +710,14 @@ compileOpPair op e1 e2 = do      OIDiv _ -> binary "/"      _ -> error "compileOpPair: got unary operator" +-- | Bool: whether to ensure that the literal itself already has the appropriate type +compileScal :: Bool -> SScalTy t -> ScalRep t -> String +compileScal pedantic typ x = case typ of +  STI32 -> (if pedantic then "(int32_t)" else "") ++ show x +  STI64 -> (if pedantic then "(int64_t)" else "") ++ show x +  STF32 -> show x ++ "f" +  STF64 -> show x +  STBool -> if x then "1" else "0" +  compose :: Foldable t => t (a -> a) -> a -> a  compose = foldr (.) id diff --git a/src/Compile/Exec.hs b/src/Compile/Exec.hs index 163be2b..83fcdad 100644 --- a/src/Compile/Exec.hs +++ b/src/Compile/Exec.hs @@ -28,7 +28,7 @@ buildKernel csource funnames = do    path <- mkdtemp template    let outso = path ++ "/out.so" -  let args = ["-O3", "-march=native", "-shared", "-fPIC", "-x", "c", "-o", outso, "-"] +  let args = ["-O3", "-march=native", "-shared", "-fPIC", "-std=c99", "-x", "c", "-o", outso, "-"]    _ <- readProcess "gcc" args csource    hPutStrLn stderr $ "[chad] loading kernel " ++ path diff --git a/src/Data.hs b/src/Data.hs index 8005737..00790fe 100644 --- a/src/Data.hs +++ b/src/Data.hs @@ -129,7 +129,7 @@ vecGenerate = \n f -> go n f SZ  unsafeCoerceRefl :: a :~: b  unsafeCoerceRefl = unsafeCoerce Refl -data Bag t = BNone | BOne t | BTwo (Bag t) (Bag t) | BMany [Bag t] +data Bag t = BNone | BOne t | BTwo !(Bag t) !(Bag t) | BMany [Bag t] | BList [t]    deriving (Show, Functor, Foldable, Traversable)  -- | This instance is mostly there just for 'pure' @@ -139,6 +139,7 @@ instance Applicative Bag where    BOne f <*> b = f <$> b    BTwo b1 b2 <*> b = BTwo (b1 <*> b) (b2 <*> b)    BMany bs <*> b = BMany (map (<*> b) bs) +  BList bs <*> b = BMany (map (<$> b) bs)  instance Semigroup (Bag t) where (<>) = BTwo  instance Monoid (Bag t) where mempty = BNone | 
