diff options
-rw-r--r-- | .gitignore | 4 | ||||
-rw-r--r-- | Makefile | 16 | ||||
-rw-r--r-- | ast.hs | 283 | ||||
-rw-r--r-- | debug.hs | 23 | ||||
-rw-r--r-- | main.hs | 31 | ||||
-rw-r--r-- | parser.hs | 236 | ||||
-rw-r--r-- | prettyprint.hs | 11 | ||||
-rw-r--r-- | simplify.hs | 172 | ||||
-rw-r--r-- | utility.hs | 20 |
9 files changed, 796 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2c67b73 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +*.o +*.hi +*.hs[0-9] +main diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..6d4ca8f --- /dev/null +++ b/Makefile @@ -0,0 +1,16 @@ +BIN = main + +hs_files = $(wildcard *.hs) + +.PHONY: all clean remake + +all: $(BIN) + +clean: + rm -f *.hi *.o $(BIN) + +remake: clean all + + +$(BIN): $(hs_files) + ghc -O3 -o $(BIN) $^ @@ -0,0 +1,283 @@ +{-# LANGUAGE TupleSections, BangPatterns, DeriveDataTypeable #-} + +module AST where + +import qualified Data.Map.Strict as Map +import Data.List +import Data.Data +import Data.Typeable +import Control.DeepSeq + +import PrettyPrint +import Debug + + +data AST = Number Double + | Variable String + | Sum [AST] + | Product [AST] + | Negative AST + | Reciprocal AST + | Apply String [AST] + -- The following only in patterns: + | Capture String + | CaptureTerm String + | CaptureConstr String AST -- AST argument only for constructor; only WHNF + deriving (Eq,Typeable,Data) + +instance PrettyPrint AST where + prettyPrint (Number n) = show n + + prettyPrint (Variable n) = n + + prettyPrint (Sum []) = "(+)" + prettyPrint (Sum args) = intercalate " + " $ map prettyPrint args + + prettyPrint (Product []) = "(*)" + prettyPrint (Product args) = intercalate "*" $ map gopp args + where gopp s@(Sum _) = '(' : prettyPrint s ++ ")" + gopp n = prettyPrint n + + prettyPrint (Negative n) = '-' : case n of + s@(Sum _) -> '(' : prettyPrint s ++ ")" + n -> prettyPrint n + + prettyPrint (Reciprocal n) = "1/" ++ case n of + s@(Sum _) -> '(' : prettyPrint s ++ ")" + s@(Product _) -> '(' : prettyPrint s ++ ")" + s@(Reciprocal _) -> '(' : prettyPrint s ++ ")" + n -> prettyPrint n + + prettyPrint (Apply name args) = name ++ "(" ++ intercalate "," (map prettyPrint args) ++ ")" + + prettyPrint (Capture name) = '[' : name ++ "]" + + prettyPrint (CaptureTerm name) = '[' : '[' : name ++ "]]" + + prettyPrint (CaptureConstr name c) = '[' : name ++ ":" ++ showConstr (toConstr c) ++ "]" + +instance Show AST where + show = prettyPrint + +instance NFData AST where + rnf (Number !_) = () + rnf (Variable !_) = () + rnf (Sum l) = seq (length $ map rnf l) () + rnf (Product l) = seq (length $ map rnf l) () + rnf (Negative n) = rnf n + rnf (Reciprocal n) = rnf n + rnf (Apply !_ l) = seq (length $ map rnf l) () + rnf (Capture !_) = () + rnf (CaptureTerm !_) = () + rnf (CaptureConstr !_ !_) = () -- explicitly not deepseq'ing the ast node + +instance Ord AST where + compare (Number _) (Number _) = EQ + compare (Variable _) (Variable _) = EQ + compare (Sum _) (Sum _) = EQ + compare (Product _) (Product _) = EQ + compare (Negative _) (Negative _) = EQ + compare (Reciprocal _) (Reciprocal _) = EQ + compare (Apply _ _) (Apply _ _) = EQ + compare (Capture _) (Capture _) = EQ + compare (CaptureTerm _) (CaptureTerm _) = EQ + compare (CaptureConstr _ _) (CaptureConstr _ _) = EQ + + compare (Capture _) _ = LT -- Unbounded captures first for efficient + compare _ (Capture _) = GT -- extraction with span isCapture + compare (Number _) _ = LT + compare _ (Number _) = GT + compare (Variable _) _ = LT + compare _ (Variable _) = GT + compare (Sum _) _ = LT + compare _ (Sum _) = GT + compare (Product _) _ = LT + compare _ (Product _) = GT + compare (Negative _) _ = LT + compare _ (Negative _) = GT + compare (Reciprocal _) _ = LT + compare _ (Reciprocal _) = GT + compare (Apply _ _) _ = LT + compare _ (Apply _ _) = GT + compare (CaptureTerm _) _ = LT + compare _ (CaptureTerm _) = GT + -- compare (CaptureConstr _ _) _ = LT + -- compare _ (CaptureConstr _ _) = GT + + +astIsNumber :: AST -> Bool +astIsNumber (Number _) = True +astIsNumber _ = False + +astIsCapture :: AST -> Bool +astIsCapture (Capture _) = True +astIsCapture _ = False + + +astMatchSimple :: AST -> AST -> Bool +astMatchSimple pat sub = let res = {-(\x -> trace (" !! RESULT: " ++ show x ++ " !! ") x) $-} astMatch pat sub + in if null res + then False + else any Map.null res + + +astMatch :: AST -- pattern + -> AST -- subject + -> [Map.Map String AST] -- list of possible capture assignments +astMatch pat sub = assertS "No captures in astMatch subject" (not $ hasCaptures sub) $ + case pat of + Number x -> case sub of + Number y | x == y -> [Map.empty] + _ -> [] + + Variable name -> case sub of + Variable name2 | name == name2 -> [Map.empty] + _ -> [] + + Sum [term] -> case sub of + Sum l2 -> matchList Sum [term] l2 + s -> astMatch term s + + Sum l -> case sub of + Sum l2 -> matchList Sum l l2 + _ -> [] + + Product [term] -> case sub of + Product l2 -> matchList Product [term] l2 + s -> astMatch term s + + Product l -> case sub of + Product l2 -> matchList Product l l2 + _ -> [] + + Negative n -> case sub of + Negative n2 -> astMatch n n2 + _ -> [] + + Reciprocal n -> case sub of + Reciprocal n2 -> astMatch n n2 + _ -> [] + + Apply name l -> case sub of + Apply name l2 -> matchOrderedList l l2 + _ -> [] + + Capture name -> [Map.singleton name sub] + + CaptureTerm name -> [Map.singleton name sub] + + CaptureConstr name constr -> + if toConstr sub == toConstr constr + then [Map.singleton name sub] + else [] + + +matchList :: ([AST] -> AST) -- AST constructor for this list (for insertion in capture) + -> [AST] -- unordered patterns + -> [AST] -- unordered subjects + -> [Map.Map String AST] -- list of possible capture assignments +matchList constr pats subs = + let ordered = sort pats + (captures,nocaps) = span astIsCapture ordered + in assertS "At most one capture in sum/product" (length captures <= 1) $ case captures of + [] -> matchListDBG Nothing nocaps subs + [c] -> matchListDBG (Just c) nocaps subs + where matchList' :: Maybe AST -> [AST] -> [AST] -> [Map.Map String AST] + matchList' Nothing [] [] = [Map.empty] + matchList' Nothing [] _ = [] + matchList' (Just (Capture name)) [] subs = [Map.singleton name $ constr subs] + matchList' (Just node) [] subs = astMatch node (constr subs) + matchList' mcap (pat:pats) subs = + let firstmatches = concat $ mapDel (\s other -> map (,other) $ astMatch pat s) subs + processed = concat + $ map (\(ass,rest) -> + let replpats = map (replaceCaptures ass) pats + replmcap = fmap (replaceCaptures ass) mcap + in map (Map.union ass) $ matchListDBG replmcap replpats rest) + firstmatches + in {-trace ("firstmatches = "++show firstmatches) $ trace ("processed = "++show processed) $-} processed + + matchListDBG :: Maybe AST -> [AST] -> [AST] -> [Map.Map String AST] + matchListDBG mcap pats subs = {-force $ trace ("\n<< "++show (mcap,pats,subs)++" >>\n") + $-} matchList' mcap pats subs + + +matchOrderedList :: [AST] -- ordered patterns + -> [AST] -- ordered subjects + -> [Map.Map String AST] -- list of possible capture assignments +matchOrderedList [] [] = [Map.empty] +matchOrderedList [] _ = [] +matchOrderedList _ [] = [] +matchOrderedList (pat:pats) (sub:subs) = + let opts = astMatch pat sub + newpatsopts = [(map (replaceCaptures opt) pats,opt) | opt <- opts] + -- ^ list of possible refined versions of the (rest of the) pattern list + in {-trace (show (pat:pats) ++ " -- " ++ show (sub:subs)) $ traceShow opts $-} + concat $ map (\(newpats,opt) -> map (Map.union opt) + $ matchOrderedList newpats subs) newpatsopts + + +replaceCaptures :: Map.Map String AST -> AST -> AST +replaceCaptures mp n = case n of + Number _ -> n + Variable _ -> n + Sum l -> Sum $ map (replaceCaptures mp) l + Product l -> Product $ map (replaceCaptures mp) l + Negative n2 -> Negative $ replaceCaptures mp n2 + Reciprocal n2 -> Reciprocal $ replaceCaptures mp n2 + Apply name n2 -> Apply name $ map (replaceCaptures mp) n2 + Capture name -> maybe n id $ Map.lookup name mp + CaptureTerm name -> maybe n id $ Map.lookup name mp + CaptureConstr name c -> maybe n id $ Map.lookup name mp + + +hasCaptures :: AST -> Bool +hasCaptures n = case n of + Number _ -> False + Variable _ -> False + Sum l -> any id [hasCaptures m | m <- l] + Product l -> any id [hasCaptures m | m <- l] + Negative m -> hasCaptures m + Reciprocal m -> hasCaptures m + Apply _ l -> any id [hasCaptures m | m <- l] + Capture _ -> True + CaptureTerm _ -> True + CaptureConstr _ _ -> True + + +assert :: Bool -> a -> a +assert = assertS "(no reason)" + +assertS :: String -> Bool -> a -> a +assertS _ True = id +assertS s False = error $ "Condition not satisfied in assert: " ++ s + + +mapDel :: (a -> [a] -> b) -> [a] -> [b] +mapDel _ [] = [] +mapDel f l = + let splits = zip l + $ map (\(a,b:bs) -> a++bs) + $ iterate (\(a,b:bs) -> (a++[b],bs)) ([],l) + in map (uncurry f) splits + + +-- some testing things +--pat = Sum [Number 1,Capture "x",Negative $ Capture "x"] +--sub = Sum [Number 4,Variable "a",Number 1,Negative $ Sum [Variable "a",Number 4]] + +--pat = Sum [Negative $ Capture "x"] +--sub = Sum [Negative $ Sum [Variable "a",Number 4]] + +--pat = Sum [Capture "x",Negative (Capture "x"),CaptureTerm "y",CaptureTerm "z"] +--sub = let x = Reciprocal (Number 7) in Sum [x,Negative x,Number 7,Number 8] + +--pat = Sum [CaptureTerm "x",CaptureTerm "y",Capture "rest",Negative $ Capture "rest"] +--sub = Sum [Number 1,Number 2,Negative $ Number 1,Variable "kaas",Negative $ Sum [Negative $ Number 1,Variable "kaas"]] + +pat = Sum [Product [Capture "x"],Product [Capture "x"]] +sub = Sum [Product [Number 1],Product [Number 1]] + +main = do + let res = astMatch pat sub + deepseq res $ putStrLn $ "\x1B[32m"++show res++"\x1B[0m" diff --git a/debug.hs b/debug.hs new file mode 100644 index 0000000..33ca1b0 --- /dev/null +++ b/debug.hs @@ -0,0 +1,23 @@ +{-# LANGUAGE CPP #-} + +module Debug where + +#if 1 + +import qualified Debug.Trace as Trace + +trace :: String -> a -> a +trace = Trace.trace + +traceShow :: (Show a) => a -> b -> b +traceShow = Trace.traceShow + +#else + +trace :: String -> a -> a +trace = flip const + +traceShow :: a -> b -> b +traceShow = flip const + +#endif @@ -0,0 +1,31 @@ +module Main where + +import Control.Monad +import Data.Either +import Data.Maybe +import System.Console.Readline + +import Simplify +import Parser +import PrettyPrint + + +repl :: IO () +repl = do + mline <- readline "> " + case mline of + Nothing -> return () -- EOF + Just "" -> repl + Just line -> do + addHistory line + let eexpr = parseExpression line + either (putStrLn . ("Error: "++)) handleExpression eexpr + repl + where + handleExpression expr = do + print expr + let sim = simplify expr + print sim + +main :: IO () +main = repl diff --git a/parser.hs b/parser.hs new file mode 100644 index 0000000..6d8ed6d --- /dev/null +++ b/parser.hs @@ -0,0 +1,236 @@ +module Parser (parseExpression) where + +import Control.Applicative +import Control.Monad +import Data.Char +import Data.Maybe + +import AST +import Utility + + +parseExpression :: String -> Either String AST +parseExpression s = case parse pexpression s of + ((node,rest):_) -> case rest of + "" -> Right node + s -> Left $ "Cannot parse from '" ++ take 10 rest ++ "'" + _ -> Left "No valid parse" + + +newtype Parser a = Parser (String -> [(a,String)]) + +parse :: Parser a -> String -> [(a,String)] +parse (Parser p) = p + +instance Functor Parser where + fmap f p = Parser (\cs -> map (\(a,s) -> (f a,s)) $ parse p cs) + +instance Applicative Parser where + pure x = Parser (\cs -> [(x,cs)]) + (<*>) f p = Parser (\cs -> concat $ + map (\(a,s) -> parse (fmap a p) s) $ parse f cs) + +instance Monad Parser where + p >>= f = Parser (\cs -> concat $ + map (\(a,s) -> parse (f a) s) $ parse p cs) + +instance Alternative Parser where + empty = Parser (\_ -> []) + (<|>) p q = Parser (\cs -> parse p cs ++ parse q cs) + +instance MonadPlus Parser + + +--The deterministic choice operator: choose the first possibile parse (if +--available at all) from the results given by the two parsers. +--mplus is the non-deterministic choice operator; it would give all results. +mplus1 :: Parser a -> Parser a -> Parser a +mplus1 p q = Parser $ \cs -> case parse (mplus p q) cs of + [] -> [] + (x:_) -> [x] + +--(++) = mplus +(+++) = mplus1 + + +pitem :: Parser Char +pitem = Parser $ \s -> case s of + "" -> [] + (c:cs) -> [(c,cs)] + +psat :: (Char -> Bool) -> Parser Char +psat f = do + c <- pitem + if f c then return c else mzero + +--checks that the next char satisfies the predicate; does NOT consume characters +passert :: (Char -> Bool) -> Parser () +passert p = Parser $ \s -> case s of + "" -> [] + (c:_) -> if p c then [((),s)] else [] + +pchar :: Char -> Parser Char +pchar c = psat (==c) + +pstring :: String -> Parser String +pstring "" = return "" +pstring (c:cs) = do + pchar c + pstring cs + return (c:cs) + +pmany :: Parser a -> Parser [a] +pmany p = pmany1 p +++ return [] + +pmany1 :: Parser a -> Parser [a] +pmany1 p = do + a <- p + as <- pmany p + return (a:as) + +pinteger :: Parser Int +pinteger = do + s <- pmany $ psat isDigit + return $ read s + +pdouble :: Parser Double +pdouble = Parser reads + +pquotstring :: Parser String +pquotstring = Parser reads + +poptws :: Parser String +poptws = Parser $ pure . span isSpace + +pws :: Parser String +pws = Parser $ \s -> case span isSpace s of + ("",_) -> [] + tup@(_,_) -> [tup] + +pword :: Parser String +pword = do + c <- psat $ options [isAlpha,(=='_')] + cs <- pmany $ psat $ options [isAlpha,isDigit,(=='_')] + return (c:cs) + + +pnumber :: Parser AST +pnumber = liftM Number pdouble + +pvariable :: Parser AST +pvariable = liftM Variable $ pstring "PI" +++ (liftM pure (psat isAlpha)) + +pinfixoperator :: (Char,Char) -- +/- symbols + -> Parser AST -- term parser + -> ([AST] -> AST) -- Sum constructor + -> (AST -> AST) -- Negative constructor + -> Bool -- whether the plus sign is optional + -> Bool -- whether a negative sign cannot follow after a term + -> Parser AST -- Resulting parser +pinfixoperator (plus,minus) pterm sumconstr negconstr plusopt noneg = do + term <- pterm + pmoreterms term +++ return (sumconstr [term]) + where + pmoreterms term = if plusopt + then pmoretermsplus term +++ pmoretermsminus term +++ pmoretermsnothing term + else pmoretermsplus term +++ pmoretermsminus term + + pmoretermsplus term = do + poptws + pchar plus + poptws + nextterm <- pterm + let thissum = sumconstr [term,nextterm] + pmoreterms thissum +++ return thissum + pmoretermsminus term = do + poptws + pchar minus + poptws + nextterm <- pterm + let thissum = sumconstr [term,negconstr nextterm] + pmoreterms thissum +++ return thissum + pmoretermsnothing term = do + poptws + if noneg then passert (/='-') else return () + nextterm <- pterm + let thissum = sumconstr [term,nextterm] + pmoreterms thissum +++ return thissum + +psum :: Parser AST +psum = pinfixoperator ('+','-') pproduct Sum Negative False False + +pproduct :: Parser AST +pproduct = pinfixoperator ('*','/') pfactor Product Reciprocal True True + +pfactor :: Parser AST +pfactor = pnegative +++ pfactornoneg +++ pcapture +++ pcaptureterm + +pnegative :: Parser AST +pnegative = do {pchar '-'; poptws; f <- pfactor; return $ Negative f} +++ pfactornoneg + +pfactornoneg :: Parser AST +pfactornoneg = do + fact <- pnumber +++ pparenthetical +++ pfunctioncall +++ pvariable + ppower fact +++ pfactorial fact +++ return fact + where + ppower fact = do + poptws + pchar '^' + poptws + fact2 <- pfactornoneg + return $ Apply "pow" [fact,fact2] + pfactorial fact = do + poptws + pchar '!' + return $ Apply "fact" [fact] + + +pparenthetical :: Parser AST +pparenthetical = do + pchar '(' + poptws + sum <- psum + poptws + pchar ')' + return sum + +pfunctioncall :: Parser AST +pfunctioncall = do + name <- pword + poptws + pchar '(' + poptws + args <- parglist + poptws + pchar ')' + return $ Apply name args + where + parglist = do + arg <- parg + poptws + pmoreargs arg +++ return [arg] + pmoreargs arg = do + pchar ',' + poptws + args <- parglist + return (arg:args) + parg = pexpression + +pcapture :: Parser AST +pcapture = do + pchar '[' + name <- pmany1 $ psat (/=']') + pchar ']' + return $ Capture name + +pcaptureterm :: Parser AST +pcaptureterm = do + pchar '[' + pchar '[' + name <- pmany1 $ psat (/=']') + pchar ']' + pchar ']' + return $ CaptureTerm name + +pexpression :: Parser AST +pexpression = psum diff --git a/prettyprint.hs b/prettyprint.hs new file mode 100644 index 0000000..45ca6ac --- /dev/null +++ b/prettyprint.hs @@ -0,0 +1,11 @@ +{-# LANGUAGE FlexibleInstances #-} + +module PrettyPrint where + +class PrettyPrint a where + prettyPrint :: a -> String + -- a = a + +instance PrettyPrint String where prettyPrint = id +instance PrettyPrint Double where prettyPrint = show +instance PrettyPrint Int where prettyPrint = show diff --git a/simplify.hs b/simplify.hs new file mode 100644 index 0000000..cb3e5af --- /dev/null +++ b/simplify.hs @@ -0,0 +1,172 @@ +{-# LANGUAGE CPP #-} + +module Simplify (simplify) where + +import Data.List +import qualified Data.Map.Strict as Map + +import AST +import Utility + +import Debug +import PrettyPrint + +tracex :: (Show a) => String -> a -> a +tracex s x = trace (s ++ ": " ++ show x) x + +tracexp :: (PrettyPrint a) => String -> a -> a +tracexp s x = trace (s ++ ": " ++ prettyPrint x) x + + +simplify :: AST -> AST +simplify = tracex "last canonicaliseOrder" . canonicaliseOrder + . (fixpoint $ tracex "applyPatterns " . applyPatterns + . tracex "flattenSums " . flattenSums + -- . tracex "collectLikeTerms " . collectLikeTerms + . tracex "canonicaliseOrder" . canonicaliseOrder + . tracex "foldNumbers " . foldNumbers) + . tracex "first flattenSums" . flattenSums + + +flattenSums :: AST -> AST +flattenSums node = case node of + (Negative n) -> Negative $ flattenSums n + (Reciprocal n) -> Reciprocal $ flattenSums n + (Apply name args) -> Apply name $ map flattenSums args + (Sum args) -> case length args of + 0 -> Number 0 + 1 -> flattenSums $ args !! 0 + otherwise -> Sum $ concat $ map (listify . flattenSums) args + where + listify (Sum args) = args + listify node = [node] + (Product args) -> case length args of + 0 -> Number 1 + 1 -> flattenSums $ args !! 0 + otherwise -> Product $ concat $ map (listify . flattenSums) args + where + listify (Product args) = args + listify node = [node] + _ -> node + +foldNumbers :: AST -> AST +foldNumbers node = case node of + (Negative n) -> let fn = foldNumbers n in case fn of + (Number x) -> Number (-x) + (Negative n2) -> n2 + (Product args) -> Product $ Number (-1) : args + _ -> Negative $ fn + (Reciprocal n) -> let fn = foldNumbers n in case fn of + (Number x) -> Number (1/x) + (Negative n) -> Negative $ Reciprocal fn + (Reciprocal n2) -> n2 + _ -> Reciprocal $ fn + (Apply name args) -> Apply name $ map foldNumbers args + (Sum args) -> Sum $ dofoldnums sum args 0 + (Product args) -> dofoldnegsToProd $ dofoldnums product args 1 + _ -> node + where + dofoldnums func args zerovalue = + let foldedArgs = map foldNumbers args + (nums,notnums) = partition astIsNumber foldedArgs + foldvalue = func $ map (\(Number n) -> n) nums + in case length nums of + x | x >= 1 -> if foldvalue == zerovalue then notnums else Number foldvalue : notnums + otherwise -> foldedArgs + dofoldnegsToProd args = + let foldedArgs = map foldNumbers args + (negs,notnegs) = partition isneg foldedArgs + isneg (Negative _) = True + isneg (Number n) = n < 0 + isneg _ = False + unneg (Negative n) = n + unneg (Number n) = Number $ abs n + unneg n = n + unnegged = map unneg negs ++ notnegs + in case length negs of + x | x < 2 -> Product args + | even x -> Product unnegged + | odd x -> Product $ Number (-1) : unnegged + +-- collectLikeTerms :: AST -> AST +-- collectLikeTerms node = case node of +-- (Reciprocal n) -> Apply "pow" [n,Number $ -1] +-- (Product args) -> +-- let ispow (Apply "pow" _) = True +-- ispow _ = False +-- (pows,nopows) = partition ispow $ map collectLikeTerms args +-- groups = groupBy (\(Apply _ [x,_]) (Apply _ [y,_]) -> x == y) pows +-- baseof (Apply _ [x,_]) = x +-- expof (Apply _ [_,x]) = x +-- collectGroup l = Apply "pow" [baseof (l!!0),Sum $ map expof l] +-- in Product $ map collectGroup groups ++ nopows +-- (Sum args) -> +-- let isnumterm (Product (Number _:_)) = True +-- isnumterm _ = False +-- (numterms,nonumterms) = partition isnumterm $ map collectLikeTerms args +-- groups = groupBy (\(Product (Number _:xs)) (Product (Number _:ys)) +-- -> astMatchSimple (Product xs) (Product ys)) +-- numterms +-- numof (Product (n:_)) = n +-- restof (Product (_:rest)) = rest +-- collectGroup l = +-- if length l == 1 +-- then l!!0 +-- else Product $ Sum (map numof l) : restof (l!!0) +-- in Sum $ map collectGroup groups ++ nonumterms +-- _ -> node + +canonicaliseOrder :: AST -> AST +canonicaliseOrder node = case node of + (Number _) -> node + (Variable _) -> node + (Sum args) -> Sum $ sort args + (Product args) -> Product $ sort args + (Negative n) -> Negative $ canonicaliseOrder n + (Reciprocal n) -> Reciprocal $ canonicaliseOrder n + (Apply name args) -> Apply name $ map canonicaliseOrder args + (Capture _) -> node + (CaptureTerm _) -> node + (CaptureConstr _ _) -> node + + +patterndb :: [(AST,AST)] +patterndb = [ + (Sum [CaptureTerm "x",CaptureTerm "x",Capture "rest"], -- x + x + [rest] -> 2*x + [rest] + Sum [Product [Number 2,Capture "x"],Capture "rest"]), + + (Sum [CaptureTerm "x", -- x + n*x + [rest] -> (1+n)*x + [rest] + Product [CaptureConstr "n" (Number undefined),Capture "x"], + Capture "rest"], + Sum [Product [Sum [Number 1,Capture "n"], + Capture "x"], + Capture "rest"]), + + (Sum [Product [CaptureConstr "n1" (Number undefined),Capture "x"], -- n1*x + n2*x + [rest] -> (n1+n2)*x + [rest] + Product [CaptureConstr "n2" (Number undefined),Capture "x"], + Capture "rest"], + + Sum [Product [Sum [Capture "n1",Capture "n2"], + Capture "x"], + Capture "rest"]) + ] + +astChildMap :: AST -> (AST -> AST) -> AST +astChildMap node f = case node of + (Number _) -> node + (Variable _) -> node + (Sum args) -> Sum $ map f args + (Product args) -> Product $ map f args + (Negative n) -> Negative $ f n + (Reciprocal n) -> Reciprocal $ f n + (Apply name args) -> Apply name $ map f args + (Capture _) -> node + (CaptureTerm _) -> node + (CaptureConstr _ _) -> node + +applyPatterns :: AST -> AST +applyPatterns node = let matches = filter (not . null . fst) $ map (\(pat,repl) -> (astMatch pat node,repl)) patterndb + in if null matches + then astChildMap node applyPatterns + else let ((capdict:_),repl) = head matches -- TODO: don't take the first option of the first match, but use them all + in replaceCaptures capdict repl diff --git a/utility.hs b/utility.hs new file mode 100644 index 0000000..c442f80 --- /dev/null +++ b/utility.hs @@ -0,0 +1,20 @@ +module Utility where + +import Data.List + +--if any of the predicates returns true, options returns true +options :: [a -> Bool] -> a -> Bool +options l x = any id [f x | f <- l] + +fixpoint :: (Eq a) => (a -> a) -> a -> a +fixpoint f x = let fx = f x in if fx == x then x else fixpoint f fx + +setcompareBy :: (a -> a -> Bool) -> [a] -> [a] -> Bool +setcompareBy _ [] [] = True +setcompareBy _ [] _ = False +setcompareBy _ _ [] = False +setcompareBy p a@(x:xs) b = length a == length b && setcompareBy p xs (deleteBy p x b) + +deleteIndex :: Int -> [a] -> [a] +deleteIndex 0 (_:xs) = xs +deleteIndex i (x:xs) | i > 0 = x : deleteIndex (i-1) xs |