blob: 3d5f544c3cf2c5f030de2bd2397e853888f245d3 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeOperators #-}
module AST.UnMonoid (unMonoid, zero, plus) where
import AST
import Data
-- | Remove 'EZero', 'EPlus' and 'EOneHot' from the program by expanding them
-- into their concrete implementations (see 'zero', 'plus' and 'onehot').
--
-- The traversal is fully recursive: every subexpression is rewritten,
-- including the zero-info argument of 'EZero', so no monoid operation
-- survives anywhere in the result. All node annotations are discarded and
-- replaced by 'ext'.
unMonoid :: Ex env t -> Ex env t
unMonoid = \case
  -- The zero-info expression may itself contain monoid operations, so it is
  -- rewritten before being expanded by 'zero'.
  EZero _ t e -> zero t (unMonoid e)
  EPlus _ t a b -> plus t (unMonoid a) (unMonoid b)
  EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b)
  EVar _ t i -> EVar ext t i
  ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body)
  EPair _ a b -> EPair ext (unMonoid a) (unMonoid b)
  EFst _ e -> EFst ext (unMonoid e)
  ESnd _ e -> ESnd ext (unMonoid e)
  ENil _ -> ENil ext
  EInl _ t e -> EInl ext t (unMonoid e)
  EInr _ t e -> EInr ext t (unMonoid e)
  ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b)
  ENothing _ t -> ENothing ext t
  EJust _ e -> EJust ext (unMonoid e)
  EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e)
  ELNil _ t1 t2 -> ELNil ext t1 t2
  ELInl _ t e -> ELInl ext t (unMonoid e)
  ELInr _ t e -> ELInr ext t (unMonoid e)
  ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c)
  EConstArr _ n t x -> EConstArr ext n t x
  EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b)
  EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c)
  ESum1Inner _ e -> ESum1Inner ext (unMonoid e)
  EUnit _ e -> EUnit ext (unMonoid e)
  EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b)
  EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)
  EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
  EConst _ t x -> EConst ext t x
  EIdx0 _ e -> EIdx0 ext (unMonoid e)
  EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b)
  EIdx _ a b -> EIdx ext (unMonoid a) (unMonoid b)
  EShape _ e -> EShape ext (unMonoid e)
  EOp _ op e -> EOp ext op (unMonoid e)
  ECustom _ t1 t2 t3 a b c e1 e2 ->
    ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)
  EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b)
  EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e)
  EError _ t s -> EError ext t s
-- | Generate code for the additive identity of the monoid type @t@.
--
-- The 'ZeroInfo' argument carries whatever runtime information the zero
-- needs beyond the type itself:
--
-- * for pairs it is let-bound once and split into its two halves, one per
--   component;
-- * for arrays the element zero is mapped over it with 'emap';
-- * for 'SMTNil', 'SMTLEither', 'SMTMaybe' and scalars it is ignored.
zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t
zero mty info = case mty of
  SMTNil -> ENil ext
  SMTPair tl tr ->
    let infoVar = EVar ext (typeOf info) IZ
    in ELet ext info $
         EPair ext (zero tl (EFst ext infoVar))
                   (zero tr (ESnd ext infoVar))
  -- Sum-like types use their "empty" constructor as the zero.
  SMTLEither tl tr -> ELNil ext (fromSMTy tl) (fromSMTy tr)
  SMTMaybe te -> ENothing ext (fromSMTy te)
  -- Arrays: apply the element zero to each element's zero info.
  SMTArr _ te -> emap (zero te (EVar ext (tZeroInfo te) IZ)) info
  SMTScal st -> case st of
    STI32 -> EConst ext STI32 0
    STI64 -> EConst ext STI64 0
    STF32 -> EConst ext STF32 0.0
    STF64 -> EConst ext STF64 0.0
-- | Generate code that adds two values of the monoid type @t@.
--
-- * Pairs: both operands are let-bound (each evaluated once) and added
--   componentwise.
-- * 'SMTLEither': 'ELNil' acts as the identity; two left (or two right)
--   values are added under their constructor, while a left/right mismatch
--   evaluates to an 'EError'.
-- * 'SMTMaybe': a 'Nothing' operand acts as the identity; two 'Just' values
--   are added under 'EJust'.
-- * Arrays: elementwise, via 'ezipWith'.
-- * Scalars: the primitive 'OAdd'.
plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t
plus mty lhs rhs = case mty of
  SMTNil -> ENil ext

  SMTPair tl tr ->
    let pairTy = STPair (fromSMTy tl) (fromSMTy tr)
    in ELet ext lhs $                      -- lhs is (IS IZ) below
       ELet ext (weakenExpr WSink rhs) $   -- rhs is IZ below
       EPair ext
         (plus tl (EFst ext (EVar ext pairTy (IS IZ))) (EFst ext (EVar ext pairTy IZ)))
         (plus tr (ESnd ext (EVar ext pairTy (IS IZ))) (ESnd ext (EVar ext pairTy IZ)))

  SMTLEither tl tr ->
    let eitherTy = STLEither (fromSMTy tl) (fromSMTy tr)
        lty = fromSMTy tl
        rty = fromSMTy tr
    in ELet ext lhs $
       ELet ext (weakenExpr WSink rhs) $
       -- Scrutinise lhs (IS IZ); rhs is IZ.
       ELCase ext (EVar ext eitherTy (IS IZ))
         -- lhs = LNil: the result is rhs.
         (EVar ext eitherTy IZ)
         -- lhs = LInl x (x is IZ): scrutinise rhs, now at (IS IZ).
         (ELCase ext (EVar ext eitherTy (IS IZ))
            (EVar ext eitherTy (IS (IS IZ)))   -- rhs = LNil: the result is lhs
            (ELInl ext rty (plus tl (EVar ext lty (IS IZ)) (EVar ext lty IZ)))
            (EError ext eitherTy "plus l+r"))
         -- lhs = LInr x (x is IZ): scrutinise rhs, now at (IS IZ).
         (ELCase ext (EVar ext eitherTy (IS IZ))
            (EVar ext eitherTy (IS (IS IZ)))   -- rhs = LNil: the result is lhs
            (EError ext eitherTy "plus r+l")
            (ELInr ext lty (plus tr (EVar ext rty (IS IZ)) (EVar ext rty IZ))))

  SMTMaybe te ->
    let elTy = fromSMTy te
        maybeTy = STMaybe elTy
    in ELet ext rhs $                      -- rhs is IZ
       EMaybe ext
         -- lhs = Nothing: the result is rhs.
         (EVar ext maybeTy IZ)
         -- lhs = Just x (x is IZ, rhs is IS IZ): inspect rhs.
         (EJust ext
            (EMaybe ext
               (EVar ext elTy IZ)                                    -- rhs = Nothing: x
               (plus te (EVar ext elTy (IS IZ)) (EVar ext elTy IZ))  -- rhs = Just y: x + y
               (EVar ext maybeTy (IS IZ))))
         (weakenExpr WSink lhs)

  SMTArr _ te ->
    let elTy = fromSMTy te
    in ezipWith (plus te (EVar ext elTy (IS IZ)) (EVar ext elTy IZ)) lhs rhs

  SMTScal st -> EOp ext (OAdd st) (EPair ext lhs rhs)
-- | Generate a value of monoid type @t@ that is zero everywhere except at the
-- position selected by the projection, where it holds @arg@.
--
-- The index argument supplies what the projection path needs:
--
-- * pairs: the component off the path is built with 'zero' from the
--   corresponding half of the index, and the other half is passed down;
-- * arrays: the index holds the target element index, an array of
--   per-element zero infos (its shape fixes the shape of the result), and
--   the index for the target element itself;
-- * 'SMTLEither' / 'SMTMaybe': the index is passed down unchanged.
onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env t
onehot mty SAPHere _ arg =
  -- At the target itself: the result is the argument.
  ELet ext arg $ EVar ext (fromSMTy mty) IZ
onehot (SMTPair tl tr) (SAPFst prj) idx arg =
  let idxTy = typeOf idx
  in ELet ext idx $
     -- Build the left component from fst idx; idx is now at (IS IZ).
     ELet ext (onehot tl prj (EFst ext (EVar ext idxTy IZ)) (weakenExpr WSink arg)) $
     EPair ext
       (EVar ext (fromSMTy tl) IZ)
       (zero tr (ESnd ext (EVar ext idxTy (IS IZ))))
onehot (SMTPair tl tr) (SAPSnd prj) idx arg =
  let idxTy = typeOf idx
  in ELet ext idx $
     -- Build the right component from snd idx; idx is now at (IS IZ).
     ELet ext (onehot tr prj (ESnd ext (EVar ext idxTy IZ)) (weakenExpr WSink arg)) $
     EPair ext
       (zero tl (EFst ext (EVar ext idxTy (IS IZ))))
       (EVar ext (fromSMTy tr) IZ)
onehot (SMTLEither tl tr) (SAPLeft prj) idx arg =
  ELInl ext (fromSMTy tr) (onehot tl prj idx arg)
onehot (SMTLEither tl tr) (SAPRight prj) idx arg =
  ELInr ext (fromSMTy tl) (onehot tr prj idx arg)
onehot (SMTMaybe te) (SAPJust prj) idx arg =
  EJust ext (onehot te prj idx arg)
onehot (SMTArr n te) (SAPArrIdx prj) idx arg =
  let idxTy = typeOf idx
      ixTupTy = tTup (sreplicate n tIx)
  in ELet ext idx $
     EBuild ext n (EShape ext (ESnd ext (EFst ext (EVar ext idxTy IZ)))) $
       -- Build body: IZ is the current element index, (IS IZ) the let-bound idx.
       eif (eidxEq n (EVar ext ixTupTy IZ) (EFst ext (EFst ext (EVar ext idxTy (IS IZ)))))
         -- Target element: recurse with the inner index.
         (onehot te prj (ESnd ext (EVar ext idxTy (IS IZ))) (weakenExpr (WSink .> WSink) arg))
         -- Any other element: zero, built from that element's zero info.
         (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext idxTy (IS IZ)))) (EVar ext ixTupTy IZ)) $
            zero te (EVar ext (tZeroInfo te) IZ))
|