summaryrefslogtreecommitdiff
path: root/Simplify.hs
blob: 9ae29f12c5b7484e77cb76a0c120d88d13283996 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
module Simplify (simplify) where

import Data.List
import qualified Data.Map.Strict as Map

import AST
import Utility

import Debug
import PrettyPrint

-- | Debug helper: emit @label: <show value>@ through 'trace' and return
-- @value@ unchanged, so it can be dropped into any composition chain.
tracex :: (Show a) => String -> a -> a
tracex label value = trace message value
  where
    message = label ++ ": " ++ show value

-- | Debug helper: emit @label: <prettyPrint value>@ through 'trace' and
-- return @value@ unchanged.
tracexp :: (PrettyPrint a) => String -> a -> a
tracexp label value = trace (label ++ ": " ++ rendered) value
  where
    rendered = prettyPrint value


-- | Simplify an expression using the rewrite rules in @db@.
--
-- The expression is flattened once, then the per-pass pipeline
-- (foldNumbers, flattenSums, canonicaliseOrder, applyPatterns, applied in that
-- order) is run via 'fixpoint'. 'fixpoint' comes from Utility; presumably it
-- iterates until the AST stops changing. Finally the result is put into
-- canonical order once more. Every stage traces its output.
simplify :: [(AST,AST)] -> AST -> AST
simplify db = finish . fixpoint pass . start
  where
    start = tracex "first flattenSums" . flattenSums
    -- One simplification pass; stages run right-to-left.
    pass = tracex "applyPatterns    " . applyPatterns db
         . tracex "canonicaliseOrder" . canonicaliseOrder
         . tracex "flattenSums      " . flattenSums
         . tracex "foldNumbers      " . foldNumbers
    finish = tracex "last canonicaliseOrder" . canonicaliseOrder


-- | Flatten nested associative nodes, bottom-up.
--
-- A 'Sum' appearing directly as an argument of a 'Sum' (likewise 'Product'
-- inside 'Product') has its arguments spliced into the parent. Degenerate
-- nodes are collapsed:
--
-- * an empty 'Sum' becomes @Number 0@;
-- * an empty 'Product' becomes @Number 1@;
-- * a one-argument 'Sum'/'Product' becomes its (flattened) argument.
--
-- 'Negative', 'Reciprocal' and 'Apply' are traversed. Every other node is
-- returned unchanged.
flattenSums :: AST -> AST
flattenSums node = case node of
    (Negative n) -> Negative $ flattenSums n
    (Reciprocal n) -> Reciprocal $ flattenSums n
    (Apply name args) -> Apply name $ map flattenSums args
    (Sum args) -> case args of
        [] -> Number 0
        [single] -> flattenSums single
        _ -> Sum $ concatMap (spliceSum . flattenSums) args
    (Product args) -> case args of
        [] -> Number 1
        [single] -> flattenSums single
        _ -> Product $ concatMap (spliceProduct . flattenSums) args
    _ -> node
  where
    -- Inline the arguments of a child of the same kind; keep anything else.
    spliceSum (Sum inner) = inner
    spliceSum other = [other]
    spliceProduct (Product inner) = inner
    spliceProduct other = [other]

-- | Constant-fold numeric sub-expressions, working bottom-up (children are
-- folded before their parent is inspected).
--
-- * @Negative@:
--   a literal is negated in place, @-(-e)@ becomes @e@, and
--   @-(Product ps)@ becomes @Product (Number (-1) : ps)@.
-- * @Reciprocal@:
--   a literal @x@ becomes @1/x@, @1/(-e)@ becomes @-(1/e)@, and
--   @1/(1/e)@ becomes @e@.
-- * @Apply "pow" [a, b]@ with two literal arguments is evaluated with '**'.
-- * @Sum@ / @Product@:
--   all literal arguments are merged into one literal, which is dropped
--   when it equals 0 (sum) or 1 (product). A product with two or more
--   negative factors has their signs collected into at most one
--   @Number (-1)@ factor.
--
-- Every other node is returned unchanged.
foldNumbers :: AST -> AST
foldNumbers node = case node of
    (Negative n) -> case foldNumbers n of
        (Number x) -> Number (-x)
        (Negative n2) -> n2
        (Product args) -> Product $ Number (-1) : args
        fn -> Negative fn
    (Reciprocal n) -> case foldNumbers n of
        -- NOTE(review): x == 0 folds to whatever 1/0 is for Number's payload
        -- type (Infinity for Double) -- confirm that is intended.
        (Number x) -> Number (1/x)
        -- 1/(-m) == -(1/m): pull the sign out of the reciprocal. The sign
        -- must be moved, not duplicated, or the rule flips the value's sign
        -- and never reaches a fixpoint.
        (Negative m) -> Negative $ Reciprocal m
        (Reciprocal n2) -> n2
        fn -> Reciprocal fn
    (Apply name args) -> let fargs = map foldNumbers args
        in case (name, fargs) of
            -- Only a well-formed binary pow is evaluated; any other arity is
            -- left as an application instead of crashing on a missing arg.
            ("pow", [base, expo]) | astIsNumber base && astIsNumber expo ->
                Number $ astFromNumber base ** astFromNumber expo
            _ -> Apply name fargs
    (Sum args) -> Sum $ dofoldnums sum args 0
    (Product args) -> dofoldnegsToProd $ dofoldnums product args 1
    _ -> node
  where
    -- Fold every argument, then merge the literal ones with @func@ (sum or
    -- product). @identity@ is @func@'s neutral element (0 or 1), and a merged
    -- literal equal to it is dropped. When no argument is a literal, the
    -- folded arguments are returned as they are.
    dofoldnums func args identity =
        let foldedArgs = map foldNumbers args
            (nums, notnums) = partition astIsNumber foldedArgs
            -- assumes astIsNumber is True only for Number nodes -- the
            -- lambda below is partial otherwise.
            foldvalue = func $ map (\(Number n) -> n) nums
        in if null nums
            then foldedArgs
            else if foldvalue == identity
                then notnums
                else Number foldvalue : notnums
    -- Collect the signs of a product's negative factors (@Negative e@ nodes
    -- and negative literals). An even count cancels out; an odd count leaves
    -- a single @Number (-1)@ factor. With fewer than two negative factors the
    -- product is built from @args@ untouched.
    -- NOTE(review): @args@ was already folded by dofoldnums. Folding it again
    -- here is redundant work, but it can change the outcome of a single pass,
    -- so it is kept.
    dofoldnegsToProd args =
        let foldedArgs = map foldNumbers args
            (negs, notnegs) = partition isneg foldedArgs
            unnegged = map unneg negs ++ notnegs
        in case length negs of
            count | count < 2 -> Product args
                  | even count -> Product unnegged
                  | otherwise -> Product $ Number (-1) : unnegged
    isneg (Negative _) = True
    isneg (Number n) = n < 0
    isneg _ = False
    unneg (Negative n) = n
    unneg (Number n) = Number $ abs n
    unneg n = n

-- | Put the whole tree into canonical order: the arguments of every 'Sum' and
-- 'Product' are sorted (by the 'Ord' instance of 'AST'), at every depth.
-- Children are canonicalised before their parent is sorted. Otherwise sums
-- and products nested inside another sum/product would keep an arbitrary
-- order, and the parent's sort would compare non-canonical subterms.
canonicaliseOrder :: AST -> AST
canonicaliseOrder node = case node of
    (Number _) -> node
    (Variable _) -> node
    (Sum args) -> Sum $ sort $ map canonicaliseOrder args
    (Product args) -> Product $ sort $ map canonicaliseOrder args
    (Negative n) -> Negative $ canonicaliseOrder n
    (Reciprocal n) -> Reciprocal $ canonicaliseOrder n
    (Apply name args) -> Apply name $ map canonicaliseOrder args
    (Capture _) -> node
    (CaptureTerm _) -> node
    (CaptureConstr _ _) -> node


-- | Rebuild a node with @f@ applied to each of its immediate children.
-- Leaf nodes (numbers, variables and captures) are returned as they are.
-- The map is one level deep; any recursion is up to @f@ itself.
astChildMap :: AST -> (AST -> AST) -> AST
astChildMap (Sum args) f = Sum (map f args)
astChildMap (Product args) f = Product (map f args)
astChildMap (Apply name args) f = Apply name (map f args)
astChildMap (Negative child) f = Negative (f child)
astChildMap (Reciprocal child) f = Reciprocal (f child)
astChildMap leaf@(Number _) _ = leaf
astChildMap leaf@(Variable _) _ = leaf
astChildMap leaf@(Capture _) _ = leaf
astChildMap leaf@(CaptureTerm _) _ = leaf
astChildMap leaf@(CaptureConstr _ _) _ = leaf

-- | Rewrite @node@ with the first rule in @db@ whose pattern matches it.
--
-- @db@ is a list of @(pattern, replacement)@ pairs. 'astMatch' yields a list
-- of capture dictionaries, which is empty when the pattern does not match.
--
-- * If some rule matches, its replacement is instantiated with the first
--   capture dictionary via 'replaceCaptures' and returned. It is not
--   rewritten further, and the node's children are not visited.
-- * If no rule matches, the rules are applied to each child instead.
--
-- TODO: don't take the first option of the first match, but use them all
applyPatterns :: [(AST,AST)] -> AST -> AST
applyPatterns db node =
    case [ (capdict, repl) | (pat, repl) <- db, capdict : _ <- [astMatch pat node] ] of
        [] -> astChildMap node (applyPatterns db)
        (capdict, repl) : _ -> replaceCaptures capdict repl