{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
module Simplify where

import Data.Function (fix)
import Data.Monoid (Any(..))

import AST
import AST.Count
import Data


-- | Run 'simplify' @n@ times in succession. Unlike 'simplifyFix', this does
-- not stop early when a pass makes no progress. A count of zero or less
-- returns the expression unchanged (previously a negative count never
-- reached the base case and looped forever).
simplifyN :: KnownEnv env => Int -> Ex env t -> Ex env t
simplifyN n
  | n <= 0 = id
  | otherwise = simplifyN (n - 1) . simplify

-- | Perform a single simplification pass and drop the progress flag.
-- Whether an accumulator may be in scope is derived from the statically
-- known environment @env@.
simplify :: forall env t. KnownEnv env => Ex env t -> Ex env t
simplify expr =
  let ?accumInScope = checkAccumInScope @env knownEnv
  in case simplify' expr of
       (_, simplified) -> simplified

-- | Simplify repeatedly until a pass reports that no rewrite rule fired.
-- The accumulator-in-scope flag is computed once from @env@ and shared by
-- every pass. Terminates only if the rules eventually stop firing.
simplifyFix :: forall env t. KnownEnv env => Ex env t -> Ex env t
simplifyFix =
  let ?accumInScope = checkAccumInScope @env knownEnv
  in fix $ \recurse expr ->
       case simplify' expr of
         (Any True, expr') -> recurse expr'
         (Any False, expr') -> expr'

-- | One simplification pass over an expression.
--
-- At each node the rewrite rules below are tried first, in order. When one
-- fires, its result is simplified again and the progress flag is set via
-- 'acted'. When no rule applies, the node is rebuilt from its simplified
-- children (the congruence cases at the bottom), which only combine the
-- children's flags. The returned 'Any' is therefore 'True' iff some rule
-- fired somewhere.
--
-- @?accumInScope@ tells whether an accumulator may be in scope. The inlining
-- rule consults it to avoid dropping or duplicating adds. It is set to 'True'
-- under the second operand of 'EWith', and to 'False' inside the first three
-- subterms of 'ECustom'. The other rules that discard an evaluated
-- subexpression first check with 'hasAdds' that it performs no accumulation.
simplify' :: (?accumInScope :: Bool) => Ex env t -> (Any, Ex env t)
simplify' = \case
  -- inlining
  ELet _ rhs body
    | cheapExpr rhs
    -> acted $ simplify' (subst1 rhs body)

    | Occ lexOcc runOcc <- occCount IZ body
    , ((not ?accumInScope || not (hasAdds rhs)) && lexOcc <= One && runOcc <= One)  -- without effects, normal rules apply
          || (lexOcc == One && runOcc == One)  -- with effects, linear inlining is still allowed, but weakening is not
    -> acted $ simplify' (subst1 rhs body)

  -- let splitting
  ELet _ (EPair _ a b) body ->
    acted $ simplify' $
      ELet ext a $
      ELet ext (weakenExpr WSink b) $
        subst (\_ t -> \case IZ -> EPair ext (EVar ext (typeOf a) (IS IZ)) (EVar ext (typeOf b) IZ)
                             IS i -> EVar ext t (IS (IS i)))
              body

  -- let rotation
  ELet _ (ELet _ rhs a) b ->
    acted $ simplify' $
      ELet ext rhs $
      ELet ext a $
        weakenExpr (WCopy WSink) (snd (simplify' b))

  -- beta rules for products
  EFst _ (EPair _ e e') | not (hasAdds e') -> acted $ simplify' e
  ESnd _ (EPair _ e' e) | not (hasAdds e') -> acted $ simplify' e

  -- beta rules for coproducts
  ECase _ (EInl _ _ e) rhs _ -> acted $ simplify' (ELet ext e rhs)
  ECase _ (EInr _ _ e) _ rhs -> acted $ simplify' (ELet ext e rhs)

  -- beta rules for maybe
  EMaybe _ e1 _ ENothing{} -> acted $ simplify' e1
  EMaybe _ _ e1 (EJust _ e2) -> acted $ simplify' $ ELet ext e2 e1

  -- let floating to facilitate beta reduction
  EFst _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EFst ext body))
  ESnd _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (ESnd ext body))
  ECase _ (ELet _ rhs body) e1 e2 -> acted $ simplify' (ELet ext rhs (ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2)))
  EIdx0 _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EIdx0 ext body))
  EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e)))

  -- projection down-commuting; the discarded pair components are evaluated
  -- in the original, so (as in the product beta rules) they may only be
  -- dropped when they perform no accumulation
  EFst _ (ECase _ e1 (EPair _ e2 e2') (EPair _ e3 e3'))
    | not (hasAdds e2'), not (hasAdds e3') ->
    acted $ simplify' $
      ECase ext e1 e2 e3
  ESnd _ (ECase _ e1 (EPair _ e2' e2) (EPair _ e3' e3))
    | not (hasAdds e2'), not (hasAdds e3') ->
    acted $ simplify' $
      ECase ext e1 e2 e3

  -- TODO: array indexing (index of build, index of fold)

  -- TODO: constant folding for operations

  -- TODO: properly concatenate accum/onehot
  -- Each rule below discards some subexpression; it only fires when the
  -- discarded part performs no accumulation.
  EAccum SZ e1 (EOneHot _ i idx val) acc
    | not (hasAdds e1) ->
    acted $ simplify' $
      EAccum i idx val acc
  EAccum _ e1 (EZero _) acc
    | not (hasAdds e1), not (hasAdds acc) -> (Any True, ENil ext)
  EPlus _ (EZero _) e -> acted $ simplify' e
  EPlus _ e (EZero _) -> acted $ simplify' e
  EOneHot _ SZ e1 e | not (hasAdds e1) -> acted $ simplify' e

  -- congruence: no rule applied at this node, simplify the children
  EVar _ t i -> pure $ EVar ext t i
  ELet _ a b -> ELet ext <$> simplify' a <*> simplify' b
  EPair _ a b -> EPair ext <$> simplify' a <*> simplify' b
  EFst _ e -> EFst ext <$> simplify' e
  ESnd _ e -> ESnd ext <$> simplify' e
  ENil _ -> pure $ ENil ext
  EInl _ t e -> EInl ext t <$> simplify' e
  EInr _ t e -> EInr ext t <$> simplify' e
  ECase _ e a b -> ECase ext <$> simplify' e <*> simplify' a <*> simplify' b
  ENothing _ t -> pure $ ENothing ext t
  EJust _ e -> EJust ext <$> simplify' e
  EMaybe _ a b e -> EMaybe ext <$> simplify' a <*> simplify' b <*> simplify' e
  EConstArr _ n t v -> pure $ EConstArr ext n t v
  EBuild _ n a b -> EBuild ext n <$> simplify' a <*> simplify' b
  EFold1Inner _ a b c -> EFold1Inner ext <$> simplify' a <*> simplify' b <*> simplify' c
  ESum1Inner _ e -> ESum1Inner ext <$> simplify' e
  EUnit _ e -> EUnit ext <$> simplify' e
  EReplicate1Inner _ a b -> EReplicate1Inner ext <$> simplify' a <*> simplify' b
  EMaximum1Inner _ e -> EMaximum1Inner ext <$> simplify' e
  EMinimum1Inner _ e -> EMinimum1Inner ext <$> simplify' e
  EConst _ t v -> pure $ EConst ext t v
  EIdx0 _ e -> EIdx0 ext <$> simplify' e
  EIdx1 _ a b -> EIdx1 ext <$> simplify' a <*> simplify' b
  EIdx _ a b -> EIdx ext <$> simplify' a <*> simplify' b
  EShape _ e -> EShape ext <$> simplify' e
  EOp _ op e -> EOp ext op <$> simplify' e
  ECustom _ s t p a b c e1 e2 ->
    ECustom ext s t p
      <$> (let ?accumInScope = False in simplify' a)
      <*> (let ?accumInScope = False in simplify' b)
      <*> (let ?accumInScope = False in simplify' c)
      <*> simplify' e1 <*> simplify' e2
  EWith e1 e2 -> EWith <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2)
  EAccum i e1 e2 e3 -> EAccum i <$> simplify' e1 <*> simplify' e2 <*> simplify' e3
  EZero t -> pure $ EZero t
  EPlus t a b -> EPlus t <$> simplify' a <*> simplify' b
  EOneHot t i a b -> EOneHot t i <$> simplify' a <*> simplify' b
  EError t s -> pure $ EError t s

-- | Record that a rewrite rule fired: keep the result and overwrite whatever
-- progress flag it carried with 'True'.
acted :: (Any, a) -> (Any, a)
acted = \case (_, result) -> (Any True, result)

-- | Expressions cheap enough to be substituted into a let body
-- unconditionally: variable references, unit ('ENil') and scalar constants
-- ('EConst'). Everything else is considered not cheap.
cheapExpr :: Expr x env t -> Bool
cheapExpr expr = case expr of
  EVar{} -> True
  ENil{} -> True
  EConst{} -> True
  _ -> False

-- | Conservatively decide whether an expression may perform an accumulation.
-- Any 'EAccum' anywhere in the tree counts, including under binders and in
-- all subterms of 'ECustom'.
--
-- This can be made more precise by tracking (and not counting) adds on
-- locally eliminated accumulators.
hasAdds :: Expr x env t -> Bool
hasAdds = \case
  -- leaves: no subexpressions, hence no adds
  EVar{} -> False
  ENil{} -> False
  ENothing{} -> False
  EConstArr{} -> False
  EConst{} -> False
  EZero{} -> False
  EError{} -> False

  -- the only construct that adds by itself
  EAccum{} -> True

  -- one subexpression
  EFst _ e -> hasAdds e
  ESnd _ e -> hasAdds e
  EInl _ _ e -> hasAdds e
  EInr _ _ e -> hasAdds e
  EJust _ e -> hasAdds e
  ESum1Inner _ e -> hasAdds e
  EUnit _ e -> hasAdds e
  EMaximum1Inner _ e -> hasAdds e
  EMinimum1Inner _ e -> hasAdds e
  EIdx0 _ e -> hasAdds e
  EShape _ e -> hasAdds e
  EOp _ _ e -> hasAdds e

  -- two subexpressions
  ELet _ lhs rhs -> hasAdds lhs || hasAdds rhs
  EPair _ lhs rhs -> hasAdds lhs || hasAdds rhs
  EBuild _ _ lhs rhs -> hasAdds lhs || hasAdds rhs
  EReplicate1Inner _ lhs rhs -> hasAdds lhs || hasAdds rhs
  EIdx1 _ lhs rhs -> hasAdds lhs || hasAdds rhs
  EIdx _ lhs rhs -> hasAdds lhs || hasAdds rhs
  EWith lhs rhs -> hasAdds lhs || hasAdds rhs
  EPlus _ lhs rhs -> hasAdds lhs || hasAdds rhs
  EOneHot _ _ lhs rhs -> hasAdds lhs || hasAdds rhs

  -- three subexpressions
  ECase _ scrut l r -> hasAdds scrut || hasAdds l || hasAdds r
  EMaybe _ l r scrut -> hasAdds l || hasAdds r || hasAdds scrut
  EFold1Inner _ f z xs -> hasAdds f || hasAdds z || hasAdds xs

  -- five subexpressions
  ECustom _ _ _ _ p q r s u ->
    hasAdds p || hasAdds q || hasAdds r || hasAdds s || hasAdds u

-- | Whether any type in the environment contains an accumulator ('STAccum'),
-- possibly nested inside pairs, sums, maybes or arrays.
checkAccumInScope :: SList STy env -> Bool
checkAccumInScope SNil = False
checkAccumInScope (SCons ty rest) = mentionsAccum ty || checkAccumInScope rest
  where
    mentionsAccum :: STy s -> Bool
    mentionsAccum = \case
      STNil -> False
      STScal _ -> False
      STAccum{} -> True
      STPair l r -> mentionsAccum l || mentionsAccum r
      STEither l r -> mentionsAccum l || mentionsAccum r
      STMaybe inner -> mentionsAccum inner
      STArr _ inner -> mentionsAccum inner