diff options
Diffstat (limited to 'SC/Acc.hs')
-rw-r--r-- | SC/Acc.hs | 38 |
1 files changed, 24 insertions, 14 deletions
@@ -100,7 +100,8 @@ compilePAcc' aenv destnames = \case CompiledFun funFD funArgbuilder usedAfun <- compileFun aenv fun tempnames <- genVars restype loops <- enumShapeNested destshnames $ \idxnames linidxexpr -> concat - [[C.SCall (C.fundefName funFD) + [[C.SDecl t n Nothing | TypedName t n <- itupList tempnames] + ,[C.SCall (C.fundefName funFD) (funArgbuilder (itupEvars (fromShNames idxnames)) tempnames)] ,[C.SStore arrname linidxexpr (C.EVar tempname) | (arrname, tempname) <- zipDestSrcNamesAE destarrnames tempnames]] @@ -108,42 +109,46 @@ compilePAcc' aenv destnames = \case [[CChunk [sheFD] [C.SCall (C.fundefName sheFD) (sheArgbuilder ITupIgnore (fromShNames destshnames))] - (map (\(TypedAName _ n) -> n) usedAshe)] + (concatMap (\(SomeArray _ ans) -> + map (\(TypedAName _ n) -> n) (itupList ans)) + usedAshe)] ,[CAlloc [] eltty n (C.StExpr [] (computeSize destshnames)) | TypedAName arrty n <- itupList destarrnames , let C.TPtr eltty = arrty] ,[CChunk [funFD] loops - (map (\(TypedAName _ n) -> n) (itupList destarrnames ++ usedAfun))]] + (map (\(TypedAName _ n) -> n) + (itupList destarrnames + ++ concatMap (\(SomeArray _ ans) -> itupList ans) usedAfun))]] _ -> throw "Unsupported Acc constructor" -- | Returns an expression of type int64_t computeSize :: ShNames sh -> C.Expr computeSize ShZ = C.ELit "1LL" -computeSize (ShS n ShZ) = C.EVar n -computeSize (ShS n ns) = C.EOp (C.EVar n) "*" (computeSize ns) +computeSize (ShS ShZ n) = C.EVar n +computeSize (ShS ns n) = C.EOp (computeSize ns) "*" (C.EVar n) -- | Given size variables and index variables, returns an expression of type int64_t linearIndexExpr :: ShNames sh -> ShNames sh -> C.Expr linearIndexExpr ShZ ShZ = C.ELit "1LL" -linearIndexExpr (ShS _ ShZ) (ShS i ShZ) = C.EVar i -linearIndexExpr (ShS n ns) (ShS i is) = +linearIndexExpr (ShS ShZ _) (ShS ShZ i) = C.EVar i +linearIndexExpr (ShS ns n) (ShS is i) = C.EOp (C.EOp (linearIndexExpr ns is) "*" (C.EVar n)) "+" (C.EVar i) -zipDestSrcNames :: ITup C.Name e -> ITup C.Name e -> [(C.Name, C.Name)] +zipDestSrcNames :: ITup C.Name t -> ITup C.Name t -> [(C.Name, C.Name)] zipDestSrcNames ITupIgnore _ = [] zipDestSrcNames _ ITupIgnore = error "Ignore in source names but not in destination names" zipDestSrcNames (ITupSingle n) (ITupSingle n') = [(n, n')] zipDestSrcNames (ITupPair a b) (ITupPair a' b') = zipDestSrcNames a a' ++ zipDestSrcNames b b' zipDestSrcNames _ _ = error "wat" -zipDestSrcNamesAA :: ANames e -> ANames e -> [(C.Name, C.Name)] +zipDestSrcNamesAA :: ANames t -> ANames t -> [(C.Name, C.Name)] zipDestSrcNamesAA ns1 ns2 = zipDestSrcNames (itupmap (\(TypedAName _ n) -> n) ns1) (itupmap (\(TypedAName _ n) -> n) ns2) -zipDestSrcNamesAE :: ANames e -> Names e -> [(C.Name, C.Name)] +zipDestSrcNamesAE :: ANames t -> Names t -> [(C.Name, C.Name)] zipDestSrcNamesAE ns1 ns2 = zipDestSrcNames (itupmap (\(TypedAName _ n) -> n) ns1) (itupmap (\(TypedName _ n) -> n) ns2) @@ -159,7 +164,7 @@ enumShapeNested sizenames fun = do idxnames <- genShNames (shNamesShape sizenames) let makeLoops :: ShNames sh -> ShNames sh -> [C.Stmt] -> [C.Stmt] makeLoops ShZ ShZ body = body - makeLoops (ShS n ns) (ShS i is) body = + makeLoops (ShS ns n) (ShS is i) body = makeLoops ns is [C.SFor (C.TInt C.B64) i (C.ELit "0") (C.EVar n) body] return (makeLoops sizenames idxnames (fun idxnames (linearIndexExpr sizenames idxnames))) @@ -174,6 +179,11 @@ genVarsAEnv (LeftHandSidePair lhs1 lhs2) env = do (n2, env2) <- genVarsAEnv lhs2 env1 return (ANPair n1 n2, env2) +genAVarsTup :: ArraysR t -> SC (TupANames t) +genAVarsTup TupRunit = return ANIgnore +genAVarsTup (TupRsingle (ArrayR sht ty)) = ANArray <$> genShNames sht <*> genAVars ty +genAVarsTup (TupRpair t1 t2) = ANPair <$> genAVarsTup t1 <*> genAVarsTup t2 + genAVars :: TypeR t -> SC (ANames t) genAVars TupRunit = return ITupIgnore genAVars (TupRsingle ty) = genAVar ty @@ -182,9 +192,9 @@ genAVars (TupRpair t1 t2) = ITupPair <$> genAVars t1 <*> genAVars t2 genShNames :: ShapeR sh -> SC (ShNames sh) genShNames ShapeRz = return ShZ genShNames (ShapeRsnoc sht) = do - name <- genName "n" names <- genShNames sht - return (ShS name names) + name <- genName "n" + return (ShS names name) genAVar :: ScalarType t -> SC (ANames t) -genAVar ty = ITupSingle <$> (TypedAName <$> cvtType ty <*> genName "a") +genAVar ty = ITupSingle <$> (TypedAName <$> fmap C.TPtr (cvtType ty) <*> genName "a") |