{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
module AST.Count where

import Data.Functor.Const
import GHC.Generics (Generic, Generically(..))

import AST
import AST.Env
import Data


-- | Saturating occurrence count: none, exactly one, or more than one.
--
-- Constructor order is significant: the derived 'Ord' yields
-- @Zero < One < Many@, which is what taking a 'max' of counts relies on.
data Count = Zero | One | Many
  deriving stock (Show, Eq, Ord)

-- | Saturating addition: 'Zero' is the identity, and any two non-zero
-- counts add up to 'Many'.
instance Semigroup Count where
  a <> b = case (a, b) of
    (Zero, _) -> b
    (_, Zero) -> a
    _         -> Many

instance Monoid Count where
  mempty = Zero

-- | Occurrence information for a single variable.
--
-- * '_occLexical': how many times the variable occurs syntactically.
-- * '_occRuntime': an approximation of how many times it is evaluated at
--   runtime (alternative branches take the maximum, repeated code saturates
--   to 'Many').
data Occ = Occ
  { _occLexical :: Count
  , _occRuntime :: Count
  }
  deriving stock (Eq, Generic)
  -- Component-wise '(<>)' and 'mempty', obtained from the Generic representation.
  deriving (Semigroup, Monoid) via (Generically Occ)

-- | Hand-written so that values print positionally as @Occ l r@ instead of
-- in record syntax.
instance Show Occ where
  showsPrec prec (Occ lexical runtime) =
    showParen (prec > 10) $
      showString "Occ "
        . showsPrec 11 lexical
        . showChar ' '
        . showsPrec 11 runtime

-- | Combine the occurrences of two alternative branches, of which only one
-- is taken: lexical occurrences add up, while the runtime count is the
-- maximum of the two.
(<||>) :: Occ -> Occ -> Occ
(<||>) (Occ lexA runA) (Occ lexB runB) =
  Occ (lexA <> lexB) (runA `max` runB)

-- | Account for code that is executed many times: a non-zero runtime count
-- becomes 'Many', a zero runtime count stays 'Zero'. The lexical count is
-- left untouched.
scaleMany :: Occ -> Occ
scaleMany (Occ lexical runtime) = case runtime of
  Zero -> Occ lexical Zero
  _    -> Occ lexical Many

-- | Occurrence information for the single variable @idx@ in an expression.
--
-- Instantiates 'occCountGeneral' with @Const Occ@: a variable occurrence
-- contributes only if, after weakening @idx@ into the current scope, it
-- refers to the same de Bruijn position.
occCount :: Idx env a -> Expr x env t -> Occ
occCount idx expr =
  getConst $
    occCountGeneral
      (\w j occ -> if idx2int (w @> idx) /= idx2int j then mempty else Const occ)
      (\(Const occ) -> Const occ)
      (\(Const lhs) (Const rhs) -> Const (lhs <||> rhs))
      (\(Const occ) -> Const (scaleMany occ))
      expr


-- | Occurrence information for every variable of an environment.
--
-- The outermost 'OccPush' describes index 'IZ', the next one @IS IZ@, and so on.
-- 'OccEnd' means that all remaining variables have 'mempty' occurrences.
data OccEnv env where
  -- Not necessarily the empty environment: it may terminate any prefix.
  OccEnd :: OccEnv e
  OccPush :: OccEnv e -> Occ -> OccEnv (a : e)

-- | Pointwise sequential combination of occurrence environments; a missing
-- tail ('OccEnd') behaves as all-'mempty'.
instance Semigroup (OccEnv env) where
  lhs <> rhs = case (lhs, rhs) of
    (OccEnd, _) -> rhs
    (_, OccEnd) -> lhs
    (OccPush restL occL, OccPush restR occR) ->
      OccPush (restL <> restR) (occL <> occR)

instance Monoid (OccEnv env) where
  mempty = OccEnd

-- | An environment where only the variable at the given index has the given
-- occurrence. Every variable bound more recently gets 'mempty', and the
-- tail beyond it is 'OccEnd'.
onehotOccEnv :: Idx env t -> Occ -> OccEnv env
onehotOccEnv idx occ = case idx of
  IZ      -> OccPush OccEnd occ
  IS idx' -> OccPush (onehotOccEnv idx' occ) mempty

-- | Pointwise '<||>' on environments: combine the occurrences of two
-- alternative branches. 'OccEnd' (all-'mempty') is the identity.
(<||>!) :: OccEnv env -> OccEnv env -> OccEnv env
lhs <||>! rhs = case (lhs, rhs) of
  (OccEnd, _) -> rhs
  (_, OccEnd) -> lhs
  (OccPush restL occL, OccPush restR occR) ->
    OccPush (restL <||>! restR) (occL <||> occR)

-- | Apply 'scaleMany' to the occurrence of every variable in the environment.
scaleManyOccEnv :: OccEnv env -> OccEnv env
scaleManyOccEnv = \case
  OccEnd           -> OccEnd
  OccPush rest occ -> OccPush (scaleManyOccEnv rest) (scaleMany occ)

-- | Drop the entry for the innermost variable ('IZ'), e.g. when leaving its
-- binder. An 'OccEnd' tail stays 'OccEnd'.
occEnvPop :: OccEnv (t : env) -> OccEnv env
occEnvPop = \case
  OccEnd         -> OccEnd
  OccPush rest _ -> rest

-- | Occurrence information for all variables of the expression's environment
-- at once, computed by instantiating 'occCountGeneral' with 'OccEnv'.
occCountAll :: Expr x env t -> OccEnv env
occCountAll expr =
  occCountGeneral
    (\_ idx occ -> onehotOccEnv idx occ)
    occEnvPop
    (<||>!)
    scaleManyOccEnv
    expr

-- | Generic occurrence-counting traversal, parameterised over an
-- environment-indexed result type @r@.
--
-- Sequentially evaluated subexpressions are combined with '<>' of @r@. The
-- caller supplies how to:
--
-- * build the result for one variable occurrence (one-hot), given the
--   weakening from the starting environment to the current one;
-- * drop the innermost binding when leaving a binder (unpush);
-- * combine two mutually exclusive branches (alternation);
-- * account for a subexpression that may run many times (scale-many).
occCountGeneral :: forall r env t x.
                   (forall env'. Monoid (r env'))
                => (forall env' a. env :> env' -> Idx env' a -> Occ -> r env')  -- ^ one-hot
                -> (forall env' a. r (a : env') -> r env')  -- ^ unpush
                -> (forall env'. r env' -> r env' -> r env')  -- ^ alternation
                -> (forall env'. r env' -> r env')  -- ^ scale-many
                -> Expr x env t -> r env
occCountGeneral onehot unpush alter many = go WId
  where
    go :: forall env' t'. Monoid (r env') => env :> env' -> Expr x env' t' -> r env'
    go w = \case
      EVar _ _ i -> onehot w i (Occ One One)
      ELet _ rhs body -> re rhs <> re1 body
      EPair _ a b -> re a <> re b
      EFst _ e -> re e
      ESnd _ e -> re e
      ENil _ -> mempty
      EInl _ _ e -> re e
      EInr _ _ e -> re e
      ECase _ e a b -> re e <> (re1 a `alter` re1 b)
      ENothing _ _ -> mempty
      EJust _ e -> re e
      -- The scrutinee @e@ is always evaluated. @a@ and @b@ (the arm that binds
      -- one variable) are alternative arms, presumably Nothing and Just, so
      -- only one of them runs. Combine them with 'alter', as for 'ECase';
      -- using '<>' would count a variable used in both arms as running twice.
      EMaybe _ a b e -> re e <> (re a `alter` re1 b)
      EConstArr{} -> mempty
      -- The body of a build is run once per element.
      EBuild _ _ a b -> re a <> many (re1 b)
      -- The combining function binds two variables and runs many times.
      EFold1Inner _ a b c -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b <> re c
      ESum1Inner _ e -> re e
      EUnit _ e -> re e
      EReplicate1Inner _ a b -> re a <> re b
      EMaximum1Inner _ e -> re e
      EMinimum1Inner _ e -> re e
      EConst{} -> mempty
      EIdx0 _ e -> re e
      EIdx1 _ a b -> re a <> re b
      EIdx _ a b -> re a <> re b
      EShape _ e -> re e
      EOp _ _ e -> re e
      ECustom _ _ _ _ _ _ _ a b -> re a <> re b
      EWith a b -> re a <> re1 b
      EAccum _ a b e -> re a <> re b <> re e
      EZero _ -> mempty
      EPlus _ a b -> re a <> re b
      EOneHot _ _ a b -> re a <> re b
      EError{} -> mempty
      where
        -- Recurse into a subexpression in the same scope.
        re :: Monoid (r env') => Expr x env' t'' -> r env'
        re = go w

        -- Recurse under one binder, then drop that binder's entry.
        re1 :: Monoid (r env') => Expr x (a : env') t'' -> r env'
        re1 = unpush . go (WSink .> w)


-- | Build a 'Subenv' that keeps exactly the variables whose runtime
-- occurrence count is non-'Zero', and pass it to the continuation.
-- Variables covered by an 'OccEnd' tail are dropped.
deleteUnused :: SList f env -> OccEnv env -> (forall env'. Subenv env env' -> r) -> r
deleteUnused SNil OccEnd k = k SETop
deleteUnused (SCons _ rest) occs k = case occs of
  OccEnd ->
    deleteUnused rest OccEnd (\sub -> k (SENo sub))
  OccPush occs' (Occ _ runtime) ->
    deleteUnused rest occs' $ \sub ->
      if runtime == Zero
        then k (SENo sub)
        else k (SEYes sub)

-- | Move an expression into a sub-environment by renumbering its variables.
--
-- Unsafe: if the expression references a variable that the 'Subenv' drops,
-- this calls 'error' at the point where that variable is substituted.
unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t
unsafeWeakenWithSubenv sub =
  subst (\ann ty i ->
           maybe (error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away")
                 (EVar ann ty)
                 (idxThroughSubenv i sub))
  where
    -- Translate an index into the kept environment, or 'Nothing' if the
    -- variable it denotes is dropped by the subenv.
    idxThroughSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t)
    idxThroughSubenv IZ = \case
      SEYes _ -> Just IZ
      SENo _  -> Nothing
    idxThroughSubenv (IS idx) = \case
      SEYes rest -> IS <$> idxThroughSubenv idx rest
      SENo rest  -> idxThroughSubenv idx rest