diff options
| author | tomsmeding <tom.smeding@gmail.com> | 2016-06-16 23:24:47 +0200 | 
|---|---|---|
| committer | tomsmeding <tom.smeding@gmail.com> | 2016-06-16 23:24:47 +0200 | 
| commit | dd3db844dd49451f28d044cd1d2fd71430d73686 (patch) | |
| tree | cd76d7ad6efbf2d2e4760695d39cb48bb479a936 | |
Initial
| -rw-r--r-- | .gitignore | 4 | ||||
| -rw-r--r-- | Makefile | 16 | ||||
| -rw-r--r-- | ast.hs | 283 | ||||
| -rw-r--r-- | debug.hs | 23 | ||||
| -rw-r--r-- | main.hs | 31 | ||||
| -rw-r--r-- | parser.hs | 236 | ||||
| -rw-r--r-- | prettyprint.hs | 11 | ||||
| -rw-r--r-- | simplify.hs | 172 | ||||
| -rw-r--r-- | utility.hs | 20 | 
9 files changed, 796 insertions, 0 deletions
| diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2c67b73 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +*.o +*.hi +*.hs[0-9] +main diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..6d4ca8f --- /dev/null +++ b/Makefile @@ -0,0 +1,16 @@ +BIN = main + +hs_files = $(wildcard *.hs) + +.PHONY: all clean remake + +all: $(BIN) + +clean: +	rm -f *.hi *.o $(BIN) + +remake: clean all + + +$(BIN): $(hs_files) +	ghc -O3 -o $(BIN) $^ @@ -0,0 +1,283 @@ +{-# LANGUAGE TupleSections, BangPatterns, DeriveDataTypeable #-} + +module AST where + +import qualified Data.Map.Strict as Map +import Data.List +import Data.Data +import Data.Typeable +import Control.DeepSeq + +import PrettyPrint +import Debug + + +data AST = Number Double +         | Variable String +         | Sum [AST] +         | Product [AST] +         | Negative AST +         | Reciprocal AST +         | Apply String [AST] +         -- The following only in patterns: +         | Capture String +         | CaptureTerm String +         | CaptureConstr String AST  -- AST argument only for constructor; only WHNF +         deriving (Eq,Typeable,Data) + +instance PrettyPrint AST where +    prettyPrint (Number n) = show n + +    prettyPrint (Variable n) = n + +    prettyPrint (Sum []) = "(+)" +    prettyPrint (Sum args) = intercalate " + " $ map prettyPrint args + +    prettyPrint (Product []) = "(*)" +    prettyPrint (Product args) = intercalate "*" $ map gopp args +        where gopp s@(Sum _) = '(' : prettyPrint s ++ ")" +              gopp n = prettyPrint n + +    prettyPrint (Negative n) = '-' : case n of +        s@(Sum _) -> '(' : prettyPrint s ++ ")" +        n -> prettyPrint n + +    prettyPrint (Reciprocal n) = "1/" ++ case n of +        s@(Sum _) -> '(' : prettyPrint s ++ ")" +        s@(Product _) -> '(' : prettyPrint s ++ ")" +        s@(Reciprocal _) -> '(' : prettyPrint s ++ ")" +        n -> prettyPrint n + +    prettyPrint (Apply name args) = name ++ "(" ++ intercalate "," (map prettyPrint args) ++ ")" + +    prettyPrint (Capture name) = '[' : name ++ "]" + +    prettyPrint (CaptureTerm name) = '[' : '[' : name ++ "]]" + +    prettyPrint (CaptureConstr name c) = '[' : name ++ ":" ++ showConstr (toConstr c) ++ "]" + +instance Show AST where +    show = prettyPrint + +instance NFData AST where +    rnf (Number !_) = () +    rnf (Variable !_) = () +    rnf (Sum l) = seq (length $ map rnf l) () +    rnf (Product l) = seq (length $ map rnf l) () +    rnf (Negative n) = rnf n +    rnf (Reciprocal n) = rnf n +    rnf (Apply !_ l) = seq (length $ map rnf l) () +    rnf (Capture !_) = () +    rnf (CaptureTerm !_) = () +    rnf (CaptureConstr !_ !_) = ()  -- explicitly not deepseq'ing the ast node + +instance Ord AST where +    compare (Number _) (Number _) = EQ +    compare (Variable _) (Variable _) = EQ +    compare (Sum _) (Sum _) = EQ +    compare (Product _) (Product _) = EQ +    compare (Negative _) (Negative _) = EQ +    compare (Reciprocal _) (Reciprocal _) = EQ +    compare (Apply _ _) (Apply _ _) = EQ +    compare (Capture _) (Capture _) = EQ +    compare (CaptureTerm _) (CaptureTerm _) = EQ +    compare (CaptureConstr _ _) (CaptureConstr _ _) = EQ + +    compare (Capture _) _ = LT  -- Unbounded captures first for efficient +    compare _ (Capture _) = GT  -- extraction with span isCapture +    compare (Number _) _ = LT +    compare _ (Number _) = GT +    compare (Variable _) _ = LT +    compare _ (Variable _) = GT +    compare (Sum _) _ = LT +    compare _ (Sum _) = GT +    compare (Product _) _ = LT +    compare _ (Product _) = GT +    compare (Negative _) _ = LT +    compare _ (Negative _) = GT +    compare (Reciprocal _) _ = LT +    compare _ (Reciprocal _) = GT +    compare (Apply _ _) _ = LT +    compare _ (Apply _ _) = GT +    compare (CaptureTerm _) _ = LT +    compare _ (CaptureTerm _) = GT +    -- compare (CaptureConstr _ _) _ = LT +    -- compare _ (CaptureConstr _ _) = GT + + +astIsNumber :: AST -> Bool +astIsNumber (Number _) = True +astIsNumber _ = False + +astIsCapture :: AST -> Bool +astIsCapture (Capture _) = True +astIsCapture _ = False + + +astMatchSimple :: AST -> AST -> Bool +astMatchSimple pat sub = let res = {-(\x -> trace (" !! RESULT: " ++ show x ++ " !! ") x) $-} astMatch pat sub +    in if null res +        then False +        else any Map.null res + + +astMatch :: AST                   -- pattern +         -> AST                   -- subject +         -> [Map.Map String AST]  -- list of possible capture assignments +astMatch pat sub = assertS "No captures in astMatch subject" (not $ hasCaptures sub) $ +    case pat of +        Number x -> case sub of +            Number y | x == y -> [Map.empty] +            _ -> [] + +        Variable name -> case sub of +            Variable name2 | name == name2 -> [Map.empty] +            _ -> [] + +        Sum [term] -> case sub of +            Sum l2 -> matchList Sum [term] l2 +            s -> astMatch term s + +        Sum l -> case sub of +            Sum l2 -> matchList Sum l l2 +            _ -> [] + +        Product [term] -> case sub of +            Product l2 -> matchList Product [term] l2 +            s -> astMatch term s + +        Product l -> case sub of +            Product l2 -> matchList Product l l2 +            _ -> [] + +        Negative n -> case sub of +            Negative n2 -> astMatch n n2 +            _ -> [] + +        Reciprocal n -> case sub of +            Reciprocal n2 -> astMatch n n2 +            _ -> [] + +        Apply name l -> case sub of +            Apply name l2 -> matchOrderedList l l2 +            _ -> [] + +        Capture name -> [Map.singleton name sub] + +        CaptureTerm name -> [Map.singleton name sub] + +        CaptureConstr name constr -> +            if toConstr sub == toConstr constr +                then [Map.singleton name sub] +                else [] + + +matchList :: ([AST] -> AST)        -- AST constructor for this list (for insertion in capture) +          -> [AST]                 -- unordered patterns +          -> [AST]                 -- unordered subjects +          -> [Map.Map String AST]  -- list of possible capture assignments +matchList constr pats subs = +    let ordered = sort pats +        (captures,nocaps) = span astIsCapture ordered +    in assertS "At most one capture in sum/product" (length captures <= 1) $ case captures of +        [] -> matchListDBG Nothing nocaps subs +        [c] -> matchListDBG (Just c) nocaps subs +    where matchList' :: Maybe AST -> [AST] -> [AST] -> [Map.Map String AST] +          matchList' Nothing [] [] = [Map.empty] +          matchList' Nothing [] _ = [] +          matchList' (Just (Capture name)) [] subs = [Map.singleton name $ constr subs] +          matchList' (Just node) [] subs = astMatch node (constr subs) +          matchList' mcap (pat:pats) subs = +              let firstmatches = concat $ mapDel (\s other -> map (,other) $ astMatch pat s) subs +                  processed = concat +                              $ map (\(ass,rest) -> +                                         let replpats = map (replaceCaptures ass) pats +                                             replmcap = fmap (replaceCaptures ass) mcap +                                         in map (Map.union ass) $ matchListDBG replmcap replpats rest) +                                    firstmatches +              in {-trace ("firstmatches = "++show firstmatches) $ trace ("processed = "++show processed) $-} processed + +          matchListDBG :: Maybe AST -> [AST] -> [AST] -> [Map.Map String AST] +          matchListDBG mcap pats subs = {-force $ trace ("\n<< "++show (mcap,pats,subs)++" >>\n") +                                        $-} matchList' mcap pats subs + + +matchOrderedList :: [AST]                 -- ordered patterns +                 -> [AST]                 -- ordered subjects +                 -> [Map.Map String AST]  -- list of possible capture assignments +matchOrderedList [] [] = [Map.empty] +matchOrderedList [] _ = [] +matchOrderedList _ [] = [] +matchOrderedList (pat:pats) (sub:subs) = +    let opts = astMatch pat sub +        newpatsopts = [(map (replaceCaptures opt) pats,opt) | opt <- opts] +            -- ^ list of possible refined versions of the (rest of the) pattern list +    in {-trace (show (pat:pats) ++ " -- " ++ show (sub:subs)) $ traceShow opts $-} +       concat $ map (\(newpats,opt) -> map (Map.union opt) +              $ matchOrderedList newpats subs) newpatsopts + + +replaceCaptures :: Map.Map String AST -> AST -> AST +replaceCaptures mp n = case n of +    Number _ -> n +    Variable _ -> n +    Sum l -> Sum $ map (replaceCaptures mp) l +    Product l -> Product $ map (replaceCaptures mp) l +    Negative n2 -> Negative $ replaceCaptures mp n2 +    Reciprocal n2 -> Reciprocal $ replaceCaptures mp n2 +    Apply name n2 -> Apply name $ map (replaceCaptures mp) n2 +    Capture name -> maybe n id $ Map.lookup name mp +    CaptureTerm name -> maybe n id $ Map.lookup name mp +    CaptureConstr name c -> maybe n id $ Map.lookup name mp + + +hasCaptures :: AST -> Bool +hasCaptures n = case n of +    Number _ -> False +    Variable _ -> False +    Sum l -> any id [hasCaptures m | m <- l] +    Product l -> any id [hasCaptures m | m <- l] +    Negative m -> hasCaptures m +    Reciprocal m -> hasCaptures m +    Apply _ l -> any id [hasCaptures m | m <- l] +    Capture _ -> True +    CaptureTerm _ -> True +    CaptureConstr _ _ -> True + + +assert :: Bool -> a -> a +assert = assertS "(no reason)" + +assertS :: String -> Bool -> a -> a +assertS _ True = id +assertS s False = error $ "Condition not satisfied in assert: " ++ s + + +mapDel :: (a -> [a] -> b) -> [a] -> [b] +mapDel _ [] = [] +mapDel f l = +    let splits = zip l +                 $ map (\(a,b:bs) -> a++bs) +                 $ iterate (\(a,b:bs) -> (a++[b],bs)) ([],l) +    in map (uncurry f) splits + + +-- some testing things +--pat = Sum [Number 1,Capture "x",Negative $ Capture "x"] +--sub = Sum [Number 4,Variable "a",Number 1,Negative $ Sum [Variable "a",Number 4]] + +--pat = Sum [Negative $ Capture "x"] +--sub = Sum [Negative $ Sum [Variable "a",Number 4]] + +--pat = Sum [Capture "x",Negative (Capture "x"),CaptureTerm "y",CaptureTerm "z"] +--sub = let x = Reciprocal (Number 7) in Sum [x,Negative x,Number 7,Number 8] + +--pat = Sum [CaptureTerm "x",CaptureTerm "y",Capture "rest",Negative $ Capture "rest"] +--sub = Sum [Number 1,Number 2,Negative $ Number 1,Variable "kaas",Negative $ Sum [Negative $ Number 1,Variable "kaas"]] + +pat = Sum [Product [Capture "x"],Product [Capture "x"]] +sub = Sum [Product [Number 1],Product [Number 1]] + +main = do +    let res = astMatch pat sub +    deepseq res $ putStrLn $ "\x1B[32m"++show res++"\x1B[0m" diff --git a/debug.hs b/debug.hs new file mode 100644 index 0000000..33ca1b0 --- /dev/null +++ b/debug.hs @@ -0,0 +1,23 @@ +{-# LANGUAGE CPP #-} + +module Debug where + +#if 1 + +import qualified Debug.Trace as Trace + +trace :: String -> a -> a +trace = Trace.trace + +traceShow :: (Show a) => a -> b -> b +traceShow = Trace.traceShow + +#else + +trace :: String -> a -> a +trace = flip const + +traceShow :: a -> b -> b +traceShow = flip const + +#endif @@ -0,0 +1,31 @@ +module Main where + +import Control.Monad +import Data.Either +import Data.Maybe +import System.Console.Readline + +import Simplify +import Parser +import PrettyPrint + + +repl :: IO () +repl = do +    mline <- readline "> " +    case mline of +        Nothing -> return ()  -- EOF +        Just "" -> repl +        Just line -> do +            addHistory line +            let eexpr = parseExpression line +            either (putStrLn . ("Error: "++)) handleExpression eexpr +            repl +    where +        handleExpression expr = do +            print expr +            let sim = simplify expr +            print sim + +main :: IO () +main = repl diff --git a/parser.hs b/parser.hs new file mode 100644 index 0000000..6d8ed6d --- /dev/null +++ b/parser.hs @@ -0,0 +1,236 @@ +module Parser (parseExpression) where + +import Control.Applicative +import Control.Monad +import Data.Char +import Data.Maybe + +import AST +import Utility + + +parseExpression :: String -> Either String AST +parseExpression s = case parse pexpression s of +    ((node,rest):_) -> case rest of +        "" -> Right node +        s -> Left $ "Cannot parse from '" ++ take 10 rest ++ "'" +    _ -> Left "No valid parse" + + +newtype Parser a = Parser (String -> [(a,String)]) + +parse :: Parser a -> String -> [(a,String)] +parse (Parser p) = p + +instance Functor Parser where +    fmap f p = Parser (\cs -> map (\(a,s) -> (f a,s)) $ parse p cs) + +instance Applicative Parser where +    pure x = Parser (\cs -> [(x,cs)]) +    (<*>) f p = Parser (\cs -> concat $ +        map (\(a,s) -> parse (fmap a p) s) $ parse f cs) + +instance Monad Parser where +    p >>= f = Parser (\cs -> concat $ +        map (\(a,s) -> parse (f a) s) $ parse p cs) + +instance Alternative Parser where +    empty = Parser (\_ -> []) +    (<|>) p q = Parser (\cs -> parse p cs ++ parse q cs) + +instance MonadPlus Parser + + +--The deterministic choice operator: choose the first possibile parse (if +--available at all) from the results given by the two parsers. +--mplus is the non-deterministic choice operator; it would give all results. +mplus1 :: Parser a -> Parser a -> Parser a +mplus1 p q = Parser $ \cs -> case parse (mplus p q) cs of +    [] -> [] +    (x:_) -> [x] + +--(++) = mplus +(+++) = mplus1 + + +pitem :: Parser Char +pitem = Parser $ \s -> case s of +    "" -> [] +    (c:cs) -> [(c,cs)] + +psat :: (Char -> Bool) -> Parser Char +psat f = do +    c <- pitem +    if f c then return c else mzero + +--checks that the next char satisfies the predicate; does NOT consume characters +passert :: (Char -> Bool) -> Parser () +passert p = Parser $ \s -> case s of +    "" -> [] +    (c:_) -> if p c then [((),s)] else [] + +pchar :: Char -> Parser Char +pchar c = psat (==c) + +pstring :: String -> Parser String +pstring "" = return "" +pstring (c:cs) = do +    pchar c +    pstring cs +    return (c:cs) + +pmany :: Parser a -> Parser [a] +pmany p = pmany1 p +++ return [] + +pmany1 :: Parser a -> Parser [a] +pmany1 p = do +    a <- p +    as <- pmany p +    return (a:as) + +pinteger :: Parser Int +pinteger = do +    s <- pmany $ psat isDigit +    return $ read s + +pdouble :: Parser Double +pdouble = Parser reads + +pquotstring :: Parser String +pquotstring = Parser reads + +poptws :: Parser String +poptws = Parser $ pure . span isSpace + +pws :: Parser String +pws = Parser $ \s -> case span isSpace s of +    ("",_) -> [] +    tup@(_,_) -> [tup] + +pword :: Parser String +pword = do +    c <- psat $ options [isAlpha,(=='_')] +    cs <- pmany $ psat $ options [isAlpha,isDigit,(=='_')] +    return (c:cs) + + +pnumber :: Parser AST +pnumber = liftM Number pdouble + +pvariable :: Parser AST +pvariable = liftM Variable $ pstring "PI" +++ (liftM pure (psat isAlpha)) + +pinfixoperator :: (Char,Char)  -- +/- symbols +               -> Parser AST  -- term parser +               -> ([AST] -> AST)  -- Sum constructor +               -> (AST -> AST)  -- Negative constructor +               -> Bool  -- whether the plus sign is optional +               -> Bool  -- whether a negative sign cannot follow after a term +               -> Parser AST  -- Resulting parser +pinfixoperator (plus,minus) pterm sumconstr negconstr plusopt noneg = do +    term <- pterm +    pmoreterms term +++ return (sumconstr [term]) +    where +        pmoreterms term = if plusopt +            then pmoretermsplus term +++ pmoretermsminus term +++ pmoretermsnothing term +            else pmoretermsplus term +++ pmoretermsminus term + +        pmoretermsplus term = do +            poptws +            pchar plus +            poptws +            nextterm <- pterm +            let thissum = sumconstr [term,nextterm] +            pmoreterms thissum +++ return thissum +        pmoretermsminus term = do +            poptws +            pchar minus +            poptws +            nextterm <- pterm +            let thissum = sumconstr [term,negconstr nextterm] +            pmoreterms thissum +++ return thissum +        pmoretermsnothing term = do +            poptws +            if noneg then passert (/='-') else return () +            nextterm <- pterm +            let thissum = sumconstr [term,nextterm] +            pmoreterms thissum +++ return thissum + +psum :: Parser AST +psum = pinfixoperator ('+','-') pproduct Sum Negative False False + +pproduct :: Parser AST +pproduct = pinfixoperator ('*','/') pfactor Product Reciprocal True True + +pfactor :: Parser AST +pfactor = pnegative +++ pfactornoneg +++ pcapture +++ pcaptureterm + +pnegative :: Parser AST +pnegative = do {pchar '-'; poptws; f <- pfactor; return $ Negative f} +++ pfactornoneg + +pfactornoneg :: Parser AST +pfactornoneg = do +    fact <- pnumber +++ pparenthetical +++ pfunctioncall +++ pvariable +    ppower fact +++ pfactorial fact +++ return fact +    where +        ppower fact = do +            poptws +            pchar '^' +            poptws +            fact2 <- pfactornoneg +            return $ Apply "pow" [fact,fact2] +        pfactorial fact = do +            poptws +            pchar '!' +            return $ Apply "fact" [fact] + + +pparenthetical :: Parser AST +pparenthetical = do +    pchar '(' +    poptws +    sum <- psum +    poptws +    pchar ')' +    return sum + +pfunctioncall :: Parser AST +pfunctioncall = do +    name <- pword +    poptws +    pchar '(' +    poptws +    args <- parglist +    poptws +    pchar ')' +    return $ Apply name args +    where +        parglist = do +            arg <- parg +            poptws +            pmoreargs arg +++ return [arg] +        pmoreargs arg = do +            pchar ',' +            poptws +            args <- parglist +            return (arg:args) +        parg = pexpression + +pcapture :: Parser AST +pcapture = do +    pchar '[' +    name <- pmany1 $ psat (/=']') +    pchar ']' +    return $ Capture name + +pcaptureterm :: Parser AST +pcaptureterm = do +    pchar '[' +    pchar '[' +    name <- pmany1 $ psat (/=']') +    pchar ']' +    pchar ']' +    return $ CaptureTerm name + +pexpression :: Parser AST +pexpression = psum diff --git a/prettyprint.hs b/prettyprint.hs new file mode 100644 index 0000000..45ca6ac --- /dev/null +++ b/prettyprint.hs @@ -0,0 +1,11 @@ +{-# LANGUAGE FlexibleInstances #-} + +module PrettyPrint where + +class PrettyPrint a where +    prettyPrint :: a -> String +    -- a = a + +instance PrettyPrint String where prettyPrint = id +instance PrettyPrint Double where prettyPrint = show +instance PrettyPrint Int where prettyPrint = show diff --git a/simplify.hs b/simplify.hs new file mode 100644 index 0000000..cb3e5af --- /dev/null +++ b/simplify.hs @@ -0,0 +1,172 @@ +{-# LANGUAGE CPP #-} + +module Simplify (simplify) where + +import Data.List +import qualified Data.Map.Strict as Map + +import AST +import Utility + +import Debug +import PrettyPrint + +tracex :: (Show a) => String -> a -> a +tracex s x = trace (s ++ ": " ++ show x) x + +tracexp :: (PrettyPrint a) => String -> a -> a +tracexp s x = trace (s ++ ": " ++ prettyPrint x) x + + +simplify :: AST -> AST +simplify = tracex "last canonicaliseOrder" . canonicaliseOrder +           . (fixpoint $ tracex "applyPatterns    " . applyPatterns +                       . tracex "flattenSums      " . flattenSums +                       -- . tracex "collectLikeTerms " . collectLikeTerms +                       . tracex "canonicaliseOrder" . canonicaliseOrder +                       . tracex "foldNumbers      " . foldNumbers) +           . tracex "first flattenSums" . flattenSums + + +flattenSums :: AST -> AST +flattenSums node = case node of +    (Negative n) -> Negative $ flattenSums n +    (Reciprocal n) -> Reciprocal $ flattenSums n +    (Apply name args) -> Apply name $ map flattenSums args +    (Sum args) -> case length args of +        0 -> Number 0 +        1 -> flattenSums $ args !! 0 +        otherwise -> Sum $ concat $ map (listify . flattenSums) args +        where +            listify (Sum args) = args +            listify node = [node] +    (Product args) -> case length args of +        0 -> Number 1 +        1 -> flattenSums $ args !! 0 +        otherwise -> Product $ concat $ map (listify . flattenSums) args +        where +            listify (Product args) = args +            listify node = [node] +    _ -> node + +foldNumbers :: AST -> AST +foldNumbers node = case node of +    (Negative n) -> let fn = foldNumbers n in case fn of +        (Number x) -> Number (-x) +        (Negative n2) -> n2 +        (Product args) -> Product $ Number (-1) : args +        _ -> Negative $ fn +    (Reciprocal n) -> let fn = foldNumbers n in case fn of +        (Number x) -> Number (1/x) +        (Negative n) -> Negative $ Reciprocal fn +        (Reciprocal n2) -> n2 +        _ -> Reciprocal $ fn +    (Apply name args) -> Apply name $ map foldNumbers args +    (Sum args) -> Sum $ dofoldnums sum args 0 +    (Product args) -> dofoldnegsToProd $ dofoldnums product args 1 +    _ -> node +    where +        dofoldnums func args zerovalue = +            let foldedArgs = map foldNumbers args +                (nums,notnums) = partition astIsNumber foldedArgs +                foldvalue = func $ map (\(Number n) -> n) nums +            in case length nums of +                x | x >= 1 -> if foldvalue == zerovalue then notnums else Number foldvalue : notnums +                otherwise -> foldedArgs +        dofoldnegsToProd args = +            let foldedArgs = map foldNumbers args +                (negs,notnegs) = partition isneg foldedArgs +                isneg (Negative _) = True +                isneg (Number n) = n < 0 +                isneg _ = False +                unneg (Negative n) = n +                unneg (Number n) = Number $ abs n +                unneg n = n +                unnegged = map unneg negs ++ notnegs +            in case length negs of +                x | x < 2 -> Product args +                  | even x -> Product unnegged +                  | odd x -> Product $ Number (-1) : unnegged + +-- collectLikeTerms :: AST -> AST +-- collectLikeTerms node = case node of +--     (Reciprocal n) -> Apply "pow" [n,Number $ -1] +--     (Product args) ->  +--         let ispow (Apply "pow" _) = True +--             ispow _ = False +--             (pows,nopows) = partition ispow $ map collectLikeTerms args +--             groups = groupBy (\(Apply _ [x,_]) (Apply _ [y,_]) -> x == y) pows +--             baseof (Apply _ [x,_]) = x +--             expof (Apply _ [_,x]) = x +--             collectGroup l = Apply "pow" [baseof (l!!0),Sum $ map expof l] +--         in Product $ map collectGroup groups ++ nopows +--     (Sum args) -> +--         let isnumterm (Product (Number _:_)) = True +--             isnumterm _ = False +--             (numterms,nonumterms) = partition isnumterm $ map collectLikeTerms args +--             groups = groupBy (\(Product (Number _:xs)) (Product (Number _:ys)) +--                               -> astMatchSimple (Product xs) (Product ys)) +--                          numterms +--             numof (Product (n:_)) = n +--             restof (Product (_:rest)) = rest +--             collectGroup l = +--                 if length l == 1 +--                     then l!!0 +--                     else Product $ Sum (map numof l) : restof (l!!0)  +--         in Sum $ map collectGroup groups ++ nonumterms +--     _ -> node + +canonicaliseOrder :: AST -> AST +canonicaliseOrder node = case node of +    (Number _) -> node +    (Variable _) -> node +    (Sum args) -> Sum $ sort args +    (Product args) -> Product $ sort args +    (Negative n) -> Negative $ canonicaliseOrder n +    (Reciprocal n) -> Reciprocal $ canonicaliseOrder n +    (Apply name args) -> Apply name $ map canonicaliseOrder args +    (Capture _) -> node +    (CaptureTerm _) -> node +    (CaptureConstr _ _) -> node + + +patterndb :: [(AST,AST)] +patterndb = [ +        (Sum [CaptureTerm "x",CaptureTerm "x",Capture "rest"],  -- x + x + [rest] -> 2*x + [rest] +         Sum [Product [Number 2,Capture "x"],Capture "rest"]), + +        (Sum [CaptureTerm "x",  -- x + n*x + [rest] -> (1+n)*x + [rest] +              Product [CaptureConstr "n" (Number undefined),Capture "x"], +              Capture "rest"], +         Sum [Product [Sum [Number 1,Capture "n"], +                       Capture "x"], +              Capture "rest"]), + +        (Sum [Product [CaptureConstr "n1" (Number undefined),Capture "x"],  -- n1*x + n2*x + [rest] -> (n1+n2)*x + [rest] +              Product [CaptureConstr "n2" (Number undefined),Capture "x"], +              Capture "rest"], + +         Sum [Product [Sum [Capture "n1",Capture "n2"], +                       Capture "x"], +              Capture "rest"]) +    ] + +astChildMap :: AST -> (AST -> AST) -> AST +astChildMap node f = case node of +    (Number _) -> node +    (Variable _) -> node +    (Sum args) -> Sum $ map f args +    (Product args) -> Product $ map f args +    (Negative n) -> Negative $ f n +    (Reciprocal n) -> Reciprocal $ f n +    (Apply name args) -> Apply name $ map f args +    (Capture _) -> node +    (CaptureTerm _) -> node +    (CaptureConstr _ _) -> node + +applyPatterns :: AST -> AST +applyPatterns node = let matches = filter (not . null . fst) $ map (\(pat,repl) -> (astMatch pat node,repl)) patterndb +    in if null matches +        then astChildMap node applyPatterns +        else let ((capdict:_),repl) = head matches  -- TODO: don't take the first option of the first match, but use them all +             in replaceCaptures capdict repl diff --git a/utility.hs b/utility.hs new file mode 100644 index 0000000..c442f80 --- /dev/null +++ b/utility.hs @@ -0,0 +1,20 @@ +module Utility where + +import Data.List + +--if any of the predicates returns true, options returns true +options :: [a -> Bool] -> a -> Bool +options l x = any id [f x | f <- l] + +fixpoint :: (Eq a) => (a -> a) -> a -> a +fixpoint f x = let fx = f x in if fx == x then x else fixpoint f fx + +setcompareBy :: (a -> a -> Bool) -> [a] -> [a] -> Bool +setcompareBy _ [] [] = True +setcompareBy _ [] _ = False +setcompareBy _ _ [] = False +setcompareBy p a@(x:xs) b = length a == length b && setcompareBy p xs (deleteBy p x b) + +deleteIndex :: Int -> [a] -> [a] +deleteIndex 0 (_:xs) = xs +deleteIndex i (x:xs) | i > 0 = x : deleteIndex (i-1) xs | 
