{-# LANGUAGE CPP #-}
-- | Simplification passes over 'AST' expressions: flattening nested
-- sums/products, folding numeric constants, sorting operands into a
-- canonical order and applying the rewrite rules in 'patterndb'.
module Simplify (simplify) where

import Data.List
import qualified Data.Map.Strict as Map

import AST
import Utility
import Debug
import PrettyPrint

-- | Trace @label: show x@ (via 'trace') and return @x@ unchanged.
tracex :: (Show a) => String -> a -> a
tracex s x = trace (s ++ ": " ++ show x) x

-- | Like 'tracex', but renders the value with 'prettyPrint' instead of 'show'.
tracexp :: (PrettyPrint a) => String -> a -> a
tracexp s x = trace (s ++ ": " ++ prettyPrint x) x

-- | Simplify an expression.
--
-- The pipeline (read right to left) flattens once, then repeats
-- fold / canonicalise / flatten / apply-patterns under 'fixpoint'
-- (presumably until the tree stops changing -- see "Utility"), and finally
-- canonicalises the operand order once more. Every stage is traced.
simplify :: AST -> AST
simplify =
    tracex "last canonicaliseOrder" . canonicaliseOrder
    . (fixpoint $
          tracex "applyPatterns " . applyPatterns
          . tracex "flattenSums " . flattenSums
          -- . tracex "collectLikeTerms " . collectLikeTerms
          . tracex "canonicaliseOrder" . canonicaliseOrder
          . tracex "foldNumbers " . foldNumbers)
    . tracex "first flattenSums" . flattenSums

-- | Flatten nested sums into their parent sum and nested products into
-- their parent product.
--
-- An empty sum becomes @Number 0@, an empty product becomes @Number 1@ and a
-- single-operand sum/product collapses to that (flattened) operand.
-- 'Negative', 'Reciprocal' and 'Apply' are traversed; every other node is
-- returned unchanged.
flattenSums :: AST -> AST
flattenSums node = case node of
    Negative n      -> Negative $ flattenSums n
    Reciprocal n    -> Reciprocal $ flattenSums n
    Apply name args -> Apply name $ map flattenSums args
    Sum []          -> Number 0
    Sum [single]    -> flattenSums single
    Sum args        -> Sum $ concatMap (sumTerms . flattenSums) args
    Product []      -> Number 1
    Product [single] -> flattenSums single
    Product args    -> Product $ concatMap (productFactors . flattenSums) args
    _               -> node
  where
    -- Operands a node contributes to an enclosing sum.
    sumTerms (Sum terms) = terms
    sumTerms other       = [other]

    -- Operands a node contributes to an enclosing product.
    productFactors (Product factors) = factors
    productFactors other             = [other]

-- | Fold numeric constants.
--
-- * Negation and reciprocal of a literal are evaluated; double negation and
--   double reciprocal cancel; a negated product gets a @Number (-1)@ factor;
--   @1/(-n)@ becomes @-(1/n)@.
-- * @pow@ applied to exactly two numeric arguments is evaluated with '**'.
-- * Numeric operands of a sum/product are combined into a single literal,
--   which is dropped when it equals the identity (0 for sums, 1 for
--   products).
-- * Pairs of negative factors in a product cancel.
foldNumbers :: AST -> AST
foldNumbers node = case node of
    Negative n -> case foldNumbers n of
        Number x     -> Number (-x)
        Negative n2  -> n2
        Product args -> Product $ Number (-1) : args
        folded       -> Negative folded
    Reciprocal n -> case foldNumbers n of
        Number x      -> Number (1 / x)
        -- 1/(-a) == -(1/a): take the reciprocal of the operand *inside* the
        -- negation. (Wrapping the whole negated node again would yield
        -- -(1/(-a)) == 1/a and lose the sign.)
        Negative inner -> Negative $ Reciprocal inner
        Reciprocal n2 -> n2
        folded        -> Reciprocal folded
    Apply name args ->
        let fargs = map foldNumbers args
        in case (name, fargs) of
            -- Only evaluate a well-formed binary pow; any other arity is
            -- left untouched instead of indexing past the end of the list
            -- or silently dropping extra arguments.
            ("pow", [base, expo])
                | astIsNumber base && astIsNumber expo ->
                    Number $ astFromNumber base ** astFromNumber expo
            _ -> Apply name fargs
    Sum args     -> Sum $ foldConstants sum args 0
    Product args -> cancelNegatives $ foldConstants product args 1
    _            -> node
  where
    -- Fold the operands, then merge all numeric operands into one literal
    -- using @combine@ ('sum' or 'product'). The literal is omitted when it
    -- equals @identity@. When there are no numeric operands the folded
    -- operands are returned in their original order.
    foldConstants combine args identity =
        let foldedArgs      = map foldNumbers args
            (nums, notnums) = partition astIsNumber foldedArgs
            foldvalue       = combine $ map astFromNumber nums
        in if null nums
               then foldedArgs
               else if foldvalue == identity
                   then notnums
                   else Number foldvalue : notnums

    -- Cancel negative factors pairwise; an odd count (>= 3) leaves a single
    -- @Number (-1)@ factor in front. With fewer than two negative factors
    -- the operands are returned as given.
    -- NOTE(review): the operands arrive already folded by 'foldConstants'
    -- and are folded a second time here; kept as in the original, but this
    -- looks redundant -- confirm before removing.
    cancelNegatives args =
        let foldedArgs      = map foldNumbers args
            (negs, notnegs) = partition isneg foldedArgs
            isneg (Negative _) = True
            isneg (Number n)   = n < 0
            isneg _            = False
            unneg (Negative n) = n
            unneg (Number n)   = Number $ abs n
            unneg n            = n
            unnegged = map unneg negs ++ notnegs
        in case length negs of
            count
                | count < 2  -> Product args
                | even count -> Product unnegged
                | otherwise  -> Product $ Number (-1) : unnegged

-- collectLikeTerms :: AST -> AST
-- collectLikeTerms node = case node of
--     (Reciprocal n) -> Apply "pow" [n,Number $ -1]
--     (Product args) ->
--         let ispow (Apply "pow" _) = True
--             ispow _ = False
--             (pows,nopows) = partition ispow $ map collectLikeTerms args
--             groups = groupBy (\(Apply _ [x,_]) (Apply _ [y,_]) -> x == y) pows
--             baseof (Apply _ [x,_]) = x
--             expof (Apply _ [_,x]) = x
--             collectGroup l = Apply "pow" [baseof (l!!0),Sum $ map expof l]
--         in Product $ map collectGroup groups ++ nopows
--     (Sum args) ->
--         let isnumterm (Product (Number _:_)) = True
--             isnumterm _ = False
--             (numterms,nonumterms) = partition isnumterm $ map collectLikeTerms args
--             groups = groupBy (\(Product (Number _:xs)) (Product (Number _:ys))
--                                  -> astMatchSimple (Product xs) (Product ys))
--                              numterms
--             numof (Product (n:_)) = n
--             restof (Product (_:rest)) = rest
--             collectGroup l =
--                 if length l == 1
--                     then l!!0
--                     else Product $ Sum (map numof l) : restof (l!!0)
--         in Sum $ map collectGroup groups ++ nonumterms
--     _ -> node

-- | Sort the operands of every sum and product in the tree.
--
-- Children are canonicalised first (via 'astChildMap'), so sums/products
-- nested inside other sums/products are sorted as well, consistent with
-- the traversal through 'Negative', 'Reciprocal' and 'Apply'.
canonicaliseOrder :: AST -> AST
canonicaliseOrder node = case astChildMap node canonicaliseOrder of
    Sum args     -> Sum $ sort args
    Product args -> Product $ sort args
    other        -> other

-- | Rewrite rules as @(pattern, replacement)@ pairs, tried in order by
-- 'applyPatterns'. The per-rule comments give the intended rewrite.
patterndb :: [(AST,AST)]
patterndb =
    [ -- x + x + [rest] -> 2*x + [rest]
      ( Sum [CaptureTerm "x", CaptureTerm "x", Capture "rest"]
      , Sum [Product [Number 2, Capture "x"], Capture "rest"] )

      -- x + n*x + [rest] -> (1+n)*x + [rest]
    , ( Sum [ CaptureTerm "x"
            , Product [CaptureConstr "n" (Number undefined), Capture "x"]
            , Capture "rest" ]
      , Sum [ Product [Sum [Number 1, Capture "n"], Capture "x"]
            , Capture "rest" ] )

      -- n1*x + n2*x + [rest] -> (n1+n2)*x + [rest]
    , ( Sum [ Product [CaptureConstr "n1" (Number undefined), Capture "x"]
            , Product [CaptureConstr "n2" (Number undefined), Capture "x"]
            , Capture "rest" ]
      , Sum [ Product [Sum [Capture "n1", Capture "n2"], Capture "x"]
            , Capture "rest" ] )

      -- x*x*[rest] -> x^2*[rest]
    , ( Product [CaptureTerm "x", CaptureTerm "x", Capture "rest"]
      , Product [Apply "pow" [Capture "x", Number 2], Capture "rest"] )

      -- x*x^n*[rest] -> x^(n+1)*[rest]
    , ( Product [ CaptureTerm "x"
                , Apply "pow" [CaptureTerm "x", CaptureTerm "n"]
                , Capture "rest" ]
      , Product [ Apply "pow" [Capture "x", Sum [Capture "n", Number 1]]
                , Capture "rest" ] )

      -- x^a*x^b*[rest] -> x^(a+b)*[rest]
    , ( Product [ Apply "pow" [CaptureTerm "x", CaptureTerm "a"]
                , Apply "pow" [CaptureTerm "x", CaptureTerm "b"]
                , Capture "rest" ]
      , Product [ Apply "pow" [Capture "x", Sum [Capture "a", Capture "b"]]
                , Capture "rest" ] )

      -- (x^n)^m -> x^(n*m)
    , ( Apply "pow" [Apply "pow" [CaptureTerm "x", CaptureTerm "n"], CaptureTerm "m"]
      , Apply "pow" [Capture "x", Product [Capture "n", Capture "m"]] )

      -- x^1 -> x
    , ( Apply "pow" [CaptureTerm "x", Number 1]
      , Capture "x" )

      -- d(x,x) -> 1
    , ( Apply "d" [CaptureConstr "x" (Variable undefined), CaptureTerm "x"]
      , Number 1 )

      -- d(x^n,x) -> n*x^(n-1)
    , ( Apply "d" [ Apply "pow" [CaptureConstr "x" (Variable undefined), CaptureTerm "n"]
                  , CaptureTerm "x" ]
      , Product [ Capture "n"
                , Apply "pow" [Capture "x", Sum [Capture "n", Number (-1)]] ] )

      -- d(a+[b],x) -> d(a,x) + d([b],x)
    , ( Apply "d" [Sum [CaptureTerm "a", Capture "b"], CaptureTerm "x"]
      , Sum [ Apply "d" [Capture "a", Capture "x"]
            , Apply "d" [Capture "b", Capture "x"] ] )
    ]

-- | Apply @f@ to each direct child of @node@ and rebuild the node.
-- Leaves (numbers, variables and the capture nodes) are returned as-is.
astChildMap :: AST -> (AST -> AST) -> AST
astChildMap node f = case node of
    Number _          -> node
    Variable _        -> node
    Sum args          -> Sum $ map f args
    Product args      -> Product $ map f args
    Negative n        -> Negative $ f n
    Reciprocal n      -> Reciprocal $ f n
    Apply name args   -> Apply name $ map f args
    Capture _         -> node
    CaptureTerm _     -> node
    CaptureConstr _ _ -> node

-- | Rewrite @node@ with the first rule of 'patterndb' whose pattern matches
-- it, using the first capture dictionary 'astMatch' returns for that rule.
-- If no rule matches at this node, recurse into its children instead.
applyPatterns :: AST -> AST
applyPatterns node =
    case firstMatches of
        [] -> astChildMap node applyPatterns
        -- TODO: don't take the first option of the first match, but use them all
        (capdict, repl) : _ -> {-applyPatterns $-} replaceCaptures capdict repl
  where
    -- One entry per rule that produced at least one capture dictionary,
    -- in 'patterndb' order; rules with no match are filtered out by the
    -- failing list pattern.
    firstMatches =
        [ (capdict, repl)
        | (pat, repl) <- patterndb
        , capdict : _ <- [astMatch pat node]
        ]