From f42d7f4562ea2e5c9ef634665952e38630f17ae4 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 22 Sep 2021 21:13:53 +0200 Subject: No more errors, but lots unimplemented --- SC/Acc.hs | 82 ++++++++++++++++++++++++++++++++------------------------------ SC/Defs.hs | 2 +- 2 files changed, 44 insertions(+), 40 deletions(-) (limited to 'SC') diff --git a/SC/Acc.hs b/SC/Acc.hs index 955c6da..1424f52 100644 --- a/SC/Acc.hs +++ b/SC/Acc.hs @@ -6,6 +6,7 @@ import qualified Data.Array.Accelerate.AST as A import Data.Array.Accelerate.AST.LeftHandSide import Data.Array.Accelerate.AST.Var import Data.Array.Accelerate.Representation.Array +import Data.Array.Accelerate.Representation.Shape hiding (zip) import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Type import Data.Bifunctor @@ -30,12 +31,12 @@ data Command insertDeallocs :: [Command] -> [Command] insertDeallocs cmds = - let allocated = Set.fromList [n | CAlloc _ _ n _ <- cmds] - `Set.union` Set.fromList [n | CKeepalive n <- cmds] + let collectable = Set.fromList [n | CAlloc _ _ n _ <- cmds] + `Set.difference` Set.fromList [n | CKeepalive n <- cmds] in fst $ foldr (\cmd (rest, done) -> case cmd of CChunk _ _ used -> - let todealloc = filter (\n -> n `Set.member` allocated && + let todealloc = filter (\n -> n `Set.member` collectable && n `Set.notMember` done) used in (cmd : map CDealloc todealloc ++ rest @@ -43,7 +44,7 @@ insertDeallocs cmds = CAlloc _ _ name _ | name `Set.notMember` done -> (rest, done) -- unused alloc | otherwise -> (cmd : rest, Set.delete name done) - CKeepalive _ -> (rest, done) -- already handled above in @allocated@ + CKeepalive _ -> (rest, done) -- already handled above in @collectable@ CDealloc _ -> error "insertDeallocs: CDealloc found") ([], mempty) cmds @@ -67,58 +68,61 @@ compilePAcc' :: AVarEnv aenv -> TupANames t -> A.PreOpenAcc A.OpenAcc aenv t -> compilePAcc' aenv destnames = \case A.Alet lhs rhs body -> do (names, aenv') <- genVarsAEnv lhs aenv - let sts1 = [C.SDecl t n Nothing | TypedAName t n <- itupList names] - let cmds1 = [CChunk [] sts1 []] + let sts1sh = [C.SDecl t n Nothing | TypedName t n <- fst (tupanamesList names)] + sts1arr = [C.SDecl t n Nothing | TypedAName t n <- snd (tupanamesList names)] + let cmds1 = [CChunk [] (sts1sh ++ sts1arr) []] cmds2 <- compileAcc' aenv names rhs cmds3 <- compileAcc' aenv' destnames body return (cmds1 ++ cmds2 ++ cmds3) - A.Avar (Var _ idx) -> - return (Right ([], ITupSingle (C.EVar (aveprj aenv idx)))) - - A.Apair a b -> do - res1 <- compileAcc' aenv a - res2 <- compileAcc' aenv b - return (Left (\case - ITupPair n1 n2 -> toStoring res1 n1 ++ toStoring res2 n2 - ITupIgnore -> [] - ITupSingle _ -> error "wat")) + A.Avar (Var _ idx) + | ANArray destshnames destarrnames <- destnames -> do + let (shnames, arrnames) = aveprj aenv idx + sts = [C.SAsg destn (C.EVar srcn) + | (TypedName _ destn, TypedName _ srcn) <- zip (shnamesList destshnames) (shnamesList shnames)] + ++ + [C.SAsg destn (C.EVar srcn) + | (destn, srcn) <- zipDestSrcNames destarrnames arrnames] + usedA = map (\(TypedAName _ n) -> n) (itupList arrnames) + return [CChunk [] sts usedA] + + A.Apair a b + | ANPair destnames1 destnames2 <- destnames -> do + res1 <- compileAcc' aenv destnames1 a + res2 <- compileAcc' aenv destnames2 b + return (res1 ++ res2) _ -> throw "Unsupported Acc constructor" - where - toStExprs :: TypeR t -> Either (ANames t -> [C.Stmt]) ([C.Stmt], Exprs t) -> SC ([C.Stmt], Exprs t) - toStExprs ty (Left fun) = do - names <- genAVars ty - let sts1 = fun names - return (sts1, itupmap (\(TypedName _ n) -> C.EVar n) names) - toStExprs _ (Right pair) = return pair - - toStoring :: Either (Names t -> [C.Stmt]) ([C.Stmt], Exprs t) -> Names t -> [C.Stmt] - toStoring (Left f) = f - toStoring (Right (sts, exs)) = (sts ++) . flip go exs - where - go :: Names t -> Exprs t -> [C.Stmt] - go (ITupSingle (TypedName _ name)) (ITupSingle ex) = [C.SAsg name ex] - go (ITupSingle _) _ = error "wat" - go ITupIgnore _ = [] - go (ITupPair ns1 ns2) (ITupPair es1 es2) = go ns1 es1 ++ go ns2 es2 - go (ITupPair _ _) _ = error "wat" + +zipDestSrcNames :: ANames e -> ANames e -> [(C.Name, C.Name)] +zipDestSrcNames ITupIgnore _ = [] +zipDestSrcNames _ ITupIgnore = error "Ignore in source names where there is none in the destination names" +zipDestSrcNames (ITupSingle (TypedAName _ n)) (ITupSingle (TypedAName _ n')) = [(n, n')] +zipDestSrcNames (ITupPair a b) (ITupPair a' b') = zipDestSrcNames a a' ++ zipDestSrcNames b b' +zipDestSrcNames _ _ = error "wat" genVarsAEnv :: A.ALeftHandSide t aenv aenv' -> AVarEnv aenv -> SC (TupANames t, AVarEnv aenv') genVarsAEnv (LeftHandSideWildcard _) env = return (ANIgnore, env) -genVarsAEnv (LeftHandSideSingle (ArrayR _ ty)) env = do - name <- genName "a" - ty' <- cvtType ty - return (ITupSingle (TypedAName ty' name), AVEPush _ name env) +genVarsAEnv (LeftHandSideSingle (ArrayR sht ty)) env = do + shnames <- genShNames sht + names <- genAVars ty + return (ANArray shnames names, AVEPush shnames names env) genVarsAEnv (LeftHandSidePair lhs1 lhs2) env = do (n1, env1) <- genVarsAEnv lhs1 env (n2, env2) <- genVarsAEnv lhs2 env1 - return (ITupPair n1 n2, env2) + return (ANPair n1 n2, env2) genAVars :: TypeR t -> SC (ANames t) genAVars TupRunit = return ITupIgnore genAVars (TupRsingle ty) = genAVar ty genAVars (TupRpair t1 t2) = ITupPair <$> genAVars t1 <*> genAVars t2 +genShNames :: ShapeR sh -> SC (ShNames sh) +genShNames ShapeRz = return ShZ +genShNames (ShapeRsnoc sht) = do + name <- genName "n" + names <- genShNames sht + return (ShS name names) + genAVar :: ScalarType t -> SC (ANames t) genAVar ty = ITupSingle <$> (TypedAName <$> cvtType ty <*> genName "a") diff --git a/SC/Defs.hs b/SC/Defs.hs index 685d408..bb8e03f 100644 --- a/SC/Defs.hs +++ b/SC/Defs.hs @@ -59,7 +59,7 @@ type ANames = ITup TypedAName type Exprs = ITup C.Expr --- Type is a pointer type +-- Type is the pointer type of the type that this name is supposed to be according to the type index. data TypedAName = TypedAName C.Type Name data TupANames t where -- cgit v1.2.3-70-g09d2