summaryrefslogtreecommitdiff
path: root/simplify.hs
blob: c62ff4b50953c7a1e625cda912e3e0371afb414d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
{-# LANGUAGE CPP #-}

module Simplify (simplify) where

import Data.List
import qualified Data.Map.Strict as Map

import AST
import Utility

import Debug
import PrettyPrint

-- | Debug helper: pass "<label>: <show value>" to 'trace' and return the
-- value itself, so it can be spliced into any pipeline stage.
tracex :: (Show a) => String -> a -> a
tracex label value = trace message value
    where message = concat [label, ": ", show value]

-- | Like 'tracex', but renders the value with 'prettyPrint' instead of
-- 'show'; the value is returned unchanged.
tracexp :: (PrettyPrint a) => String -> a -> a
tracexp label value = trace message value
    where message = concat [label, ": ", prettyPrint value]


-- | Full simplification pipeline.
--
-- The input is flattened once. Then one round of foldNumbers, flattenSums,
-- canonicaliseOrder and applyPatterns (applied in that order) is repeated
-- via 'fixpoint' (presumably until the tree stops changing -- see Utility).
-- Finally the result gets one more canonicaliseOrder. Every stage's output
-- is logged with 'tracex'.
simplify :: AST -> AST
simplify = finish . fixpoint step . start
  where
    start, step, finish :: AST -> AST
    start  = tracex "first flattenSums" . flattenSums
    step   = tracex "applyPatterns    " . applyPatterns
           . tracex "canonicaliseOrder" . canonicaliseOrder
           . tracex "flattenSums      " . flattenSums
           . tracex "foldNumbers      " . foldNumbers
    finish = tracex "last canonicaliseOrder" . canonicaliseOrder


-- | Flatten nested n-ary operators.
--
-- A @Sum@ directly inside a @Sum@ (and a @Product@ directly inside a
-- @Product@) has its operands spliced into the parent. Degenerate nodes are
-- collapsed:
--
-- * an empty sum becomes @Number 0@;
-- * an empty product becomes @Number 1@;
-- * a single-operand sum or product becomes its (flattened) operand.
--
-- Recurses through @Negative@, @Reciprocal@ and @Apply@; every other node
-- is returned unchanged.
flattenSums :: AST -> AST
flattenSums node = case node of
    (Negative n) -> Negative $ flattenSums n
    (Reciprocal n) -> Reciprocal $ flattenSums n
    (Apply name args) -> Apply name $ map flattenSums args
    (Sum args) -> case args of
        [] -> Number 0
        [single] -> flattenSums single
        _ -> Sum $ concatMap (sumOperands . flattenSums) args
    (Product args) -> case args of
        [] -> Number 1
        [single] -> flattenSums single
        _ -> Product $ concatMap (productOperands . flattenSums) args
    _ -> node
    where
        -- Operands of an already-flattened child, ready to be spliced into
        -- the parent operator's operand list.
        sumOperands (Sum inner) = inner
        sumOperands other = [other]
        productOperands (Product inner) = inner
        productOperands other = [other]

-- | Constant-fold numeric sub-expressions, bottom-up.
--
-- * @Negative@: a number is negated in place, a double negation cancels,
--   and negating a product prepends a @-1@ factor.
-- * @Reciprocal@: a number becomes @1/x@, a double reciprocal cancels, and
--   @1/(-y)@ is rewritten as @-(1/y)@.
-- * @pow@ with exactly two numeric arguments is evaluated; any other
--   @pow@ is kept symbolic (with folded arguments).
-- * @Sum@ / @Product@: all numeric operands are merged into one leading
--   number, which is dropped when it equals the identity (0 / 1). In
--   products, pairs of negative factors are then cancelled.
--
-- All other nodes are returned unchanged.
foldNumbers :: AST -> AST
foldNumbers node = case node of
    (Negative n) -> case foldNumbers n of
        (Number x) -> Number (-x)
        (Negative inner) -> inner
        (Product args) -> Product $ Number (-1) : args
        fn -> Negative fn
    (Reciprocal n) -> case foldNumbers n of
        (Number x) -> Number (1/x)
        -- 1/(-y) == -(1/y). Strip the sign from the operand before taking
        -- the reciprocal. Wrapping the still-negated operand would apply
        -- the sign twice (giving 1/y) and would oscillate with the Negative
        -- branch above, so 'fixpoint' would never settle.
        (Negative inner) -> Negative $ Reciprocal inner
        (Reciprocal inner) -> inner
        fn -> Reciprocal fn
    (Apply "pow" args) -> case map foldNumbers args of
        -- Only a well-formed binary pow is evaluated. Other arities are
        -- left symbolic instead of crashing on a missing argument or
        -- silently dropping extra ones.
        [base, ex] | astIsNumber base && astIsNumber ex ->
            Number $ astFromNumber base ** astFromNumber ex
        fargs -> Apply "pow" fargs
    (Apply name args) -> Apply name $ map foldNumbers args
    (Sum args) -> Sum $ foldOperands sum 0 args
    (Product args) -> cancelNegatives $ foldOperands product 1 args
    _ -> node
    where
        -- Fold every operand, then merge the numeric ones with @combine@.
        -- The merged number is omitted when it equals @identity@. If there
        -- are no numeric operands the folded operands are returned as-is.
        foldOperands combine identity args =
            let foldedArgs = map foldNumbers args
                (nums, others) = partition astIsNumber foldedArgs
                combined = combine [v | Number v <- nums]
            in if null nums
                then foldedArgs
                else if combined == identity
                    then others
                    else Number combined : others

        -- Given product operands that are ALREADY folded (they come from
        -- foldOperands, so they are not folded a second time), strip the
        -- sign from every negative factor when there are at least two of
        -- them. A single leading -1 is kept if their count is odd.
        cancelNegatives args
            | count < 2 = Product args
            | even count = Product unnegged
            | otherwise = Product $ Number (-1) : unnegged
            where
                (negs, others) = partition isNeg args
                count = length negs
                unnegged = map unneg negs ++ others

        isNeg (Negative _) = True
        isNeg (Number x) = x < 0
        isNeg _ = False

        unneg (Negative inner) = inner
        unneg (Number x) = Number $ abs x
        unneg other = other

-- | Put the operands of every @Sum@ and @Product@ into a canonical order
-- (sorted by AST's 'Ord' instance), recursively.
--
-- Sum and Product operands are canonicalised before being sorted. Without
-- that, nested sums/products would keep their original operand order,
-- unlike children of @Negative@, @Reciprocal@ and @Apply@, which are always
-- visited. Leaves and capture nodes are returned unchanged.
canonicaliseOrder :: AST -> AST
canonicaliseOrder node = case node of
    (Number _) -> node
    (Variable _) -> node
    (Sum args) -> Sum $ sort $ map canonicaliseOrder args
    (Product args) -> Product $ sort $ map canonicaliseOrder args
    (Negative n) -> Negative $ canonicaliseOrder n
    (Reciprocal n) -> Reciprocal $ canonicaliseOrder n
    (Apply name args) -> Apply name $ map canonicaliseOrder args
    (Capture _) -> node
    (CaptureTerm _) -> node
    (CaptureConstr _ _) -> node


-- | Rewrite rules used by 'applyPatterns', as (pattern, replacement) pairs.
--
-- 'applyPatterns' tries the rules in list order and rewrites a node with
-- the first rule whose pattern 'astMatch'es it, using the first capture
-- dictionary returned. The replacement is instantiated with
-- 'replaceCaptures'. Rule order therefore matters: earlier rules shadow
-- later ones.
--
-- Capture forms (their semantics live outside this module -- NOTE(review):
-- confirm against astMatch):
--
-- * @CaptureTerm name@: appears to bind exactly one operand/term.
-- * @Capture name@: in operand lists appears to bind the remaining operands
--   (written "[rest]" in the per-rule comments).
-- * @CaptureConstr name shape@: presumably binds a term whose constructor
--   matches @shape@. The @undefined@ payload is only a constructor marker;
--   forcing it would throw.
--
-- Reusing a capture name within one pattern (e.g. "x" twice) presumably
-- requires both occurrences to match equal terms.
patterndb :: [(AST,AST)]
patterndb = [
        -- Reciprocal / product cancellation.
        (Reciprocal $ Product [Reciprocal (CaptureTerm "x"),Capture "rest"],  -- 1/(1/x * [rest]) -> x * 1/[rest]
         Product [Capture "x",Reciprocal $ Capture "rest"]),

        (Product [CaptureTerm "x",Reciprocal (CaptureTerm "x"),Capture "rest"],  -- x * 1/x * [rest] -> [rest]
         Capture "rest"),

        (Product [CaptureTerm "x",  -- x * 1/(x*[rrest]) * [rest] -> [rest] * 1/[rrest]
                  Reciprocal (Product [CaptureTerm "x",Capture "rrest"]),
                  Capture "rest"],
         Product [Capture "rest",Reciprocal (Capture "rrest")]),

        -- Collecting like terms in sums.
        (Sum [CaptureTerm "x",CaptureTerm "x",Capture "rest"],  -- x + x + [rest] -> 2*x + [rest]
         Sum [Product [Number 2,Capture "x"],Capture "rest"]),

        (Sum [CaptureTerm "x",  -- x + n*x + [rest] -> (1+n)*x + [rest]
              Product [CaptureConstr "n" (Number undefined),Capture "x"],
              Capture "rest"],
         Sum [Product [Sum [Number 1,Capture "n"],
                       Capture "x"],
              Capture "rest"]),

        (Sum [Product [CaptureConstr "n1" (Number undefined),Capture "x"],  -- n1*x + n2*x + [rest] -> (n1+n2)*x + [rest]
              Product [CaptureConstr "n2" (Number undefined),Capture "x"],
              Capture "rest"],

         Sum [Product [Sum [Capture "n1",Capture "n2"],
                       Capture "x"],
              Capture "rest"]),

        -- Collecting like factors in products into powers.
        (Product [CaptureTerm "x",  -- x*x*[rest] -> x^2*[rest]
                  CaptureTerm "x",
                  Capture "rest"],
         Product [Apply "pow" [Capture "x",Number 2],
                  Capture "rest"]),

        (Product [CaptureTerm "x",  -- x*x^n*[rest] -> x^(n+1)*[rest]
                  Apply "pow" [CaptureTerm "x",CaptureTerm "n"],
                  Capture "rest"],
         Product [Apply "pow" [Capture "x",Sum [Capture "n",Number 1]],
                  Capture "rest"]),

        (Product [Apply "pow" [CaptureTerm "x",CaptureTerm "a"],  -- x^a*x^b*[rest] -> x^(a+b)*[rest]
                  Apply "pow" [CaptureTerm "x",CaptureTerm "b"],
                  Capture "rest"],
         Product [Apply "pow" [Capture "x",Sum [Capture "a",Capture "b"]],
                  Capture "rest"]),

        -- Power identities.
        (Apply "pow" [Apply "pow" [CaptureTerm "x",CaptureTerm "n"],CaptureTerm "m"],  -- (x^n)^m -> x^(n*m)
         Apply "pow" [Capture "x",Product [Capture "n",Capture "m"]]),

        (Apply "pow" [CaptureTerm "x",Number 1],  -- x^1 -> x
         Capture "x"),

        -- Multiplicative identities / annihilator.
        (Product [Number 0,Capture "rest"],  -- 0*[rest] -> 0
         Number 0),

        (Product [Number 1,Capture "rest"],  -- 1*[rest] -> [rest]
         Capture "rest"),


        -- Differentiation: d(expr, var).
        -- NOTE(review): the variable slot is matched with Capture "x" in
        -- some rules and CaptureTerm "x" in others; confirm this is
        -- intentional.
        (Apply "d" [CaptureConstr "n" (Number undefined),Capture "x"], -- d(n,x) -> 0
         Number 0),

        (Apply "d" [CaptureConstr "x" (Variable undefined),CaptureTerm "x"],  -- d(x,x) -> 1
         Number 1),

        -- NOTE(review): "n" is an unconstrained term, so this power rule
        -- also fires when the exponent depends on x (e.g. d(x^x,x)), which
        -- yields an incorrect derivative. Consider constraining "n".
        (Apply "d" [Apply "pow" [CaptureConstr "x" (Variable undefined),CaptureTerm "n"],  -- d(x^n,x) -> n*x^(n-1)
                    CaptureTerm "x"],
         Product [Capture "n",Apply "pow" [Capture "x",Sum [Capture "n",Number (-1)]]]),

        (Apply "d" [Sum [CaptureTerm "a",Capture "b"],CaptureTerm "x"],  -- d(a+[b],x) -> d(a,x) + d([b],x)
         Sum [Apply "d" [Capture "a",Capture "x"],Apply "d" [Capture "b",Capture "x"]]),

        (Apply "d" [Product [CaptureTerm "a",Capture "b"],Capture "x"],  -- d(ab,x) -> d(a,x)*b + a*d(b,x)
         Sum [Product [Apply "d" [Capture "a",Capture "x"],Capture "b"],
              Product [Capture "a",Apply "d" [Capture "b",Capture "x"]]]),

        (Apply "d" [Apply "pow" [CaptureConstr "a" (Variable undefined),  -- d(a^expr,x) -> a^expr * ln(a) * d(expr,x)
                                 Capture "expr"],
                    Capture "x"],
         Product [Apply "pow" [Capture "a",Capture "expr"],
                  Apply "ln" [Capture "a"],
                  Apply "d" [Capture "expr",Capture "x"]]),

        (Apply "d" [Apply "ln" [Capture "args"],Capture "x"],  -- d(ln([args]),x) -> 1/[args]*d([args],x)
         Product [Reciprocal (Capture "args"),Apply "d" [Capture "args",Capture "x"]])
    ]

-- | Rebuild a node with @f@ applied to each of its immediate children.
-- Only the top layer is touched; @f@ decides whether to recurse further.
-- Leaf and capture nodes have no children and are returned as-is.
astChildMap :: AST -> (AST -> AST) -> AST
astChildMap (Sum operands) f = Sum (map f operands)
astChildMap (Product operands) f = Product (map f operands)
astChildMap (Apply fname operands) f = Apply fname (map f operands)
astChildMap (Negative child) f = Negative (f child)
astChildMap (Reciprocal child) f = Reciprocal (f child)
astChildMap leaf@(Number _) _ = leaf
astChildMap leaf@(Variable _) _ = leaf
astChildMap leaf@(Capture _) _ = leaf
astChildMap leaf@(CaptureTerm _) _ = leaf
astChildMap leaf@(CaptureConstr _ _) _ = leaf

-- | Apply a single rewrite from 'patterndb' at this node or, when no rule
-- matches here, recurse into the node's children.
--
-- Rules are tried in 'patterndb' order. The first rule for which 'astMatch'
-- returns at least one capture dictionary wins, and its first dictionary is
-- used to instantiate the replacement. Later rules are never matched (the
-- comprehension is consumed lazily). The rewritten node is not re-simplified
-- here; 'simplify' repeats passes via 'fixpoint'.
applyPatterns :: AST -> AST
applyPatterns node =
    -- A rule whose match list is empty fails the `capdict : _` pattern and
    -- is skipped, so no partial head/irrefutable pattern is needed.
    case [ (capdict, repl) | (pat, repl) <- patterndb
                           , capdict : _ <- [astMatch pat node] ] of
        [] -> astChildMap node applyPatterns
        -- TODO: don't take the first option of the first match, but use them all
        (capdict, repl) : _ -> {-applyPatterns $-} replaceCaptures capdict repl