{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Language.AST where

import Data.Kind (Type)
import Data.Type.Equality
import GHC.OverloadedLabels
import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..))

import Array
import AST
import CHAD.Types
import Data


-- | Surface-language expression with named variables. It is converted to the
-- De Bruijn-indexed 'Ex' by 'fromNamedExpr'.
--
-- The @env@ index lists the variables in scope as @(name, type)@ pairs.
-- Binders cons their variable onto the head of the list, so the innermost
-- binding comes first. 'Lookup' searches from the head, so inner bindings
-- shadow outer ones with the same name.
type NExpr :: [(Symbol, Ty)] -> Ty -> Type
data NExpr env t where
  -- lambda calculus
  -- | Variable reference; its type must equal the type of the innermost
  -- binding with that name in @env@.
  NEVar :: Lookup name env ~ t => Var name t -> NExpr env t
  -- | @let name = rhs in body@; the body sees @name@ as the innermost variable.
  NELet :: Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t

  -- environment management
  -- | The subexpression is typed in @env@ with its @i@-th entry (0 = innermost)
  -- removed, see 'DropNth'; e.g. this makes a shadowed outer variable
  -- reachable again.
  NEDrop :: SNat i -> NExpr (DropNth i env) t -> NExpr env t

  -- base types
  NEPair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b)
  NEFst :: NExpr env (TPair a b) -> NExpr env a
  NESnd :: NExpr env (TPair a b) -> NExpr env b
  NENil :: NExpr env TNil
  -- | The injections carry the type of the other summand explicitly.
  NEInl :: STy b -> NExpr env a -> NExpr env (TEither a b)
  NEInr :: STy a -> NExpr env b -> NExpr env (TEither a b)
  -- | Case analysis on a sum; each branch binds the payload under its own name.
  NECase :: NExpr env (TEither a b) -> Var name1 a -> NExpr ('(name1, a) : env) c -> Var name2 b -> NExpr ('(name2, b) : env) c -> NExpr env c

  -- array operations
  NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t))
  -- | Rank-@n@ array from an @n@-tuple of indices (presumably the shape,
  -- cf. 'NEShape') and a body that receives the index tuple bound as @name@.
  NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t)
  -- | Folds away one dimension. The combining function binds its operands as
  -- @name1@ and @name2@ (@name2@ innermost); the next argument has the element
  -- type (presumably the initial value); the last is the array.
  NEFold1Inner :: Var name1 t -> Var name2 t -> NExpr ('(name2, t) : '(name1, t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
  NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
  -- | Wraps a value in a rank-0 array.
  NEUnit :: NExpr env t -> NExpr env (TArr Z t)
  -- | Adds one dimension; the 'TIx' argument presumably gives its length.
  NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t)
  NEMaximum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
  NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))

  -- expression operations
  NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t)
  -- | Reads the single element of a rank-0 array.
  NEIdx0 :: NExpr env (TArr Z t) -> NExpr env t
  -- | Indexes one dimension, yielding an array of one rank lower.
  NEIdx1 :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t)
  -- | Indexes with a full @n@-tuple of indices.
  NEIdx :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t
  NEShape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx))
  -- | Applies a primitive operation. Binary operations take their operands as
  -- a pair (see the 'Num' instance, which passes an 'NEPair').
  NEOp :: SOp a t -> NExpr env a -> NExpr env t

  -- custom derivatives
  -- | Operation with user-supplied derivative code. The three bodies are
  -- closed: each sees only its own two parameters (the second innermost). The
  -- final two arguments are the operands, of types @a@ and @b@.
  NECustom :: Var n1 a -> Var n2 b -> NExpr ['(n2, b), '(n1, a)] t  -- ^ regular operation
           -> Var nf1 (D1 a) -> Var nf2 (D1 b) -> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape)  -- ^ CHAD forward pass
           -> Var nr1 tape -> Var nr2 (D2 t) -> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b)  -- ^ CHAD reverse derivative
           -> NExpr env a -> NExpr env b
           -> NExpr env t

  -- partiality
  -- | Error with the given message; the result type is given explicitly.
  NEError :: STy a -> String -> NExpr env a

  -- embedded unnamed expressions
  -- | Embeds a De Bruijn expression. The list supplies one named expression
  -- per variable of its environment (list head = De Bruijn index 0). The
  -- conversion let-binds these around the embedded expression.
  NEUnnamed :: Ex unenv t -> SList (NExpr env) unenv -> NExpr env t
deriving instance Show (NExpr env t)

-- | Type of the innermost binding of @name@ in @env@. The equations are tried
-- in order, so the first matching entry (nearest the head) wins.
--
-- Using the name @"_"@, or a name that is not in scope, is reported as a
-- compile-time type error.
type family Lookup name env where
  Lookup "_" _ = TypeError (Text "Attempt to use variable with name '_'")
  Lookup name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope")
  Lookup name ('(name, t) : env) = t
  Lookup name (_ : env) = Lookup name env

-- | Removes the @i@-th entry of @env@, counting from 0 at the head (the
-- innermost binding). There is no equation for an empty list, so the family
-- does not reduce when the index is out of range.
type family DropNth i env where
  DropNth Z (_ : env) = env
  DropNth (S i) (p : env) = p : DropNth i env

-- | A named, typed variable: the singleton for its name together with the
-- singleton for its type. Can be written as an overloaded label (@#x@)
-- through the 'IsLabel' instance in this module.
data Var name t = Var (SSymbol name) (STy t)
  deriving (Show)

-- | Arithmetic on scalar expressions builds 'NEOp' nodes. Binary operators
-- pass their operands as an 'NEPair'. Integer literals become 'NEConst'
-- nodes. 'abs' and 'signum' have no primitive and raise an error.
instance (t ~ TScal st, ScalIsNumeric st ~ True, KnownScalTy st, Num (ScalRep st)) => Num (NExpr env t) where
  (+) x y = NEOp (OAdd knownScalTy) (NEPair x y)
  (*) x y = NEOp (OMul knownScalTy) (NEPair x y)
  negate = NEOp (ONeg knownScalTy)
  abs = error "abs undefined for NExpr"
  signum = error "signum undefined for NExpr"
  fromInteger n =
    -- Matching on the dictionary brings 'Show (ScalRep st)' into scope,
    -- which 'NEConst' requires.
    case scalRepIsShow sty of
      Dict -> NEConst sty (fromInteger n)
    where
      sty = knownScalTy

-- | Reciprocal maps to the 'ORecip' primitive. Division is left to the class
-- default (@x * recip y@). Rational literals become 'NEConst' nodes.
instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Fractional (ScalRep st))
      => Fractional (NExpr env t) where
  recip = NEOp (ORecip knownScalTy)
  fromRational q =
    -- The dictionary match provides the 'Show (ScalRep st)' needed by 'NEConst'.
    case scalRepIsShow sty of
      Dict -> NEConst sty (fromRational q)
    where
      sty = knownScalTy

-- | Transcendental functions on scalar expressions. Only 'exp' and 'log' have
-- primitives ('OExp' and 'OLog'). The trigonometric and hyperbolic functions
-- have none and raise an error when used.
--
-- The class defaults for 'sqrt', '(**)' and 'logBase' are defined in terms of
-- 'exp' and 'log', so they are only correct if 'log' really is the logarithm.
instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Floating (ScalRep st))
      => Floating (NExpr env t) where
  pi =
    let ty = knownScalTy
    in case scalRepIsShow ty of
         Dict -> NEConst ty pi
  exp = NEOp (OExp knownScalTy)
  -- Must be the logarithm primitive, not 'OExp': the defaults for 'sqrt',
  -- '(**)' and 'logBase' are built from 'log'.
  log = NEOp (OLog knownScalTy)
  sin = error "sin undefined for NExpr"
  cos = error "cos undefined for NExpr"
  tan = error "tan undefined for NExpr"
  asin = error "asin undefined for NExpr"
  acos = error "acos undefined for NExpr"
  atan = error "atan undefined for NExpr"
  sinh = error "sinh undefined for NExpr"
  cosh = error "cosh undefined for NExpr"
  asinh = error "asinh undefined for NExpr"
  acosh = error "acosh undefined for NExpr"
  atanh = error "atanh undefined for NExpr"

-- | @#x :: Var "x" t@. The label supplies the name and 'KnownTy' supplies the
-- type. The head matches a 'Var' with any name; the @name ~ n'@ constraint
-- then forces that name to equal the label.
instance (KnownTy t, KnownSymbol name, name ~ n') => IsLabel name (Var n' t) where
  fromLabel = Var (symbolSing @name) knownTy

-- | @#x@ used as an expression: a variable reference whose type is the type
-- of the innermost binding of @x@ in the environment.
instance (KnownTy t, KnownSymbol name, Lookup name env ~ t) => IsLabel name (NExpr env t) where
  fromLabel = NEVar (Var (symbolSing @name) knownTy)

-- | Value-level environment of named variables, mirroring the @env@ index of
-- 'NExpr'. The innermost variable is held by the outermost 'NPush', so it is
-- written rightmost: in @NPush (NPush NTop x) y@, @y@ is innermost.
data NEnv env where
  NTop :: NEnv '[]
  NPush :: NEnv env -> Var name t -> NEnv ('(name, t) : env)

-- | A function with named parameters: a chain of 'NLam's (one per parameter,
-- the first parameter in the outermost 'NLam') ending in an 'NBody'.
--
-- * @env@: environment of this (sub)function. Every 'NLam' pushes its
--   parameter onto it, so it grows as you go deeper inside the lambdas.
-- * @env'@: environment of the body. This is @env@ extended with all
--   remaining parameters, with the last parameter at the head of the list.
-- * @t@: type of the body.
data NFun env env' t where
  NLam :: Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t
  NBody :: NExpr env' t -> NFun env' env' t

-- | Drops the names from a named environment, keeping the types in the same
-- order. A named expression in environment @env@ converts to an 'Ex' indexed
-- by @UnName env@ (see 'fromNamedExpr').
type family UnName env where
  UnName '[] = '[]
  UnName ('(name, t) : env) = t : UnName env

-- | The types of the variables in a named environment, in the same order
-- (innermost first).
envFromNEnv :: NEnv env -> SList STy (UnName env)
envFromNEnv = \case
  NTop -> SNil
  NPush rest (Var _ ty) -> SCons ty (envFromNEnv rest)

-- | Uses a closed named function as an expression, supplying one argument
-- expression per parameter. The argument list follows the function's body
-- environment, so the /last/ parameter comes first.
inlineNFun :: NFun '[] envB t -> SList (NExpr env) (UnName envB) -> NExpr env t
inlineNFun = NEUnnamed . fromNamed

-- | Converts a closed named function to a De Bruijn expression whose
-- environment consists of the function's parameters (last parameter at
-- index 0).
fromNamed :: NFun '[] env t -> Ex (UnName env) t
fromNamed fun = fromNamedFun NTop fun

-- | Some parameters of the function are already in the environment; the rest
-- are still lambdas. Move the remaining parameters into the environment,
-- outermost first, then convert the body there:
--
-- > [] `fromNamedFun` λx y z. E
-- >   = []:x `fromNamedFun` λy z. E
-- >   = []:x:y `fromNamedFun` λz. E
-- >   = []:x:y:z `fromNamedFun` λ. E
-- >   = []:x:y:z `fromNamedExpr` E
fromNamedFun :: NEnv env -> NFun env env' t -> Ex (UnName env') t
fromNamedFun env = \case
  NBody body -> fromNamedExpr env body
  NLam param rest -> fromNamedFun (NPush env param) rest

-- | Converts a named expression to a De Bruijn-indexed 'Ex'. The 'NEnv'
-- argument is the value-level mirror of the expression's @env@ index. It is
-- extended whenever the conversion goes under a binder.
--
-- Variables are resolved by name and type, searching from the innermost
-- binding outwards. The bodies of 'NECustom' are closed and are converted in
-- fresh environments built from 'NTop'.
--
-- Calls 'error' in two cases:
--
-- * a variable is not found in the 'NEnv'. This is not expected when the
--   'NEnv' matches @env@, since 'NEVar' requires 'Lookup' to succeed.
-- * an 'NEDrop' index is out of range (via 'dropNth' / 'dropNthW').
fromNamedExpr :: forall env t. NEnv env -> NExpr env t -> Ex (UnName env) t
fromNamedExpr val = \case
  NEVar var@(Var _ ty)
    | Just idx <- find var val -> EVar ext ty idx
    | otherwise -> error "Variable out of scope in conversion from surface \
                         \expression to De Bruijn expression"
  NELet n a b -> ELet ext (go a) (lambda val n b)

  -- Convert in the environment with the entry removed, then weaken the
  -- result back into the full environment.
  NEDrop i e -> weakenExpr (dropNthW i val) (fromNamedExpr (dropNth i val) e)

  NEPair a b -> EPair ext (go a) (go b)
  NEFst e -> EFst ext (go e)
  NESnd e -> ESnd ext (go e)
  NENil -> ENil ext
  NEInl t e -> EInl ext t (go e)
  NEInr t e -> EInr ext t (go e)
  NECase e n1 a n2 b -> ECase ext (go e) (lambda val n1 a) (lambda val n2 b)

  NEConstArr n t x -> EConstArr ext n t x
  NEBuild k a n b -> EBuild ext k (go a) (lambda val n b)
  NEFold1Inner n1 n2 a b c -> EFold1Inner ext (lambda2 val n1 n2 a) (go b) (go c)
  NESum1Inner e -> ESum1Inner ext (go e)
  NEUnit e -> EUnit ext (go e)
  NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b)
  NEMaximum1Inner e -> EMaximum1Inner ext (go e)
  NEMinimum1Inner e -> EMinimum1Inner ext (go e)

  NEConst t x -> EConst ext t x
  NEIdx0 e -> EIdx0 ext (go e)
  NEIdx1 a b -> EIdx1 ext (go a) (go b)
  NEIdx a b -> EIdx ext (go a) (go b)
  NEShape e -> EShape ext (go e)
  NEOp op e -> EOp ext op (go e)

  -- The three bodies are closed. Each is converted in a fresh environment
  -- holding only its two parameters (the second one innermost).
  NECustom n1@(Var _ ta) n2@(Var _ tb) a nf1 nf2 b nr1@(Var _ ttape) nr2 c e1 e2 ->
    ECustom ext ta tb ttape
            (fromNamedExpr (NTop `NPush` n1 `NPush` n2) a)
            (fromNamedExpr (NTop `NPush` nf1 `NPush` nf2) b)
            (fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c)
            (go e1) (go e2)

  NEError t s -> EError ext t s

  -- Weaken the embedded expression from its own environment @unenv@ to
  -- @Append unenv (UnName env)@, then let-bind the converted arguments for
  -- the @unenv@ variables.
  NEUnnamed e args -> injectWrapLet (weakenExpr (wRaiseAbove args (envFromNEnv val)) e) args
  where
    -- Convert a subterm in the same environment.
    go :: NExpr env t' -> Ex (UnName env) t'
    go = fromNamedExpr val

    -- De Bruijn index of the innermost variable with the same name and type,
    -- or Nothing if there is none.
    find :: Var name t' -> NEnv env' -> Maybe (Idx (UnName env') t')
    find _ NTop = Nothing
    find var@(Var s ty) (val' `NPush` Var s' ty')
      | Just Refl <- testEquality s s'
      , Just Refl <- testEquality ty ty'
      = Just IZ
      | otherwise
      = IS <$> find var val'

    -- Convert a body that binds one extra variable.
    lambda :: NEnv env' -> Var name a -> NExpr ('(name, a) : env') b -> Ex (a : UnName env') b
    lambda val' var e = fromNamedExpr (val' `NPush` var) e

    -- Convert a body that binds two extra variables; var2 is the innermost.
    lambda2 :: NEnv env' -> Var name1 a -> Var name2 b -> NExpr ('(name2, b) : '(name1, a) : env') c -> Ex (b : a : UnName env') c
    lambda2 val' var1 var2 e = fromNamedExpr (val' `NPush` var1 `NPush` var2) e

    -- Wrap the expression in one let per argument. Each argument is
    -- converted in the current environment and then weakened ('wSinks') past
    -- the variables of the lets that will enclose it. The first list element
    -- (De Bruijn index 0) becomes the innermost let.
    injectWrapLet :: Ex (Append unenv (UnName env)) t -> SList (NExpr env) unenv -> Ex (UnName env) t
    injectWrapLet e SNil = e
    injectWrapLet e (arg `SCons` args) =
      injectWrapLet (ELet ext (weakenExpr (wSinks args) $ fromNamedExpr val arg) e)
                    args

-- | Value-level counterpart of 'DropNth'. Removes the @i@-th variable
-- (0 = innermost) from the environment, and calls 'error' if the environment
-- has fewer entries than that.
dropNth :: SNat i -> NEnv env -> NEnv (DropNth i env)
dropNth i val = case i of
  SZ | NPush rest _ <- val -> rest
  SS i' | NPush rest top <- val -> NPush (dropNth i' rest) top
  _ -> error "DropNth: index out of range"

-- | Weakening that puts back the entry removed by 'dropNth'. It embeds the
-- environment without its @i@-th entry into the full environment, and calls
-- 'error' if the environment has fewer entries than that.
dropNthW :: SNat i -> NEnv env -> UnName (DropNth i env) :> UnName env
dropNthW i val = case i of
  SZ | NPush _ _ <- val -> WSink
  SS i' | NPush rest _ <- val -> WCopy (dropNthW i' rest)
  _ -> error "DropNth: index out of range"