blob: 1675dabd28fe5eeaf972f9e4b13fe4fafe0245f3 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeOperators #-}
module AST.UnMonoid where
import AST
import CHAD.Types
import Data
-- | Replace the monoid operations 'EZero' and 'EPlus' with equivalent
-- expressions built from ordinary AST constructors (see 'zero' and 'plus').
-- Every other constructor is rebuilt as-is, with the transformation applied
-- recursively to its subexpressions.
unMonoid :: Ex env t -> Ex env t
unMonoid = \case
  EZero t -> zero t
  -- The operands may themselves contain monoid operations, so they must be
  -- transformed too; otherwise nested EZero/EPlus would survive this pass.
  EPlus t a b -> plus t (unMonoid a) (unMonoid b)
  -- TODO: lower one-hot construction to plain AST as well. Until that exists
  -- the node is kept (with its operands transformed) instead of a typed hole,
  -- so that this module compiles.
  EOneHot t i a b -> EOneHot t i (unMonoid a) (unMonoid b)

  EVar _ t i -> EVar ext t i
  ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body)
  EPair _ a b -> EPair ext (unMonoid a) (unMonoid b)
  EFst _ e -> EFst ext (unMonoid e)
  ESnd _ e -> ESnd ext (unMonoid e)
  ENil _ -> ENil ext
  EInl _ t e -> EInl ext t (unMonoid e)
  EInr _ t e -> EInr ext t (unMonoid e)
  ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b)
  ENothing _ t -> ENothing ext t
  EJust _ e -> EJust ext (unMonoid e)
  EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e)
  EConstArr _ n t x -> EConstArr ext n t x
  EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b)
  EFold1Inner _ a b c -> EFold1Inner ext (unMonoid a) (unMonoid b) (unMonoid c)
  ESum1Inner _ e -> ESum1Inner ext (unMonoid e)
  EUnit _ e -> EUnit ext (unMonoid e)
  EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b)
  EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)
  EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
  EConst _ t x -> EConst ext t x
  EIdx0 _ e -> EIdx0 ext (unMonoid e)
  EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b)
  EIdx _ a b -> EIdx ext (unMonoid a) (unMonoid b)
  EShape _ e -> EShape ext (unMonoid e)
  EOp _ op e -> EOp ext op (unMonoid e)
  ECustom _ t1 t2 t3 a b c e1 e2 ->
    ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)
  EWith a b -> EWith (unMonoid a) (unMonoid b)
  EAccum n a b e -> EAccum n (unMonoid a) (unMonoid b) (unMonoid e)
  EError t s -> EError t s
-- | An expression for the additive identity of the cotangent type @D2 t@.
--
-- * Pairs, sums and maybes use 'ENothing' (the sparse "no contribution" value).
-- * Arrays become an 'EBuild' whose shape has every dimension set to 0.
-- * Floating-point scalars become the constant 0.
-- * Integral and boolean scalars, and the unit type, become 'ENil'.
zero :: STy t -> Ex env (D2 t)
zero = \case
  STNil -> ENil ext
  STPair t1 t2 -> ENothing ext (STPair (d2 t1) (d2 t2))
  STEither t1 t2 -> ENothing ext (STEither (d2 t1) (d2 t2))
  STMaybe t -> ENothing ext (d2 t)
  STArr n t ->
    let allZeroShape = eTup (sreplicate n (EConst ext STI64 0))
    in EBuild ext n allZeroShape (zero t)
  STScal STF32 -> EConst ext STF32 0.0
  STScal STF64 -> EConst ext STF64 0.0
  STScal STI32 -> ENil ext
  STScal STI64 -> ENil ext
  STScal STBool -> ENil ext
  STAccum{} -> error "Accumulators not allowed in input program"
-- | Build an expression that adds two cotangent values of type @D2 t@.
--
-- Pairs, sums and maybes are sparse: they go through 'plusSparse', so an
-- absent ('ENothing') operand acts as the identity. Floating-point scalars
-- are added with 'OAdd'. Integral and boolean scalars, and the unit type,
-- produce 'ENil'.
plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t)
plus STNil _ _ = ENil ext
plus (STPair t1 t2) a b =
  let t = STPair (d2 t1) (d2 t2)
  in plusSparse t a b $
       -- Inside the adder, IS IZ is the left operand and IZ the right one.
       EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ)))
                          (EFst ext (EVar ext t IZ)))
                 (plus t2 (ESnd ext (EVar ext t (IS IZ)))
                          (ESnd ext (EVar ext t IZ)))
plus (STEither t1 t2) a b =
  let t = STEither (d2 t1) (d2 t2)
  in plusSparse t a b $
       -- Scrutinise the left operand (IS IZ). Inside each branch its payload
       -- is IZ, so the right operand is then at IS IZ. Operands built with
       -- different injections produce an 'EError' node.
       ECase ext (EVar ext t (IS IZ))
         (ECase ext (EVar ext t (IS IZ))
            (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ)))
            (EError t "plus l+r"))
         (ECase ext (EVar ext t (IS IZ))
            (EError t "plus r+l")
            (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ))))
plus (STMaybe t) a b =
  plusSparse (d2 t) a b $
    plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ)
plus STArr{} _ _ =
  -- TODO: implement array addition. Because 'zero' produces an empty array,
  -- a correct implementation must return the other operand when either one
  -- is empty, and add elementwise otherwise. This previously ended in an
  -- incomplete 'ECase' and did not compile, so fail explicitly for now.
  error "AST.UnMonoid.plus: addition of array cotangents is not implemented yet"
plus (STScal t) a b = case t of
  STI32 -> ENil ext
  STI64 -> ENil ext
  STF32 -> EOp ext (OAdd STF32) (EPair ext a b)
  STF64 -> EOp ext (OAdd STF64) (EPair ext a b)
  STBool -> ENil ext
plus STAccum{} _ _ = error "Accumulators not allowed in input program"
-- | Add two optional values, using @adder@ only when both are present.
--
-- * If the left operand is 'ENothing', the right operand is returned as-is.
-- * If only the left operand is present, its value is returned in 'EJust'.
-- * If both are present, @adder@ computes the sum, which is wrapped in
--   'EJust'. In @adder@'s environment, 'IZ' is the right value and
--   @IS IZ@ the left value.
plusSparse :: STy a
           -> Ex env (TMaybe a) -> Ex env (TMaybe a)
           -> Ex (a : a : env) a
           -> Ex env (TMaybe a)
plusSparse t a b adder =
  let tMaybe = STMaybe t
      -- lift the adder past the binding of the right operand (see below)
      combineBoth = weakenExpr (WCopy (WCopy WSink)) adder
  in ELet ext b $                        -- IZ: the right operand (optional)
       EMaybe ext
         (EVar ext tMaybe IZ)            -- left absent: result is the right operand
         (EJust ext $                    -- left present, payload at IZ
            EMaybe ext
              (EVar ext t IZ)            -- right absent: keep the left value
              combineBoth                -- both present: defer to the adder
              (EVar ext tMaybe (IS IZ)))
         (weakenExpr WSink a)
|