summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortomsmeding <tom.smeding@gmail.com>2016-06-16 23:24:47 +0200
committertomsmeding <tom.smeding@gmail.com>2016-06-16 23:24:47 +0200
commitdd3db844dd49451f28d044cd1d2fd71430d73686 (patch)
treecd76d7ad6efbf2d2e4760695d39cb48bb479a936
Initial
-rw-r--r--.gitignore4
-rw-r--r--Makefile16
-rw-r--r--ast.hs283
-rw-r--r--debug.hs23
-rw-r--r--main.hs31
-rw-r--r--parser.hs236
-rw-r--r--prettyprint.hs11
-rw-r--r--simplify.hs172
-rw-r--r--utility.hs20
9 files changed, 796 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..2c67b73
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+*.o
+*.hi
+*.hs[0-9]
+main
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..6d4ca8f
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,16 @@
+BIN = main
+
+hs_files = $(wildcard *.hs)
+
+.PHONY: all clean remake
+
+all: $(BIN)
+
+clean:
+ rm -f *.hi *.o $(BIN)
+
+remake: clean all
+
+
+$(BIN): $(hs_files)
+ ghc -O3 -o $(BIN) $^
diff --git a/ast.hs b/ast.hs
new file mode 100644
index 0000000..1e267c1
--- /dev/null
+++ b/ast.hs
@@ -0,0 +1,283 @@
+{-# LANGUAGE TupleSections, BangPatterns, DeriveDataTypeable #-}
+
+module AST where
+
+import qualified Data.Map.Strict as Map
+import Data.List
+import Data.Data
+import Data.Typeable
+import Control.DeepSeq
+
+import PrettyPrint
+import Debug
+
+
+data AST = Number Double
+ | Variable String
+ | Sum [AST]
+ | Product [AST]
+ | Negative AST
+ | Reciprocal AST
+ | Apply String [AST]
+ -- The following only in patterns:
+ | Capture String
+ | CaptureTerm String
+ | CaptureConstr String AST -- AST argument only for constructor; only WHNF
+ deriving (Eq,Typeable,Data)
+
+instance PrettyPrint AST where
+ prettyPrint (Number n) = show n
+
+ prettyPrint (Variable n) = n
+
+ prettyPrint (Sum []) = "(+)"
+ prettyPrint (Sum args) = intercalate " + " $ map prettyPrint args
+
+ prettyPrint (Product []) = "(*)"
+ prettyPrint (Product args) = intercalate "*" $ map gopp args
+ where gopp s@(Sum _) = '(' : prettyPrint s ++ ")"
+ gopp n = prettyPrint n
+
+ prettyPrint (Negative n) = '-' : case n of
+ s@(Sum _) -> '(' : prettyPrint s ++ ")"
+ n -> prettyPrint n
+
+ prettyPrint (Reciprocal n) = "1/" ++ case n of
+ s@(Sum _) -> '(' : prettyPrint s ++ ")"
+ s@(Product _) -> '(' : prettyPrint s ++ ")"
+ s@(Reciprocal _) -> '(' : prettyPrint s ++ ")"
+ n -> prettyPrint n
+
+ prettyPrint (Apply name args) = name ++ "(" ++ intercalate "," (map prettyPrint args) ++ ")"
+
+ prettyPrint (Capture name) = '[' : name ++ "]"
+
+ prettyPrint (CaptureTerm name) = '[' : '[' : name ++ "]]"
+
+ prettyPrint (CaptureConstr name c) = '[' : name ++ ":" ++ showConstr (toConstr c) ++ "]"
+
+instance Show AST where
+ show = prettyPrint
+
+instance NFData AST where
+ rnf (Number !_) = ()
+ rnf (Variable !_) = ()
+ rnf (Sum l) = seq (length $ map rnf l) ()
+ rnf (Product l) = seq (length $ map rnf l) ()
+ rnf (Negative n) = rnf n
+ rnf (Reciprocal n) = rnf n
+ rnf (Apply !_ l) = seq (length $ map rnf l) ()
+ rnf (Capture !_) = ()
+ rnf (CaptureTerm !_) = ()
+ rnf (CaptureConstr !_ !_) = () -- explicitly not deepseq'ing the ast node
+
+instance Ord AST where
+ compare (Number _) (Number _) = EQ
+ compare (Variable _) (Variable _) = EQ
+ compare (Sum _) (Sum _) = EQ
+ compare (Product _) (Product _) = EQ
+ compare (Negative _) (Negative _) = EQ
+ compare (Reciprocal _) (Reciprocal _) = EQ
+ compare (Apply _ _) (Apply _ _) = EQ
+ compare (Capture _) (Capture _) = EQ
+ compare (CaptureTerm _) (CaptureTerm _) = EQ
+ compare (CaptureConstr _ _) (CaptureConstr _ _) = EQ
+
+ compare (Capture _) _ = LT -- Unbounded captures first for efficient
+ compare _ (Capture _) = GT -- extraction with span isCapture
+ compare (Number _) _ = LT
+ compare _ (Number _) = GT
+ compare (Variable _) _ = LT
+ compare _ (Variable _) = GT
+ compare (Sum _) _ = LT
+ compare _ (Sum _) = GT
+ compare (Product _) _ = LT
+ compare _ (Product _) = GT
+ compare (Negative _) _ = LT
+ compare _ (Negative _) = GT
+ compare (Reciprocal _) _ = LT
+ compare _ (Reciprocal _) = GT
+ compare (Apply _ _) _ = LT
+ compare _ (Apply _ _) = GT
+ compare (CaptureTerm _) _ = LT
+ compare _ (CaptureTerm _) = GT
+ -- compare (CaptureConstr _ _) _ = LT
+ -- compare _ (CaptureConstr _ _) = GT
+
+
+astIsNumber :: AST -> Bool
+astIsNumber (Number _) = True
+astIsNumber _ = False
+
+astIsCapture :: AST -> Bool
+astIsCapture (Capture _) = True
+astIsCapture _ = False
+
+
+astMatchSimple :: AST -> AST -> Bool
+astMatchSimple pat sub = let res = {-(\x -> trace (" !! RESULT: " ++ show x ++ " !! ") x) $-} astMatch pat sub
+ in if null res
+ then False
+ else any Map.null res
+
+
+astMatch :: AST -- pattern
+ -> AST -- subject
+ -> [Map.Map String AST] -- list of possible capture assignments
+astMatch pat sub = assertS "No captures in astMatch subject" (not $ hasCaptures sub) $
+ case pat of
+ Number x -> case sub of
+ Number y | x == y -> [Map.empty]
+ _ -> []
+
+ Variable name -> case sub of
+ Variable name2 | name == name2 -> [Map.empty]
+ _ -> []
+
+ Sum [term] -> case sub of
+ Sum l2 -> matchList Sum [term] l2
+ s -> astMatch term s
+
+ Sum l -> case sub of
+ Sum l2 -> matchList Sum l l2
+ _ -> []
+
+ Product [term] -> case sub of
+ Product l2 -> matchList Product [term] l2
+ s -> astMatch term s
+
+ Product l -> case sub of
+ Product l2 -> matchList Product l l2
+ _ -> []
+
+ Negative n -> case sub of
+ Negative n2 -> astMatch n n2
+ _ -> []
+
+ Reciprocal n -> case sub of
+ Reciprocal n2 -> astMatch n n2
+ _ -> []
+
+ Apply name l -> case sub of
+ Apply name l2 -> matchOrderedList l l2
+ _ -> []
+
+ Capture name -> [Map.singleton name sub]
+
+ CaptureTerm name -> [Map.singleton name sub]
+
+ CaptureConstr name constr ->
+ if toConstr sub == toConstr constr
+ then [Map.singleton name sub]
+ else []
+
+
+matchList :: ([AST] -> AST) -- AST constructor for this list (for insertion in capture)
+ -> [AST] -- unordered patterns
+ -> [AST] -- unordered subjects
+ -> [Map.Map String AST] -- list of possible capture assignments
+matchList constr pats subs =
+ let ordered = sort pats
+ (captures,nocaps) = span astIsCapture ordered
+ in assertS "At most one capture in sum/product" (length captures <= 1) $ case captures of
+ [] -> matchListDBG Nothing nocaps subs
+ [c] -> matchListDBG (Just c) nocaps subs
+ where matchList' :: Maybe AST -> [AST] -> [AST] -> [Map.Map String AST]
+ matchList' Nothing [] [] = [Map.empty]
+ matchList' Nothing [] _ = []
+ matchList' (Just (Capture name)) [] subs = [Map.singleton name $ constr subs]
+ matchList' (Just node) [] subs = astMatch node (constr subs)
+ matchList' mcap (pat:pats) subs =
+ let firstmatches = concat $ mapDel (\s other -> map (,other) $ astMatch pat s) subs
+ processed = concat
+ $ map (\(ass,rest) ->
+ let replpats = map (replaceCaptures ass) pats
+ replmcap = fmap (replaceCaptures ass) mcap
+ in map (Map.union ass) $ matchListDBG replmcap replpats rest)
+ firstmatches
+ in {-trace ("firstmatches = "++show firstmatches) $ trace ("processed = "++show processed) $-} processed
+
+ matchListDBG :: Maybe AST -> [AST] -> [AST] -> [Map.Map String AST]
+ matchListDBG mcap pats subs = {-force $ trace ("\n<< "++show (mcap,pats,subs)++" >>\n")
+ $-} matchList' mcap pats subs
+
+
+matchOrderedList :: [AST] -- ordered patterns
+ -> [AST] -- ordered subjects
+ -> [Map.Map String AST] -- list of possible capture assignments
+matchOrderedList [] [] = [Map.empty]
+matchOrderedList [] _ = []
+matchOrderedList _ [] = []
+matchOrderedList (pat:pats) (sub:subs) =
+ let opts = astMatch pat sub
+ newpatsopts = [(map (replaceCaptures opt) pats,opt) | opt <- opts]
+ -- ^ list of possible refined versions of the (rest of the) pattern list
+ in {-trace (show (pat:pats) ++ " -- " ++ show (sub:subs)) $ traceShow opts $-}
+ concat $ map (\(newpats,opt) -> map (Map.union opt)
+ $ matchOrderedList newpats subs) newpatsopts
+
+
+replaceCaptures :: Map.Map String AST -> AST -> AST
+replaceCaptures mp n = case n of
+ Number _ -> n
+ Variable _ -> n
+ Sum l -> Sum $ map (replaceCaptures mp) l
+ Product l -> Product $ map (replaceCaptures mp) l
+ Negative n2 -> Negative $ replaceCaptures mp n2
+ Reciprocal n2 -> Reciprocal $ replaceCaptures mp n2
+ Apply name n2 -> Apply name $ map (replaceCaptures mp) n2
+ Capture name -> maybe n id $ Map.lookup name mp
+ CaptureTerm name -> maybe n id $ Map.lookup name mp
+ CaptureConstr name c -> maybe n id $ Map.lookup name mp
+
+
+hasCaptures :: AST -> Bool
+hasCaptures n = case n of
+ Number _ -> False
+ Variable _ -> False
+ Sum l -> any id [hasCaptures m | m <- l]
+ Product l -> any id [hasCaptures m | m <- l]
+ Negative m -> hasCaptures m
+ Reciprocal m -> hasCaptures m
+ Apply _ l -> any id [hasCaptures m | m <- l]
+ Capture _ -> True
+ CaptureTerm _ -> True
+ CaptureConstr _ _ -> True
+
+
+assert :: Bool -> a -> a
+assert = assertS "(no reason)"
+
+assertS :: String -> Bool -> a -> a
+assertS _ True = id
+assertS s False = error $ "Condition not satisfied in assert: " ++ s
+
+
+mapDel :: (a -> [a] -> b) -> [a] -> [b]
+mapDel _ [] = []
+mapDel f l =
+ let splits = zip l
+ $ map (\(a,b:bs) -> a++bs)
+ $ iterate (\(a,b:bs) -> (a++[b],bs)) ([],l)
+ in map (uncurry f) splits
+
+
+-- some testing things
+--pat = Sum [Number 1,Capture "x",Negative $ Capture "x"]
+--sub = Sum [Number 4,Variable "a",Number 1,Negative $ Sum [Variable "a",Number 4]]
+
+--pat = Sum [Negative $ Capture "x"]
+--sub = Sum [Negative $ Sum [Variable "a",Number 4]]
+
+--pat = Sum [Capture "x",Negative (Capture "x"),CaptureTerm "y",CaptureTerm "z"]
+--sub = let x = Reciprocal (Number 7) in Sum [x,Negative x,Number 7,Number 8]
+
+--pat = Sum [CaptureTerm "x",CaptureTerm "y",Capture "rest",Negative $ Capture "rest"]
+--sub = Sum [Number 1,Number 2,Negative $ Number 1,Variable "kaas",Negative $ Sum [Negative $ Number 1,Variable "kaas"]]
+
+pat = Sum [Product [Capture "x"],Product [Capture "x"]]
+sub = Sum [Product [Number 1],Product [Number 1]]
+
+main = do
+ let res = astMatch pat sub
+ deepseq res $ putStrLn $ "\x1B[32m"++show res++"\x1B[0m"
diff --git a/debug.hs b/debug.hs
new file mode 100644
index 0000000..33ca1b0
--- /dev/null
+++ b/debug.hs
@@ -0,0 +1,23 @@
+{-# LANGUAGE CPP #-}
+
+module Debug where
+
+#if 1
+
+import qualified Debug.Trace as Trace
+
+trace :: String -> a -> a
+trace = Trace.trace
+
+traceShow :: (Show a) => a -> b -> b
+traceShow = Trace.traceShow
+
+#else
+
+trace :: String -> a -> a
+trace = flip const
+
+traceShow :: a -> b -> b
+traceShow = flip const
+
+#endif
diff --git a/main.hs b/main.hs
new file mode 100644
index 0000000..5beb0e4
--- /dev/null
+++ b/main.hs
@@ -0,0 +1,31 @@
+module Main where
+
+import Control.Monad
+import Data.Either
+import Data.Maybe
+import System.Console.Readline
+
+import Simplify
+import Parser
+import PrettyPrint
+
+
+repl :: IO ()
+repl = do
+ mline <- readline "> "
+ case mline of
+ Nothing -> return () -- EOF
+ Just "" -> repl
+ Just line -> do
+ addHistory line
+ let eexpr = parseExpression line
+ either (putStrLn . ("Error: "++)) handleExpression eexpr
+ repl
+ where
+ handleExpression expr = do
+ print expr
+ let sim = simplify expr
+ print sim
+
+main :: IO ()
+main = repl
diff --git a/parser.hs b/parser.hs
new file mode 100644
index 0000000..6d8ed6d
--- /dev/null
+++ b/parser.hs
@@ -0,0 +1,236 @@
+module Parser (parseExpression) where
+
+import Control.Applicative
+import Control.Monad
+import Data.Char
+import Data.Maybe
+
+import AST
+import Utility
+
+
+parseExpression :: String -> Either String AST
+parseExpression s = case parse pexpression s of
+ ((node,rest):_) -> case rest of
+ "" -> Right node
+ s -> Left $ "Cannot parse from '" ++ take 10 rest ++ "'"
+ _ -> Left "No valid parse"
+
+
+newtype Parser a = Parser (String -> [(a,String)])
+
+parse :: Parser a -> String -> [(a,String)]
+parse (Parser p) = p
+
+instance Functor Parser where
+ fmap f p = Parser (\cs -> map (\(a,s) -> (f a,s)) $ parse p cs)
+
+instance Applicative Parser where
+ pure x = Parser (\cs -> [(x,cs)])
+ (<*>) f p = Parser (\cs -> concat $
+ map (\(a,s) -> parse (fmap a p) s) $ parse f cs)
+
+instance Monad Parser where
+ p >>= f = Parser (\cs -> concat $
+ map (\(a,s) -> parse (f a) s) $ parse p cs)
+
+instance Alternative Parser where
+ empty = Parser (\_ -> [])
+ (<|>) p q = Parser (\cs -> parse p cs ++ parse q cs)
+
+instance MonadPlus Parser
+
+
+--The deterministic choice operator: choose the first possibile parse (if
+--available at all) from the results given by the two parsers.
+--mplus is the non-deterministic choice operator; it would give all results.
+mplus1 :: Parser a -> Parser a -> Parser a
+mplus1 p q = Parser $ \cs -> case parse (mplus p q) cs of
+ [] -> []
+ (x:_) -> [x]
+
+--(++) = mplus
+(+++) = mplus1
+
+
+pitem :: Parser Char
+pitem = Parser $ \s -> case s of
+ "" -> []
+ (c:cs) -> [(c,cs)]
+
+psat :: (Char -> Bool) -> Parser Char
+psat f = do
+ c <- pitem
+ if f c then return c else mzero
+
+--checks that the next char satisfies the predicate; does NOT consume characters
+passert :: (Char -> Bool) -> Parser ()
+passert p = Parser $ \s -> case s of
+ "" -> []
+ (c:_) -> if p c then [((),s)] else []
+
+pchar :: Char -> Parser Char
+pchar c = psat (==c)
+
+pstring :: String -> Parser String
+pstring "" = return ""
+pstring (c:cs) = do
+ pchar c
+ pstring cs
+ return (c:cs)
+
+pmany :: Parser a -> Parser [a]
+pmany p = pmany1 p +++ return []
+
+pmany1 :: Parser a -> Parser [a]
+pmany1 p = do
+ a <- p
+ as <- pmany p
+ return (a:as)
+
+pinteger :: Parser Int
+pinteger = do
+ s <- pmany $ psat isDigit
+ return $ read s
+
+pdouble :: Parser Double
+pdouble = Parser reads
+
+pquotstring :: Parser String
+pquotstring = Parser reads
+
+poptws :: Parser String
+poptws = Parser $ pure . span isSpace
+
+pws :: Parser String
+pws = Parser $ \s -> case span isSpace s of
+ ("",_) -> []
+ tup@(_,_) -> [tup]
+
+pword :: Parser String
+pword = do
+ c <- psat $ options [isAlpha,(=='_')]
+ cs <- pmany $ psat $ options [isAlpha,isDigit,(=='_')]
+ return (c:cs)
+
+
+pnumber :: Parser AST
+pnumber = liftM Number pdouble
+
+pvariable :: Parser AST
+pvariable = liftM Variable $ pstring "PI" +++ (liftM pure (psat isAlpha))
+
+pinfixoperator :: (Char,Char) -- +/- symbols
+ -> Parser AST -- term parser
+ -> ([AST] -> AST) -- Sum constructor
+ -> (AST -> AST) -- Negative constructor
+ -> Bool -- whether the plus sign is optional
+ -> Bool -- whether a negative sign cannot follow after a term
+ -> Parser AST -- Resulting parser
+pinfixoperator (plus,minus) pterm sumconstr negconstr plusopt noneg = do
+ term <- pterm
+ pmoreterms term +++ return (sumconstr [term])
+ where
+ pmoreterms term = if plusopt
+ then pmoretermsplus term +++ pmoretermsminus term +++ pmoretermsnothing term
+ else pmoretermsplus term +++ pmoretermsminus term
+
+ pmoretermsplus term = do
+ poptws
+ pchar plus
+ poptws
+ nextterm <- pterm
+ let thissum = sumconstr [term,nextterm]
+ pmoreterms thissum +++ return thissum
+ pmoretermsminus term = do
+ poptws
+ pchar minus
+ poptws
+ nextterm <- pterm
+ let thissum = sumconstr [term,negconstr nextterm]
+ pmoreterms thissum +++ return thissum
+ pmoretermsnothing term = do
+ poptws
+ if noneg then passert (/='-') else return ()
+ nextterm <- pterm
+ let thissum = sumconstr [term,nextterm]
+ pmoreterms thissum +++ return thissum
+
+psum :: Parser AST
+psum = pinfixoperator ('+','-') pproduct Sum Negative False False
+
+pproduct :: Parser AST
+pproduct = pinfixoperator ('*','/') pfactor Product Reciprocal True True
+
+pfactor :: Parser AST
+pfactor = pnegative +++ pfactornoneg +++ pcapture +++ pcaptureterm
+
+pnegative :: Parser AST
+pnegative = do {pchar '-'; poptws; f <- pfactor; return $ Negative f} +++ pfactornoneg
+
+pfactornoneg :: Parser AST
+pfactornoneg = do
+ fact <- pnumber +++ pparenthetical +++ pfunctioncall +++ pvariable
+ ppower fact +++ pfactorial fact +++ return fact
+ where
+ ppower fact = do
+ poptws
+ pchar '^'
+ poptws
+ fact2 <- pfactornoneg
+ return $ Apply "pow" [fact,fact2]
+ pfactorial fact = do
+ poptws
+ pchar '!'
+ return $ Apply "fact" [fact]
+
+
+pparenthetical :: Parser AST
+pparenthetical = do
+ pchar '('
+ poptws
+ sum <- psum
+ poptws
+ pchar ')'
+ return sum
+
+pfunctioncall :: Parser AST
+pfunctioncall = do
+ name <- pword
+ poptws
+ pchar '('
+ poptws
+ args <- parglist
+ poptws
+ pchar ')'
+ return $ Apply name args
+ where
+ parglist = do
+ arg <- parg
+ poptws
+ pmoreargs arg +++ return [arg]
+ pmoreargs arg = do
+ pchar ','
+ poptws
+ args <- parglist
+ return (arg:args)
+ parg = pexpression
+
+pcapture :: Parser AST
+pcapture = do
+ pchar '['
+ name <- pmany1 $ psat (/=']')
+ pchar ']'
+ return $ Capture name
+
+pcaptureterm :: Parser AST
+pcaptureterm = do
+ pchar '['
+ pchar '['
+ name <- pmany1 $ psat (/=']')
+ pchar ']'
+ pchar ']'
+ return $ CaptureTerm name
+
+pexpression :: Parser AST
+pexpression = psum
diff --git a/prettyprint.hs b/prettyprint.hs
new file mode 100644
index 0000000..45ca6ac
--- /dev/null
+++ b/prettyprint.hs
@@ -0,0 +1,11 @@
+{-# LANGUAGE FlexibleInstances #-}
+
+module PrettyPrint where
+
+class PrettyPrint a where
+ prettyPrint :: a -> String
+ -- a = a
+
+instance PrettyPrint String where prettyPrint = id
+instance PrettyPrint Double where prettyPrint = show
+instance PrettyPrint Int where prettyPrint = show
diff --git a/simplify.hs b/simplify.hs
new file mode 100644
index 0000000..cb3e5af
--- /dev/null
+++ b/simplify.hs
@@ -0,0 +1,172 @@
+{-# LANGUAGE CPP #-}
+
+module Simplify (simplify) where
+
+import Data.List
+import qualified Data.Map.Strict as Map
+
+import AST
+import Utility
+
+import Debug
+import PrettyPrint
+
+tracex :: (Show a) => String -> a -> a
+tracex s x = trace (s ++ ": " ++ show x) x
+
+tracexp :: (PrettyPrint a) => String -> a -> a
+tracexp s x = trace (s ++ ": " ++ prettyPrint x) x
+
+
+simplify :: AST -> AST
+simplify = tracex "last canonicaliseOrder" . canonicaliseOrder
+ . (fixpoint $ tracex "applyPatterns " . applyPatterns
+ . tracex "flattenSums " . flattenSums
+ -- . tracex "collectLikeTerms " . collectLikeTerms
+ . tracex "canonicaliseOrder" . canonicaliseOrder
+ . tracex "foldNumbers " . foldNumbers)
+ . tracex "first flattenSums" . flattenSums
+
+
+flattenSums :: AST -> AST
+flattenSums node = case node of
+ (Negative n) -> Negative $ flattenSums n
+ (Reciprocal n) -> Reciprocal $ flattenSums n
+ (Apply name args) -> Apply name $ map flattenSums args
+ (Sum args) -> case length args of
+ 0 -> Number 0
+ 1 -> flattenSums $ args !! 0
+ otherwise -> Sum $ concat $ map (listify . flattenSums) args
+ where
+ listify (Sum args) = args
+ listify node = [node]
+ (Product args) -> case length args of
+ 0 -> Number 1
+ 1 -> flattenSums $ args !! 0
+ otherwise -> Product $ concat $ map (listify . flattenSums) args
+ where
+ listify (Product args) = args
+ listify node = [node]
+ _ -> node
+
+foldNumbers :: AST -> AST
+foldNumbers node = case node of
+ (Negative n) -> let fn = foldNumbers n in case fn of
+ (Number x) -> Number (-x)
+ (Negative n2) -> n2
+ (Product args) -> Product $ Number (-1) : args
+ _ -> Negative $ fn
+ (Reciprocal n) -> let fn = foldNumbers n in case fn of
+ (Number x) -> Number (1/x)
+ (Negative n) -> Negative $ Reciprocal fn
+ (Reciprocal n2) -> n2
+ _ -> Reciprocal $ fn
+ (Apply name args) -> Apply name $ map foldNumbers args
+ (Sum args) -> Sum $ dofoldnums sum args 0
+ (Product args) -> dofoldnegsToProd $ dofoldnums product args 1
+ _ -> node
+ where
+ dofoldnums func args zerovalue =
+ let foldedArgs = map foldNumbers args
+ (nums,notnums) = partition astIsNumber foldedArgs
+ foldvalue = func $ map (\(Number n) -> n) nums
+ in case length nums of
+ x | x >= 1 -> if foldvalue == zerovalue then notnums else Number foldvalue : notnums
+ otherwise -> foldedArgs
+ dofoldnegsToProd args =
+ let foldedArgs = map foldNumbers args
+ (negs,notnegs) = partition isneg foldedArgs
+ isneg (Negative _) = True
+ isneg (Number n) = n < 0
+ isneg _ = False
+ unneg (Negative n) = n
+ unneg (Number n) = Number $ abs n
+ unneg n = n
+ unnegged = map unneg negs ++ notnegs
+ in case length negs of
+ x | x < 2 -> Product args
+ | even x -> Product unnegged
+ | odd x -> Product $ Number (-1) : unnegged
+
+-- collectLikeTerms :: AST -> AST
+-- collectLikeTerms node = case node of
+-- (Reciprocal n) -> Apply "pow" [n,Number $ -1]
+-- (Product args) ->
+-- let ispow (Apply "pow" _) = True
+-- ispow _ = False
+-- (pows,nopows) = partition ispow $ map collectLikeTerms args
+-- groups = groupBy (\(Apply _ [x,_]) (Apply _ [y,_]) -> x == y) pows
+-- baseof (Apply _ [x,_]) = x
+-- expof (Apply _ [_,x]) = x
+-- collectGroup l = Apply "pow" [baseof (l!!0),Sum $ map expof l]
+-- in Product $ map collectGroup groups ++ nopows
+-- (Sum args) ->
+-- let isnumterm (Product (Number _:_)) = True
+-- isnumterm _ = False
+-- (numterms,nonumterms) = partition isnumterm $ map collectLikeTerms args
+-- groups = groupBy (\(Product (Number _:xs)) (Product (Number _:ys))
+-- -> astMatchSimple (Product xs) (Product ys))
+-- numterms
+-- numof (Product (n:_)) = n
+-- restof (Product (_:rest)) = rest
+-- collectGroup l =
+-- if length l == 1
+-- then l!!0
+-- else Product $ Sum (map numof l) : restof (l!!0)
+-- in Sum $ map collectGroup groups ++ nonumterms
+-- _ -> node
+
+canonicaliseOrder :: AST -> AST
+canonicaliseOrder node = case node of
+ (Number _) -> node
+ (Variable _) -> node
+ (Sum args) -> Sum $ sort args
+ (Product args) -> Product $ sort args
+ (Negative n) -> Negative $ canonicaliseOrder n
+ (Reciprocal n) -> Reciprocal $ canonicaliseOrder n
+ (Apply name args) -> Apply name $ map canonicaliseOrder args
+ (Capture _) -> node
+ (CaptureTerm _) -> node
+ (CaptureConstr _ _) -> node
+
+
+patterndb :: [(AST,AST)]
+patterndb = [
+ (Sum [CaptureTerm "x",CaptureTerm "x",Capture "rest"], -- x + x + [rest] -> 2*x + [rest]
+ Sum [Product [Number 2,Capture "x"],Capture "rest"]),
+
+ (Sum [CaptureTerm "x", -- x + n*x + [rest] -> (1+n)*x + [rest]
+ Product [CaptureConstr "n" (Number undefined),Capture "x"],
+ Capture "rest"],
+ Sum [Product [Sum [Number 1,Capture "n"],
+ Capture "x"],
+ Capture "rest"]),
+
+ (Sum [Product [CaptureConstr "n1" (Number undefined),Capture "x"], -- n1*x + n2*x + [rest] -> (n1+n2)*x + [rest]
+ Product [CaptureConstr "n2" (Number undefined),Capture "x"],
+ Capture "rest"],
+
+ Sum [Product [Sum [Capture "n1",Capture "n2"],
+ Capture "x"],
+ Capture "rest"])
+ ]
+
+astChildMap :: AST -> (AST -> AST) -> AST
+astChildMap node f = case node of
+ (Number _) -> node
+ (Variable _) -> node
+ (Sum args) -> Sum $ map f args
+ (Product args) -> Product $ map f args
+ (Negative n) -> Negative $ f n
+ (Reciprocal n) -> Reciprocal $ f n
+ (Apply name args) -> Apply name $ map f args
+ (Capture _) -> node
+ (CaptureTerm _) -> node
+ (CaptureConstr _ _) -> node
+
+applyPatterns :: AST -> AST
+applyPatterns node = let matches = filter (not . null . fst) $ map (\(pat,repl) -> (astMatch pat node,repl)) patterndb
+ in if null matches
+ then astChildMap node applyPatterns
+ else let ((capdict:_),repl) = head matches -- TODO: don't take the first option of the first match, but use them all
+ in replaceCaptures capdict repl
diff --git a/utility.hs b/utility.hs
new file mode 100644
index 0000000..c442f80
--- /dev/null
+++ b/utility.hs
@@ -0,0 +1,20 @@
+module Utility where
+
+import Data.List
+
+--if any of the predicates returns true, options returns true
+options :: [a -> Bool] -> a -> Bool
+options l x = any id [f x | f <- l]
+
+fixpoint :: (Eq a) => (a -> a) -> a -> a
+fixpoint f x = let fx = f x in if fx == x then x else fixpoint f fx
+
+setcompareBy :: (a -> a -> Bool) -> [a] -> [a] -> Bool
+setcompareBy _ [] [] = True
+setcompareBy _ [] _ = False
+setcompareBy _ _ [] = False
+setcompareBy p a@(x:xs) b = length a == length b && setcompareBy p xs (deleteBy p x b)
+
+deleteIndex :: Int -> [a] -> [a]
+deleteIndex 0 (_:xs) = xs
+deleteIndex i (x:xs) | i > 0 = x : deleteIndex (i-1) xs