summaryrefslogtreecommitdiff
path: root/simplify.hs
diff options
context:
space:
mode:
Diffstat (limited to 'simplify.hs')
-rw-r--r--simplify.hs172
1 files changed, 172 insertions, 0 deletions
diff --git a/simplify.hs b/simplify.hs
new file mode 100644
index 0000000..cb3e5af
--- /dev/null
+++ b/simplify.hs
@@ -0,0 +1,172 @@
+{-# LANGUAGE CPP #-}
+
+module Simplify (simplify) where
+
+import Data.List
+import qualified Data.Map.Strict as Map
+
+import AST
+import Utility
+
+import Debug
+import PrettyPrint
+
+tracex :: (Show a) => String -> a -> a
+tracex s x = trace (s ++ ": " ++ show x) x
+
+tracexp :: (PrettyPrint a) => String -> a -> a
+tracexp s x = trace (s ++ ": " ++ prettyPrint x) x
+
+
+simplify :: AST -> AST
+simplify = tracex "last canonicaliseOrder" . canonicaliseOrder
+ . (fixpoint $ tracex "applyPatterns " . applyPatterns
+ . tracex "flattenSums " . flattenSums
+ -- . tracex "collectLikeTerms " . collectLikeTerms
+ . tracex "canonicaliseOrder" . canonicaliseOrder
+ . tracex "foldNumbers " . foldNumbers)
+ . tracex "first flattenSums" . flattenSums
+
+
+flattenSums :: AST -> AST
+flattenSums node = case node of
+ (Negative n) -> Negative $ flattenSums n
+ (Reciprocal n) -> Reciprocal $ flattenSums n
+ (Apply name args) -> Apply name $ map flattenSums args
+ (Sum args) -> case length args of
+ 0 -> Number 0
+ 1 -> flattenSums $ args !! 0
+ otherwise -> Sum $ concat $ map (listify . flattenSums) args
+ where
+ listify (Sum args) = args
+ listify node = [node]
+ (Product args) -> case length args of
+ 0 -> Number 1
+ 1 -> flattenSums $ args !! 0
+ otherwise -> Product $ concat $ map (listify . flattenSums) args
+ where
+ listify (Product args) = args
+ listify node = [node]
+ _ -> node
+
+foldNumbers :: AST -> AST
+foldNumbers node = case node of
+ (Negative n) -> let fn = foldNumbers n in case fn of
+ (Number x) -> Number (-x)
+ (Negative n2) -> n2
+ (Product args) -> Product $ Number (-1) : args
+ _ -> Negative $ fn
+ (Reciprocal n) -> let fn = foldNumbers n in case fn of
+ (Number x) -> Number (1/x)
+ (Negative n) -> Negative $ Reciprocal fn
+ (Reciprocal n2) -> n2
+ _ -> Reciprocal $ fn
+ (Apply name args) -> Apply name $ map foldNumbers args
+ (Sum args) -> Sum $ dofoldnums sum args 0
+ (Product args) -> dofoldnegsToProd $ dofoldnums product args 1
+ _ -> node
+ where
+ dofoldnums func args zerovalue =
+ let foldedArgs = map foldNumbers args
+ (nums,notnums) = partition astIsNumber foldedArgs
+ foldvalue = func $ map (\(Number n) -> n) nums
+ in case length nums of
+ x | x >= 1 -> if foldvalue == zerovalue then notnums else Number foldvalue : notnums
+ otherwise -> foldedArgs
+ dofoldnegsToProd args =
+ let foldedArgs = map foldNumbers args
+ (negs,notnegs) = partition isneg foldedArgs
+ isneg (Negative _) = True
+ isneg (Number n) = n < 0
+ isneg _ = False
+ unneg (Negative n) = n
+ unneg (Number n) = Number $ abs n
+ unneg n = n
+ unnegged = map unneg negs ++ notnegs
+ in case length negs of
+ x | x < 2 -> Product args
+ | even x -> Product unnegged
+ | odd x -> Product $ Number (-1) : unnegged
+
+-- collectLikeTerms :: AST -> AST
+-- collectLikeTerms node = case node of
+-- (Reciprocal n) -> Apply "pow" [n,Number $ -1]
+-- (Product args) ->
+-- let ispow (Apply "pow" _) = True
+-- ispow _ = False
+-- (pows,nopows) = partition ispow $ map collectLikeTerms args
+-- groups = groupBy (\(Apply _ [x,_]) (Apply _ [y,_]) -> x == y) pows
+-- baseof (Apply _ [x,_]) = x
+-- expof (Apply _ [_,x]) = x
+-- collectGroup l = Apply "pow" [baseof (l!!0),Sum $ map expof l]
+-- in Product $ map collectGroup groups ++ nopows
+-- (Sum args) ->
+-- let isnumterm (Product (Number _:_)) = True
+-- isnumterm _ = False
+-- (numterms,nonumterms) = partition isnumterm $ map collectLikeTerms args
+-- groups = groupBy (\(Product (Number _:xs)) (Product (Number _:ys))
+-- -> astMatchSimple (Product xs) (Product ys))
+-- numterms
+-- numof (Product (n:_)) = n
+-- restof (Product (_:rest)) = rest
+-- collectGroup l =
+-- if length l == 1
+-- then l!!0
+-- else Product $ Sum (map numof l) : restof (l!!0)
+-- in Sum $ map collectGroup groups ++ nonumterms
+-- _ -> node
+
+canonicaliseOrder :: AST -> AST
+canonicaliseOrder node = case node of
+ (Number _) -> node
+ (Variable _) -> node
+ (Sum args) -> Sum $ sort args
+ (Product args) -> Product $ sort args
+ (Negative n) -> Negative $ canonicaliseOrder n
+ (Reciprocal n) -> Reciprocal $ canonicaliseOrder n
+ (Apply name args) -> Apply name $ map canonicaliseOrder args
+ (Capture _) -> node
+ (CaptureTerm _) -> node
+ (CaptureConstr _ _) -> node
+
+
+patterndb :: [(AST,AST)]
+patterndb = [
+ (Sum [CaptureTerm "x",CaptureTerm "x",Capture "rest"], -- x + x + [rest] -> 2*x + [rest]
+ Sum [Product [Number 2,Capture "x"],Capture "rest"]),
+
+ (Sum [CaptureTerm "x", -- x + n*x + [rest] -> (1+n)*x + [rest]
+ Product [CaptureConstr "n" (Number undefined),Capture "x"],
+ Capture "rest"],
+ Sum [Product [Sum [Number 1,Capture "n"],
+ Capture "x"],
+ Capture "rest"]),
+
+ (Sum [Product [CaptureConstr "n1" (Number undefined),Capture "x"], -- n1*x + n2*x + [rest] -> (n1+n2)*x + [rest]
+ Product [CaptureConstr "n2" (Number undefined),Capture "x"],
+ Capture "rest"],
+
+ Sum [Product [Sum [Capture "n1",Capture "n2"],
+ Capture "x"],
+ Capture "rest"])
+ ]
+
+astChildMap :: AST -> (AST -> AST) -> AST
+astChildMap node f = case node of
+ (Number _) -> node
+ (Variable _) -> node
+ (Sum args) -> Sum $ map f args
+ (Product args) -> Product $ map f args
+ (Negative n) -> Negative $ f n
+ (Reciprocal n) -> Reciprocal $ f n
+ (Apply name args) -> Apply name $ map f args
+ (Capture _) -> node
+ (CaptureTerm _) -> node
+ (CaptureConstr _ _) -> node
+
+applyPatterns :: AST -> AST
+applyPatterns node = let matches = filter (not . null . fst) $ map (\(pat,repl) -> (astMatch pat node,repl)) patterndb
+ in if null matches
+ then astChildMap node applyPatterns
+ else let ((capdict:_),repl) = head matches -- TODO: don't take the first option of the first match, but use them all
+ in replaceCaptures capdict repl