{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeOperators #-}
module AST.Pretty (ppExpr, ppTy) where

import Control.Monad (ap)
import Data.List (intersperse)
import Data.Functor.Const

import AST
import AST.Count
import CHAD.Types
import Data
import ForwardAD.DualNumbers.Types


-- | Environment of printed variable names: one 'String' (wrapped in 'Const')
-- per entry of the type-level environment.
type SVal = SList (Const String)

-- | Name-supply monad: a state monad threading an 'Int' counter.
newtype M a = M { runM :: Int -> (a, Int) }
  deriving (Functor)

instance Applicative M where
  pure x = M (\counter -> (x, counter))
  (<*>) = ap

instance Monad M where
  -- Run the first action, then feed its result and the updated counter to
  -- the continuation.
  m >>= k = M $ \counter ->
    let (x, counter') = runM m counter
    in runM (k x) counter'

-- | Return the current counter value and advance the counter by one.
genId :: M Int
genId = M next
  where
    next counter = (counter, counter + 1)

-- | Generate a fresh name consisting of the given prefix followed by the next
-- identifier from 'genId'.
genName' :: String -> M String
genName' prefix = fmap (\n -> prefix ++ show n) genId

-- | Generate a fresh name with the default prefix @"x"@.
genName :: M String
genName = genName' defaultPrefix
  where
    defaultPrefix = "x"

-- | Pick a binder name for the variable at index @idx@ inside @ex@.
--
-- If 'occCount' reports no occurrences (i.e. equals 'mempty'), no fresh name
-- is consumed: the binder is printed as @"()"@ when its type is 'STNil' and
-- as @"_"@ otherwise. If the variable does occur, a fresh name with the given
-- prefix is generated.
genNameIfUsedIn' :: String -> STy a -> Idx env a -> Expr x env t -> M String
genNameIfUsedIn' prefix ty idx ex =
  if occCount idx ex == mempty
    then return placeholder
    else genName' prefix
  where
    placeholder :: String
    placeholder = case ty of
      STNil -> "()"
      _ -> "_"

-- | 'genNameIfUsedIn'' with the default prefix @"x"@.
genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String
genNameIfUsedIn ty idx ex = genNameIfUsedIn' "x" ty idx ex

-- | Pretty-print an expression to a 'String'.
--
-- Every entry of the environment @senv@ is first given a fresh name (the
-- counter starts at 1 and the tail of the environment is named before its
-- head); the expression is then printed at precedence 0 with those names.
ppExpr :: SList f env -> Expr x env t -> String
ppExpr senv e = fst (runM action 1) ""
  where
    action = do
      names <- nameEnv senv
      ppExpr' 0 names e

    -- Assign one fresh name per environment entry, innermost (tail) first.
    nameEnv :: SList f env -> M (SVal env)
    nameEnv = \case
      SNil -> return SNil
      SCons _ rest -> do
        restNames <- nameEnv rest
        fresh <- genName
        return (SCons (Const fresh) restNames)

-- | Pretty-print an expression inside the name-supply monad.
--
-- [@d@] Precedence of the surrounding context, in the style of 'showsPrec':
--   application-like forms are parenthesised when @d > 10@, infix forms when
--   @d > 8@, and @case@ when @d > 0@.
--
-- [@val@] The printed name of every variable in scope, indexed like the
--   expression's environment.
--
-- Binders introduced by an expression receive a fresh name only when the
-- bound variable is actually used (see 'genNameIfUsedIn'). Every form that
-- binds variables prints its binders, so each name occurring in the output
-- has a visible binding site.
ppExpr' :: Int -> SVal env -> Expr x env t -> M ShowS
ppExpr' d val = \case
  EVar _ _ i -> return $ showString $ getConst $ slistIdx val i

  e@ELet{} -> ppExprLet d val e

  EPair _ a b -> do
    a' <- ppExpr' 0 val a
    b' <- ppExpr' 0 val b
    return $ showString "(" . a' . showString ", " . b' . showString ")"

  EFst _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "fst " . e'

  ESnd _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "snd " . e'

  ENil _ -> return $ showString "()"

  EInl _ _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "Inl " . e'

  EInr _ _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "Inr " . e'

  ECase _ e a b -> do
    e' <- ppExpr' 0 val e
    let STEither t1 t2 = typeOf e
    name1 <- genNameIfUsedIn t1 IZ a
    a' <- ppExpr' 0 (Const name1 `SCons` val) a
    name2 <- genNameIfUsedIn t2 IZ b
    b' <- ppExpr' 0 (Const name2 `SCons` val) b
    return $ showParen (d > 0) $
      showString "case " . e' . showString (" of { Inl " ++ name1 ++ " -> ") . a'
      . showString (" ; Inr " ++ name2 ++ " -> ") . b' . showString " }"

  -- Capitalised like the other constructors (Just, Inl, Inr).
  ENothing _ _ -> return $ showString "Nothing"

  EJust _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "Just " . e'

  EMaybe _ a b e -> do
    let STMaybe t = typeOf e
    a' <- ppExpr' 11 val a
    name <- genNameIfUsedIn t IZ b
    -- The body is wrapped in an explicit lambda below, so it is printed at
    -- precedence 0.
    b' <- ppExpr' 0 (Const name `SCons` val) b
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $
      showString "maybe " . a' . showString " " . ppLambda [name] b'
      . showString " " . e'

  EConstArr _ _ ty v
    | Dict <- scalRepIsShow ty -> return $ showsPrec d v

  EBuild _ n a b -> do
    a' <- ppExpr' 11 val a
    name <- genNameIfUsedIn' "i" (tTup (sreplicate n tIx)) IZ b
    e' <- ppExpr' 0 (Const name `SCons` val) b
    return $ showParen (d > 10) $
      showString "build " . a' . showString " " . ppLambda [name] e'

  EFold1Inner _ a b c -> do
    name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a
    name2 <- genNameIfUsedIn (typeOf a) IZ a
    a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a
    b' <- ppExpr' 11 val b
    c' <- ppExpr' 11 val c
    return $ showParen (d > 10) $
      showString "fold1i " . ppLambda [name1, name2] a'
      . showString " " . b' . showString " " . c'

  ESum1Inner _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "sum1i " . e'

  EUnit _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "unit " . e'

  EReplicate1Inner _ a b -> do
    a' <- ppExpr' 11 val a
    b' <- ppExpr' 11 val b
    return $ showParen (d > 10) $ showString "replicate1i " . a' . showString " " . b'

  EConst _ ty v
    | Dict <- scalRepIsShow ty -> return $ showsPrec d v

  EIdx0 _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "idx0 " . e'

  EIdx1 _ a b -> do
    a' <- ppExpr' 9 val a
    b' <- ppExpr' 9 val b
    return $ showParen (d > 8) $ a' . showString " .! " . b'

  EIdx _ a b -> do
    a' <- ppExpr' 9 val a
    b' <- ppExpr' 10 val b
    return $ showParen (d > 8) $
      a' . showString " ! " . b'

  EShape _ e -> do
    e' <- ppExpr' 11 val e
    return $ showParen (d > 10) $ showString "shape " . e'

  -- An infix operator applied to a literal pair is printed infix.
  EOp _ op (EPair _ a b)
    | (Infix, ops) <- operator op -> do
        a' <- ppExpr' 9 val a
        b' <- ppExpr' 9 val b
        return $ showParen (d > 8) $ a' . showString (" " ++ ops ++ " ") . b'

  -- Any other operator application is printed prefix; an infix operator is
  -- put in parentheses to turn it into a prefix function.
  EOp _ op e -> do
    e' <- ppExpr' 11 val e
    let ops = case operator op of
                (Infix, s) -> "(" ++ s ++ ")"
                (Prefix, s) -> s
    return $ showParen (d > 10) $ showString (ops ++ " ") . e'

  ECustom _ t1 t2 a b c e1 e2 -> do
    pn1 <- genNameIfUsedIn t1 (IS IZ) a
    pn2 <- genNameIfUsedIn t2 IZ a
    fn1 <- genNameIfUsedIn t1 (IS IZ) b
    fn2 <- genNameIfUsedIn (dn t2) IZ b
    rn1 <- genNameIfUsedIn (d1 t1) (IS (IS IZ)) c
    rn2 <- genNameIfUsedIn (d1 t2) (IS IZ) c
    rn3 <- genNameIfUsedIn' "d" (d2 (typeOf a)) IZ c
    -- The three function bodies live in their own closed environments; each
    -- is wrapped in an explicit lambda below, hence precedence 0.
    a' <- ppExpr' 0 (Const pn2 `SCons` Const pn1 `SCons` SNil) a
    b' <- ppExpr' 0 (Const fn2 `SCons` Const fn1 `SCons` SNil) b
    c' <- ppExpr' 0 (Const rn3 `SCons` Const rn2 `SCons` Const rn1 `SCons` SNil) c
    e1' <- ppExpr' 11 val e1
    e2' <- ppExpr' 11 val e2
    return $ showParen (d > 10) $
      showString "custom "
      . ppLambda [pn1, pn2] a' . showString " "
      . ppLambda [fn1, fn2] b' . showString " "
      . ppLambda [rn1, rn2, rn3] c' . showString " "
      . e1' . showString " " . e2'

  EWith e1 e2 -> do
    e1' <- ppExpr' 11 val e1
    name <- genNameIfUsedIn' "ac" (STAccum (typeOf e1)) IZ e2
    e2' <- ppExpr' 0 (Const name `SCons` val) e2
    return $ showParen (d > 10) $
      showString "with " . e1' . showString " " . ppLambda [name] e2'

  EAccum i e1 e2 e3 -> do
    e1' <- ppExpr' 11 val e1
    e2' <- ppExpr' 11 val e2
    e3' <- ppExpr' 11 val e3
    return $ showParen (d > 10) $
      showString ("accum " ++ show (fromSNat i) ++ " ") . e1' . showString " " . e2' . showString " " . e3'

  EZero _ -> return $ showString "zero"

  EPlus _ a b -> do
    a' <- ppExpr' 11 val a
    b' <- ppExpr' 11 val b
    return $ showParen (d > 10) $ showString "plus " . a' . showString " " . b'

  EOneHot _ i a b -> do
    a' <- ppExpr' 11 val a
    b' <- ppExpr' 11 val b
    return $ showParen (d > 10) $
      showString ("onehot " ++ show (fromSNat i) ++ " ") . a' . showString " " . b'

  EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s)
  where
    -- Render @(\\n1 n2 ... -> body)@; the enclosing parentheses are always
    -- emitted, so the body may be printed at precedence 0.
    ppLambda :: [String] -> ShowS -> ShowS
    ppLambda names body =
      showString ("(\\" ++ unwords names ++ " -> ") . body . showString ")"

-- | Pretty-print a chain of nested @let@ bindings as a single @let@ block.
--
-- Consecutive 'ELet' nodes are collected into one list of bindings; the first
-- non-'ELet' body becomes the expression after @in@. Braces are emitted around
-- the bindings unless there is exactly one. The result is parenthesised when
-- @d > 0@.
ppExprLet :: Int -> SVal env -> Expr x env t -> M ShowS
ppExprLet d val etop = do
  (binds, core) <- gather val etop
  let (open, close) = case binds of
        [_] -> ("", "")
        _ -> ("{ ", " }")
      -- The occurrence count is carried along for debugging only; it can be
      -- shown by re-enabling the commented-out fragment.
      renderBind (name, _occ, rhs) =
        showString (name ++ {- " (" ++ show _occ ++ ")" ++ -} " = ") . rhs
      bindings =
        foldr (.) id (intersperse (showString " ; ") (map renderBind binds))
  return $ showParen (d > 0) $
    showString ("let " ++ open) . bindings . showString (close ++ " in ") . core
  where
    -- Walk down the spine of nested lets, naming each bound variable and
    -- printing each right-hand side in the environment built so far.
    gather :: SVal env -> Expr x env t -> M ([(String, Occ, ShowS)], ShowS)
    gather names (ELet _ rhs body) = do
      let occ = occCount IZ body
      name <- genNameIfUsedIn (typeOf rhs) IZ body
      rhs' <- ppExpr' 0 names rhs
      (binds, core) <- gather (Const name `SCons` names) body
      return ((name, occ, rhs') : binds, core)
    gather names e = ([],) <$> ppExpr' 0 names e

-- | How a primitive operator is written: before its argument or between the
-- two components of a pair argument.
data Fixity
  = Prefix
  | Infix
  deriving (Show)

-- | The fixity and printed symbol of each primitive operator.
operator :: SOp a t -> (Fixity, String)
operator = \case
  OAdd{} -> (Infix, "+")
  OMul{} -> (Infix, "*")
  ONeg{} -> (Prefix, "negate")
  OLt{} -> (Infix, "<")
  OLe{} -> (Infix, "<=")
  OEq{} -> (Infix, "==")
  ONot -> (Prefix, "not")
  OAnd -> (Infix, "&&")
  OOr -> (Infix, "||")
  OIf -> (Prefix, "ifB")
  ORound64 -> (Prefix, "round")
  OToFl64 -> (Prefix, "toFl64")

-- | Pretty-print a type to a 'String' in a context of precedence @d@.
ppTy :: Int -> STy t -> String
ppTy d ty = render ""
  where
    render = ppTys d ty

-- | 'ShowS'-based worker for 'ppTy'.
--
-- Products are printed with @*@ (parenthesised when @d > 7@), sums with @+@
-- (when @d > 6@), and the prefix type constructors @Maybe@, @Arr@ and @Accum@
-- like function applications (when @d > 10@).
ppTys :: Int -> STy t -> ShowS
ppTys d = \case
  STNil -> showString "1"
  STPair a b -> binary 7 " * " a b
  STEither a b -> binary 6 " + " a b
  STMaybe t -> applied (showString "Maybe ") t
  STArr n t -> applied (showString "Arr " . shows (fromSNat n) . showString " ") t
  STScal sty -> showString $ case sty of
    STI32 -> "i32"
    STI64 -> "i64"
    STF32 -> "f32"
    STF64 -> "f64"
    STBool -> "bool"
  STAccum t -> applied (showString "Accum ") t
  where
    -- Infix type operator of the given precedence; both operands are printed
    -- one level tighter.
    binary :: Int -> String -> STy a -> STy b -> ShowS
    binary prec sep l r =
      showParen (d > prec) $
        ppTys (prec + 1) l . showString sep . ppTys (prec + 1) r

    -- Prefix type constructor (already rendered, including its trailing
    -- space) applied to a final type argument.
    applied :: ShowS -> STy a -> ShowS
    applied hd t = showParen (d > 10) $ hd . ppTys 11 t