summaryrefslogtreecommitdiff
path: root/Normalise.hs
diff options
context:
space:
mode:
Diffstat (limited to 'Normalise.hs')
-rw-r--r--Normalise.hs54
1 files changed, 54 insertions, 0 deletions
diff --git a/Normalise.hs b/Normalise.hs
new file mode 100644
index 0000000..3e4a696
--- /dev/null
+++ b/Normalise.hs
@@ -0,0 +1,54 @@
+{-# LANGUAGE ViewPatterns #-}
+{-# LANGUAGE LambdaCase #-}
+{-# OPTIONS_GHC -Wno-deferred-out-of-scope-variables #-}
+module Normalise where
+
+import Expr
+import Data.List (foldl1', foldl')
+
+
+data SumCollect =
+ SumCollect [Integer] -- ^ literals
+ [Expr] -- ^ positive terms
+ [Expr] -- ^ negative terms
+
+instance Semigroup SumCollect where
+ SumCollect a b c <> SumCollect a' b' c' =
+ SumCollect (a <> a') (b <> b') (c <> c')
+
+instance Monoid SumCollect where
+ mempty = SumCollect [] [] []
+
+scNegate :: SumCollect -> SumCollect
+scNegate (SumCollect n post negt) = SumCollect (map negate n) negt post
+
+collectSum :: Expr -> SumCollect
+collectSum = \case
+ EInfix e1 "+" e2 -> collectSum e1 <> collectSum e2
+ EInfix e1 "-" e2 -> collectSum e1 <> scNegate (collectSum e2)
+ EParens e -> collectSum e
+ EPrefix "+" (ELitInt n) -> SumCollect [n] [] []
+ e -> SumCollect [] [e] []
+
+normalise :: Expr -> Expr
+normalise e
+ | SumCollect literals posterms negterms <- collectSum e
+ , length literals + length posterms + length negterms > 1
+ = case (sum literals, map normalise posterms, map normalise negterms) of
+ (l, [], [])
+ | l < 0 -> EPrefix "-" (ELitInt (-l))
+ | otherwise -> EPrefix "+" (ELitInt l)
+ (0, [], nt0 : nts) ->
+ foldl' (einfix "-") (EPrefix "-" nt0) nts
+ (0, pt, nt) ->
+ foldl' (einfix "-") (foldl1' (einfix "+") pt) nt
+ (l, pt, nt) ->
+ foldl' (einfix "-") (foldl' (einfix "+") (EPrefix "+" (ELitInt l)) pt) nt
+normalise e = recurse normalise e
+
+recurse :: (Expr -> Expr) -> Expr -> Expr
+recurse f (EInfix e1 n e2) = EInfix (f e1) n (f e2)
+recurse f (EPrefix n e) = EPrefix n (f e)
+recurse f (EParens e) = EParens (f e)
+recurse _ (ELitInt x) = ELitInt x
+recurse _ (EVar n) = EVar n