From dd3db844dd49451f28d044cd1d2fd71430d73686 Mon Sep 17 00:00:00 2001 From: tomsmeding Date: Thu, 16 Jun 2016 23:24:47 +0200 Subject: Initial --- .gitignore | 4 + Makefile | 16 ++++ ast.hs | 283 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ debug.hs | 23 +++++ main.hs | 31 +++++++ parser.hs | 236 +++++++++++++++++++++++++++++++++++++++++++++++ prettyprint.hs | 11 +++ simplify.hs | 172 +++++++++++++++++++++++++++++++++++ utility.hs | 20 ++++ 9 files changed, 796 insertions(+) create mode 100644 .gitignore create mode 100644 Makefile create mode 100644 ast.hs create mode 100644 debug.hs create mode 100644 main.hs create mode 100644 parser.hs create mode 100644 prettyprint.hs create mode 100644 simplify.hs create mode 100644 utility.hs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2c67b73 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +*.o +*.hi +*.hs[0-9] +main diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..6d4ca8f --- /dev/null +++ b/Makefile @@ -0,0 +1,16 @@ +BIN = main + +hs_files = $(wildcard *.hs) + +.PHONY: all clean remake + +all: $(BIN) + +clean: + rm -f *.hi *.o $(BIN) + +remake: clean all + + +$(BIN): $(hs_files) + ghc -O3 -o $(BIN) $^ diff --git a/ast.hs b/ast.hs new file mode 100644 index 0000000..1e267c1 --- /dev/null +++ b/ast.hs @@ -0,0 +1,283 @@ +{-# LANGUAGE TupleSections, BangPatterns, DeriveDataTypeable #-} + +module AST where + +import qualified Data.Map.Strict as Map +import Data.List +import Data.Data +import Data.Typeable +import Control.DeepSeq + +import PrettyPrint +import Debug + + +data AST = Number Double + | Variable String + | Sum [AST] + | Product [AST] + | Negative AST + | Reciprocal AST + | Apply String [AST] + -- The following only in patterns: + | Capture String + | CaptureTerm String + | CaptureConstr String AST -- AST argument only for constructor; only WHNF + deriving (Eq,Typeable,Data) + +instance PrettyPrint AST where + prettyPrint (Number n) = show n + + prettyPrint (Variable n) = n + + prettyPrint (Sum []) = "(+)" + prettyPrint (Sum args) = intercalate " + " $ map prettyPrint args + + prettyPrint (Product []) = "(*)" + prettyPrint (Product args) = intercalate "*" $ map gopp args + where gopp s@(Sum _) = '(' : prettyPrint s ++ ")" + gopp n = prettyPrint n + + prettyPrint (Negative n) = '-' : case n of + s@(Sum _) -> '(' : prettyPrint s ++ ")" + n -> prettyPrint n + + prettyPrint (Reciprocal n) = "1/" ++ case n of + s@(Sum _) -> '(' : prettyPrint s ++ ")" + s@(Product _) -> '(' : prettyPrint s ++ ")" + s@(Reciprocal _) -> '(' : prettyPrint s ++ ")" + n -> prettyPrint n + + prettyPrint (Apply name args) = name ++ "(" ++ intercalate "," (map prettyPrint args) ++ ")" + + prettyPrint (Capture name) = '[' : name ++ "]" + + prettyPrint (CaptureTerm name) = '[' : '[' : name ++ "]]" + + prettyPrint (CaptureConstr name c) = '[' : name ++ ":" ++ showConstr (toConstr c) ++ "]" + +instance Show AST where + show = prettyPrint + +instance NFData AST where + rnf (Number !_) = () + rnf (Variable !_) = () + rnf (Sum l) = seq (length $ map rnf l) () + rnf (Product l) = seq (length $ map rnf l) () + rnf (Negative n) = rnf n + rnf (Reciprocal n) = rnf n + rnf (Apply !_ l) = seq (length $ map rnf l) () + rnf (Capture !_) = () + rnf (CaptureTerm !_) = () + rnf (CaptureConstr !_ !_) = () -- explicitly not deepseq'ing the ast node + +instance Ord AST where + compare (Number _) (Number _) = EQ + compare (Variable _) (Variable _) = EQ + compare (Sum _) (Sum _) = EQ + compare (Product _) (Product _) = EQ + compare (Negative _) (Negative _) = EQ + compare (Reciprocal _) (Reciprocal _) = EQ + compare (Apply _ _) (Apply _ _) = EQ + compare (Capture _) (Capture _) = EQ + compare (CaptureTerm _) (CaptureTerm _) = EQ + compare (CaptureConstr _ _) (CaptureConstr _ _) = EQ + + compare (Capture _) _ = LT -- Unbounded captures first for efficient + compare _ (Capture _) = GT -- extraction with span isCapture + compare (Number _) _ = LT + compare _ (Number _) = GT + compare (Variable _) _ = LT + compare _ (Variable _) = GT + compare (Sum _) _ = LT + compare _ (Sum _) = GT + compare (Product _) _ = LT + compare _ (Product _) = GT + compare (Negative _) _ = LT + compare _ (Negative _) = GT + compare (Reciprocal _) _ = LT + compare _ (Reciprocal _) = GT + compare (Apply _ _) _ = LT + compare _ (Apply _ _) = GT + compare (CaptureTerm _) _ = LT + compare _ (CaptureTerm _) = GT + -- compare (CaptureConstr _ _) _ = LT + -- compare _ (CaptureConstr _ _) = GT + + +astIsNumber :: AST -> Bool +astIsNumber (Number _) = True +astIsNumber _ = False + +astIsCapture :: AST -> Bool +astIsCapture (Capture _) = True +astIsCapture _ = False + + +astMatchSimple :: AST -> AST -> Bool +astMatchSimple pat sub = let res = {-(\x -> trace (" !! RESULT: " ++ show x ++ " !! ") x) $-} astMatch pat sub + in if null res + then False + else any Map.null res + + +astMatch :: AST -- pattern + -> AST -- subject + -> [Map.Map String AST] -- list of possible capture assignments +astMatch pat sub = assertS "No captures in astMatch subject" (not $ hasCaptures sub) $ + case pat of + Number x -> case sub of + Number y | x == y -> [Map.empty] + _ -> [] + + Variable name -> case sub of + Variable name2 | name == name2 -> [Map.empty] + _ -> [] + + Sum [term] -> case sub of + Sum l2 -> matchList Sum [term] l2 + s -> astMatch term s + + Sum l -> case sub of + Sum l2 -> matchList Sum l l2 + _ -> [] + + Product [term] -> case sub of + Product l2 -> matchList Product [term] l2 + s -> astMatch term s + + Product l -> case sub of + Product l2 -> matchList Product l l2 + _ -> [] + + Negative n -> case sub of + Negative n2 -> astMatch n n2 + _ -> [] + + Reciprocal n -> case sub of + Reciprocal n2 -> astMatch n n2 + _ -> [] + + Apply name l -> case sub of + Apply name l2 -> matchOrderedList l l2 + _ -> [] + + Capture name -> [Map.singleton name sub] + + CaptureTerm name -> [Map.singleton name sub] + + CaptureConstr name constr -> + if toConstr sub == toConstr constr + then [Map.singleton name sub] + else [] + + +matchList :: ([AST] -> AST) -- AST constructor for this list (for insertion in capture) + -> [AST] -- unordered patterns + -> [AST] -- unordered subjects + -> [Map.Map String AST] -- list of possible capture assignments +matchList constr pats subs = + let ordered = sort pats + (captures,nocaps) = span astIsCapture ordered + in assertS "At most one capture in sum/product" (length captures <= 1) $ case captures of + [] -> matchListDBG Nothing nocaps subs + [c] -> matchListDBG (Just c) nocaps subs + where matchList' :: Maybe AST -> [AST] -> [AST] -> [Map.Map String AST] + matchList' Nothing [] [] = [Map.empty] + matchList' Nothing [] _ = [] + matchList' (Just (Capture name)) [] subs = [Map.singleton name $ constr subs] + matchList' (Just node) [] subs = astMatch node (constr subs) + matchList' mcap (pat:pats) subs = + let firstmatches = concat $ mapDel (\s other -> map (,other) $ astMatch pat s) subs + processed = concat + $ map (\(ass,rest) -> + let replpats = map (replaceCaptures ass) pats + replmcap = fmap (replaceCaptures ass) mcap + in map (Map.union ass) $ matchListDBG replmcap replpats rest) + firstmatches + in {-trace ("firstmatches = "++show firstmatches) $ trace ("processed = "++show processed) $-} processed + + matchListDBG :: Maybe AST -> [AST] -> [AST] -> [Map.Map String AST] + matchListDBG mcap pats subs = {-force $ trace ("\n<< "++show (mcap,pats,subs)++" >>\n") + $-} matchList' mcap pats subs + + +matchOrderedList :: [AST] -- ordered patterns + -> [AST] -- ordered subjects + -> [Map.Map String AST] -- list of possible capture assignments +matchOrderedList [] [] = [Map.empty] +matchOrderedList [] _ = [] +matchOrderedList _ [] = [] +matchOrderedList (pat:pats) (sub:subs) = + let opts = astMatch pat sub + newpatsopts = [(map (replaceCaptures opt) pats,opt) | opt <- opts] + -- ^ list of possible refined versions of the (rest of the) pattern list + in {-trace (show (pat:pats) ++ " -- " ++ show (sub:subs)) $ traceShow opts $-} + concat $ map (\(newpats,opt) -> map (Map.union opt) + $ matchOrderedList newpats subs) newpatsopts + + +replaceCaptures :: Map.Map String AST -> AST -> AST +replaceCaptures mp n = case n of + Number _ -> n + Variable _ -> n + Sum l -> Sum $ map (replaceCaptures mp) l + Product l -> Product $ map (replaceCaptures mp) l + Negative n2 -> Negative $ replaceCaptures mp n2 + Reciprocal n2 -> Reciprocal $ replaceCaptures mp n2 + Apply name n2 -> Apply name $ map (replaceCaptures mp) n2 + Capture name -> maybe n id $ Map.lookup name mp + CaptureTerm name -> maybe n id $ Map.lookup name mp + CaptureConstr name c -> maybe n id $ Map.lookup name mp + + +hasCaptures :: AST -> Bool +hasCaptures n = case n of + Number _ -> False + Variable _ -> False + Sum l -> any id [hasCaptures m | m <- l] + Product l -> any id [hasCaptures m | m <- l] + Negative m -> hasCaptures m + Reciprocal m -> hasCaptures m + Apply _ l -> any id [hasCaptures m | m <- l] + Capture _ -> True + CaptureTerm _ -> True + CaptureConstr _ _ -> True + + +assert :: Bool -> a -> a +assert = assertS "(no reason)" + +assertS :: String -> Bool -> a -> a +assertS _ True = id +assertS s False = error $ "Condition not satisfied in assert: " ++ s + + +mapDel :: (a -> [a] -> b) -> [a] -> [b] +mapDel _ [] = [] +mapDel f l = + let splits = zip l + $ map (\(a,b:bs) -> a++bs) + $ iterate (\(a,b:bs) -> (a++[b],bs)) ([],l) + in map (uncurry f) splits + + +-- some testing things +--pat = Sum [Number 1,Capture "x",Negative $ Capture "x"] +--sub = Sum [Number 4,Variable "a",Number 1,Negative $ Sum [Variable "a",Number 4]] + +--pat = Sum [Negative $ Capture "x"] +--sub = Sum [Negative $ Sum [Variable "a",Number 4]] + +--pat = Sum [Capture "x",Negative (Capture "x"),CaptureTerm "y",CaptureTerm "z"] +--sub = let x = Reciprocal (Number 7) in Sum [x,Negative x,Number 7,Number 8] + +--pat = Sum [CaptureTerm "x",CaptureTerm "y",Capture "rest",Negative $ Capture "rest"] +--sub = Sum [Number 1,Number 2,Negative $ Number 1,Variable "kaas",Negative $ Sum [Negative $ Number 1,Variable "kaas"]] + +pat = Sum [Product [Capture "x"],Product [Capture "x"]] +sub = Sum [Product [Number 1],Product [Number 1]] + +main = do + let res = astMatch pat sub + deepseq res $ putStrLn $ "\x1B[32m"++show res++"\x1B[0m" diff --git a/debug.hs b/debug.hs new file mode 100644 index 0000000..33ca1b0 --- /dev/null +++ b/debug.hs @@ -0,0 +1,23 @@ +{-# LANGUAGE CPP #-} + +module Debug where + +#if 1 + +import qualified Debug.Trace as Trace + +trace :: String -> a -> a +trace = Trace.trace + +traceShow :: (Show a) => a -> b -> b +traceShow = Trace.traceShow + +#else + +trace :: String -> a -> a +trace = flip const + +traceShow :: a -> b -> b +traceShow = flip const + +#endif diff --git a/main.hs b/main.hs new file mode 100644 index 0000000..5beb0e4 --- /dev/null +++ b/main.hs @@ -0,0 +1,31 @@ +module Main where + +import Control.Monad +import Data.Either +import Data.Maybe +import System.Console.Readline + +import Simplify +import Parser +import PrettyPrint + + +repl :: IO () +repl = do + mline <- readline "> " + case mline of + Nothing -> return () -- EOF + Just "" -> repl + Just line -> do + addHistory line + let eexpr = parseExpression line + either (putStrLn . ("Error: "++)) handleExpression eexpr + repl + where + handleExpression expr = do + print expr + let sim = simplify expr + print sim + +main :: IO () +main = repl diff --git a/parser.hs b/parser.hs new file mode 100644 index 0000000..6d8ed6d --- /dev/null +++ b/parser.hs @@ -0,0 +1,236 @@ +module Parser (parseExpression) where + +import Control.Applicative +import Control.Monad +import Data.Char +import Data.Maybe + +import AST +import Utility + + +parseExpression :: String -> Either String AST +parseExpression s = case parse pexpression s of + ((node,rest):_) -> case rest of + "" -> Right node + s -> Left $ "Cannot parse from '" ++ take 10 rest ++ "'" + _ -> Left "No valid parse" + + +newtype Parser a = Parser (String -> [(a,String)]) + +parse :: Parser a -> String -> [(a,String)] +parse (Parser p) = p + +instance Functor Parser where + fmap f p = Parser (\cs -> map (\(a,s) -> (f a,s)) $ parse p cs) + +instance Applicative Parser where + pure x = Parser (\cs -> [(x,cs)]) + (<*>) f p = Parser (\cs -> concat $ + map (\(a,s) -> parse (fmap a p) s) $ parse f cs) + +instance Monad Parser where + p >>= f = Parser (\cs -> concat $ + map (\(a,s) -> parse (f a) s) $ parse p cs) + +instance Alternative Parser where + empty = Parser (\_ -> []) + (<|>) p q = Parser (\cs -> parse p cs ++ parse q cs) + +instance MonadPlus Parser + + +--The deterministic choice operator: choose the first possibile parse (if +--available at all) from the results given by the two parsers. +--mplus is the non-deterministic choice operator; it would give all results. +mplus1 :: Parser a -> Parser a -> Parser a +mplus1 p q = Parser $ \cs -> case parse (mplus p q) cs of + [] -> [] + (x:_) -> [x] + +--(++) = mplus +(+++) = mplus1 + + +pitem :: Parser Char +pitem = Parser $ \s -> case s of + "" -> [] + (c:cs) -> [(c,cs)] + +psat :: (Char -> Bool) -> Parser Char +psat f = do + c <- pitem + if f c then return c else mzero + +--checks that the next char satisfies the predicate; does NOT consume characters +passert :: (Char -> Bool) -> Parser () +passert p = Parser $ \s -> case s of + "" -> [] + (c:_) -> if p c then [((),s)] else [] + +pchar :: Char -> Parser Char +pchar c = psat (==c) + +pstring :: String -> Parser String +pstring "" = return "" +pstring (c:cs) = do + pchar c + pstring cs + return (c:cs) + +pmany :: Parser a -> Parser [a] +pmany p = pmany1 p +++ return [] + +pmany1 :: Parser a -> Parser [a] +pmany1 p = do + a <- p + as <- pmany p + return (a:as) + +pinteger :: Parser Int +pinteger = do + s <- pmany $ psat isDigit + return $ read s + +pdouble :: Parser Double +pdouble = Parser reads + +pquotstring :: Parser String +pquotstring = Parser reads + +poptws :: Parser String +poptws = Parser $ pure . span isSpace + +pws :: Parser String +pws = Parser $ \s -> case span isSpace s of + ("",_) -> [] + tup@(_,_) -> [tup] + +pword :: Parser String +pword = do + c <- psat $ options [isAlpha,(=='_')] + cs <- pmany $ psat $ options [isAlpha,isDigit,(=='_')] + return (c:cs) + + +pnumber :: Parser AST +pnumber = liftM Number pdouble + +pvariable :: Parser AST +pvariable = liftM Variable $ pstring "PI" +++ (liftM pure (psat isAlpha)) + +pinfixoperator :: (Char,Char) -- +/- symbols + -> Parser AST -- term parser + -> ([AST] -> AST) -- Sum constructor + -> (AST -> AST) -- Negative constructor + -> Bool -- whether the plus sign is optional + -> Bool -- whether a negative sign cannot follow after a term + -> Parser AST -- Resulting parser +pinfixoperator (plus,minus) pterm sumconstr negconstr plusopt noneg = do + term <- pterm + pmoreterms term +++ return (sumconstr [term]) + where + pmoreterms term = if plusopt + then pmoretermsplus term +++ pmoretermsminus term +++ pmoretermsnothing term + else pmoretermsplus term +++ pmoretermsminus term + + pmoretermsplus term = do + poptws + pchar plus + poptws + nextterm <- pterm + let thissum = sumconstr [term,nextterm] + pmoreterms thissum +++ return thissum + pmoretermsminus term = do + poptws + pchar minus + poptws + nextterm <- pterm + let thissum = sumconstr [term,negconstr nextterm] + pmoreterms thissum +++ return thissum + pmoretermsnothing term = do + poptws + if noneg then passert (/='-') else return () + nextterm <- pterm + let thissum = sumconstr [term,nextterm] + pmoreterms thissum +++ return thissum + +psum :: Parser AST +psum = pinfixoperator ('+','-') pproduct Sum Negative False False + +pproduct :: Parser AST +pproduct = pinfixoperator ('*','/') pfactor Product Reciprocal True True + +pfactor :: Parser AST +pfactor = pnegative +++ pfactornoneg +++ pcapture +++ pcaptureterm + +pnegative :: Parser AST +pnegative = do {pchar '-'; poptws; f <- pfactor; return $ Negative f} +++ pfactornoneg + +pfactornoneg :: Parser AST +pfactornoneg = do + fact <- pnumber +++ pparenthetical +++ pfunctioncall +++ pvariable + ppower fact +++ pfactorial fact +++ return fact + where + ppower fact = do + poptws + pchar '^' + poptws + fact2 <- pfactornoneg + return $ Apply "pow" [fact,fact2] + pfactorial fact = do + poptws + pchar '!' + return $ Apply "fact" [fact] + + +pparenthetical :: Parser AST +pparenthetical = do + pchar '(' + poptws + sum <- psum + poptws + pchar ')' + return sum + +pfunctioncall :: Parser AST +pfunctioncall = do + name <- pword + poptws + pchar '(' + poptws + args <- parglist + poptws + pchar ')' + return $ Apply name args + where + parglist = do + arg <- parg + poptws + pmoreargs arg +++ return [arg] + pmoreargs arg = do + pchar ',' + poptws + args <- parglist + return (arg:args) + parg = pexpression + +pcapture :: Parser AST +pcapture = do + pchar '[' + name <- pmany1 $ psat (/=']') + pchar ']' + return $ Capture name + +pcaptureterm :: Parser AST +pcaptureterm = do + pchar '[' + pchar '[' + name <- pmany1 $ psat (/=']') + pchar ']' + pchar ']' + return $ CaptureTerm name + +pexpression :: Parser AST +pexpression = psum diff --git a/prettyprint.hs b/prettyprint.hs new file mode 100644 index 0000000..45ca6ac --- /dev/null +++ b/prettyprint.hs @@ -0,0 +1,11 @@ +{-# LANGUAGE FlexibleInstances #-} + +module PrettyPrint where + +class PrettyPrint a where + prettyPrint :: a -> String + -- a = a + +instance PrettyPrint String where prettyPrint = id +instance PrettyPrint Double where prettyPrint = show +instance PrettyPrint Int where prettyPrint = show diff --git a/simplify.hs b/simplify.hs new file mode 100644 index 0000000..cb3e5af --- /dev/null +++ b/simplify.hs @@ -0,0 +1,172 @@ +{-# LANGUAGE CPP #-} + +module Simplify (simplify) where + +import Data.List +import qualified Data.Map.Strict as Map + +import AST +import Utility + +import Debug +import PrettyPrint + +tracex :: (Show a) => String -> a -> a +tracex s x = trace (s ++ ": " ++ show x) x + +tracexp :: (PrettyPrint a) => String -> a -> a +tracexp s x = trace (s ++ ": " ++ prettyPrint x) x + + +simplify :: AST -> AST +simplify = tracex "last canonicaliseOrder" . canonicaliseOrder + . (fixpoint $ tracex "applyPatterns " . applyPatterns + . tracex "flattenSums " . flattenSums + -- . tracex "collectLikeTerms " . collectLikeTerms + . tracex "canonicaliseOrder" . canonicaliseOrder + . tracex "foldNumbers " . foldNumbers) + . tracex "first flattenSums" . flattenSums + + +flattenSums :: AST -> AST +flattenSums node = case node of + (Negative n) -> Negative $ flattenSums n + (Reciprocal n) -> Reciprocal $ flattenSums n + (Apply name args) -> Apply name $ map flattenSums args + (Sum args) -> case length args of + 0 -> Number 0 + 1 -> flattenSums $ args !! 0 + otherwise -> Sum $ concat $ map (listify . flattenSums) args + where + listify (Sum args) = args + listify node = [node] + (Product args) -> case length args of + 0 -> Number 1 + 1 -> flattenSums $ args !! 0 + otherwise -> Product $ concat $ map (listify . flattenSums) args + where + listify (Product args) = args + listify node = [node] + _ -> node + +foldNumbers :: AST -> AST +foldNumbers node = case node of + (Negative n) -> let fn = foldNumbers n in case fn of + (Number x) -> Number (-x) + (Negative n2) -> n2 + (Product args) -> Product $ Number (-1) : args + _ -> Negative $ fn + (Reciprocal n) -> let fn = foldNumbers n in case fn of + (Number x) -> Number (1/x) + (Negative n) -> Negative $ Reciprocal fn + (Reciprocal n2) -> n2 + _ -> Reciprocal $ fn + (Apply name args) -> Apply name $ map foldNumbers args + (Sum args) -> Sum $ dofoldnums sum args 0 + (Product args) -> dofoldnegsToProd $ dofoldnums product args 1 + _ -> node + where + dofoldnums func args zerovalue = + let foldedArgs = map foldNumbers args + (nums,notnums) = partition astIsNumber foldedArgs + foldvalue = func $ map (\(Number n) -> n) nums + in case length nums of + x | x >= 1 -> if foldvalue == zerovalue then notnums else Number foldvalue : notnums + otherwise -> foldedArgs + dofoldnegsToProd args = + let foldedArgs = map foldNumbers args + (negs,notnegs) = partition isneg foldedArgs + isneg (Negative _) = True + isneg (Number n) = n < 0 + isneg _ = False + unneg (Negative n) = n + unneg (Number n) = Number $ abs n + unneg n = n + unnegged = map unneg negs ++ notnegs + in case length negs of + x | x < 2 -> Product args + | even x -> Product unnegged + | odd x -> Product $ Number (-1) : unnegged + +-- collectLikeTerms :: AST -> AST +-- collectLikeTerms node = case node of +-- (Reciprocal n) -> Apply "pow" [n,Number $ -1] +-- (Product args) -> +-- let ispow (Apply "pow" _) = True +-- ispow _ = False +-- (pows,nopows) = partition ispow $ map collectLikeTerms args +-- groups = groupBy (\(Apply _ [x,_]) (Apply _ [y,_]) -> x == y) pows +-- baseof (Apply _ [x,_]) = x +-- expof (Apply _ [_,x]) = x +-- collectGroup l = Apply "pow" [baseof (l!!0),Sum $ map expof l] +-- in Product $ map collectGroup groups ++ nopows +-- (Sum args) -> +-- let isnumterm (Product (Number _:_)) = True +-- isnumterm _ = False +-- (numterms,nonumterms) = partition isnumterm $ map collectLikeTerms args +-- groups = groupBy (\(Product (Number _:xs)) (Product (Number _:ys)) +-- -> astMatchSimple (Product xs) (Product ys)) +-- numterms +-- numof (Product (n:_)) = n +-- restof (Product (_:rest)) = rest +-- collectGroup l = +-- if length l == 1 +-- then l!!0 +-- else Product $ Sum (map numof l) : restof (l!!0) +-- in Sum $ map collectGroup groups ++ nonumterms +-- _ -> node + +canonicaliseOrder :: AST -> AST +canonicaliseOrder node = case node of + (Number _) -> node + (Variable _) -> node + (Sum args) -> Sum $ sort args + (Product args) -> Product $ sort args + (Negative n) -> Negative $ canonicaliseOrder n + (Reciprocal n) -> Reciprocal $ canonicaliseOrder n + (Apply name args) -> Apply name $ map canonicaliseOrder args + (Capture _) -> node + (CaptureTerm _) -> node + (CaptureConstr _ _) -> node + + +patterndb :: [(AST,AST)] +patterndb = [ + (Sum [CaptureTerm "x",CaptureTerm "x",Capture "rest"], -- x + x + [rest] -> 2*x + [rest] + Sum [Product [Number 2,Capture "x"],Capture "rest"]), + + (Sum [CaptureTerm "x", -- x + n*x + [rest] -> (1+n)*x + [rest] + Product [CaptureConstr "n" (Number undefined),Capture "x"], + Capture "rest"], + Sum [Product [Sum [Number 1,Capture "n"], + Capture "x"], + Capture "rest"]), + + (Sum [Product [CaptureConstr "n1" (Number undefined),Capture "x"], -- n1*x + n2*x + [rest] -> (n1+n2)*x + [rest] + Product [CaptureConstr "n2" (Number undefined),Capture "x"], + Capture "rest"], + + Sum [Product [Sum [Capture "n1",Capture "n2"], + Capture "x"], + Capture "rest"]) + ] + +astChildMap :: AST -> (AST -> AST) -> AST +astChildMap node f = case node of + (Number _) -> node + (Variable _) -> node + (Sum args) -> Sum $ map f args + (Product args) -> Product $ map f args + (Negative n) -> Negative $ f n + (Reciprocal n) -> Reciprocal $ f n + (Apply name args) -> Apply name $ map f args + (Capture _) -> node + (CaptureTerm _) -> node + (CaptureConstr _ _) -> node + +applyPatterns :: AST -> AST +applyPatterns node = let matches = filter (not . null . fst) $ map (\(pat,repl) -> (astMatch pat node,repl)) patterndb + in if null matches + then astChildMap node applyPatterns + else let ((capdict:_),repl) = head matches -- TODO: don't take the first option of the first match, but use them all + in replaceCaptures capdict repl diff --git a/utility.hs b/utility.hs new file mode 100644 index 0000000..c442f80 --- /dev/null +++ b/utility.hs @@ -0,0 +1,20 @@ +module Utility where + +import Data.List + +--if any of the predicates returns true, options returns true +options :: [a -> Bool] -> a -> Bool +options l x = any id [f x | f <- l] + +fixpoint :: (Eq a) => (a -> a) -> a -> a +fixpoint f x = let fx = f x in if fx == x then x else fixpoint f fx + +setcompareBy :: (a -> a -> Bool) -> [a] -> [a] -> Bool +setcompareBy _ [] [] = True +setcompareBy _ [] _ = False +setcompareBy _ _ [] = False +setcompareBy p a@(x:xs) b = length a == length b && setcompareBy p xs (deleteBy p x b) + +deleteIndex :: Int -> [a] -> [a] +deleteIndex 0 (_:xs) = xs +deleteIndex i (x:xs) | i > 0 = x : deleteIndex (i-1) xs -- cgit v1.2.3-70-g09d2