summaryrefslogtreecommitdiff
path: root/Normalise.hs
blob: 3e4a696beabc43581208aca81a9f24c34bc9f72c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE LambdaCase #-}
{-# OPTIONS_GHC -Wno-deferred-out-of-scope-variables #-}
module Normalise where

import Expr
import Data.List (foldl1', foldl')


-- | The summands of a flattened @+@ / @-@ chain, as built by 'collectSum'.
-- The value it stands for is the sum of the literals, plus every positive
-- term, minus every negative term.
data SumCollect =
  SumCollect [Integer]  -- ^ integer literals, with their sign already applied
             [Expr]   -- ^ positive terms (added)
             [Expr]   -- ^ negative terms (subtracted)

-- | Two collections combine component-wise: each of the three lists is
-- concatenated, the left operand's entries first.
instance Semigroup SumCollect where
  (<>) (SumCollect lits pos neg) (SumCollect lits' pos' neg') =
    SumCollect (lits ++ lits') (pos ++ pos') (neg ++ neg')

-- | The empty collection (no literals, no terms), the identity for '(<>)'.
instance Monoid SumCollect where
  mempty = SumCollect mempty mempty mempty

-- | Negate a collected sum: every literal flips sign, and the positive and
-- negative term lists swap places.
scNegate :: SumCollect -> SumCollect
scNegate (SumCollect lits added subtracted) =
  SumCollect [negate lit | lit <- lits] subtracted added

-- | Flatten a chain of binary @+@ / @-@ (looking through parentheses) into
-- its summands.
--
-- Signed integer literals (@+n@ and @-n@) are collected as numbers. Every
-- other sub-expression becomes an opaque positive term. The right operand
-- of a binary @-@ is negated as a whole via 'scNegate'.
collectSum :: Expr -> SumCollect
collectSum = \case
  EInfix e1 "+" e2 -> collectSum e1 <> collectSum e2
  EInfix e1 "-" e2 -> collectSum e1 <> scNegate (collectSum e2)
  EParens e -> collectSum e
  EPrefix "+" (ELitInt n) -> SumCollect [n] [] []
  -- A negative literal. 'normalise' itself renders negative constants in
  -- this form, so it must be recognised here for them to be folded again.
  EPrefix "-" (ELitInt n) -> SumCollect [negate n] [] []
  e -> SumCollect [] [e] []

-- | Normalise an expression by flattening each @+@ / @-@ chain (found via
-- 'collectSum') into a canonical sum.
--
-- All literals are added into one constant, which leads the chain. It is
-- followed by the positive terms joined with @+@, then the negative terms
-- joined with @-@. Every term is itself normalised recursively. A zero
-- constant is omitted unless it is the only summand left.
--
-- The constant is always rendered as @+n@ when non-negative and @-n@ when
-- negative (never as @+(-n)@).
--
-- An expression with fewer than two summands is not rebuilt; only its
-- children are normalised (see 'recurse').
normalise :: Expr -> Expr
normalise e
  | SumCollect literals posterms negterms <- collectSum e
  , length literals + length posterms + length negterms > 1
  = case (sum literals, map normalise posterms, map normalise negterms) of
      -- Only literals: the whole sum folds to a single constant.
      (l, [], []) -> signedLit l
      -- Zero constant and no positive terms: lead with a unary minus.
      (0, [], nt0 : nts) ->
        foldl' (einfix "-") (EPrefix "-" nt0) nts
      -- Zero constant is dropped. pt is non-empty here because both
      -- empty-pt cases are matched above, so foldl1' cannot fail.
      (0, pt, nt) ->
        foldl' (einfix "-") (foldl1' (einfix "+") pt) nt
      -- Non-zero constant leads the chain.
      (l, pt, nt) ->
        foldl' (einfix "-") (foldl' (einfix "+") (signedLit l) pt) nt
  where
    -- Render a constant with an explicit sign: @-n@ for negatives, @+n@
    -- otherwise.
    signedLit :: Integer -> Expr
    signedLit l
      | l < 0     = EPrefix "-" (ELitInt (negate l))
      | otherwise = EPrefix "+" (ELitInt l)
normalise e = recurse normalise e

-- | Apply @f@ to each immediate sub-expression of a node, keeping the
-- node's constructor and operator name. Literals and variables have no
-- children and are returned unchanged.
recurse :: (Expr -> Expr) -> Expr -> Expr
recurse f = \case
  EInfix lhs op rhs -> EInfix (f lhs) op (f rhs)
  EPrefix op inner  -> EPrefix op (f inner)
  EParens inner     -> EParens (f inner)
  lit@(ELitInt _)   -> lit
  var@(EVar _)      -> var