summaryrefslogtreecommitdiff
path: root/AST.hs
diff options
context:
space:
mode:
Diffstat (limited to 'AST.hs')
-rw-r--r--AST.hs295
1 files changed, 295 insertions, 0 deletions
diff --git a/AST.hs b/AST.hs
new file mode 100644
index 0000000..d4c669b
--- /dev/null
+++ b/AST.hs
@@ -0,0 +1,295 @@
+{-# LANGUAGE TupleSections, BangPatterns, DeriveDataTypeable #-}
+
+module AST where
+
+import qualified Data.Map.Strict as Map
+import Data.List
+import Data.Data
+import Data.Typeable
+import Control.DeepSeq
+
+import PrettyPrint
+import Debug
+
+
+data AST = Number Double
+ | Variable String
+ | Sum [AST]
+ | Product [AST]
+ | Negative AST
+ | Reciprocal AST
+ | Apply String [AST]
+ -- The following only in patterns:
+ | Capture String
+ | CaptureTerm String
+ | CaptureConstr String AST -- AST argument only for constructor; only WHNF
+ deriving (Show,Read,Eq,Typeable,Data)
+
+instance PrettyPrint AST where
+ prettyPrint (Number n) = show n
+
+ prettyPrint (Variable n) = n
+
+ prettyPrint (Sum []) = "(+)"
+ prettyPrint (Sum args) = intercalate " + " $ map prettyPrint args
+
+ prettyPrint (Product []) = "(*)"
+ prettyPrint (Product args) = intercalate "*" $ map gopp args
+ where gopp s@(Sum _) = '(' : prettyPrint s ++ ")"
+ gopp n = prettyPrint n
+
+ prettyPrint (Negative n) = '-' : case n of
+ s@(Sum _) -> '(' : prettyPrint s ++ ")"
+ n -> prettyPrint n
+
+ prettyPrint (Reciprocal n) = "1/" ++ case n of
+ s@(Sum _) -> '(' : prettyPrint s ++ ")"
+ s@(Product _) -> '(' : prettyPrint s ++ ")"
+ s@(Reciprocal _) -> '(' : prettyPrint s ++ ")"
+ n -> prettyPrint n
+
+ prettyPrint (Apply name args) = name ++ "(" ++ intercalate "," (map prettyPrint args) ++ ")"
+
+ prettyPrint (Capture name) = '[' : name ++ "]"
+
+ prettyPrint (CaptureTerm name) = '[' : '[' : name ++ "]]"
+
+ prettyPrint (CaptureConstr name c) = '[' : name ++ ":" ++ showConstr (toConstr c) ++ "]"
+
+instance NFData AST where
+ rnf (Number !_) = ()
+ rnf (Variable !_) = ()
+ rnf (Sum l) = seq (length $ map rnf l) ()
+ rnf (Product l) = seq (length $ map rnf l) ()
+ rnf (Negative n) = rnf n
+ rnf (Reciprocal n) = rnf n
+ rnf (Apply !_ l) = seq (length $ map rnf l) ()
+ rnf (Capture !_) = ()
+ rnf (CaptureTerm !_) = ()
+ rnf (CaptureConstr !_ !_) = () -- explicitly not deepseq'ing the ast node
+
+instance Ord AST where
+ compare (Number a) (Number b) = compare a b
+ compare (Variable a) (Variable b) = compare a b
+ compare (Sum a) (Sum b) = compare a b
+ compare (Product a) (Product b) = compare a b
+ compare (Negative a) (Negative b) = compare a b
+ compare (Reciprocal a) (Reciprocal b) = compare a b
+ compare (Apply n1 a) (Apply n2 b) = let r = compare n1 n2 in if r == EQ then compare a b else r
+ compare (Capture a) (Capture b) = compare a b
+ compare (CaptureTerm a) (CaptureTerm b) = compare a b
+ compare (CaptureConstr n1 node1) (CaptureConstr n2 node2) = let r = compare n1 n2 in if r == EQ then compare node1 node2 else r
+
+ compare (Capture _) _ = LT -- Unbounded captures first for efficient
+ compare _ (Capture _) = GT -- extraction with span isCapture
+ compare (Number _) _ = LT
+ compare _ (Number _) = GT
+ compare (Variable _) _ = LT
+ compare _ (Variable _) = GT
+ compare (Sum _) _ = LT
+ compare _ (Sum _) = GT
+ compare (Product _) _ = LT
+ compare _ (Product _) = GT
+ compare (Negative _) _ = LT
+ compare _ (Negative _) = GT
+ compare (Reciprocal _) _ = LT
+ compare _ (Reciprocal _) = GT
+ compare (Apply _ _) _ = LT
+ compare _ (Apply _ _) = GT
+ compare (CaptureTerm _) _ = LT
+ compare _ (CaptureTerm _) = GT
+ -- compare (CaptureConstr _ _) _ = LT
+ -- compare _ (CaptureConstr _ _) = GT
+
+
+astIsNumber :: AST -> Bool
+astIsNumber (Number _) = True
+astIsNumber _ = False
+
+astIsCapture :: AST -> Bool
+astIsCapture (Capture _) = True
+astIsCapture _ = False
+
+
+astFromNumber :: AST -> Double
+astFromNumber (Number n) = n
+
+
+astMatchSimple :: AST -> AST -> Bool
+astMatchSimple pat sub = let res = {-(\x -> trace (" !! RESULT: " ++ show x ++ " !! ") x) $-} astMatch pat sub
+ in if null res
+ then False
+ else any Map.null res
+
+
+astMatch :: AST -- pattern
+ -> AST -- subject
+ -> [Map.Map String AST] -- list of possible capture assignments
+astMatch pat sub = assertS "No captures in astMatch subject" (not $ hasCaptures sub) $
+ case pat of
+ Number x -> case sub of
+ Number y | x == y -> [Map.empty]
+ _ -> []
+
+ Variable name -> case sub of
+ Variable name2 | name == name2 -> [Map.empty]
+ _ -> []
+
+ Sum [term] -> case sub of
+ Sum l2 -> matchList Sum [term] l2
+ s -> astMatch term s
+
+ Sum l -> case sub of
+ Sum l2 -> matchList Sum l l2
+ _ -> []
+
+ Product [term] -> case sub of
+ Product l2 -> matchList Product [term] l2
+ s -> astMatch term s
+
+ Product l -> case sub of
+ Product l2 -> matchList Product l l2
+ _ -> []
+
+ Negative n -> case sub of
+ Negative n2 -> astMatch n n2
+ _ -> []
+
+ Reciprocal n -> case sub of
+ Reciprocal n2 -> astMatch n n2
+ _ -> []
+
+ Apply name l -> case sub of
+ Apply name2 l2 | name == name2 -> matchOrderedList l l2
+ _ -> []
+
+ Capture name -> [Map.singleton name sub]
+
+ CaptureTerm name -> [Map.singleton name sub]
+
+ CaptureConstr name constr ->
+ if toConstr sub == toConstr constr
+ then [Map.singleton name sub]
+ else []
+
+
+matchList :: ([AST] -> AST) -- AST constructor for this list (for insertion in capture)
+ -> [AST] -- unordered patterns
+ -> [AST] -- unordered subjects
+ -> [Map.Map String AST] -- list of possible capture assignments
+matchList constr pats subs =
+ let ordered = sort pats
+ (captures,nocaps) = span astIsCapture ordered
+ in assertS "At most one capture in sum/product" (length captures <= 1) $ case captures of
+ [] -> matchListDBG Nothing nocaps subs
+ [c] -> matchListDBG (Just c) nocaps subs
+ where matchList' :: Maybe AST -> [AST] -> [AST] -> [Map.Map String AST]
+ matchList' Nothing [] [] = [Map.empty]
+ matchList' Nothing [] _ = []
+ matchList' (Just (Capture name)) [] subs = [Map.singleton name $ constr subs]
+ matchList' (Just node) [] subs = astMatch node (constr subs)
+ matchList' mcap (pat:pats) subs =
+ let firstmatches = concat $ mapDel (\s other -> map (,other) $ astMatch pat s) subs
+ processed = concat
+ $ map (\(ass,rest) ->
+ let replpats = map (replaceCaptures ass) pats
+ replmcap = fmap (replaceCaptures ass) mcap
+ in map (Map.union ass) $ matchListDBG replmcap replpats rest)
+ firstmatches
+ in {-trace ("firstmatches = "++show firstmatches) $ trace ("processed = "++show processed) $-} processed
+
+ matchListDBG :: Maybe AST -> [AST] -> [AST] -> [Map.Map String AST]
+ matchListDBG mcap pats subs = {-force $ trace ("\n<< "++show (mcap,pats,subs)++" >>\n")
+ $-} matchList' mcap pats subs
+
+
+matchOrderedList :: [AST] -- ordered patterns
+ -> [AST] -- ordered subjects
+ -> [Map.Map String AST] -- list of possible capture assignments
+matchOrderedList [] [] = [Map.empty]
+matchOrderedList [] _ = []
+matchOrderedList _ [] = []
+matchOrderedList (pat:pats) (sub:subs) =
+ let opts = astMatch pat sub
+ newpatsopts = [(map (replaceCaptures opt) pats,opt) | opt <- opts]
+ -- ^ list of possible refined versions of the (rest of the) pattern list
+ in {-trace (show (pat:pats) ++ " -- " ++ show (sub:subs)) $ traceShow opts $-}
+ concat $ map (\(newpats,opt) -> map (Map.union opt)
+ $ matchOrderedList newpats subs) newpatsopts
+
+
+flattenSingle :: AST -> AST
+flattenSingle (Sum args) =
+ let listify (Sum a) = a
+ listify node = [node]
+ in Sum $ concat $ map listify args
+flattenSingle (Product args) =
+ let listify (Product a) = a
+ listify node = [node]
+ in Product $ concat $ map listify args
+flattenSingle node = node
+
+replaceCaptures :: Map.Map String AST -> AST -> AST
+replaceCaptures mp n = case n of
+ Number _ -> n
+ Variable _ -> n
+ Sum l -> flattenSingle $ Sum $ map (replaceCaptures mp) l
+ Product l -> flattenSingle $ Product $ map (replaceCaptures mp) l
+ Negative n2 -> Negative $ replaceCaptures mp n2
+ Reciprocal n2 -> Reciprocal $ replaceCaptures mp n2
+ Apply name n2 -> Apply name $ map (replaceCaptures mp) n2
+ Capture name -> maybe n id $ Map.lookup name mp
+ CaptureTerm name -> maybe n id $ Map.lookup name mp
+ CaptureConstr name c -> maybe n id $ Map.lookup name mp
+
+
+hasCaptures :: AST -> Bool
+hasCaptures n = case n of
+ Number _ -> False
+ Variable _ -> False
+ Sum l -> any id [hasCaptures m | m <- l]
+ Product l -> any id [hasCaptures m | m <- l]
+ Negative m -> hasCaptures m
+ Reciprocal m -> hasCaptures m
+ Apply _ l -> any id [hasCaptures m | m <- l]
+ Capture _ -> True
+ CaptureTerm _ -> True
+ CaptureConstr _ _ -> True
+
+
+assert :: Bool -> a -> a
+assert = assertS "(no reason)"
+
+assertS :: String -> Bool -> a -> a
+assertS _ True = id
+assertS s False = error $ "Condition not satisfied in assert: " ++ s
+
+
+mapDel :: (a -> [a] -> b) -> [a] -> [b]
+mapDel _ [] = []
+mapDel f l =
+ let splits = zip l
+ $ map (\(a,b:bs) -> a++bs)
+ $ iterate (\(a,b:bs) -> (a++[b],bs)) ([],l)
+ in map (uncurry f) splits
+
+
+-- some testing things
+--pat = Sum [Number 1,Capture "x",Negative $ Capture "x"]
+--sub = Sum [Number 4,Variable "a",Number 1,Negative $ Sum [Variable "a",Number 4]]
+
+--pat = Sum [Negative $ Capture "x"]
+--sub = Sum [Negative $ Sum [Variable "a",Number 4]]
+
+--pat = Sum [Capture "x",Negative (Capture "x"),CaptureTerm "y",CaptureTerm "z"]
+--sub = let x = Reciprocal (Number 7) in Sum [x,Negative x,Number 7,Number 8]
+
+--pat = Sum [CaptureTerm "x",CaptureTerm "y",Capture "rest",Negative $ Capture "rest"]
+--sub = Sum [Number 1,Number 2,Negative $ Number 1,Variable "kaas",Negative $ Sum [Negative $ Number 1,Variable "kaas"]]
+
+--pat = Sum [Product [Capture "x"],Product [Capture "x"]]
+--sub = Sum [Product [Number 1],Product [Number 1]]
+
+--main = do
+-- let res = astMatch pat sub
+-- deepseq res $ putStrLn $ "\x1B[32m"++show res++"\x1B[0m"