{-# LANGUAGE TupleSections, BangPatterns, DeriveDataTypeable #-}

-- | Expression syntax tree plus a pattern matcher over it.
--
-- Patterns are ordinary 'AST' values that may additionally contain the
-- capture constructors ('Capture', 'CaptureTerm', 'CaptureConstr');
-- 'astMatch' returns every capture assignment under which a pattern
-- matches a (capture-free) subject.
module AST where

import qualified Data.Map.Strict as Map
import Data.List
import Data.Data
import Control.DeepSeq
import PrettyPrint
-- import Debug

data AST
    = Number Double
    | Variable String
    | Sum [AST]
    | Product [AST]
    | Negative AST
    | Reciprocal AST
    | Apply String [AST]
    -- The following only in patterns:
    | Capture String           -- inside a Sum/Product: binds all remaining terms
    | CaptureTerm String       -- binds exactly one node
    | CaptureConstr String AST -- AST argument only for constructor; only WHNF
    deriving (Show, Read, Eq, Typeable, Data)

instance PrettyPrint AST where
    prettyPrint (Number n) = show n
    prettyPrint (Variable n) = n
    prettyPrint (Sum []) = "(+)"
    prettyPrint (Sum args) = intercalate " + " $ map prettyPrint args
    prettyPrint (Product []) = "(*)"
    prettyPrint (Product args) = intercalate "*" $ map factor args
      where
        -- Only sums need parentheses inside a product.
        factor s@(Sum _) = '(' : prettyPrint s ++ ")"
        factor n = prettyPrint n
    prettyPrint (Negative n) = '-' : case n of
        Sum _ -> '(' : prettyPrint n ++ ")"
        _ -> prettyPrint n
    prettyPrint (Reciprocal n) = "1/" ++ case n of
        Sum _ -> '(' : prettyPrint n ++ ")"
        Product _ -> '(' : prettyPrint n ++ ")"
        Reciprocal _ -> '(' : prettyPrint n ++ ")"
        _ -> prettyPrint n
    prettyPrint (Apply name args) =
        name ++ "(" ++ intercalate "," (map prettyPrint args) ++ ")"
    prettyPrint (Capture name) = '[' : name ++ "]"
    prettyPrint (CaptureTerm name) = "[[" ++ name ++ "]]"
    prettyPrint (CaptureConstr name c) =
        '[' : name ++ ":" ++ showConstr (toConstr c) ++ "]"

-- Fully evaluates a tree. The previous version used
-- @seq (length (map rnf l)) ()@, which only walks the list spine and never
-- evaluates the 'rnf' calls, so children were left unevaluated.
instance NFData AST where
    rnf (Number n) = rnf n
    rnf (Variable name) = rnf name
    rnf (Sum l) = rnf l
    rnf (Product l) = rnf l
    rnf (Negative n) = rnf n
    rnf (Reciprocal n) = rnf n
    rnf (Apply name l) = rnf name `seq` rnf l
    rnf (Capture name) = rnf name
    rnf (CaptureTerm name) = rnf name
    -- Explicitly not deepseq'ing the AST node: only its constructor matters,
    -- so it is forced to WHNF only.
    rnf (CaptureConstr name c) = rnf name `seq` c `seq` ()

instance Ord AST where
    compare (Number a) (Number b) = compare a b
    compare (Variable a) (Variable b) = compare a b
    compare (Sum a) (Sum b) = compare a b
    compare (Product a) (Product b) = compare a b
    compare (Negative a) (Negative b) = compare a b
    compare (Reciprocal a) (Reciprocal b) = compare a b
    compare (Apply n1 a) (Apply n2 b) = case compare n1 n2 of
        EQ -> compare a b
        r -> r
    compare (Capture a) (Capture b) = compare a b
    compare (CaptureTerm a) (CaptureTerm b) = compare a b
    compare (CaptureConstr n1 node1) (CaptureConstr n2 node2) =
        case compare n1 n2 of
            EQ -> compare node1 node2
            r -> r
    -- Different constructors: order by rank. Unbounded captures come first
    -- so that 'matchList' can split them off with @span astIsCapture@ after
    -- sorting.
    compare a b = compare (rank a) (rank b)
      where
        rank :: AST -> Int
        rank (Capture _) = 0
        rank (Number _) = 1
        rank (Variable _) = 2
        rank (Sum _) = 3
        rank (Product _) = 4
        rank (Negative _) = 5
        rank (Reciprocal _) = 6
        rank (Apply _ _) = 7
        rank (CaptureTerm _) = 8
        rank (CaptureConstr _ _) = 9

-- | Is this node a 'Number'?
astIsNumber :: AST -> Bool
astIsNumber (Number _) = True
astIsNumber _ = False

-- | Is this node an unbounded 'Capture' (not 'CaptureTerm'/'CaptureConstr')?
astIsCapture :: AST -> Bool
astIsCapture (Capture _) = True
astIsCapture _ = False

-- | Extract the value of a 'Number' node; errors on any other node.
astFromNumber :: AST -> Double
astFromNumber (Number n) = n
astFromNumber node = error $ "astFromNumber: not a Number node: " ++ show node

-- | True when the pattern matches the subject with some assignment that
-- binds no captures (i.e. an empty capture map).
astMatchSimple :: AST -> AST -> Bool
astMatchSimple pat sub =
    {-(\x -> trace (" !! RESULT: " ++ show x ++ " !! ") x) $-}
    any Map.null (astMatch pat sub)

-- | All capture assignments under which @pat@ matches @sub@. An empty result
-- means no match. The subject must not contain captures (asserted).
astMatch :: AST                 -- ^ pattern
         -> AST                 -- ^ subject
         -> [Map.Map String AST] -- ^ list of possible capture assignments
astMatch pat sub =
    assertS "No captures in astMatch subject" (not $ hasCaptures sub) $
        case pat of
            Number x -> case sub of
                Number y | x == y -> [Map.empty]
                _ -> []
            Variable name -> case sub of
                Variable name2 | name == name2 -> [Map.empty]
                _ -> []
            -- A one-term sum pattern may also match a bare (non-sum) subject.
            Sum [term] -> case sub of
                Sum l2 -> matchList Sum [term] l2
                s -> astMatch term s
            Sum l -> case sub of
                Sum l2 -> matchList Sum l l2
                _ -> []
            -- Likewise for a one-factor product pattern.
            Product [term] -> case sub of
                Product l2 -> matchList Product [term] l2
                s -> astMatch term s
            Product l -> case sub of
                Product l2 -> matchList Product l l2
                _ -> []
            Negative n -> case sub of
                Negative n2 -> astMatch n n2
                _ -> []
            Reciprocal n -> case sub of
                Reciprocal n2 -> astMatch n n2
                _ -> []
            Apply name l -> case sub of
                Apply name2 l2 | name == name2 -> matchOrderedList l l2
                _ -> []
            Capture name -> [Map.singleton name sub]
            CaptureTerm name -> [Map.singleton name sub]
            CaptureConstr name constr
                | toConstr sub == toConstr constr -> [Map.singleton name sub]
                | otherwise -> []

-- | Match an unordered list of patterns against an unordered list of
-- subjects (the terms of a sum or factors of a product). At most one
-- unbounded 'Capture' is allowed; it binds whatever subjects are left over,
-- wrapped with the given constructor.
matchList :: ([AST] -> AST)       -- ^ AST constructor for this list (for insertion in capture)
          -> [AST]                -- ^ unordered patterns
          -> [AST]                -- ^ unordered subjects
          -> [Map.Map String AST] -- ^ list of possible capture assignments
matchList constr toppats topsubs =
    assertS "At most one capture in sum/product" (length captures <= 1) $
        case captures of
            [] -> matchListDBG Nothing nocaps topsubs
            [c] -> matchListDBG (Just c) nocaps topsubs
            -- Unreachable: the assertion above rejects this case first.
            _ -> error "matchList: more than one unbounded capture"
  where
    -- Sorting puts all 'Capture' nodes at the front (see the Ord instance).
    (captures, nocaps) = span astIsCapture (sort toppats)

    matchList' :: Maybe AST -> [AST] -> [AST] -> [Map.Map String AST]
    matchList' Nothing [] [] = [Map.empty]
    matchList' Nothing [] _ = []
    matchList' (Just (Capture name)) [] subs = [Map.singleton name $ constr subs]
    -- The capture was already bound by an earlier pattern and replaced by
    -- its value; the leftovers must match that value.
    matchList' (Just node) [] subs = astMatch node (constr subs)
    matchList' mcap (pat:pats) subs =
        {-trace ("firstmatches = "++show firstmatches) $-}
        [ Map.union ass result
        | (ass, rest) <- firstmatches
        , result <- matchListDBG (fmap (replaceCaptures ass) mcap)
                                 (map (replaceCaptures ass) pats)
                                 rest
        ]
      where
        -- Every way of matching 'pat' against one subject, paired with the
        -- subjects that remain.
        firstmatches =
            concat $ mapDel (\s other -> map (, other) $ astMatch pat s) subs

    -- Debugging hook around matchList'.
    matchListDBG :: Maybe AST -> [AST] -> [AST] -> [Map.Map String AST]
    matchListDBG mcap pats subs =
        {-force $ trace ("\n<< "++show (mcap,pats,subs)++" >>\n") $-}
        matchList' mcap pats subs

-- | Match patterns against subjects position by position (function
-- arguments). The lists must have equal length. Captures bound by earlier
-- positions are substituted into later patterns.
matchOrderedList :: [AST]                -- ^ ordered patterns
                 -> [AST]                -- ^ ordered subjects
                 -> [Map.Map String AST] -- ^ list of possible capture assignments
matchOrderedList [] [] = [Map.empty]
matchOrderedList [] _ = []
matchOrderedList _ [] = []
matchOrderedList (pat:pats) (sub:subs) =
    [ Map.union opt result
    | opt <- astMatch pat sub
    , result <- matchOrderedList (map (replaceCaptures opt) pats) subs
    ]

-- | Splice directly nested sums into a sum (and products into a product),
-- one level deep. Other nodes are returned unchanged.
flattenSingle :: AST -> AST
flattenSingle (Sum args) = Sum $ concatMap listify args
  where
    listify (Sum a) = a
    listify node = [node]
flattenSingle (Product args) = Product $ concatMap listify args
  where
    listify (Product a) = a
    listify node = [node]
flattenSingle node = node

-- | Substitute bound captures by their values. Captures missing from the map
-- are left in place; sums and products are re-flattened after substitution.
replaceCaptures :: Map.Map String AST -> AST -> AST
replaceCaptures mp n = case n of
    Number _ -> n
    Variable _ -> n
    Sum l -> flattenSingle $ Sum $ map (replaceCaptures mp) l
    Product l -> flattenSingle $ Product $ map (replaceCaptures mp) l
    Negative n2 -> Negative $ replaceCaptures mp n2
    Reciprocal n2 -> Reciprocal $ replaceCaptures mp n2
    Apply name n2 -> Apply name $ map (replaceCaptures mp) n2
    Capture name -> Map.findWithDefault n name mp
    CaptureTerm name -> Map.findWithDefault n name mp
    CaptureConstr name _ -> Map.findWithDefault n name mp

-- | Does the tree contain any capture node?
hasCaptures :: AST -> Bool
hasCaptures n = case n of
    Number _ -> False
    Variable _ -> False
    Sum l -> any hasCaptures l
    Product l -> any hasCaptures l
    Negative m -> hasCaptures m
    Reciprocal m -> hasCaptures m
    Apply _ l -> any hasCaptures l
    Capture _ -> True
    CaptureTerm _ -> True
    CaptureConstr _ _ -> True

-- | 'assertS' without a reason message.
assert :: Bool -> a -> a
assert = assertS "(no reason)"

-- | Return the second argument if the condition holds, otherwise raise an
-- error mentioning the given reason.
assertS :: String -> Bool -> a -> a
assertS _ True = id
assertS s False = error $ "Condition not satisfied in assert: " ++ s

-- | Apply @f@ to each element together with the list of all other elements
-- (order preserved): @mapDel f [a,b,c] == [f a [b,c], f b [a,c], f c [a,b]]@.
mapDel :: (a -> [a] -> b) -> [a] -> [b]
mapDel f l = [f x (before ++ after) | (before, x:after) <- zip (inits l) (tails l)]

-- some testing things
--pat = Sum [Number 1,Capture "x",Negative $ Capture "x"]
--sub = Sum [Number 4,Variable "a",Number 1,Negative $ Sum [Variable "a",Number 4]]
--pat = Sum [Negative $ Capture "x"]
--sub = Sum [Negative $ Sum [Variable "a",Number 4]]
--pat = Sum [Capture "x",Negative (Capture "x"),CaptureTerm "y",CaptureTerm "z"]
--sub = let x = Reciprocal (Number 7) in Sum [x,Negative x,Number 7,Number 8]
--pat = Sum [CaptureTerm "x",CaptureTerm "y",Capture "rest",Negative $ Capture "rest"]
--sub = Sum [Number 1,Number 2,Negative $ Number 1,Variable "kaas",Negative $ Sum [Negative $ Number 1,Variable "kaas"]]
--pat = Sum [Product [Capture "x"],Product [Capture "x"]]
--sub = Sum [Product [Number 1],Product [Number 1]]
--main = do
--  let res = astMatch pat sub
--  deepseq res $ putStrLn $ "\x1B[32m"++show res++"\x1B[0m"