summaryrefslogtreecommitdiff
path: root/Simplify.hs
diff options
context:
space:
mode:
Diffstat (limited to 'Simplify.hs')
-rw-r--r--Simplify.hs125
1 files changed, 125 insertions, 0 deletions
diff --git a/Simplify.hs b/Simplify.hs
new file mode 100644
index 0000000..9ae29f1
--- /dev/null
+++ b/Simplify.hs
@@ -0,0 +1,125 @@
+module Simplify (simplify) where
+
+import Data.List
+import qualified Data.Map.Strict as Map
+
+import AST
+import Utility
+
+import Debug
+import PrettyPrint
+
+tracex :: (Show a) => String -> a -> a
+tracex s x = trace (s ++ ": " ++ show x) x
+
+tracexp :: (PrettyPrint a) => String -> a -> a
+tracexp s x = trace (s ++ ": " ++ prettyPrint x) x
+
+
+simplify :: [(AST,AST)] -> AST -> AST
+simplify db = tracex "last canonicaliseOrder" . canonicaliseOrder
+ . (fixpoint $ tracex "applyPatterns " . applyPatterns db
+ . tracex "canonicaliseOrder" . canonicaliseOrder
+ . tracex "flattenSums " . flattenSums
+ . tracex "foldNumbers " . foldNumbers)
+ . tracex "first flattenSums" . flattenSums
+
+
+flattenSums :: AST -> AST
+flattenSums node = case node of
+ (Negative n) -> Negative $ flattenSums n
+ (Reciprocal n) -> Reciprocal $ flattenSums n
+ (Apply name args) -> Apply name $ map flattenSums args
+ (Sum args) -> case length args of
+ 0 -> Number 0
+ 1 -> flattenSums $ args !! 0
+ otherwise -> Sum $ concat $ map (listify . flattenSums) args
+ where
+ listify (Sum args) = args
+ listify node = [node]
+ (Product args) -> case length args of
+ 0 -> Number 1
+ 1 -> flattenSums $ args !! 0
+ otherwise -> Product $ concat $ map (listify . flattenSums) args
+ where
+ listify (Product args) = args
+ listify node = [node]
+ _ -> node
+
+foldNumbers :: AST -> AST
+foldNumbers node = case node of
+ (Negative n) -> let fn = foldNumbers n in case fn of
+ (Number x) -> Number (-x)
+ (Negative n2) -> n2
+ (Product args) -> Product $ Number (-1) : args
+ _ -> Negative $ fn
+ (Reciprocal n) -> let fn = foldNumbers n in case fn of
+ (Number x) -> Number (1/x)
+ (Negative n) -> Negative $ Reciprocal fn
+ (Reciprocal n2) -> n2
+ _ -> Reciprocal $ fn
+ (Apply name args) -> let fargs = map foldNumbers args
+ in case name of
+ "pow" -> if all astIsNumber fargs
+ then Number $ astFromNumber (fargs!!0) ** astFromNumber (fargs!!1)
+ else Apply "pow" fargs
+ _ -> Apply name fargs
+ (Sum args) -> Sum $ dofoldnums sum args 0
+ (Product args) -> dofoldnegsToProd $ dofoldnums product args 1
+ _ -> node
+ where
+ dofoldnums func args zerovalue =
+ let foldedArgs = map foldNumbers args
+ (nums,notnums) = partition astIsNumber foldedArgs
+ foldvalue = func $ map (\(Number n) -> n) nums
+ in case length nums of
+ x | x >= 1 -> if foldvalue == zerovalue then notnums else Number foldvalue : notnums
+ otherwise -> foldedArgs
+ dofoldnegsToProd args =
+ let foldedArgs = map foldNumbers args
+ (negs,notnegs) = partition isneg foldedArgs
+ isneg (Negative _) = True
+ isneg (Number n) = n < 0
+ isneg _ = False
+ unneg (Negative n) = n
+ unneg (Number n) = Number $ abs n
+ unneg n = n
+ unnegged = map unneg negs ++ notnegs
+ in case length negs of
+ x | x < 2 -> Product args
+ | even x -> Product unnegged
+ | odd x -> Product $ Number (-1) : unnegged
+
+canonicaliseOrder :: AST -> AST
+canonicaliseOrder node = case node of
+ (Number _) -> node
+ (Variable _) -> node
+ (Sum args) -> Sum $ sort args
+ (Product args) -> Product $ sort args
+ (Negative n) -> Negative $ canonicaliseOrder n
+ (Reciprocal n) -> Reciprocal $ canonicaliseOrder n
+ (Apply name args) -> Apply name $ map canonicaliseOrder args
+ (Capture _) -> node
+ (CaptureTerm _) -> node
+ (CaptureConstr _ _) -> node
+
+
+astChildMap :: AST -> (AST -> AST) -> AST
+astChildMap node f = case node of
+ (Number _) -> node
+ (Variable _) -> node
+ (Sum args) -> Sum $ map f args
+ (Product args) -> Product $ map f args
+ (Negative n) -> Negative $ f n
+ (Reciprocal n) -> Reciprocal $ f n
+ (Apply name args) -> Apply name $ map f args
+ (Capture _) -> node
+ (CaptureTerm _) -> node
+ (CaptureConstr _ _) -> node
+
+applyPatterns :: [(AST,AST)] -> AST -> AST
+applyPatterns db node = let matches = filter (not . null . fst) $ map (\(pat,repl) -> (astMatch pat node,repl)) db
+ in if null matches
+ then astChildMap node (applyPatterns db)
+ else let ((capdict:_),repl) = head matches -- TODO: don't take the first option of the first match, but use them all
+ in {-applyPatterns $-} replaceCaptures capdict repl