1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
|
{-# LANGUAGE CPP #-}
module Simplify (simplify) where
import Data.List
import qualified Data.Map.Strict as Map
import AST
import Utility
import Debug
import PrettyPrint
-- | Pass the message @label ++ ": " ++ show value@ to 'trace' and return
-- @value@ unchanged, so it can be spliced into a composition pipeline.
tracex :: (Show a) => String -> a -> a
tracex label value = trace message value
  where
    message = concat [label, ": ", show value]
-- | Like 'tracex', but renders the value with 'prettyPrint' instead of
-- 'show'. The value itself is returned unchanged.
tracexp :: (PrettyPrint a) => String -> a -> a
tracexp label value = trace message value
  where
    message = concat [label, ": ", prettyPrint value]
-- | Simplify an expression tree.
--
-- The input is flattened once, then a rewrite pass (fold numbers, flatten,
-- canonicalise order, apply patterns - in that order) is driven by
-- 'fixpoint', and finally the result is put in canonical order once more.
-- Every stage's output is routed through 'tracex' with a stage label.
--
-- NOTE(review): 'fixpoint' comes from @Utility@; presumably it re-applies
-- the pass until the tree stops changing - confirm against that module.
simplify :: AST -> AST
simplify = finalOrdering . fixpoint rewritePass . initialFlatten
  where
    -- One-off flattening before the iterative phase starts.
    initialFlatten = tracex "first flattenSums" . flattenSums

    -- A single rewrite pass; composition runs right-to-left, so numbers are
    -- folded first and patterns are applied last.
    rewritePass =
      tracex "applyPatterns " . applyPatterns
        . tracex "canonicaliseOrder" . canonicaliseOrder
        . tracex "flattenSums " . flattenSums
        . tracex "foldNumbers " . foldNumbers

    -- Final ordering pass on the result of the iterative phase.
    finalOrdering = tracex "last canonicaliseOrder" . canonicaliseOrder
-- | Flatten nested sums and products.
--
-- * An empty 'Sum' becomes @Number 0@; an empty 'Product' becomes @Number 1@.
-- * A single-element 'Sum' or 'Product' is replaced by its (flattened) element.
-- * Otherwise each argument is flattened, and any argument that is itself a
--   'Sum' (inside a 'Sum') or a 'Product' (inside a 'Product') is spliced
--   into the parent's argument list.
-- * 'Negative', 'Reciprocal' and 'Apply' are traversed; every other node is
--   returned untouched.
flattenSums :: AST -> AST
flattenSums node = case node of
  Negative inner -> Negative (flattenSums inner)
  Reciprocal inner -> Reciprocal (flattenSums inner)
  Apply name args -> Apply name (flattenSums <$> args)
  Sum [] -> Number 0
  Sum [only] -> flattenSums only
  Sum terms -> Sum (concatMap (spliceSum . flattenSums) terms)
  Product [] -> Number 1
  Product [only] -> flattenSums only
  Product factors -> Product (concatMap (spliceProduct . flattenSums) factors)
  _ -> node
  where
    -- Arguments of a nested sum are lifted into the enclosing sum.
    spliceSum (Sum inner) = inner
    spliceSum other = [other]

    -- Arguments of a nested product are lifted into the enclosing product.
    spliceProduct (Product inner) = inner
    spliceProduct other = [other]
-- | Fold numeric constants and normalise signs throughout an expression.
--
-- * @Negative@: negates a number, cancels a double negation, and turns the
--   negation of a product into a product with a leading @Number (-1)@.
-- * @Reciprocal@: inverts a number, cancels a double reciprocal, and pulls a
--   negation outwards (@1/(-y)@ becomes @-(1/y)@).
-- * @Apply "pow"@ with exactly two numeric arguments is evaluated with '**'.
--   Any other application only has its arguments folded.
-- * @Sum@ / @Product@: all numeric arguments are combined into one number
--   (dropped entirely when it equals the identity 0 resp. 1); products
--   additionally have pairs of negative factors cancelled.
-- * Every other node is returned unchanged.
foldNumbers :: AST -> AST
foldNumbers node = case node of
  Negative n -> case foldNumbers n of
    Number x -> Number (-x)
    Negative inner -> inner
    Product args -> Product (Number (-1) : args)
    folded -> Negative folded
  Reciprocal n -> case foldNumbers n of
    Number x -> Number (1 / x)
    -- 1/(-y) == -(1/y). The reciprocal must wrap the *inner* term; wrapping
    -- the whole negated term again would flip the sign of the expression.
    Negative inner -> Negative (Reciprocal inner)
    Reciprocal inner -> inner
    folded -> Reciprocal folded
  Apply name args ->
    let fargs = map foldNumbers args
     in case (name, fargs) of
          -- Only a binary pow over two numbers can be evaluated; any other
          -- arity is left as an application instead of indexing blindly.
          ("pow", [base, ex])
            | astIsNumber base && astIsNumber ex ->
                Number (astFromNumber base ** astFromNumber ex)
          _ -> Apply name fargs
  Sum args -> Sum (foldNums sum args 0)
  Product args -> foldNegsToProd (foldNums product args 1)
  _ -> node
  where
    -- Fold every argument, then merge all numeric arguments with @combine@.
    -- The merged number is put in front of the remaining arguments, or
    -- omitted when it equals @identity@. Without any numeric argument the
    -- folded arguments are returned in their original order.
    foldNums combine args identity =
      let foldedArgs = map foldNumbers args
          (nums, others) = partition astIsNumber foldedArgs
          total = combine [n | Number n <- nums]
       in if null nums
            then foldedArgs
            else
              if total == identity
                then others
                else Number total : others

    -- Cancel negative factors pairwise; an odd count leaves one factor -1.
    -- With fewer than two negative factors the given arguments are kept.
    -- NOTE(review): the arguments arriving here were already folded by
    -- foldNums; the second fold is kept to preserve single-pass results.
    foldNegsToProd args =
      let foldedArgs = map foldNumbers args
          (negs, notnegs) = partition isNeg foldedArgs
          unnegged = map unNeg negs ++ notnegs
       in case length negs of
            count
              | count < 2 -> Product args
              | even count -> Product unnegged
              | otherwise -> Product (Number (-1) : unnegged)

    -- A factor counts as negative when it is a negation or a number below 0.
    isNeg (Negative _) = True
    isNeg (Number n) = n < 0
    isNeg _ = False

    -- Strip the sign from a negative factor.
    unNeg (Negative n) = n
    unNeg (Number n) = Number (abs n)
    unNeg n = n
-- | Put an expression into canonical argument order.
--
-- The arguments of every 'Sum' and 'Product' are first canonicalised
-- themselves and then sorted with 'sort', so nested sums and products are
-- ordered as well (consistent with how 'Negative', 'Reciprocal' and 'Apply'
-- already recurse into their children). Argument order of 'Apply' is kept.
-- Leaf nodes and capture nodes are returned unchanged.
canonicaliseOrder :: AST -> AST
canonicaliseOrder node = case node of
  -- Children are canonicalised before sorting so that the sort compares
  -- canonical forms.
  Sum args -> Sum (sort (map canonicaliseOrder args))
  Product args -> Product (sort (map canonicaliseOrder args))
  Negative n -> Negative (canonicaliseOrder n)
  Reciprocal n -> Reciprocal (canonicaliseOrder n)
  Apply name args -> Apply name (map canonicaliseOrder args)
  Number _ -> node
  Variable _ -> node
  Capture _ -> node
  CaptureTerm _ -> node
  CaptureConstr _ _ -> node
-- | Rewrite rules as @(pattern, replacement)@ pairs.
--
-- 'applyPatterns' tries the rules in list order and uses the first one whose
-- pattern matches, so earlier entries take precedence over later ones.
--
-- In the per-rule comments, @[name]@ denotes a 'Capture' and a bare name a
-- 'CaptureTerm' or 'CaptureConstr'.
-- NOTE(review): presumably 'Capture' binds the remaining arguments,
-- 'CaptureTerm' a single term, and 'CaptureConstr' a term with the same
-- constructor as its template (the template's payload is @undefined@, so it
-- must never be forced) - confirm against @astMatch@.
patterndb :: [(AST, AST)]
patterndb =
  [ -- 1/(1/x * [rest]) -> x * 1/[rest]
    ( Reciprocal (Product [Reciprocal (CaptureTerm "x"), rest])
    , Product [Capture "x", Reciprocal rest]
    )
  , -- x * 1/x * [rest] -> [rest]
    ( Product [CaptureTerm "x", Reciprocal (CaptureTerm "x"), rest]
    , rest
    )
  , -- x * 1/(x*[rrest]) * [rest] -> [rest] * 1/[rrest]
    ( Product
        [ CaptureTerm "x"
        , Reciprocal (Product [CaptureTerm "x", Capture "rrest"])
        , rest
        ]
    , Product [rest, Reciprocal (Capture "rrest")]
    )
  , -- x + x + [rest] -> 2*x + [rest]
    ( Sum [CaptureTerm "x", CaptureTerm "x", rest]
    , Sum [Product [Number 2, Capture "x"], rest]
    )
  , -- x + n*x + [rest] -> (1+n)*x + [rest]
    ( Sum [CaptureTerm "x", Product [numberNamed "n", Capture "x"], rest]
    , Sum [Product [Sum [Number 1, Capture "n"], Capture "x"], rest]
    )
  , -- n1*x + n2*x + [rest] -> (n1+n2)*x + [rest]
    ( Sum
        [ Product [numberNamed "n1", Capture "x"]
        , Product [numberNamed "n2", Capture "x"]
        , rest
        ]
    , Sum [Product [Sum [Capture "n1", Capture "n2"], Capture "x"], rest]
    )
  , -- x*x*[rest] -> x^2*[rest]
    ( Product [CaptureTerm "x", CaptureTerm "x", rest]
    , Product [pow (Capture "x") (Number 2), rest]
    )
  , -- x*x^n*[rest] -> x^(n+1)*[rest]
    ( Product [CaptureTerm "x", pow (CaptureTerm "x") (CaptureTerm "n"), rest]
    , Product [pow (Capture "x") (Sum [Capture "n", Number 1]), rest]
    )
  , -- x^a*x^b*[rest] -> x^(a+b)*[rest]
    ( Product
        [ pow (CaptureTerm "x") (CaptureTerm "a")
        , pow (CaptureTerm "x") (CaptureTerm "b")
        , rest
        ]
    , Product [pow (Capture "x") (Sum [Capture "a", Capture "b"]), rest]
    )
  , -- (x^n)^m -> x^(n*m)
    ( pow (pow (CaptureTerm "x") (CaptureTerm "n")) (CaptureTerm "m")
    , pow (Capture "x") (Product [Capture "n", Capture "m"])
    )
  , -- x^1 -> x
    (pow (CaptureTerm "x") (Number 1), Capture "x")
  , -- 0*[rest] -> 0
    (Product [Number 0, rest], Number 0)
  , -- 1*[rest] -> [rest]
    (Product [Number 1, rest], rest)
  , -- d(n,x) -> 0
    (deriv (numberNamed "n") (Capture "x"), Number 0)
  , -- d(x,x) -> 1
    (deriv (variableNamed "x") (CaptureTerm "x"), Number 1)
  , -- d(x^n,x) -> n*x^(n-1)
    ( deriv (pow (variableNamed "x") (CaptureTerm "n")) (CaptureTerm "x")
    , Product [Capture "n", pow (Capture "x") (Sum [Capture "n", Number (-1)])]
    )
  , -- d(a+[b],x) -> d(a,x) + d([b],x)
    ( deriv (Sum [CaptureTerm "a", Capture "b"]) (CaptureTerm "x")
    , Sum [deriv (Capture "a") (Capture "x"), deriv (Capture "b") (Capture "x")]
    )
  , -- d(ab,x) -> d(a,x)*b + a*d(b,x)
    ( deriv (Product [CaptureTerm "a", Capture "b"]) (Capture "x")
    , Sum
        [ Product [deriv (Capture "a") (Capture "x"), Capture "b"]
        , Product [Capture "a", deriv (Capture "b") (Capture "x")]
        ]
    )
  , -- d(a^expr,x) -> a^expr * ln(a) * d(expr,x)
    ( deriv (pow (variableNamed "a") (Capture "expr")) (Capture "x")
    , Product
        [ pow (Capture "a") (Capture "expr")
        , Apply "ln" [Capture "a"]
        , deriv (Capture "expr") (Capture "x")
        ]
    )
  , -- d(ln([args]),x) -> 1/[args]*d([args],x)
    ( deriv (Apply "ln" [Capture "args"]) (Capture "x")
    , Product [Reciprocal (Capture "args"), deriv (Capture "args") (Capture "x")]
    )
  ]
  where
    -- The capture named "rest", shared by most rules.
    rest = Capture "rest"

    -- Binary application of the function named "pow".
    pow base ex = Apply "pow" [base, ex]

    -- Binary application of the function named "d".
    deriv expr var = Apply "d" [expr, var]

    -- Constrained captures whose template is a 'Number' resp. 'Variable'
    -- with an undefined payload.
    numberNamed name = CaptureConstr name (Number undefined)
    variableNamed name = CaptureConstr name (Variable undefined)
-- | Apply @f@ to each direct child of @node@ and rebuild the node with the
-- results. Nodes without children ('Number', 'Variable' and the three
-- capture forms) are returned as they are. @f@ is applied one level deep
-- only; recursion is up to the caller.
astChildMap :: AST -> (AST -> AST) -> AST
astChildMap node f = rebuild node
  where
    -- Nodes with children: map over them.
    rebuild (Sum children) = Sum (f <$> children)
    rebuild (Product children) = Product (f <$> children)
    rebuild (Negative child) = Negative (f child)
    rebuild (Reciprocal child) = Reciprocal (f child)
    rebuild (Apply name children) = Apply name (f <$> children)
    -- Leaves: nothing to map over.
    rebuild leaf@(Number _) = leaf
    rebuild leaf@(Variable _) = leaf
    rebuild leaf@(Capture _) = leaf
    rebuild leaf@(CaptureTerm _) = leaf
    rebuild leaf@(CaptureConstr _ _) = leaf
-- | Apply the rewrite rules of 'patterndb' to an expression.
--
-- Every rule's pattern is matched against @node@ with @astMatch@. For the
-- first rule (in 'patterndb' order) that yields at least one match, the
-- first match's captures are substituted into the rule's replacement with
-- @replaceCaptures@, and that result is returned without being rewritten
-- again. If no rule matches at this node, the direct children are rewritten
-- recursively via 'astChildMap'.
applyPatterns :: AST -> AST
applyPatterns node =
  case firstMatches of
    -- TODO: don't take the first option of the first match, but use them all
    (capdict, repl) : _ -> {-applyPatterns $-} replaceCaptures capdict repl
    [] -> astChildMap node applyPatterns
  where
    -- For each rule with a non-empty match list: its first capture
    -- dictionary paired with the rule's replacement template.
    firstMatches =
      [ (capdict, repl)
      | (pat, repl) <- patterndb
      , capdict : _ <- [astMatch pat node]
      ]
|