{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DataKinds #-}
module Simplify where

import AST
import AST.Count


-- | Apply 'simplify' @n@ times in a row.
--
-- A single pass does not revisit a node after its children have been
-- simplified. For example, a let whose right-hand side only becomes cheap
-- once that right-hand side is simplified is not inlined in the same pass.
-- Repeating the pass can therefore simplify further.
--
-- A count of zero or less is the identity, so a negative count terminates
-- instead of counting down forever.
simplifyN :: Int -> Ex env t -> Ex env t
simplifyN n
  | n <= 0 = id
  | otherwise = simplifyN (n - 1) . simplify

-- | One pass of the simplifier.
--
-- At each node the rewrite rules below are tried in order: the first
-- alternative whose pattern and guards both match wins. The result of a
-- rule is itself passed through 'simplify' again. If no rule applies, the
-- node is rebuilt with its children simplified.
--
-- Wherever a rebuilt node has an annotation field, the original annotation
-- is replaced with 'ext'.
--
-- A redex that only appears after a node's children have been simplified
-- is not revisited in the same pass. 'simplifyN' runs several passes.
simplify :: Ex env t -> Ex env t
simplify = \case
  -- inlining
  --
  -- Substitute the right-hand side into the body in either of two cases:
  -- the bound variable (IZ) occurs at most once both lexically and at run
  -- time, or the right-hand side is cheap to duplicate (see 'cheapExpr').
  -- If neither guard holds, matching falls through to the alternatives
  -- below.
  ELet _ rhs body
    | Occ lexOcc runOcc <- occCount IZ body
    , lexOcc <= One  -- prevent code size blowup
    , runOcc <= One  -- prevent runtime increase
    -> simplify (subst1 rhs body)
    | cheapExpr rhs
    -> simplify (subst1 rhs body)

  -- let splitting
  --
  -- A let-bound pair becomes two nested lets, one per component. The second
  -- component b is weakened because it now sits under the binder for a.
  -- In the body, the old pair variable (IZ) is replaced by a pair rebuilt
  -- from the two new variables. Every other variable is shifted past both
  -- new binders (IS i becomes IS (IS i)).
  ELet _ (EPair _ a b) body ->
    simplify $
      ELet ext a $
      ELet ext (weakenExpr WSink b) $
        subst (\_ t -> \case IZ -> EPair ext (EVar ext (typeOf a) (IS IZ)) (EVar ext (typeOf b) IZ)
                             IS i -> EVar ext t (IS (IS i)))
              body

  -- beta rules for products
  --
  -- Projection of a literal pair. The other component is dropped without
  -- being simplified.
  EFst _ (EPair _ e _) -> simplify e
  ESnd _ (EPair _ _ e) -> simplify e

  -- beta rules for coproducts
  --
  -- A case on a literal injection becomes a let that binds the payload in
  -- the matching branch. The other branch is dropped.
  ECase _ (EInl _ _ e) rhs _ -> simplify (ELet ext e rhs)
  ECase _ (EInr _ _ e) _ rhs -> simplify (ELet ext e rhs)

  -- TODO: array indexing (index of build, index of fold)

  -- TODO: constant folding for operations

  -- left identity law for return+bind: (return a >>= b) is (let a in b)
  EMBind (EMReturn _ a) b -> simplify (ELet ext a b)

  -- associativity of bind
  --
  -- Reassociate nested binds to the right. Afterwards c sits under two
  -- binders instead of one. It keeps its own innermost binder and gains
  -- a's result beneath it, hence the weakening WCopy WSink.
  EMBind (EMBind a b) c -> simplify (EMBind a (EMBind b (weakenExpr (WCopy WSink) c)))

  -- bind-let commute
  --
  -- Float a let out of the left operand of a bind. c is weakened past the
  -- let's binder in the same way as above.
  EMBind (ELet _ a b) c -> simplify (ELet ext a (EMBind b (weakenExpr (WCopy WSink) c)))

  -- return-let commute
  --
  -- Float a let out of a return.
  EMReturn env (ELet _ a b) -> simplify (ELet ext a (EMReturn env b))

  -- No rule applied: simplify the children and rebuild the node.
  EVar _ t i -> EVar ext t i
  ELet _ a b -> ELet ext (simplify a) (simplify b)
  EPair _ a b -> EPair ext (simplify a) (simplify b)
  EFst _ e -> EFst ext (simplify e)
  ESnd _ e -> ESnd ext (simplify e)
  ENil _ -> ENil ext
  EInl _ t e -> EInl ext t (simplify e)
  EInr _ t e -> EInr ext t (simplify e)
  ECase _ e a b -> ECase ext (simplify e) (simplify a) (simplify b)
  EBuild1 _ a b -> EBuild1 ext (simplify a) (simplify b)
  EBuild _ es e -> EBuild ext (fmap simplify es) (simplify e)
  EFold1 _ a b -> EFold1 ext (simplify a) (simplify b)
  EConst _ t v -> EConst ext t v
  EIdx1 _ a b -> EIdx1 ext (simplify a) (simplify b)
  EIdx _ e es -> EIdx ext (simplify e) (fmap simplify es)
  EOp _ op e -> EOp ext op (simplify e)
  EMOne t i e -> EMOne t i (simplify e)
  EMScope e -> EMScope (simplify e)
  EMReturn t e -> EMReturn t (simplify e)
  EMBind a b -> EMBind (simplify a) (simplify b)
  EError t s -> EError t s

-- | Whether an expression is cheap enough to duplicate freely when inlining.
-- Only variable references, 'ENil' nodes and constants qualify. Every other
-- form is treated as potentially expensive.
cheapExpr :: Expr x env t -> Bool
cheapExpr EVar{} = True
cheapExpr ENil{} = True
cheapExpr EConst{} = True
cheapExpr _ = False

-- | Replace the innermost bound variable (index 'IZ') with @repl@. Every
-- other free variable moves down by one index, because that binder is gone.
subst1 :: Expr x env a -> Expr x (a : env) t -> Expr x env t
subst1 repl =
  subst (\ann ty idx -> case idx of { IZ -> repl; IS idx' -> EVar ann ty idx' })

-- | Apply a substitution to every free variable of an expression.
--
-- For each variable of @env@ (given its annotation and type), @f@ returns
-- its replacement in @env'@. When a replacement lands beneath binders
-- introduced by the traversal, it is weakened accordingly. The traversal
-- starts from the identity weakening.
subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a)
      -> Expr x env t -> Expr x env' t
subst f = subst' (\ann ty wk -> weakenExpr wk . f ann ty) WId

-- | Worker for 'subst'. Traverses an expression and replaces each free
-- variable @i@ with @f x t w i@. Here @x@ and @t@ are the variable's
-- annotation and type, and @w@ is the weakening from @env'@ to the
-- environment at that point of the traversal.
--
-- The second argument is that weakening for the current position. Under
-- each binder it is extended with 'WCopy', and @f@ is wrapped (by @sinkF@ or
-- @sinkFN@). The newly bound variables then map to themselves, and the
-- remaining indices are passed on to the original @f@. Node annotations are
-- kept unchanged.
subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a)
       -> env' :> envOut
       -> Expr x env t
       -> Expr x envOut t
subst' f w = \case
  EVar x t i -> f x t w i
  -- rhs is outside the binder; body is under one binder.
  ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body)
  EPair x a b -> EPair x (subst' f w a) (subst' f w b)
  EFst x e -> EFst x (subst' f w e)
  ESnd x e -> ESnd x (subst' f w e)
  ENil x -> ENil x
  EInl x t e -> EInl x t (subst' f w e)
  EInr x t e -> EInr x t (subst' f w e)
  -- Each branch binds one variable: the payload of the matched injection.
  ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b)
  -- b is under one binder; a is not.
  EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b)
  -- e is under one TIx binder per element of es.
  EBuild x es e -> EBuild x (fmap (subst' f w) es) (subst' (sinkFN (vecLength es) f) (wcopyN (vecLength es) w) e)
  -- a is under two binders; b is under none.
  EFold1 x a b -> EFold1 x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b)
  EConst x t v -> EConst x t v
  EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
  EIdx x e es -> EIdx x (subst' f w e) (fmap (subst' f w) es)
  EOp x op e -> EOp x op (subst' f w e)
  -- NOTE(review): i is passed through unchanged, so it is presumably not
  -- an index into the expression environment. Confirm against AST.
  EMOne t i e -> EMOne t i (subst' f w e)
  EMScope e -> EMScope (subst' f w e)
  EMReturn t e -> EMReturn t (subst' f w e)
  -- b is under one binder: the result of a.
  EMBind a b -> EMBind (subst' f w a) (subst' (sinkF f) (WCopy w) b)
  EError t s -> EError t s
  where
    -- Extend a substitution under one binder. The new innermost variable
    -- maps to itself through the weakening. Every other index is looked up
    -- in f' after popping that binder off the weakening.
    sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
          -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t
    sinkF f' x' t w' = \case
      IZ -> EVar x' t (w' @> IZ)
      IS i -> f' x' t (WPop w') i

    -- Like sinkF, but for n binders of type TIx. Each of the n innermost
    -- variables maps to itself; deeper indices reach f' with all n binders
    -- popped.
    sinkFN :: SNat n
           -> (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
           -> x t -> STy t -> (ConsN n TIx env' :> env2) -> Idx (ConsN n TIx env) t -> Expr x env2 t
    sinkFN SZ f' x t w' i = f' x t w' i
    sinkFN (SS _) _ x t w' IZ = EVar x t (w' @> IZ)
    sinkFN (SS n) f' x t w' (IS i) = sinkFN n f' x t (WPop w') i