summaryrefslogtreecommitdiff
path: root/SC/Acc.hs
diff options
context:
space:
mode:
Diffstat (limited to 'SC/Acc.hs')
-rw-r--r--SC/Acc.hs38
1 files changed, 24 insertions, 14 deletions
diff --git a/SC/Acc.hs b/SC/Acc.hs
index b50bf24..5ae2532 100644
--- a/SC/Acc.hs
+++ b/SC/Acc.hs
@@ -100,7 +100,8 @@ compilePAcc' aenv destnames = \case
CompiledFun funFD funArgbuilder usedAfun <- compileFun aenv fun
tempnames <- genVars restype
loops <- enumShapeNested destshnames $ \idxnames linidxexpr -> concat
- [[C.SCall (C.fundefName funFD)
+ [[C.SDecl t n Nothing | TypedName t n <- itupList tempnames]
+ ,[C.SCall (C.fundefName funFD)
(funArgbuilder (itupEvars (fromShNames idxnames)) tempnames)]
,[C.SStore arrname linidxexpr (C.EVar tempname)
| (arrname, tempname) <- zipDestSrcNamesAE destarrnames tempnames]]
@@ -108,42 +109,46 @@ compilePAcc' aenv destnames = \case
[[CChunk [sheFD]
[C.SCall (C.fundefName sheFD)
(sheArgbuilder ITupIgnore (fromShNames destshnames))]
- (map (\(TypedAName _ n) -> n) usedAshe)]
+ (concatMap (\(SomeArray _ ans) ->
+ map (\(TypedAName _ n) -> n) (itupList ans))
+ usedAshe)]
,[CAlloc [] eltty n (C.StExpr [] (computeSize destshnames))
| TypedAName arrty n <- itupList destarrnames
, let C.TPtr eltty = arrty]
,[CChunk [funFD]
loops
- (map (\(TypedAName _ n) -> n) (itupList destarrnames ++ usedAfun))]]
+ (map (\(TypedAName _ n) -> n)
+ (itupList destarrnames
+ ++ concatMap (\(SomeArray _ ans) -> itupList ans) usedAfun))]]
_ -> throw "Unsupported Acc constructor"
-- | Returns an expression of type int64_t
computeSize :: ShNames sh -> C.Expr
computeSize ShZ = C.ELit "1LL"
-computeSize (ShS n ShZ) = C.EVar n
-computeSize (ShS n ns) = C.EOp (C.EVar n) "*" (computeSize ns)
+computeSize (ShS ShZ n) = C.EVar n
+computeSize (ShS ns n) = C.EOp (computeSize ns) "*" (C.EVar n)
-- | Given size variables and index variables, returns an expression of type int64_t
linearIndexExpr :: ShNames sh -> ShNames sh -> C.Expr
linearIndexExpr ShZ ShZ = C.ELit "1LL"
-linearIndexExpr (ShS _ ShZ) (ShS i ShZ) = C.EVar i
-linearIndexExpr (ShS n ns) (ShS i is) =
+linearIndexExpr (ShS ShZ _) (ShS ShZ i) = C.EVar i
+linearIndexExpr (ShS ns n) (ShS is i) =
C.EOp (C.EOp (linearIndexExpr ns is) "*" (C.EVar n)) "+" (C.EVar i)
-zipDestSrcNames :: ITup C.Name e -> ITup C.Name e -> [(C.Name, C.Name)]
+zipDestSrcNames :: ITup C.Name t -> ITup C.Name t -> [(C.Name, C.Name)]
zipDestSrcNames ITupIgnore _ = []
zipDestSrcNames _ ITupIgnore = error "Ignore in source names but not in destination names"
zipDestSrcNames (ITupSingle n) (ITupSingle n') = [(n, n')]
zipDestSrcNames (ITupPair a b) (ITupPair a' b') = zipDestSrcNames a a' ++ zipDestSrcNames b b'
zipDestSrcNames _ _ = error "wat"
-zipDestSrcNamesAA :: ANames e -> ANames e -> [(C.Name, C.Name)]
+zipDestSrcNamesAA :: ANames t -> ANames t -> [(C.Name, C.Name)]
zipDestSrcNamesAA ns1 ns2 =
zipDestSrcNames (itupmap (\(TypedAName _ n) -> n) ns1)
(itupmap (\(TypedAName _ n) -> n) ns2)
-zipDestSrcNamesAE :: ANames e -> Names e -> [(C.Name, C.Name)]
+zipDestSrcNamesAE :: ANames t -> Names t -> [(C.Name, C.Name)]
zipDestSrcNamesAE ns1 ns2 =
zipDestSrcNames (itupmap (\(TypedAName _ n) -> n) ns1)
(itupmap (\(TypedName _ n) -> n) ns2)
@@ -159,7 +164,7 @@ enumShapeNested sizenames fun = do
idxnames <- genShNames (shNamesShape sizenames)
let makeLoops :: ShNames sh -> ShNames sh -> [C.Stmt] -> [C.Stmt]
makeLoops ShZ ShZ body = body
- makeLoops (ShS n ns) (ShS i is) body =
+ makeLoops (ShS ns n) (ShS is i) body =
makeLoops ns is [C.SFor (C.TInt C.B64) i (C.ELit "0") (C.EVar n) body]
return (makeLoops sizenames idxnames (fun idxnames (linearIndexExpr sizenames idxnames)))
@@ -174,6 +179,11 @@ genVarsAEnv (LeftHandSidePair lhs1 lhs2) env = do
(n2, env2) <- genVarsAEnv lhs2 env1
return (ANPair n1 n2, env2)
+genAVarsTup :: ArraysR t -> SC (TupANames t)
+genAVarsTup TupRunit = return ANIgnore
+genAVarsTup (TupRsingle (ArrayR sht ty)) = ANArray <$> genShNames sht <*> genAVars ty
+genAVarsTup (TupRpair t1 t2) = ANPair <$> genAVarsTup t1 <*> genAVarsTup t2
+
genAVars :: TypeR t -> SC (ANames t)
genAVars TupRunit = return ITupIgnore
genAVars (TupRsingle ty) = genAVar ty
@@ -182,9 +192,9 @@ genAVars (TupRpair t1 t2) = ITupPair <$> genAVars t1 <*> genAVars t2
genShNames :: ShapeR sh -> SC (ShNames sh)
genShNames ShapeRz = return ShZ
genShNames (ShapeRsnoc sht) = do
- name <- genName "n"
names <- genShNames sht
- return (ShS name names)
+ name <- genName "n"
+ return (ShS names name)
genAVar :: ScalarType t -> SC (ANames t)
-genAVar ty = ITupSingle <$> (TypedAName <$> cvtType ty <*> genName "a")
+genAVar ty = ITupSingle <$> (TypedAName <$> fmap C.TPtr (cvtType ty) <*> genName "a")