summaryrefslogtreecommitdiff
path: root/SC/Afun.hs
blob: 3379cc60a05d0330d2da14d9e507f9567cf91b39 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE StandaloneDeriving #-}
module SC.Afun where

import qualified Data.Array.Accelerate.AST as A
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Type

import qualified Language.C as C
import SC.Acc
import SC.Defs
import SC.Monad


-- | Variable names for a tuple of arrays. Each array is represented in
-- struct-of-arrays form. For example: the type
-- @(Scalar Double, Matrix (Int, Float))@, which is internally represented as
-- @(Array () Double, Array (((), Int), Int) (Int, Float))@, would be
-- described as follows: (the variable names will differ)
--
-- > CATNPair (CATNArray CSNNil
-- >                     (CANName TDouble (Name "a")))
-- >          (CATNArray (CSNSnoc (CSNSnoc CSNNil (Name "n1")) (Name "n2"))
-- >                     (CANPair (CANName (TInt B64) (Name "b"))
-- >                              (CANName TFloat (Name "c"))))
--
-- Suppose that the Accelerate function in question has return type
-- @Vector Double@, which is to say @Array ((), Int) Double@, with description:
--
-- > CATNArray (CSNSnoc CSNNil (Name "m"))
-- >           (CANName TDouble (Name "r"))
--
-- Then its C function definition would look as follows:
--
-- > void function(double *a,
-- >               int64_t n1, int64_t n2, int64_t *b, float *c,
-- >               int64_t *m, double **r);
--
-- Note that the first input array here has zero shape arguments because
-- it is zero-dimensional. The result components (here @m@ and @r@) are passed
-- by pointer, so that the function can store the shape and data of its result
-- into them.
data CArrTupNames a where
    -- | A pair of array tuples.
    CATNPair :: CArrTupNames a -> CArrTupNames b -> CArrTupNames (a, b)
    -- | A single array: the names of its shape components and the names of
    -- its (struct-of-arrays) data components.
    CATNArray :: CShNames sh -> CArrNames sh a -> CArrTupNames (Array sh a)
    -- | The empty tuple of arrays.
    CATNNil :: CArrTupNames ()
deriving instance Show (CArrTupNames a)

-- | Names for the shape of an array. See 'CArrTupNames' for more information.
--
-- Note that the names in this structure are to be interpreted as variables
-- of type @int64_t@.
data CShNames sh where
    -- | Append the name of one more dimension to a shape, mirroring the nested
    -- @(sh, Int)@ representation of shapes.
    CSNSnoc :: CShNames sh -> C.Name -> CShNames (sh, Int)
    -- | The shape of a zero-dimensional array, which has no dimension names.
    CSNNil :: CShNames ()
deriving instance Show (CShNames a)

-- | Names for a single array. See 'CArrTupNames' for more information.
--
-- Note that the 'C.Type' in 'CANName' is the /element/ type of the array.
data CArrNames sh a where
    -- | Names for the two halves of a pair-typed element.
    CANPair :: CArrNames sh a -> CArrNames sh b -> CArrNames sh (a, b)
    -- | A single component: its element C type and the name of the variable
    -- holding its data. Note that this constructor does not constrain the
    -- type index @a@.
    CANName :: C.Type -> C.Name -> CArrNames sh a
    -- | A unit-typed component, which has no associated variable.
    CANNil :: CArrNames sh ()
deriving instance Show (CArrNames sh a)

-- | Compile a one-argument Accelerate array function to a C procedure.
--
-- The function passed should have exactly one argument (that may consist of
-- multiple arrays in a tuple, of course); any other 'A.Afun' is rejected with
-- 'throw'.
--
-- The result consists of:
--
-- 1. A list of auxiliary function definitions that the program needs, on top
--    of the prelude.
-- 2. The function that implements the top-level Accelerate array function.
-- 3. The variable names corresponding to the components of the argument.
-- 4. The variable names corresponding to the components of the result. These
--    are pointer arguments to the function in (2.): each shape component is
--    passed as a single pointer and each array data component as a double
--    pointer, and the function stores the computed result into them.
--
-- For an example, see the documentation of 'CArrTupNames'.
compileAfun1 :: C.Name
             -> A.Afun (a -> b)
             -> SC ([C.FunDef], C.FunDef, CArrTupNames a, CArrTupNames b)
compileAfun1 procname (A.Alam lhs (A.Abody acc)) = do
    (argnames, aenv) <- genVarsAEnv lhs AVENil
    -- Destination names handed to 'compileAcc''; their shape components are
    -- declared as local variables at the start of the procedure body.
    destnames <- genAVarsTup (A.arraysR acc)
    let destShapeDeclSts = [C.SDecl t n Nothing
                           | TypedName t n <- fst (tupanamesList destnames)]
    -- Names of the output parameters of the generated procedure.
    outnames <- genAVarsTup (A.arraysR acc)
    (auxdefs, stmts) <- compileCommands <$> compileAcc' aenv destnames acc
    let params = map (toParam id) (tupanamesList' argnames)
                 ++ map (toParam C.TPtr) (tupanamesList' outnames)
        -- Store each destination value at index 0 of its output pointer.
        resultStores = [C.SStore outn (C.ELit "0") (C.EVar destn)
                       | (outn, destn) <- zipOutSrcNamesT outnames destnames]
    return (auxdefs
           ,C.ProcDef procname params (destShapeDeclSts ++ stmts ++ resultStores)
           ,makeCArrTupNames (lhsToTupR lhs) elemType argnames
           ,makeCArrTupNames (A.arraysR acc) elemType outnames)
  where
    -- Turn a generated (shape or array) name into a C parameter, passing its
    -- type through @wrap@.
    toParam wrap = \case
        Left (TypedName t n) -> (wrap t, n)
        Right (TypedAName t n) -> (wrap t, n)

    -- Array names are expected to carry pointer types, whereas 'CANName'
    -- records the element type, so strip one pointer layer. A non-pointer
    -- type is an internal invariant violation; report it with a descriptive
    -- message instead of an anonymous incomplete-pattern failure.
    elemType :: C.Type -> C.Type
    elemType (C.TPtr t) = t
    elemType t =
        error ("compileAfun1: expected a pointer type for an array name, got "
               ++ show t)

    -- Convert generated names into the public 'CArrTupNames' description,
    -- mapping each array component's C type through @typefun@.
    makeCArrTupNames :: ArraysR a -> (C.Type -> C.Type) -> TupANames a -> CArrTupNames a
    makeCArrTupNames (TupRpair t1 t2) typefun (ANPair an1 an2) =
        CATNPair (makeCArrTupNames t1 typefun an1) (makeCArrTupNames t2 typefun an2)
    makeCArrTupNames (TupRsingle (ArrayR _ t)) typefun (ANArray shn ans) =
        CATNArray (makeCShNames shn) (makeCArrNames t typefun ans)
    makeCArrTupNames TupRunit _ ANIgnore = CATNNil
    makeCArrTupNames _ _ ANIgnore = error "Ignore of non-nil element in generated names"

    -- Drop the types from shape names, keeping only the variable names.
    makeCShNames :: ShNames sh -> CShNames sh
    makeCShNames ShZ = CSNNil
    makeCShNames (ShS ns n) = CSNSnoc (makeCShNames ns) n

    -- Convert the names of one array's components, mapping each C type
    -- through @typefun@.
    makeCArrNames :: TypeR a -> (C.Type -> C.Type) -> ANames a -> CArrNames sh a
    makeCArrNames (TupRpair t1 t2) typefun (ITupPair an1 an2) =
        CANPair (makeCArrNames t1 typefun an1) (makeCArrNames t2 typefun an2)
    makeCArrNames (TupRsingle _) typefun (ITupSingle (TypedAName ty n)) =
        CANName (typefun ty) n
    makeCArrNames TupRunit _ ITupIgnore = CANNil
    makeCArrNames _ _ ITupIgnore = error "Ignore of non-nil element in generated names"
    makeCArrNames _ _ _ = error "Invalid GADTs"

    -- Pair every output-parameter name with the destination name whose value
    -- it receives. Components ignored in the output names are skipped.
    zipOutSrcNamesT :: TupANames t -> TupANames t -> [(C.Name, C.Name)]
    zipOutSrcNamesT ANIgnore _ = []
    zipOutSrcNamesT _ ANIgnore = error "Ignore in source names but not in out names"
    zipOutSrcNamesT (ANArray shn ns) (ANArray shn' ns') =
        zipWith (\(TypedName _ n) (TypedName _ n') -> (n, n'))
                (shnamesList shn) (shnamesList shn')
           ++ zipOutSrcNames ns ns'
    zipOutSrcNamesT (ANPair a b) (ANPair a' b') = zipOutSrcNamesT a a' ++ zipOutSrcNamesT b b'

    -- Like 'zipOutSrcNamesT', but for the data components of a single array.
    zipOutSrcNames :: ANames t -> ANames t -> [(C.Name, C.Name)]
    zipOutSrcNames ITupIgnore _ = []
    zipOutSrcNames _ ITupIgnore = error "Ignore in source names but not in out names"
    zipOutSrcNames (ITupPair a b) (ITupPair a' b') = zipOutSrcNames a a' ++ zipOutSrcNames b b'
    zipOutSrcNames (ITupSingle (TypedAName _ n)) (ITupSingle (TypedAName _ n')) = [(n, n')]
    zipOutSrcNames _ _ = error "Invalid GADTs"
compileAfun1 _ _ = throw "Not an array function with exactly one argument"