-- Source path: root/src/CHAD/Types.hs
-- Blob: e8ec0c98daad97afca75ee5da9c9d0fb58ff03c6
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module CHAD.Types where

import AST.Types
import Data


-- | Type-level map that rebuilds a type with 'D1' applied to every component:
-- nil, pairs, eithers, maybes and arrays keep their outer constructor, and
-- scalar types are returned unchanged.
--
-- There is no equation for accumulator types, so 'D1' does not reduce on them.
-- (Presumably this is the primal-type half of the CHAD transformation.)
type family D1 ty where
  D1 TNil = TNil
  D1 (TPair l r) = TPair (D1 l) (D1 r)
  D1 (TEither l r) = TEither (D1 l) (D1 r)
  D1 (TMaybe inner) = TMaybe (D1 inner)
  D1 (TArr rank elt) = TArr rank (D1 elt)
  D1 (TScal sty) = TScal sty

-- | Type-level map applied recursively to the components of a type.
--
-- Pairs and eithers are additionally wrapped in 'TMaybe'; maybes and arrays
-- keep their outer constructor; scalars are delegated to 'D2s'.
--
-- There is no equation for accumulator types, so 'D2' does not reduce on them.
-- (Presumably this is the cotangent-type half of the CHAD transformation.)
type family D2 ty where
  D2 TNil = TNil
  D2 (TPair l r) = TMaybe (TPair (D2 l) (D2 r))
  D2 (TEither l r) = TMaybe (TEither (D2 l) (D2 r))
  D2 (TMaybe inner) = TMaybe (D2 inner)
  D2 (TArr rank elt) = TArr rank (D2 elt)
  D2 (TScal sty) = D2s sty

-- | Scalar case of 'D2': the two floating-point scalar types map to themselves
-- (as 'TScal' types), while the integer and boolean scalar types map to 'TNil'.
type family D2s sty where
  D2s TI32 = TNil
  D2s TI64 = TNil
  D2s TF32 = TScal TF32
  D2s TF64 = TScal TF64
  D2s TBool = TNil

-- | 'D1' applied element-wise to a type-level list.
type family D1E tys where
  D1E '[] = '[]
  D1E (ty ': rest) = D1 ty ': D1E rest

-- | 'D2' applied element-wise to a type-level list.
type family D2E tys where
  D2E '[] = '[]
  D2E (ty ': rest) = D2 ty ': D2E rest

-- | Wrap every element of a type-level list in 'TAccum'.
type family D2AcE tys where
  D2AcE '[] = '[]
  D2AcE (ty ': rest) = TAccum ty ': D2AcE rest

-- | Singleton-level counterpart of the 'D1' type family: from the singleton
-- for @t@, build the singleton for @D1 t@ by structural recursion.
--
-- 'D1' has no equation for accumulator types; an 'STAccum' input therefore
-- calls 'error'.
d1 :: STy t -> STy (D1 t)
d1 ty = case ty of
  STNil -> STNil
  STPair l r -> STPair (d1 l) (d1 r)
  STEither l r -> STEither (d1 l) (d1 r)
  STMaybe inner -> STMaybe (d1 inner)
  STArr rank elt -> STArr rank (d1 elt)
  STScal sty -> STScal sty
  STAccum{} -> error "Accumulators not allowed in input program"

-- | Apply 'd1' to every singleton in the list, mirroring 'D1E'.
d1e :: SList STy env -> SList STy (D1E env)
d1e SNil = SNil
d1e (SCons ty rest) = SCons (d1 ty) (d1e rest)

-- | Singleton-level counterpart of the 'D2' type family: from the singleton
-- for @t@, build the singleton for @D2 t@ by structural recursion.
--
-- The scalar equations spell out 'D2s': floating-point scalars are kept,
-- integer and boolean scalars become 'STNil'.
--
-- 'D2' has no equation for accumulator types; an 'STAccum' input therefore
-- calls 'error'.
d2 :: STy t -> STy (D2 t)
d2 STNil = STNil
d2 (STPair l r) = STMaybe (STPair (d2 l) (d2 r))
d2 (STEither l r) = STMaybe (STEither (d2 l) (d2 r))
d2 (STMaybe inner) = STMaybe (d2 inner)
d2 (STArr rank elt) = STArr rank (d2 elt)
d2 (STScal STI32) = STNil
d2 (STScal STI64) = STNil
d2 (STScal STF32) = STScal STF32
d2 (STScal STF64) = STScal STF64
d2 (STScal STBool) = STNil
d2 STAccum{} = error "Accumulators not allowed in input program"

-- | Apply 'd2' to every singleton in the list, mirroring 'D2E'.
d2e :: SList STy env -> SList STy (D2E env)
d2e tys = case tys of
  SNil -> SNil
  SCons ty rest -> SCons (d2 ty) (d2e rest)

-- | Wrap every singleton in the list in 'STAccum', mirroring 'D2AcE'.
d2ace :: SList STy env -> SList STy (D2AcE env)
d2ace SNil = SNil
d2ace (SCons ty rest) = SCons (STAccum ty) (d2ace rest)


-- | Switches selecting which binding sites handle array-containing variables
-- in accumulator mode.
data CHADConfig = CHADConfig
  { chcLetArrayAccum :: Bool
    -- ^ D[let] will bind variables containing arrays in accumulator mode.
  , chcCaseArrayAccum :: Bool
    -- ^ D[case] will bind variables containing arrays in accumulator mode.
  , chcArgArrayAccum :: Bool
    -- ^ Introduce top-level arguments containing arrays in accumulator mode.
  }
  deriving Show

-- | Baseline configuration: every accumulator-mode switch is turned off.
defaultConfig :: CHADConfig
defaultConfig =
  CHADConfig
    { chcLetArrayAccum = disabled
    , chcCaseArrayAccum = disabled
    , chcArgArrayAccum = disabled
    }
  where
    disabled = False

-- | Turn on all three accumulator-mode switches of the given configuration.
chcSetAccum :: CHADConfig -> CHADConfig
chcSetAccum cfg =
  cfg
    { chcLetArrayAccum = True
    , chcCaseArrayAccum = True
    , chcArgArrayAccum = True
    }


------------------------------------ LEMMAS ------------------------------------

-- | Lemma: 'D1' leaves the type @Tup (Replicate n TIx)@ unchanged.
--
-- Proved by induction on @n@: the zero case holds directly, and the successor
-- case matches on the proof for the predecessor to bring its equality into
-- scope.
indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx))
indexTupD1Id len = case len of
  SZ -> Refl
  SS prev -> case indexTupD1Id prev of
    Refl -> Refl