{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -Wno-deferred-out-of-scope-variables #-}

-- | Normalisation of additive expressions.
--
-- Chains of binary @+@ \/ @-@ (looking through parentheses and unary minus)
-- are flattened, their signed integer literals (@+n@ \/ @-n@) are folded into
-- a single constant, and the sum is rebuilt as a left-nested chain.
module Normalise where

import Data.List (foldl1', foldl')

import Expr

-- | The pieces of a flattened sum:
-- @sum literals + sum addedTerms - sum subtractedTerms@.
data SumCollect = SumCollect
  [Integer] -- ^ literals, already carrying their sign
  [Expr]    -- ^ terms that are added
  [Expr]    -- ^ terms that are subtracted

-- | Concatenates the three buckets pairwise.
instance Semigroup SumCollect where
  SumCollect a b c <> SumCollect a' b' c' =
    SumCollect (a <> a') (b <> b') (c <> c')

instance Monoid SumCollect where
  mempty = SumCollect [] [] []

-- | Negate a collected sum: flip the sign of every literal and swap the added
-- and subtracted terms.
scNegate :: SumCollect -> SumCollect
scNegate (SumCollect lits added subtracted) =
  SumCollect (map negate lits) subtracted added

-- | Flatten an expression into its additive components.
--
-- Descends through binary @+@ and @-@, parentheses and unary @-@. A signed
-- literal (@+n@ or @-n@) becomes a constant; anything else is kept as an
-- opaque added term.
collectSum :: Expr -> SumCollect
collectSum = \case
  EInfix e1 "+" e2 -> collectSum e1 <> collectSum e2
  EInfix e1 "-" e2 -> collectSum e1 <> scNegate (collectSum e2)
  EParens e -> collectSum e
  EPrefix "+" (ELitInt n) -> SumCollect [n] [] []
  -- 'normalise' writes negative constants as @-n@, so that form must count
  -- as a literal too; otherwise it is kept as an opaque term and never folded
  -- together with the other literals.
  EPrefix "-" (ELitInt n) -> SumCollect [negate n] [] []
  -- 'normalise' also uses unary minus to negate a leading term, so treat it
  -- as negation of whatever sum it wraps.
  EPrefix "-" e -> scNegate (collectSum e)
  e -> SumCollect [] [e] []

-- | Normalise every sum in an expression.
--
-- When an expression flattens (via 'collectSum') into more than one
-- component, its literals are added into one constant and the sum is rebuilt
-- as: the constant (omitted when it is zero, unless it is the only thing
-- left), then the added terms joined with @+@, then the subtracted terms
-- joined with @-@. Every term is normalised recursively. Any other expression
-- is normalised structurally via 'recurse'.
--
-- NOTE(review): 'einfix' is not defined in this module; presumably it comes
-- from "Expr" with @einfix op l r == EInfix l op r@ -- confirm. The
-- @-Wno-deferred-out-of-scope-variables@ pragma hints it may be missing.
normalise :: Expr -> Expr
normalise e
  | SumCollect literals posterms negterms <- collectSum e
  , length literals + length posterms + length negterms > 1
  = rebuild (sum literals) (map normalise posterms) (map normalise negterms)
  | otherwise
  = recurse normalise e
  where
    -- Only a constant is left.
    rebuild l [] [] = signedLit l
    -- No constant and nothing added: lead with the negated first term.
    rebuild 0 [] (n : ns) = subtractAll (EPrefix "-" n) ns
    -- No constant: start from the added terms (non-empty by the match).
    rebuild 0 ps@(_ : _) ns = subtractAll (foldl1' (einfix "+") ps) ns
    -- Otherwise the constant leads, followed by the added terms.
    rebuild l ps ns = subtractAll (foldl' (einfix "+") (signedLit l) ps) ns

    -- Subtract each term in turn from the accumulated expression.
    subtractAll acc ns = foldl' (einfix "-") acc ns

    -- A negative constant is written @-n@ in every branch, never @+(-n)@.
    signedLit l
      | l < 0 = EPrefix "-" (ELitInt (negate l))
      | otherwise = EPrefix "+" (ELitInt l)

-- | Apply @f@ to each immediate subexpression, rebuilding the same
-- constructor; literals and variables are returned unchanged.
recurse :: (Expr -> Expr) -> Expr -> Expr
recurse f = \case
  EInfix e1 op e2 -> EInfix (f e1) op (f e2)
  EPrefix op e -> EPrefix op (f e)
  EParens e -> EParens (f e)
  ELitInt x -> ELitInt x
  EVar v -> EVar v