{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE TupleSections #-}
module AST.Pretty where

import Control.Monad (ap)
import Data.List (intersperse)
import Data.Foldable (toList)
import Data.Functor.Const

import AST
import AST.Count


-- | A heterogeneous environment whose type index @ts@ lists the types of its
-- entries, most recently pushed first. 'VTop' is the empty environment and
-- 'VPush' puts one more entry on top.
data Val g ts where
  VTop :: Val g '[]
  VPush :: g t -> Val g ts -> Val g (t : ts)

-- | Environment whose entries are display strings (variable names).
type SVal = Val (Const String)

-- | Project the entry at index @idx@ out of an environment: 'IZ' selects the
-- entry on top, and @'IS' i@ skips the top entry and looks up @i@ in the rest.
-- The empty environment has no valid index, so that case is an empty match.
valprj :: Val f env -> Idx env t -> f t
valprj VTop idx = case idx of {}
valprj (VPush top rest) idx = case idx of
  IZ -> top
  IS i -> valprj rest i

-- | Push every element of a vector onto an environment, each wrapped in
-- 'Const'. The first element of the vector ends up on top (index 'IZ'), the
-- last one directly above the original environment.
vpushN :: Vec n a -> Val (Const a) env -> Val (Const a) (ConsN n TIx env)
vpushN items base = case items of
  VNil -> base
  hd :< tl -> VPush (Const hd) (vpushN tl base)

-- | A state monad threading an 'Int' counter, used to hand out fresh names.
-- @runM m n@ runs @m@ starting from counter @n@ and returns the result
-- together with the final counter.
newtype M a = M { runM :: Int -> (a, Int) }
  deriving Functor

instance Applicative M where
  -- Leave the counter untouched.
  pure x = M $ \s -> (x, s)
  (<*>) = ap

instance Monad M where
  -- Run the first action, then feed its result and the updated counter to
  -- the continuation. The pair is taken apart with a lazy let.
  m >>= k = M $ \s ->
    let (a, s') = runM m s
    in runM (k a) s'

-- | Return the current counter value and advance the counter by one.
genId :: M Int
genId = M $ \next -> (next, next + 1)

-- | Produce a fresh variable name: the letter @x@ followed by the next counter
-- value, e.g. @x1@, @x2@, ...
genName :: M String
genName = do
  n <- genId
  return ('x' : show n)

-- | Choose the display name for a binder of type @ty@ that is referenced as
-- @idx@ inside @ex@.
--
-- When @'occCount' idx ex@ equals 'mempty' (the binder is not referenced), no
-- counter value is consumed and a placeholder is returned: @()@ for an
-- 'STNil'-typed binder, @_@ otherwise. Otherwise a fresh name is generated.
genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String
genNameIfUsedIn ty idx ex =
  if occCount idx ex == mempty
    then return (placeholder ty)
    else genName
  where
    placeholder :: STy b -> String
    placeholder STNil = "()"
    placeholder _ = "_"

-- | Render an expression whose free variables are described by @senv@.
--
-- Every entry of @senv@ gets a generated name, with the counter starting at 1.
-- The tail of the list is named before its head, so the last entry of @senv@
-- receives the first name. The expression is then printed at top-level
-- precedence (0).
ppExpr :: SList STy env -> Expr x env t -> String
ppExpr senv e = renderS ""
  where
    renderS = fst (runM printed 1)

    printed = do
      names <- mkVal senv
      ppExpr' 0 names e

    -- Name the rest of the environment first, then the head entry.
    mkVal :: SList STy es -> M (SVal es)
    mkVal SNil = pure VTop
    mkVal (SCons _ rest) =
      (\inner fresh -> VPush (Const fresh) inner) <$> mkVal rest <*> genName

-- | Render an expression in Haskell-like concrete syntax.
--
-- @d@ is the precedence of the enclosing context, following the 'showsPrec'
-- convention:
--
-- * 0 at top level;
-- * 9 for an operand of an infix operator or @!@;
-- * 11 for a function argument.
--
-- Each construct adds parentheses via 'showParen' when it binds more loosely
-- than @d@ allows.
--
-- @val@ gives the display name of every variable in scope. Binders introduced
-- by the expression get names from the 'M' counter, so the order of the
-- monadic steps in each alternative determines the numbering of the names
-- that appear in the output.
ppExpr' :: Int -> SVal env -> Expr x env t -> M ShowS
ppExpr' d val = \case
  EVar _ _ i -> return $ showString $ getConst $ valprj val i

  EPair _ a b -> do
    a' <- ppExpr' 0 val a
    b' <- ppExpr' 0 val b
    return $ showString "(" . a' . showString ", " . b' . showString ")"

  EFst _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "fst " . e'

  ESnd _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "snd " . e'

  ENil _ -> return $ showString "()"

  EInl _ _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "Inl " . e'

  EInr _ _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "Inr " . e'

  -- The scrutinee is rendered first. Each branch's binder name is generated
  -- right before that branch is rendered.
  ECase _ e a b -> do
    e' <- ppExpr' 0 val e
    let STEither t1 t2 = typeOf e
    name1 <- genNameIfUsedIn t1 IZ a
    a' <- ppExpr' 0 (VPush (Const name1) val) a
    name2 <- genNameIfUsedIn t2 IZ b
    b' <- ppExpr' 0 (VPush (Const name2) val) b
    return $ showParen (d > 0) $
      showString "case " . e' . showString (" of { Inl " ++ name1 ++ " -> ") . a'
      . showString (" ; Inr " ++ name2 ++ " -> ") . b' . showString " }"

  EBuild1 _ a b -> do
    a' <- ppExpr' 11 val a
    name <- genNameIfUsedIn (STScal STI64) IZ b
    b' <- ppExpr' 0 (VPush (Const name) val) b
    return $ showParen (d > 10) $
      showString "build1 " . a' . showString (" (\\" ++ name ++ " -> ") . b' . showString ")"

  -- The shape expressions and the index names are both printed in reverse of
  -- their 'toList' order, so the two lists line up with each other.
  EBuild _ es e -> do
    es' <- mapM (ppExpr' 0 val) es
    names <- mapM (const genName) es  -- TODO generate underscores
    e' <- ppExpr' 0 (vpushN names val) e
    return $ showParen (d > 10) $
      showString "build ["
      . foldr (.) id (intersperse (showString ", ") (reverse (toList es')))
      . showString "] (\\["
      . foldr (.) id (intersperse (showString ",") (map showString (reverse (toList names))))
      . showString ("] -> ") . e' . showString ")"

  -- The fold lambda binds two variables: name1 sits at index (IS IZ) and
  -- name2 at index IZ in the environment of @a@.
  EFold1 _ a b -> do
    name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a
    name2 <- genNameIfUsedIn (typeOf a) IZ a
    a' <- ppExpr' 0 (VPush (Const name2) (VPush (Const name1) val)) a
    b' <- ppExpr' 11 val b
    return $ showParen (d > 10) $
      showString ("fold1 (\\" ++ name1 ++ " " ++ name2 ++ " -> ") . a'
      . showString ") " . b'

  -- Use 'showsPrec' at the context precedence rather than 'show'. This
  -- parenthesises negative literals in operand and argument positions:
  -- @negate (-1)@ and @x + (-1)@ instead of @negate -1@ (which parses as a
  -- subtraction) and @x + -1@ (a parse error).
  EConst _ ty v -> return $ case ty of
    STI32 -> showsPrec d v
    STI64 -> showsPrec d v
    STF32 -> showsPrec d v
    STF64 -> showsPrec d v
    STBool -> showsPrec d v

  EIdx1 _ a b -> do
    a' <- ppExpr' 9 val a
    b' <- ppExpr' 9 val b
    return $ showParen (d > 8) $ a' . showString " ! " . b'

  EIdx _ e es -> do
    e' <- ppExpr' 9 val e
    es' <- traverse (ppExpr' 0 val) es
    return $ showParen (d > 8) $
      e' . showString " ! "
      . showString "["
      . foldr (.) id (intersperse (showString ", ") (reverse (toList es')))
      . showString "]"

  -- An infix operator applied to a literal pair is printed infix. Any other
  -- operator application falls through to the prefix form below.
  EOp _ op (EPair _ a b)
    | (Infix, ops) <- operator op -> do
        a' <- ppExpr' 9 val a
        b' <- ppExpr' 9 val b
        return $ showParen (d > 8) $ a' . showString (" " ++ ops ++ " ") . b'

  EOp _ op e -> do
    e' <- ppExpr' 11 val e
    let ops = case operator op of
                (Infix, s) -> "(" ++ s ++ ")"
                (Prefix, s) -> s
    return $ showParen (d > 10) $ showString (ops ++ " ") . e'

  -- The variable is named 'v' followed by (environment length - index). The
  -- name is printed as a string literal via 'show'.
  EMOne venv i e -> do
    let venvlen = length (unSList venv)
        varname = 'v' : show (venvlen - idx2int i)
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $
      showString ("one " ++ show varname ++ " ") . e'

  -- The introduced variable is named 'v' followed by the length of the inner
  -- computation's environment, taken from the 'STEVM' type of @e@.
  EMScope e -> do
    let venv = case typeOf e of STEVM v _ -> v
        venvlen = length (unSList venv)
        varname = 'v' : show venvlen
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $
      showString ("scope " ++ show varname ++ " ") . e'

  EMReturn _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString ("return ") . e'

  -- Chains of monadic binds and lets are flattened into a single do/let block.
  e@EMBind{} -> ppExprDo d val e
  e@ELet{} -> ppExprDo d val e

  -- Earlier direct rendering of EMBind, superseded by ppExprDo above; kept for
  -- reference:
  -- EMBind a b -> do
  --   let STEVM _ t = typeOf a
  --   a' <- ppExpr' 0 val a
  --   name <- genNameIfUsedIn t IZ b
  --   b' <- ppExpr' 0 (VPush (Const name) val) b
  --   return $ showParen (d > 10) $ a' . showString (" >>= \\" ++ name ++ " -> ") . b'

  EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s)

-- | One statement of a flattened binder chain: the bound name and the
-- rendered right-hand side. 'MonadBind' is a monadic @name <- rhs@ and
-- 'LetBind' a pure @name = rhs@.
data Binding
  = MonadBind String ShowS
  | LetBind String ShowS

-- | Render a chain of 'EMBind' / 'ELet' nodes as a single block.
--
-- The chain is flattened by following the body of each 'EMBind' / 'ELet'
-- until a node of any other shape is reached; that node is the final
-- expression.
--
-- * If every link is an 'ELet', the output is @let ... in final@.
-- * Otherwise it is @do { ... ; final }@, with each 'ELet' shown as a
--   @let { name = rhs }@ statement.
--
-- The result is parenthesised whenever @d > 0@.
ppExprDo :: Int -> SVal env -> Expr x env t -> M ShowS
ppExprDo d val etop = layout <$> flatten val etop
  where
    -- For each binder: generate its name, then render the rest of the chain,
    -- then render the bound expression. This order fixes which counter values
    -- the names receive.
    flatten :: SVal env' -> Expr x' env' t' -> M ([Binding], ShowS)
    flatten scope = \case
      EMBind lhs body -> do
        let STEVM _ t = typeOf lhs
        name <- genNameIfUsedIn t IZ body
        (rest, final) <- flatten (VPush (Const name) scope) body
        lhsDoc <- ppExpr' 0 scope lhs
        return (MonadBind name lhsDoc : rest, final)
      ELet _ rhs body -> do
        name <- genNameIfUsedIn (typeOf rhs) IZ body
        (rest, final) <- flatten (VPush (Const name) scope) body
        rhsDoc <- ppExpr' 0 scope rhs
        return (LetBind name rhsDoc : rest, final)
      other -> do
        doc <- ppExpr' 0 scope other
        return ([], doc)

    layout (binds, final) = showParen (d > 0) $
      maybe (asDo binds final) (\lets -> asLet lets final) (traverse onlyLet binds)

    onlyLet (LetBind n s) = Just (n, s)
    onlyLet _ = Nothing

    sepBy sep = foldr (.) id . intersperse sep

    -- NOTE(review): braces are added only when there is exactly one binding.
    -- Several bindings are separated by " ; " without braces and rely on the
    -- layout rule ('in' closes the implicit block). Confirm this is intended.
    asLet lets final =
      let (open, close) = case lets of
            [_] -> ("{ ", " }")
            _ -> ("", "")
          binding (n, s) = showString (n ++ " = ") . s
      in showString ("let " ++ open)
           . sepBy (showString " ; ") (map binding lets)
           . showString (close ++ " in ")
           . final

    asDo binds final =
      showString "do { "
        . sepBy (showString " ; ") (map stmt binds)
        . showString " ; "
        . final
        . showString " }"

    stmt (MonadBind n s) = showString (n ++ " <- ") . s
    stmt (LetBind n s) = showString ("let { " ++ n ++ " = ") . s . showString " }"

-- | How an operator is written: applied prefix (@negate x@) or between its two
-- operands (@a + b@).
data Fixity
  = Prefix
  | Infix
  deriving (Show)

-- | Surface syntax of a primitive operator: its 'Fixity' and the name it is
-- printed with.
operator :: SOp a t -> (Fixity, String)
operator = \case
    OAdd{} -> infixOp "+"
    OMul{} -> infixOp "*"
    ONeg{} -> prefixOp "negate"
    OLt{} -> infixOp "<"
    OLe{} -> infixOp "<="
    OEq{} -> infixOp "=="
    ONot -> prefixOp "not"
    OIf -> prefixOp "ifB"
  where
    infixOp name = (Infix, name)
    prefixOp name = (Prefix, name)