{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE TupleSections #-}
module AST.Pretty where

import Control.Monad (ap)
import Data.List (intersperse)
import Data.Foldable (toList)
import Data.Functor.Const

import AST


-- | A heterogeneous environment holding one @f a@ value for every type @a@
-- in the type-level list @env@. The head of the list is the most recently
-- pushed entry.
data Val f env where
  -- | The empty environment.
  VTop :: Val f '[]
  -- | Extend an environment with a new head entry.
  VPush :: f a -> Val f as -> Val f (a ': as)

-- | An environment that assigns a printed name to every variable.
type SVal = Val (Const String)

-- | Project the entry at an index: 'IZ' selects the head (most recently
-- pushed) entry, and @'IS' i@ skips the head and looks up @i@ in the rest.
-- The empty environment has no valid index, which the type rules out.
valprj :: Val f env -> Idx env t -> f t
valprj env idx = case env of
  VTop -> case idx of {}
  VPush entry rest -> case idx of
    IZ -> entry
    IS i -> valprj rest i

-- | Push every element of a vector onto an environment of names. The first
-- element of the vector becomes the head ('IZ') entry of the result.
vpushN :: Vec n a -> Val (Const a) env -> Val (Const a) (ConsN n TIx env)
vpushN names base = case names of
  VNil -> base
  nm :< nms -> VPush (Const nm) (vpushN nms base)

-- | A name-supply monad: a state monad over an 'Int' counter. 'runM' takes
-- the initial counter and returns the result paired with the final counter.
newtype M a = M { runM :: Int -> (a, Int) }
  deriving (Functor)

instance Applicative M where
  pure x = M $ \counter -> (x, counter)
  (<*>) = ap

instance Monad M where
  m >>= k = M $ \before ->
    let (result, after) = runM m before
     in runM (k result) after

-- | Return the current counter value and advance the counter by one.
genId :: M Int
genId = M step
  where
    step n = (n, n + 1)

-- | A fresh variable name: the letter @x@ followed by the next counter value
-- (e.g. @x3@ when the counter is 3).
genName :: M String
genName = do
  n <- genId
  pure ("x" ++ show n)

-- | Render an expression as a 'String' at precedence 0.
--
-- Every variable of the typing environment @senv@ receives a generated name.
-- Names come from a counter that starts at 1. The same counter is used for
-- the binders introduced while printing the expression, so all names are
-- distinct.
ppExpr :: SList STy env -> Expr x env t -> String
ppExpr senv e = rendered ""
  where
    (rendered, _) = runM printer 1

    printer = do
      names <- mkVal senv
      ppExpr' 0 names e

    mkVal :: SList STy env -> M (SVal env)
    mkVal = \case
      SNil -> pure VTop
      SCons _ rest -> do
        -- Name the tail first, so entries further from the head receive
        -- smaller numbers than the head entry.
        inner <- mkVal rest
        fresh <- genName
        pure (VPush (Const fresh) inner)

-- | Pretty-print an expression inside the 'M' name-supply monad.
--
-- @d@ is the precedence of the enclosing context, in the style of
-- 'showsPrec':
--
-- * 0 means a delimited or top-level position.
-- * 9 is used for operands of infix operators and @!@.
-- * 11 is used for arguments of prefix applications.
--
-- @val@ gives the printed name of every variable in scope. Each binder
-- introduced by the expression gets a fresh name from 'genName'.
ppExpr' :: Int -> SVal env -> Expr x env t -> M ShowS
ppExpr' d val = \case
  EVar _ _ i -> return $ showString $ getConst $ valprj val i

  etop@ELet{} -> do
    -- Flatten a chain of lets nested in body position into one binding group.
    let collect :: SVal env -> Expr x env t -> M ([(String, ShowS)], ShowS)
        collect val' (ELet _ rhs body) = do
          name <- genName
          (binds, core) <- collect (VPush (Const name) val') body
          rhs' <- ppExpr' 0 val' rhs
          return ((name, rhs') : binds, core)
        collect val' e = ([],) <$> ppExpr' 0 val' e

    (binds, core) <- collect val etop
    -- Always use an explicit "{ ... }" block. Bindings are separated by
    -- " ; ", so a multi-binding group needs explicit delimiters to read
    -- unambiguously. This matches the "do { ... }" and "case ... of { ... }"
    -- forms below.
    return $ showParen (d > 0) $
      showString "let { "
      . foldr (.) id (intersperse (showString " ; ")
                        (map (\(name, rhs) -> showString (name ++ " = ") . rhs) binds))
      . showString " } in "
      . core

  EPair _ a b -> do
    a' <- ppExpr' 0 val a
    b' <- ppExpr' 0 val b
    return $ showString "(" . a' . showString ", " . b' . showString ")"

  EFst _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "fst " . e'

  ESnd _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "snd " . e'

  ENil _ -> return $ showString "()"

  EInl _ _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "Inl " . e'

  EInr _ _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "Inr " . e'

  ECase _ e a b -> do
    e' <- ppExpr' 0 val e
    name1 <- genName
    a' <- ppExpr' 0 (VPush (Const name1) val) a
    name2 <- genName
    b' <- ppExpr' 0 (VPush (Const name2) val) b
    return $ showParen (d > 0) $
      showString "case " . e' . showString (" of { Inl " ++ name1 ++ " -> ") . a'
      . showString (" ; Inr " ++ name2 ++ " -> ") . b' . showString " }"

  EBuild1 _ a b -> do
    a' <- ppExpr' 11 val a
    name <- genName
    b' <- ppExpr' 0 (VPush (Const name) val) b
    return $ showParen (d > 10) $
      showString "build1 " . a' . showString (" (\\" ++ name ++ " -> ") . b' . showString ")"

  EBuild _ es e -> do
    es' <- mapM (ppExpr' 0 val) es
    names <- mapM (const genName) es
    e' <- ppExpr' 0 (vpushN names val) e
    return $ showParen (d > 10) $
      showString "build ["
      . foldr (.) id (intersperse (showString ", ") (reverse (toList es')))
      . showString "] (\\["
      . foldr (.) id (intersperse (showString ",") (map showString (reverse (toList names))))
      . showString ("] -> ") . e' . showString ")"

  EFold1 _ a b -> do
    name1 <- genName
    name2 <- genName
    a' <- ppExpr' 0 (VPush (Const name2) (VPush (Const name1) val)) a
    b' <- ppExpr' 11 val b
    return $ showParen (d > 10) $
      showString ("fold1 (\\" ++ name1 ++ " " ++ name2 ++ " -> ") . a'
      . showString ") " . b'

  -- 'showsPrec' parenthesises negative literals in operand and argument
  -- positions, e.g. "negate (-1)" rather than "negate -1". At d = 0 the
  -- output is the same as 'show'.
  EConst _ ty v -> return $ case ty of
    STI32 -> showsPrec d v ; STI64 -> showsPrec d v ; STF32 -> showsPrec d v
    STF64 -> showsPrec d v ; STBool -> showsPrec d v

  EIdx1 _ a b -> do
    a' <- ppExpr' 9 val a
    b' <- ppExpr' 9 val b
    return $ showParen (d > 8) $ a' . showString " ! " . b'

  EIdx _ e es -> do
    e' <- ppExpr' 9 val e
    es' <- traverse (ppExpr' 0 val) es
    return $ showParen (d > 8) $
      e' . showString " ! "
      . showString "["
      . foldr (.) id (intersperse (showString ", ") (reverse (toList es')))
      . showString "]"

  EOp _ op (EPair _ a b)
    | (Infix, ops) <- operator op -> do
        a' <- ppExpr' 9 val a
        b' <- ppExpr' 9 val b
        return $ showParen (d > 8) $ a' . showString (" " ++ ops ++ " ") . b'

  EOp _ op e -> do
    e' <- ppExpr' 11 val e
    let ops = case operator op of
                (Infix, s) -> "(" ++ s ++ ")"
                (Prefix, s) -> s
    return $ showParen (d > 10) $ showString (ops ++ " ") . e'

  EMOne venv i e -> do
    let venvlen = length (unSList venv)
        varname = 'v' : show (venvlen - idx2int i)
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $
      showString ("one " ++ show varname ++ " ") . e'

  EMScope e -> do
    let venv = case typeOf e of STEVM v _ -> v
        venvlen = length (unSList venv)
        varname = 'v' : show venvlen
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $
      showString ("scope " ++ show varname ++ " ") . e'

  EMReturn _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "return " . e'

  etop@(EMBind _ EMBind{}) -> do
    -- Flatten a chain of binds in continuation position into a do-block.
    let collect :: SVal env -> Expr x env t -> M ([(String, ShowS)], ShowS)
        collect val' (EMBind lhs cont) = do
          name <- genName
          (binds, core) <- collect (VPush (Const name) val') cont
          lhs' <- ppExpr' 0 val' lhs
          return ((name, lhs') : binds, core)
        collect val' e = ([],) <$> ppExpr' 0 val' e

    (binds, core) <- collect val etop
    return $ showParen (d > 0) $
      showString "do { "
      . foldr (.) id (intersperse (showString " ; ")
                        (map (\(name, rhs) -> showString (name ++ " <- ") . rhs) binds))
      . showString " ; " . core . showString " }"

  EMBind a b -> do
    -- Print the left operand at precedence 1 so that a let/case/do/bind there
    -- is parenthesised. Otherwise a let body would swallow the ">>=", and a
    -- nested ">>=" would re-associate to the right.
    a' <- ppExpr' 1 val a
    name <- genName
    b' <- ppExpr' 0 (VPush (Const name) val) b
    -- The trailing lambda extends as far right as possible, so parenthesise
    -- in every non-top-level context, like the let/case/do forms.
    return $ showParen (d > 0) $ a' . showString (" >>= \\" ++ name ++ " -> ") . b'

  EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s)

-- | Whether an operator is written before its argument or between its
-- two operands.
data Fixity
  = Prefix -- ^ written before its argument, like a function name
  | Infix  -- ^ written between its two operands
  deriving Show

-- | The surface syntax of a primitive operator: its 'Fixity' and the name
-- it is printed with.
operator :: SOp a t -> (Fixity, String)
operator = \case
    OAdd{} -> infixOp "+"
    OMul{} -> infixOp "*"
    ONeg{} -> prefixOp "negate"
    OLt{} -> infixOp "<"
    OLe{} -> infixOp "<="
    OEq{} -> infixOp "=="
    ONot -> prefixOp "not"
  where
    infixOp name = (Infix, name)
    prefixOp name = (Prefix, name)