diff options
author | Tom Smeding <tom@tomsmeding.com> | 2021-09-19 18:06:03 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2021-09-19 18:06:03 +0200 |
commit | 956a60dc5253da43dc0fddaecf88116597023fdf (patch) | |
tree | f56402103483c853a0bdd7551092418025156e51 /SC |
Initial
Diffstat (limited to 'SC')
-rw-r--r-- | SC/Acc.hs | 124 | ||||
-rw-r--r-- | SC/Defs.hs | 116 | ||||
-rw-r--r-- | SC/Exp.hs | 170 | ||||
-rw-r--r-- | SC/Monad.hs | 25 |
4 files changed, 435 insertions, 0 deletions
diff --git a/SC/Acc.hs b/SC/Acc.hs new file mode 100644 index 0000000..955c6da --- /dev/null +++ b/SC/Acc.hs @@ -0,0 +1,124 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +module SC.Acc where + +import qualified Data.Array.Accelerate.AST as A +import Data.Array.Accelerate.AST.LeftHandSide +import Data.Array.Accelerate.AST.Var +import Data.Array.Accelerate.Representation.Array +import Data.Array.Accelerate.Representation.Type +import Data.Array.Accelerate.Type +import Data.Bifunctor +import qualified Data.Set as Set + +import qualified Language.C as C +import SC.Defs +import SC.Monad + + +data Command + = CChunk [C.FunDef] -- ^ Emitted top-level function definitions + [C.Stmt] -- ^ Code to execute + [C.Name] -- ^ Array variables used + | CAlloc [C.FunDef] -- ^ Emitted top-level function definitions + C.Type -- ^ Element type of the allocated array + C.Name -- ^ Variable to store it in + C.StExpr -- ^ Code that computes the array size + | CKeepalive C.Name -- ^ Never deallocate this + | CDealloc C.Name + deriving (Show) + +insertDeallocs :: [Command] -> [Command] +insertDeallocs cmds = + let allocated = Set.fromList [n | CAlloc _ _ n _ <- cmds] + `Set.union` Set.fromList [n | CKeepalive n <- cmds] + in fst $ foldr + (\cmd (rest, done) -> case cmd of + CChunk _ _ used -> + let todealloc = filter (\n -> n `Set.member` allocated && + n `Set.notMember` done) + used + in (cmd : map CDealloc todealloc ++ rest + ,done `Set.union` Set.fromList todealloc) + CAlloc _ _ name _ + | name `Set.notMember` done -> (rest, done) -- unused alloc + | otherwise -> (cmd : rest, Set.delete name done) + CKeepalive _ -> (rest, done) -- already handled above in @allocated@ + CDealloc _ -> error "insertDeallocs: CDealloc found") + ([], mempty) cmds + +compileCommands :: [Command] -> ([C.FunDef], [C.Stmt]) +compileCommands [] = ([], []) +compileCommands (CChunk defs code _ : cmds) = + bimap (defs ++) (code ++) (compileCommands cmds) +compileCommands (CAlloc defs typ name (C.StExpr szstmts szexpr) : cmds) = + let allocstmt = C.SDecl (C.TPtr typ) name + (Just (C.ECall (C.Name "malloc") [C.EOp szexpr "*" (C.ESizeOf typ)])) + in bimap (defs ++) ((szstmts ++ [allocstmt]) ++) (compileCommands cmds) +compileCommands (CDealloc name : cmds) = + second ([C.SCall (C.Name "free") [C.EVar name]] ++) (compileCommands cmds) +compileCommands (CKeepalive _ : cmds) = compileCommands cmds + + +compileAcc' :: AVarEnv aenv -> TupANames t -> A.OpenAcc aenv t -> SC [Command] +compileAcc' aenv dest (A.OpenAcc acc) = compilePAcc' aenv dest acc + +compilePAcc' :: AVarEnv aenv -> TupANames t -> A.PreOpenAcc A.OpenAcc aenv t -> SC [Command] +compilePAcc' aenv destnames = \case + A.Alet lhs rhs body -> do + (names, aenv') <- genVarsAEnv lhs aenv + let sts1 = [C.SDecl t n Nothing | TypedAName t n <- itupList names] + let cmds1 = [CChunk [] sts1 []] + cmds2 <- compileAcc' aenv names rhs + cmds3 <- compileAcc' aenv' destnames body + return (cmds1 ++ cmds2 ++ cmds3) + + A.Avar (Var _ idx) -> + return (Right ([], ITupSingle (C.EVar (aveprj aenv idx)))) + + A.Apair a b -> do + res1 <- compileAcc' aenv a + res2 <- compileAcc' aenv b + return (Left (\case + ITupPair n1 n2 -> toStoring res1 n1 ++ toStoring res2 n2 + ITupIgnore -> [] + ITupSingle _ -> error "wat")) + + _ -> throw "Unsupported Acc constructor" + where + toStExprs :: TypeR t -> Either (ANames t -> [C.Stmt]) ([C.Stmt], Exprs t) -> SC ([C.Stmt], Exprs t) + toStExprs ty (Left fun) = do + names <- genAVars ty + let sts1 = fun names + return (sts1, itupmap (\(TypedName _ n) -> C.EVar n) names) + toStExprs _ (Right pair) = return pair + + toStoring :: Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t) -> Names t -> [C.Stmt] + toStoring (Left f) = f + toStoring (Right (sts, exs)) = (sts ++) . flip go exs + where + go :: Names t -> Exprs t -> [C.Stmt] + go (ITupSingle (TypedName _ name)) (ITupSingle ex) = [C.SAsg name ex] + go (ITupSingle _) _ = error "wat" + go ITupIgnore _ = [] + go (ITupPair ns1 ns2) (ITupPair es1 es2) = go ns1 es1 ++ go ns2 es2 + go (ITupPair _ _) _ = error "wat" + +genVarsAEnv :: A.ALeftHandSide t aenv aenv' -> AVarEnv aenv -> SC (TupANames t, AVarEnv aenv') +genVarsAEnv (LeftHandSideWildcard _) env = return (ANIgnore, env) +genVarsAEnv (LeftHandSideSingle (ArrayR _ ty)) env = do + name <- genName "a" + ty' <- cvtType ty + return (ITupSingle (TypedAName ty' name), AVEPush _ name env) +genVarsAEnv (LeftHandSidePair lhs1 lhs2) env = do + (n1, env1) <- genVarsAEnv lhs1 env + (n2, env2) <- genVarsAEnv lhs2 env1 + return (ITupPair n1 n2, env2) + +genAVars :: TypeR t -> SC (ANames t) +genAVars TupRunit = return ITupIgnore +genAVars (TupRsingle ty) = genAVar ty +genAVars (TupRpair t1 t2) = ITupPair <$> genAVars t1 <*> genAVars t2 + +genAVar :: ScalarType t -> SC (ANames t) +genAVar ty = ITupSingle <$> (TypedAName <$> cvtType ty <*> genName "a") diff --git a/SC/Defs.hs b/SC/Defs.hs new file mode 100644 index 0000000..685d408 --- /dev/null +++ b/SC/Defs.hs @@ -0,0 +1,116 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE RankNTypes #-} +module SC.Defs where + +import Data.Array.Accelerate.AST.Idx +import Data.Array.Accelerate.Representation.Array +import Data.Array.Accelerate.Type + +import qualified Language.C as C +import Language.C (Name(..)) +import SC.Monad + + +-- ENVIRONMENTS +-- ------------ + +data AVarEnv env where + AVENil :: AVarEnv () + AVEPush :: ShNames sh -> ANames t -> AVarEnv env -> AVarEnv (env, Array sh t) + +aveprj :: AVarEnv env -> Idx env (Array sh t) -> (ShNames sh, ANames t) +aveprj (AVEPush shn n _) ZeroIdx = (shn, n) +aveprj (AVEPush _ _ aenv) (SuccIdx idx) = aveprj aenv idx + +data VarEnv env where + VENil :: VarEnv () + VEPush :: Name -> VarEnv env -> VarEnv (env, t) + +veprj :: VarEnv env -> Idx env t -> Name +veprj (VEPush n _) ZeroIdx = n +veprj (VEPush _ env) (SuccIdx idx) = veprj env idx + + +-- IGNORE TUPLES +-- ------------- + +data ITup s t where + ITupPair :: ITup s a -> ITup s b -> ITup s (a, b) + ITupSingle :: s -> ITup s a + ITupIgnore :: ITup s a + +itupfold :: (forall a. f a) -> (forall a. s -> f a) -> (forall a b. f a -> f b -> f (a, b)) + -> ITup s t -> f t +itupfold z _ _ ITupIgnore = z +itupfold _ f _ (ITupSingle x) = f x +itupfold z f g (ITupPair a b) = g (itupfold z f g a) (itupfold z f g b) + +itupmap :: (s1 -> s2) -> ITup s1 t -> ITup s2 t +itupmap f = itupfold ITupIgnore (ITupSingle . f) ITupPair + +itupList :: ITup s t -> [s] +itupList (ITupPair t1 t2) = itupList t1 ++ itupList t2 +itupList (ITupSingle x) = [x] +itupList ITupIgnore = [] + +data TypedName = TypedName C.Type Name +type Names = ITup TypedName +type ANames = ITup TypedAName + +type Exprs = ITup C.Expr + +-- Type is a pointer type +data TypedAName = TypedAName C.Type Name + +data TupANames t where + ANPair :: TupANames a -> TupANames b -> TupANames (a, b) + ANArray :: ShNames sh -> ITup TypedAName t -> TupANames (Array sh t) + ANIgnore :: TupANames a + +-- Shape names and data array names +tupanamesList :: TupANames t -> ([TypedName], [TypedAName]) +tupanamesList (ANPair a b) = + let (shn1, an1) = tupanamesList a + (shn2, an2) = tupanamesList b + in (shn1 ++ shn2, an1 ++ an2) +tupanamesList (ANArray shn ns) = (shnamesList shn, itupList ns) +tupanamesList ANIgnore = ([], []) + +data ShNames sh where + ShZ :: ShNames () + ShS :: Name -> ShNames sh -> ShNames (sh, Int) + +shnamesList :: ShNames sh -> [TypedName] +shnamesList ShZ = [] +shnamesList (ShS n shns) = TypedName (C.TInt C.B64) n : shnamesList shns + + +-- GENERATING VARIABLE NAMES +-- ------------------------- + +genName :: String -> SC Name +genName prefix = Name . (prefix ++) . show <$> genId + + +-- TYPE CONVERSION +-- --------------- + +cvtType :: ScalarType t -> SC C.Type +cvtType (SingleScalarType (NumSingleType (IntegralNumType it))) = return (cvtIT it) + where cvtIT :: IntegralType t -> C.Type + cvtIT TypeInt = C.TInt C.B64 + cvtIT TypeInt8 = C.TInt C.B8 + cvtIT TypeInt16 = C.TInt C.B16 + cvtIT TypeInt32 = C.TInt C.B32 + cvtIT TypeInt64 = C.TInt C.B64 + cvtIT TypeWord = C.TUInt C.B64 + cvtIT TypeWord8 = C.TUInt C.B8 + cvtIT TypeWord16 = C.TUInt C.B16 + cvtIT TypeWord32 = C.TUInt C.B32 + cvtIT TypeWord64 = C.TUInt C.B64 +cvtType (SingleScalarType (NumSingleType (FloatingNumType ft))) = cvtFT ft + where cvtFT :: FloatingType t -> SC C.Type + cvtFT TypeHalf = throw "Half floats not supported" + cvtFT TypeFloat = return C.TFloat + cvtFT TypeDouble = return C.TDouble +cvtType VectorScalarType{} = throw "Vector types not supported" diff --git a/SC/Exp.hs b/SC/Exp.hs new file mode 100644 index 0000000..d033cc8 --- /dev/null +++ b/SC/Exp.hs @@ -0,0 +1,170 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +module SC.Exp where + +import qualified Data.Array.Accelerate.AST as A +import Data.Array.Accelerate.AST.LeftHandSide +import Data.Array.Accelerate.AST.Var +import Data.Array.Accelerate.Representation.Array +import Data.Array.Accelerate.Representation.Shape +import Data.Array.Accelerate.Representation.Type +import Data.Array.Accelerate.Type + +import qualified Language.C as C +import SC.Defs +import SC.Monad + + +data CompiledFun aenv t1 t2 = + CompiledFun + C.FunDef -- ^ expression function implementation + (Exprs t1 -> Names t2 -> [C.Expr]) + -- ^ arguments builder. Given: + -- - expressions that compute the direct arguments; + -- - names that the output values should be stored in; + -- returns the list of arguments to be passed to the compiled + -- function. The outputs will be stored by storing to pointers to + -- the given names. + -- The arguments will refer to array variable names found in the + -- original array environment. + +-- | The function must be single-argument. Uncurry if necessary (e.g. for zipWith). +compileFun :: AVarEnv aenv -> A.Fun aenv (t1 -> t2) -> SC (CompiledFun aenv t1 t2) +compileFun aenv (A.Lam lhs (A.Body body)) = do + funname <- genName "expfun_" + (argnames, env) <- genVarsEnv lhs VENil + outnames <- itupmap (\(TypedName t n) -> TypedName (C.TPtr t) n) + <$> genVars (A.expType body) + (usedA, res) <- compileExp' aenv env body + (sts1, retexprs) <- toStExprs (A.expType body) res + let sts2 = genoutstores outnames retexprs + arguments = + map (\(TypedAName t n) -> (t, n)) usedA + ++ map (\(TypedName t n) -> (t, n)) (itupList argnames) + ++ map (\(TypedName t n) -> (t, n)) (itupList outnames) + return $ CompiledFun + (C.ProcDef funname arguments (sts1 ++ sts2)) + (\argexprs destnames -> + map (\(TypedAName _ n) -> C.EVar n) usedA + ++ itupList argexprs + ++ map (\(TypedName _ n) -> C.EPtrTo (C.EVar n)) (itupList destnames)) + where + genoutstores :: Names t -> Exprs t -> [C.Stmt] + genoutstores ITupIgnore _ = [] + genoutstores (ITupSingle (TypedName _ n)) (ITupSingle e) = [C.SStore n (C.ELit "0") e] + genoutstores (ITupPair n1 n2) (ITupPair e1 e2) = genoutstores n1 e1 ++ genoutstores n2 e2 + genoutstores _ _ = error "wat" +compileFun _ _ = error "compileFun: Not single-argument function" + +compileExp :: AVarEnv aenv -> A.Exp aenv t -> SC (CompiledFun aenv () t) +compileExp aenv expr = compileFun aenv (A.Lam (LeftHandSideWildcard TupRunit) (A.Body expr)) + +compileExp' :: AVarEnv aenv -> VarEnv env -> A.OpenExp env aenv t + -> SC ([TypedAName], Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t)) +compileExp' aenv env = \case + A.Let lhs rhs body -> do + (names, env') <- genVarsEnv lhs env + let sts1 = [C.SDecl t n Nothing | TypedName t n <- itupList names] + (usedA2, sts2) <- fmap (`toStoring` names) <$> compileExp' aenv env rhs + (usedA3, res3) <- compileExp' aenv env' body + return (usedA2 ++ usedA3 + ,fmap (\(sts, exs) -> (sts1 ++ sts2 ++ sts, exs)) res3) + + A.Evar (Var _ idx) -> + return ([], Right ([], ITupSingle (C.EVar (veprj env idx)))) + + A.Pair a b -> do + (usedA1, res1) <- compileExp' aenv env a + (usedA2, res2) <- compileExp' aenv env b + return (usedA1 ++ usedA2, Left (\case + ITupPair n1 n2 -> toStoring res1 n1 ++ toStoring res2 n2 + ITupIgnore -> [] + ITupSingle _ -> error "wat")) + + A.PrimApp (A.PrimAdd _) e -> binary aenv env "+" e + A.PrimApp (A.PrimSub _) e -> binary aenv env "-" e + A.PrimApp (A.PrimMul _) e -> binary aenv env "*" e + A.PrimApp (A.PrimQuot _) e -> binary aenv env "/" e + A.PrimApp (A.PrimRem _) e -> binary aenv env "%" e + + A.Shape (Var _ idx) -> + let (shnames, _) = aveprj aenv idx + buildExprs :: ShNames sh -> Exprs sh + buildExprs ShZ = ITupIgnore + buildExprs (ShS n names) = ITupPair (buildExprs names) (ITupSingle (C.EVar n)) + in return ([], Right ([], buildExprs shnames)) + + A.ToIndex shr she idxe -> do + let build :: ShapeR sh -> Exprs sh -> Exprs sh -> C.Expr + build ShapeRz _ _ = C.ELit "0" + build (ShapeRsnoc ShapeRz) _ (ITupPair _ (ITupSingle idxe')) = idxe' + build (ShapeRsnoc shr') (ITupPair shes' (ITupSingle she')) + (ITupPair idxes' (ITupSingle idxe')) = + C.EOp (C.EOp (build shr' shes' idxes') "*" she') "+" idxe' + build _ _ _ = error "wat" + (usedA1, res1) <- compileExp' aenv env she + (sts1, shes) <- toStExprs (shapeType shr) res1 + (usedA2, res2) <- compileExp' aenv env idxe + (sts2, idxes) <- toStExprs (shapeType shr) res2 + return (usedA1 ++ usedA2, Right (sts1 ++ sts2, ITupSingle (build shr shes idxes))) + + A.Index avar@(Var (ArrayR shr _) _) she -> + compileExp' aenv env $ + A.LinearIndex avar (A.ToIndex shr (A.Shape avar) she) + + A.LinearIndex (Var _ idx) e -> do + temp <- genName "i" + let sts0 = [C.SDecl (C.TInt C.B64) temp Nothing] + (usedA1, sts1) <- fmap (`toStoring` ITupSingle (TypedName (C.TInt C.B64) temp)) + <$> compileExp' aenv env e + let (_, anames) = aveprj aenv idx + usedA = itupList anames ++ usedA1 + return (usedA, Right (sts0 ++ sts1 + ,itupmap (\(TypedAName _ name) -> C.EIndex name (C.EVar temp)) anames)) + + _ -> throw "Unsupported Exp constructor" + where + binary :: AVarEnv aenv -> VarEnv env -> String -> A.OpenExp env aenv (a, b) + -> SC ([TypedAName], Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t)) + binary aenv' env' op e' = do + (usedA, res) <- compileExp' aenv' env' e' + (sts, ITupPair (ITupSingle e1) (ITupSingle e2)) <- + toStExprs (A.expType e') res + return (usedA, Right (sts, ITupSingle (C.EOp e1 op e2))) + +toStExprs :: TypeR t -> Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t) -> SC ([C.Stmt], Exprs t) +toStExprs ty (Left fun) = do + names <- genVars ty + let sts1 = fun names + return (sts1, itupmap (\(TypedName _ n) -> C.EVar n) names) +toStExprs _ (Right pair) = return pair + +toStoring :: Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t) -> Names t -> [C.Stmt] +toStoring (Left f) = f +toStoring (Right (sts, exs)) = (sts ++) . flip go exs + where + go :: Names t -> Exprs t -> [C.Stmt] + go (ITupSingle (TypedName _ name)) (ITupSingle ex) = [C.SAsg name ex] + go (ITupSingle _) _ = error "wat" + go ITupIgnore _ = [] + go (ITupPair ns1 ns2) (ITupPair es1 es2) = go ns1 es1 ++ go ns2 es2 + go (ITupPair _ _) _ = error "wat" + +genVarsEnv :: A.ELeftHandSide t env env' -> VarEnv env -> SC (Names t, VarEnv env') +genVarsEnv (LeftHandSideWildcard _) env = return (ITupIgnore, env) +genVarsEnv (LeftHandSideSingle ty) env = do + name <- genName "x" + ty' <- cvtType ty + return (ITupSingle (TypedName ty' name), VEPush name env) +genVarsEnv (LeftHandSidePair lhs1 lhs2) env = do + (n1, env1) <- genVarsEnv lhs1 env + (n2, env2) <- genVarsEnv lhs2 env1 + return (ITupPair n1 n2, env2) + +genVars :: TypeR t -> SC (Names t) +genVars TupRunit = return ITupIgnore +genVars (TupRsingle ty) = genVar ty +genVars (TupRpair t1 t2) = ITupPair <$> genVars t1 <*> genVars t2 + +genVar :: ScalarType t -> SC (Names t) +genVar ty = ITupSingle <$> (TypedName <$> cvtType ty <*> genName "x") diff --git a/SC/Monad.hs b/SC/Monad.hs new file mode 100644 index 0000000..c58755f --- /dev/null +++ b/SC/Monad.hs @@ -0,0 +1,25 @@ +{-# LANGUAGE DerivingVia #-} +module SC.Monad where + +import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.Except +import Control.Monad.Trans.State.Strict + + +newtype SC a = SC (ExceptT String (State Int) a) + deriving (Functor, Applicative, Monad) via (ExceptT String (State Int)) + +instance MonadFail SC where + fail = throw + +evalSC :: SC a -> Either String a +evalSC (SC m) = evalState (runExceptT m) 1 + +genId :: SC Int +genId = SC $ do + value <- lift get + lift (modify (+1)) + return value + +throw :: String -> SC a +throw = SC . throwE |