blob: ae9728a39c33a3704f5e2200bbfcc0b9fd314cf7 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeOperators #-}
module AST.UnMonoid (unMonoid, zero, plus) where
import AST
import CHAD.Types
import Data
-- | Expand the monoid primitives of the AST into ordinary code.
--
-- 'EZero', 'EPlus' and 'EOneHot' are replaced by the expressions built by
-- 'zero', 'plus' and 'onehot' respectively; their subexpressions are
-- processed first. Every other constructor is rebuilt with its annotation
-- field replaced by 'ext' and 'unMonoid' applied to each subexpression.
-- Non-expression payloads (types, variables, constants, operators,
-- projections, error messages) are passed through untouched.
unMonoid :: Ex env t -> Ex env t
unMonoid expr = case expr of
  -- The monoid primitives themselves.
  EZero _ ty -> zero ty
  EPlus _ ty lhs rhs -> plus ty (unMonoid lhs) (unMonoid rhs)
  EOneHot _ ty prj idx val -> onehot ty prj (unMonoid idx) (unMonoid val)

  -- Variables, let-bindings, products.
  EVar _ ty var -> EVar ext ty var
  ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body)
  EPair _ lhs rhs -> EPair ext (unMonoid lhs) (unMonoid rhs)
  EFst _ sub -> EFst ext (unMonoid sub)
  ESnd _ sub -> ESnd ext (unMonoid sub)
  ENil _ -> ENil ext

  -- Sums and options.
  EInl _ ty sub -> EInl ext ty (unMonoid sub)
  EInr _ ty sub -> EInr ext ty (unMonoid sub)
  ECase _ scrut onLeft onRight ->
    ECase ext (unMonoid scrut) (unMonoid onLeft) (unMonoid onRight)
  ENothing _ ty -> ENothing ext ty
  EJust _ sub -> EJust ext (unMonoid sub)
  EMaybe _ onNothing onJust scrut ->
    EMaybe ext (unMonoid onNothing) (unMonoid onJust) (unMonoid scrut)

  -- Arrays.
  EConstArr _ rank ty arr -> EConstArr ext rank ty arr
  EBuild _ rank shape body -> EBuild ext rank (unMonoid shape) (unMonoid body)
  EFold1Inner _ x y z -> EFold1Inner ext (unMonoid x) (unMonoid y) (unMonoid z)
  ESum1Inner _ sub -> ESum1Inner ext (unMonoid sub)
  EUnit _ sub -> EUnit ext (unMonoid sub)
  EReplicate1Inner _ x y -> EReplicate1Inner ext (unMonoid x) (unMonoid y)
  EMaximum1Inner _ sub -> EMaximum1Inner ext (unMonoid sub)
  EMinimum1Inner _ sub -> EMinimum1Inner ext (unMonoid sub)

  -- Scalars, indexing and primitive operations.
  EConst _ ty val -> EConst ext ty val
  EIdx0 _ sub -> EIdx0 ext (unMonoid sub)
  EIdx1 _ x y -> EIdx1 ext (unMonoid x) (unMonoid y)
  EIdx _ x y -> EIdx ext (unMonoid x) (unMonoid y)
  EShape _ sub -> EShape ext (unMonoid sub)
  EOp _ op sub -> EOp ext op (unMonoid sub)

  -- Everything else.
  ECustom _ t1 t2 t3 x y z arg1 arg2 ->
    ECustom ext t1 t2 t3
      (unMonoid x) (unMonoid y) (unMonoid z)
      (unMonoid arg1) (unMonoid arg2)
  EWith _ ty x y -> EWith ext ty (unMonoid x) (unMonoid y)
  EAccum _ ty p x y z -> EAccum ext ty p (unMonoid x) (unMonoid y) (unMonoid z)
  EError _ ty msg -> EError ext ty msg

-- | The additive zero of the cotangent type @D2 t@, as an expression.
--
-- * Pairs, sums and maybes use the sparse representation: 'ENothing'.
-- * An array of positive rank becomes an array whose shape is all zeros.
--   Its element expression is an 'EError' with message @"empty"@; since the
--   array has no positions, that error is presumably never evaluated.
-- * For rank 0 the all-zeros-shape trick does not apply (there is no
--   dimension to set to zero). A singleton array ('EUnit') holding the
--   element zero is used instead.
-- * Integral and boolean scalars have a trivial cotangent ('ENil').
--   Floating-point scalars get the literal @0.0@.
--
-- Accumulator types are rejected with a runtime 'error'.
zero :: STy t -> Ex env (D2 t)
zero ty = case ty of
  STNil -> ENil ext
  STPair l r -> ENothing ext (STPair (d2 l) (d2 r))
  STEither l r -> ENothing ext (STEither (d2 l) (d2 r))
  STMaybe inner -> ENothing ext (d2 inner)
  -- Must stay above the general array case.
  STArr SZ el -> EUnit ext (zero el)
  STArr rank el ->
    EBuild ext rank
      (eTup (sreplicate rank (EConst ext STI64 0)))
      (EError ext (d2 el) "empty")
  STScal STI32 -> ENil ext
  STScal STI64 -> ENil ext
  STScal STF32 -> EConst ext STF32 0.0
  STScal STF64 -> EConst ext STF64 0.0
  STScal STBool -> ENil ext
  STAccum{} -> error "Accumulators not allowed in input program"

-- | Add two cotangents of type @D2 t@.
--
-- * Sparse types (pairs, sums, maybes) go through 'plusSparse'. If either
--   operand is absent, the other one is returned unchanged. When both are
--   present, their components are added recursively.
-- * For sums, adding a left injection to a right one (or vice versa)
--   yields a runtime 'EError'.
-- * For arrays, if either operand has an empty shape (as built by 'zero'),
--   the other operand is returned. Otherwise the arrays are combined
--   elementwise with 'ezipWith'.
-- * Integral and boolean scalars have trivial cotangents. Floating-point
--   scalars are added with 'OAdd'.
--
-- Accumulator types are rejected with a runtime 'error'.
plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t)
plus ty lhs rhs = case ty of
  STNil -> ENil ext

  -- In every 'plusSparse' adder below, @IS IZ@ is the left operand's payload
  -- and @IZ@ is the right operand's payload.
  STPair tl tr ->
    let tyP = STPair (d2 tl) (d2 tr)
    in plusSparse tyP lhs rhs
         (EPair ext
            (plus tl (EFst ext (EVar ext tyP (IS IZ))) (EFst ext (EVar ext tyP IZ)))
            (plus tr (ESnd ext (EVar ext tyP (IS IZ))) (ESnd ext (EVar ext tyP IZ))))

  STEither tl tr ->
    let tyE = STEither (d2 tl) (d2 tr)
    in plusSparse tyE lhs rhs
         (ECase ext (EVar ext tyE (IS IZ))
            -- Left operand is an Inl: the right operand must be one too.
            (ECase ext (EVar ext tyE (IS IZ))
               (EInl ext (d2 tr) (plus tl (EVar ext (d2 tl) (IS IZ)) (EVar ext (d2 tl) IZ)))
               (EError ext tyE "plus l+r"))
            -- Left operand is an Inr: the right operand must be one too.
            (ECase ext (EVar ext tyE (IS IZ))
               (EError ext tyE "plus r+l")
               (EInr ext (d2 tl) (plus tr (EVar ext (d2 tr) (IS IZ)) (EVar ext (d2 tr) IZ)))))

  STMaybe inner ->
    plusSparse (d2 inner) lhs rhs
      (plus inner (EVar ext (d2 inner) (IS IZ)) (EVar ext (d2 inner) IZ))

  STArr rank el ->
    let tyA = STArr rank (d2 el)
    in ELet ext lhs $
       ELet ext (weakenExpr WSink rhs) $
         -- From here on, IS IZ is lhs and IZ is rhs. An empty operand acts
         -- as the identity, so the other operand is returned as the sum.
         eif (eshapeEmpty rank (EShape ext (EVar ext tyA (IS IZ))))
           (EVar ext tyA IZ)
           (eif (eshapeEmpty rank (EShape ext (EVar ext tyA IZ)))
              (EVar ext tyA (IS IZ))
              (ezipWith (plus el (EVar ext (d2 el) (IS IZ)) (EVar ext (d2 el) IZ))
                 (EVar ext tyA (IS IZ))
                 (EVar ext tyA IZ)))

  STScal STI32 -> ENil ext
  STScal STI64 -> ENil ext
  STScal STF32 -> EOp ext (OAdd STF32) (EPair ext lhs rhs)
  STScal STF64 -> EOp ext (OAdd STF64) (EPair ext lhs rhs)
  STScal STBool -> ENil ext

  STAccum{} -> error "Accumulators not allowed in input program"

-- | Add two sparse ('TMaybe'-wrapped) values, given an adder for the case
-- where both are present.
--
-- If the left operand is 'Nothing', the right operand is returned.
-- If the right operand is 'Nothing', the left one is returned (re-wrapped
-- in 'EJust'). Only when both are present is the adder used, and its result
-- is wrapped in 'EJust'. Inside the adder, @'IS' 'IZ'@ refers to the left
-- operand's payload and 'IZ' to the right operand's payload.
--
-- The right operand is let-bound first. The left operand is then evaluated
-- as the scrutinee of the outer 'EMaybe'.
plusSparse :: STy a
           -> Ex env (TMaybe a) -> Ex env (TMaybe a)
           -> Ex (a : a : env) a
           -> Ex env (TMaybe a)
plusSparse ty lhs rhs combine =
  let tyMaybe = STMaybe ty
      -- Left operand is @Just x@ (bound at IZ), and the let-bound right
      -- operand sits at IS IZ.
      withLeft =
        EJust ext $
          EMaybe ext
            (EVar ext ty IZ)                            -- right is Nothing: keep x
            (weakenExpr (WCopy (WCopy WSink)) combine)  -- right is Just y: combine x y
            (EVar ext tyMaybe (IS IZ))
  in ELet ext rhs $
       EMaybe ext
         (EVar ext tyMaybe IZ)                          -- left is Nothing: keep right
         withLeft
         (weakenExpr WSink lhs)

-- | Build a sparse cotangent of type @D2 t@ that is zero everywhere except
-- at the position selected by the projection (and, at array levels, by the
-- runtime index @idx@), where it holds @arg@.
--
-- Each projection step wraps the result in the matching 'EJust' / pair /
-- injection and fills sibling components with 'zero'.
--
-- At an array level, the index is a nested pair:
--
-- * The first component is a pair of (target position, array shape).
-- * The second component is the index for the remaining projection.
--
-- An array of that shape is built in which only the target position holds
-- a non-zero element.
onehot :: STy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env (D2 a) -> Ex env (D2 t)
onehot _ SAPHere _ arg = arg
onehot (STPair tl tr) (SAPFst sub) idx arg =
  EJust ext (EPair ext (onehot tl sub idx arg) (zero tr))
onehot (STPair tl tr) (SAPSnd sub) idx arg =
  EJust ext (EPair ext (zero tl) (onehot tr sub idx arg))
onehot (STEither tl tr) (SAPLeft sub) idx arg =
  EJust ext (EInl ext (d2 tr) (onehot tl sub idx arg))
onehot (STEither tl tr) (SAPRight sub) idx arg =
  EJust ext (EInr ext (d2 tl) (onehot tr sub idx arg))
onehot (STMaybe inner) (SAPJust sub) idx arg =
  EJust ext (onehot inner sub idx arg)
onehot (STArr rank el) (SAPArrIdx sub _) idx arg =
  let tyIdx = typeOf idx
      tyPos = tTup (sreplicate rank tIx)
  in ELet ext idx $
       -- Shape of the result: second half of the (position, shape) pair.
       EBuild ext rank (ESnd ext (EFst ext (EVar ext tyIdx IZ))) $
         -- IZ is the current build position, and IS IZ the let-bound index.
         eif (eidxEq rank (EVar ext tyPos IZ) (EFst ext (EFst ext (EVar ext tyIdx (IS IZ)))))
           (onehot el sub (ESnd ext (EVar ext tyIdx (IS IZ))) (weakenExpr (WSink .> WSink) arg))
           (zero el)
|