-- | Algebraic simplification of 'AST' expressions.
--
-- 'simplify' repeatedly normalises an expression (constant folding,
-- flattening of nested sums/products, canonical argument ordering) and
-- rewrites it with a database of @(pattern, replacement)@ rules.
module Simplify (simplify) where

import Data.List
import AST
import Utility
import Debug
-- import PrettyPrint

-- | Emit @label: show x@ via 'trace' and return @x@ unchanged.
tracex :: (Show a) => String -> a -> a
tracex s x = trace (s ++ ": " ++ show x) x

-- tracexp :: (PrettyPrint a) => String -> a -> a
-- tracexp s x = trace (s ++ ": " ++ prettyPrint x) x

-- | Simplify an expression using the rewrite rules in @db@.
--
-- The input is flattened once.  The pipeline
-- @foldNumbers -> flattenSums -> canonicaliseOrder -> applyPatterns@ is then
-- iterated with 'fixpoint' from Utility, presumably until the AST stops
-- changing; confirm against Utility.  The result is put into canonical
-- order one final time.
--
-- NOTE(review): every stage is wrapped in 'tracex', so each call dumps all
-- intermediate ASTs to the debug trace.  Drop these once debugging is done.
simplify :: [(AST, AST)] -> AST -> AST
simplify db =
    tracex "last canonicaliseOrder" . canonicaliseOrder
  . fixpoint step
  . tracex "first flattenSums" . flattenSums
  where
    -- One simplification round; applied until a fixpoint is reached.
    step =
        tracex "applyPatterns " . applyPatterns db
      . tracex "canonicaliseOrder" . canonicaliseOrder
      . tracex "flattenSums " . flattenSums
      . tracex "foldNumbers " . foldNumbers

-- | Flatten nested sums and products.
--
-- For example, @Sum [a, Sum [b, c]]@ becomes @Sum [a, b, c]@, and likewise
-- for 'Product'.  An empty sum becomes @Number 0@ and an empty product
-- becomes @Number 1@.  A one-argument sum or product is replaced by its
-- flattened argument.  The function recurses through 'Negative',
-- 'Reciprocal' and 'Apply'; every other node is returned unchanged.
flattenSums :: AST -> AST
flattenSums node = case node of
    Negative n      -> Negative (flattenSums n)
    Reciprocal n    -> Reciprocal (flattenSums n)
    Apply name args -> Apply name (map flattenSums args)
    Sum args        -> flattenWith Sum sumChildren (Number 0) args
    Product args    -> flattenWith Product productChildren (Number 1) args
    _               -> node
  where
    -- Shared Sum/Product logic.  With no arguments, return the identity
    -- element.  With one argument, unwrap it.  Otherwise splice the
    -- arguments of same-kind children into this node.
    flattenWith :: ([AST] -> AST) -> (AST -> [AST]) -> AST -> [AST] -> AST
    flattenWith _    _      identity []  = identity
    flattenWith _    _      _        [x] = flattenSums x
    flattenWith ctor splice _        xs  = ctor (concatMap (splice . flattenSums) xs)

    sumChildren (Sum a) = a
    sumChildren n       = [n]

    productChildren (Product a) = a
    productChildren n           = [n]

-- | Constant-fold numeric sub-expressions, working bottom-up.
--
-- * @Negative@: negates literals, cancels double negation, and pushes the
--   sign into a product as a @Number (-1)@ factor.
-- * @Reciprocal@: inverts literals, moves a negation outward
--   (@1/(-m) = -(1/m)@), and cancels a double reciprocal.
-- * @pow@ with exactly two numeric arguments is evaluated.
-- * In @Sum@ and @Product@, all numeric literals are combined into one.
--   The combined literal is dropped when it is the identity (0 or 1).
-- * In a @Product@ with two or more negative factors, pairs of signs cancel.
foldNumbers :: AST -> AST
foldNumbers node = case node of
    Negative n -> case foldNumbers n of
        Number x    -> Number (-x)
        Negative m  -> m                                -- -(-m) = m
        Product fs  -> Product (Number (-1) : fs)       -- sign becomes a factor
        fn          -> Negative fn
    Reciprocal n -> case foldNumbers n of
        Number x     -> Number (1 / x)
        -- 1/(-m) = -(1/m).  Wrapping the *folded* node here
        -- (Negative (Reciprocal (Negative m))) would flip the sign and make
        -- the Negative branch undo it on the next pass, so 'fixpoint' would
        -- never terminate.
        Negative m   -> Negative (Reciprocal m)
        Reciprocal m -> m                               -- 1/(1/m) = m
        fn           -> Reciprocal fn
    Apply name args -> foldApply name (map foldNumbers args)
    Sum args        -> Sum (foldConstants sum 0 args)
    Product args    -> foldNegatives (foldConstants product 1 args)
    _               -> node
  where
    -- Evaluate pow only when it has exactly two numeric arguments.  Any
    -- other arity is left symbolic rather than crashing or dropping args.
    foldApply "pow" [base, ex]
        | astIsNumber base && astIsNumber ex =
            Number (astFromNumber base ** astFromNumber ex)
    foldApply name fargs = Apply name fargs

    -- Fold each argument, then combine the numeric literals with
    -- @combine@.  Omit the combined literal when it equals @identity@.
    foldConstants combine identity args =
        let folded         = map foldNumbers args
            (nums, others) = partition astIsNumber folded
            value          = combine [x | Number x <- nums]
        in if null nums
               then folded
               else if value == identity
                   then others
                   else Number value : others

    -- Cancel signs among a product's factors.  With fewer than two negative
    -- factors the product is left alone.  Otherwise every sign is stripped,
    -- and one Number (-1) is prepended iff the count is odd.
    -- NOTE(review): the arguments were already folded by foldConstants;
    -- folding them again here is redundant and costs time exponential in
    -- product nesting depth.  It is kept so results are unchanged.
    foldNegatives args =
        let refolded       = map foldNumbers args
            (negs, others) = partition isNegative refolded
            stripped       = map stripSign negs ++ others
        in case length negs of
               k | k < 2     -> Product args
                 | even k    -> Product stripped
                 | otherwise -> Product (Number (-1) : stripped)

    isNegative (Negative _) = True
    isNegative (Number x)   = x < 0
    isNegative _            = False

    stripSign (Negative m) = m
    stripSign (Number x)   = Number (abs x)
    stripSign m            = m

-- | Put an expression into canonical form by recursively sorting the
-- arguments of every 'Sum' and 'Product' (they are commutative).
--
-- Children are canonicalised before their parent is sorted, so nested sums
-- and products, e.g. a product inside a sum, are normalised too.  Leaves
-- are returned unchanged.
canonicaliseOrder :: AST -> AST
canonicaliseOrder node = case astChildMap node canonicaliseOrder of
    Sum args     -> Sum (sort args)
    Product args -> Product (sort args)
    other        -> other

-- | Rebuild @node@ with @f@ applied to each of its immediate children.
--
-- Leaves (numbers, variables and the capture nodes) are returned as-is.
-- Every constructor is listed explicitly, so a new 'AST' constructor
-- triggers an incomplete-pattern warning here.
astChildMap :: AST -> (AST -> AST) -> AST
astChildMap node f = case node of
    Sum args          -> Sum (map f args)
    Product args      -> Product (map f args)
    Negative n        -> Negative (f n)
    Reciprocal n      -> Reciprocal (f n)
    Apply name args   -> Apply name (map f args)
    Number _          -> node
    Variable _        -> node
    Capture _         -> node
    CaptureTerm _     -> node
    CaptureConstr _ _ -> node

-- | Try each @(pattern, replacement)@ rule in @db@, in order, against @node@.
--
-- The first rule for which 'astMatch' yields a non-empty list wins.  The
-- first capture dictionary of that match is substituted into the
-- replacement via 'replaceCaptures'.  If no rule matches, the rules are
-- applied to each child instead.  A replacement is not re-simplified here;
-- the 'fixpoint' loop in 'simplify' takes care of further rewriting.
-- TODO: don't take the first option of the first match, but use them all
applyPatterns :: [(AST, AST)] -> AST -> AST
applyPatterns db node =
    case [ (capdict, repl) | (pat, repl) <- db, capdict : _ <- [astMatch pat node] ] of
        []                  -> astChildMap node (applyPatterns db)
        (capdict, repl) : _ -> replaceCaptures capdict repl