{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Language.AST where

import Data.Kind (Type)
import Data.Type.Equality
import GHC.OverloadedLabels
import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(Text, (:<>:)))

import Array
import AST
import Data


-- | Surface-syntax expressions that refer to variables by name.
--
-- The first index is the type-level environment: a list of @(name, type)@
-- pairs with the innermost binding at the head (every binder below conses its
-- variable onto the front). The second index is the type of the expression.
-- 'fromNamedExpr' translates these into the de Bruijn 'Ex' form.
type NExpr :: [(Symbol, Ty)] -> Ty -> Type
data NExpr env t where
  -- Lambda calculus ---------------------------------------------------------

  -- | Variable reference; its type must be the one 'Lookup' finds for its
  -- name in @env@.
  NEVar :: Lookup name env ~ t => Var name t -> NExpr env t
  -- | @let@ binding: the body is typed with the new variable in scope.
  NELet
    :: Var name a
    -> NExpr env a
    -> NExpr ('(name, a) : env) t
    -> NExpr env t

  -- Base types --------------------------------------------------------------

  NEPair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b)
  NEFst :: NExpr env (TPair a b) -> NExpr env a
  NESnd :: NExpr env (TPair a b) -> NExpr env b
  NENil :: NExpr env TNil
  -- | Left injection; the 'STy' argument supplies the type of the right side.
  NEInl :: STy b -> NExpr env a -> NExpr env (TEither a b)
  -- | Right injection; the 'STy' argument supplies the type of the left side.
  NEInr :: STy a -> NExpr env b -> NExpr env (TEither a b)
  -- | Case analysis on a sum; each branch binds its own variable.
  NECase
    :: NExpr env (TEither a b)
    -> Var name1 a
    -> NExpr ('(name1, a) : env) c
    -> Var name2 b
    -> NExpr ('(name2, b) : env) c
    -> NExpr env c

  -- Array operations --------------------------------------------------------

  -- | Array literal with scalar elements.
  NEConstArr
    :: Show (ScalRep t)
    => SNat n
    -> SScalTy t
    -> Array n (ScalRep t)
    -> NExpr env (TArr n (TScal t))
  -- | Rank-1 array built from a 'TIx' size expression and a body in which
  -- the index variable is bound.
  NEBuild1
    :: NExpr env TIx
    -> Var name TIx
    -> NExpr ('(name, TIx) : env) t
    -> NExpr env (TArr (S Z) t)
  -- | Rank-@n@ build: the shape and the bound index are @n@-tuples of 'TIx'.
  NEBuild
    :: SNat n
    -> NExpr env (Tup (Replicate n TIx))
    -> Var name (Tup (Replicate n TIx))
    -> NExpr ('(name, Tup (Replicate n TIx)) : env) t
    -> NExpr env (TArr n t)
  -- | Removes one dimension (the inner one, per the name) by combining
  -- elements with a binary function of @name1@ and @name2@; @name2@ is
  -- bound innermost.
  NEFold1Inner
    :: Var name1 t
    -> Var name2 t
    -> NExpr ('(name2, t) : '(name1, t) : env) t
    -> NExpr env (TArr (S n) t)
    -> NExpr env (TArr n t)
  -- | Removes one dimension (the inner one, per the name) of a numeric
  -- array, presumably by summing.
  NESum1Inner
    :: ScalIsNumeric t ~ True
    => NExpr env (TArr (S n) (TScal t))
    -> NExpr env (TArr n (TScal t))
  -- | Wraps a value as a rank-'Z' array.
  NEUnit :: NExpr env t -> NExpr env (TArr Z t)
  -- | Adds one dimension whose extent is given by the 'TIx' expression
  -- (presumably the inner one, per the name).
  NEReplicate1Inner
    :: NExpr env TIx
    -> NExpr env (TArr n t)
    -> NExpr env (TArr (S n) t)

  -- Expression operations ---------------------------------------------------

  -- | Scalar literal.
  NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t)
  -- | The element of a rank-'Z' array.
  NEIdx0 :: NExpr env (TArr Z t) -> NExpr env t
  -- | Indexes away one dimension with a single 'TIx'.
  NEIdx1 :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t)
  -- | Full indexing with an @n@-tuple of indices.
  NEIdx
    :: SNat n
    -> NExpr env (TArr n t)
    -> NExpr env (Tup (Replicate n TIx))
    -> NExpr env t
  -- | Shape of an array as an @n@-tuple of 'TIx'.
  NEShape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx))
  -- | Primitive operation applied to its (possibly tupled) argument.
  NEOp :: SOp a t -> NExpr env a -> NExpr env t

  -- Partiality --------------------------------------------------------------

  -- | Error carrying a message; the 'STy' fixes the result type.
  NEError :: STy a -> String -> NExpr env a
deriving instance Show (NExpr env t)

-- | Type-level lookup of a variable's type in a named environment.
--
-- The environment is searched from the head outward, so an inner binding
-- shadows any outer binding with the same name. Two cases are reported as
-- custom type errors instead of leaving the family stuck on an opaque
-- @Lookup name env@ type:
--
-- * looking up the reserved name @"_"@, and
-- * looking up a name that is not bound anywhere in the environment.
type family Lookup name env where
  Lookup "_" _ = TypeError (Text "Attempt to use variable with name '_'")
  -- Reached only once the whole environment has been searched without a
  -- match; name the offending variable in the error.
  Lookup name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope")
  Lookup name ('(name, t) : env) = t
  Lookup name (_ : env) = Lookup name env

-- | A variable: a singleton for its name paired with a singleton for its
-- type.
type Var :: Symbol -> Ty -> Type
data Var name t where
  Var :: SSymbol name -> STy t -> Var name t
  deriving Show

-- | Numeric syntax for numeric scalar expressions: '+', '*' and 'negate'
-- build 'NEOp' nodes (binary operands are paired with 'NEPair'), and integer
-- literals become 'NEConst'. '-' falls back to the class default
-- (@x + negate y@). 'abs' and 'signum' are not supported and call 'error'.
instance (t ~ TScal st, ScalIsNumeric st ~ True, KnownScalTy st, Num (ScalRep st)) => Num (NExpr env t) where
  (+) x y = NEOp (OAdd knownScalTy) (NEPair x y)
  (*) x y = NEOp (OMul knownScalTy) (NEPair x y)
  negate = NEOp (ONeg knownScalTy)
  abs = error "abs undefined for NExpr"
  signum = error "signum undefined for NExpr"
  fromInteger n =
    -- 'NEConst' needs a Show instance for the scalar representation; it is
    -- recovered from the scalar type singleton.
    case scalRepIsShow sty of
      Dict -> NEConst sty (fromInteger n)
    where
      sty = knownScalTy

-- | Overloaded labels for variables: @#x@ denotes the variable named @"x"@,
-- with its type singleton taken from 'KnownTy'. The head matches any
-- @'Var' n' t@, and the @name ~ n'@ constraint then fixes @n'@ to the label.
instance (KnownTy t, KnownSymbol name, name ~ n') => IsLabel name (Var n' t) where
  fromLabel = Var (symbolSing @name) knownTy

-- | Overloaded labels for variable references: @#x@ at type @'NExpr' env t@
-- is an 'NEVar' for the variable @"x"@. It type-checks only when 'Lookup'
-- of @"x"@ in @env@ is @t@.
instance (KnownTy t, KnownSymbol name, Lookup name env ~ t) => IsLabel name (NExpr env t) where
  fromLabel = NEVar (Var (symbolSing @name) knownTy)

-- | Runtime mirror of a named type-level environment. 'NTop' is the empty
-- environment; 'NPush' adds a variable at the head (innermost position).
type NEnv :: [(Symbol, Ty)] -> Type
data NEnv scope where
  NTop :: NEnv '[]
  NPush :: NEnv scope -> Var x a -> NEnv ('(x, a) : scope)

-- | A named function: a chain of 'NLam' binders ending in an 'NBody'. The
-- first index is the environment before the binders and the second is the
-- environment the body sees. Each 'NLam' pushes its variable onto the head
-- of the environment.
type NFun :: [(Symbol, Ty)] -> [(Symbol, Ty)] -> Ty -> Type
data NFun outer inner t where
  NLam :: Var x a -> NFun ('(x, a) : outer) inner t -> NFun outer inner t
  NBody :: NExpr outer t -> NFun outer outer t

-- | Erase the names from a named environment, keeping the types in the same
-- order. This is the environment index of the de Bruijn 'Ex' produced from
-- a named expression.
type family UnName env where
  UnName '[] = '[]
  UnName ('(_, ty) : rest) = ty : UnName rest

-- | Convert a closed named function (no free variables) into a de Bruijn
-- expression. The result's environment lists the types of the function's
-- parameters, with the last 'NLam' binder at the head.
fromNamed :: NFun '[] env t -> Ex (UnName env) t
fromNamed fun = fromNamedFun NTop fun

-- | Convert a named function under a given runtime environment. Each 'NLam'
-- pushes its variable onto the environment, and the final 'NBody' is
-- converted with 'fromNamedExpr' in the fully extended environment.
fromNamedFun :: NEnv env -> NFun env env' t -> Ex (UnName env') t
fromNamedFun env = \case
  NLam v rest -> fromNamedFun (NPush env v) rest
  NBody body -> fromNamedExpr env body

-- | Convert a named expression into the de Bruijn representation 'Ex'.
--
-- The 'NEnv' argument is the runtime mirror of the type-level environment
-- @env@. Every binder met on the way down pushes its variable onto it, and
-- every variable reference is resolved to an 'Idx' by scanning it from the
-- head (innermost binding) outward.
--
-- A binding matches only if both its name and its type singleton compare
-- equal to the reference's (via 'testEquality'). If the name matches but
-- the type does not, the scan continues outward.
--
-- Calls 'error' if a reference cannot be resolved. The 'Lookup' constraint
-- on 'NEVar' should rule that out for well-typed input, assuming the
-- singletons' 'testEquality' is complete.
fromNamedExpr :: forall env t. NEnv env -> NExpr env t -> Ex (UnName env) t
fromNamedExpr scope = \case
  NEVar v@(Var _ sty) ->
    case lookupIdx v scope of
      Just i -> EVar ext sty i
      Nothing -> error "Variable out of scope in conversion from surface \
                       \expression to De Bruijn expression"
  NELet v rhs body -> ELet ext (conv rhs) (bind1 v body)

  NEPair l r -> EPair ext (conv l) (conv r)
  NEFst p -> EFst ext (conv p)
  NESnd p -> ESnd ext (conv p)
  NENil -> ENil ext
  NEInl sty x -> EInl ext sty (conv x)
  NEInr sty x -> EInr ext sty (conv x)
  NECase scrut v1 l v2 r -> ECase ext (conv scrut) (bind1 v1 l) (bind1 v2 r)

  NEConstArr k sty arr -> EConstArr ext k sty arr
  NEBuild1 len v body -> EBuild1 ext (conv len) (bind1 v body)
  NEBuild k sh v body -> EBuild ext k (conv sh) (bind1 v body)
  NEFold1Inner v1 v2 f arr -> EFold1Inner ext (bind2 v1 v2 f) (conv arr)
  NESum1Inner arr -> ESum1Inner ext (conv arr)
  NEUnit x -> EUnit ext (conv x)
  NEReplicate1Inner len arr -> EReplicate1Inner ext (conv len) (conv arr)

  NEConst sty x -> EConst ext sty x
  NEIdx0 arr -> EIdx0 ext (conv arr)
  NEIdx1 arr i -> EIdx1 ext (conv arr) (conv i)
  NEIdx k arr ix -> EIdx ext k (conv arr) (conv ix)
  NEShape arr -> EShape ext (conv arr)
  NEOp op x -> EOp ext op (conv x)

  NEError sty msg -> EError sty msg
  where
    -- Convert a subexpression that lives in the current environment.
    conv :: NExpr env u -> Ex (UnName env) u
    conv = fromNamedExpr scope

    -- Convert the body of a one-variable binder; the variable becomes the
    -- head of the environment.
    bind1 :: Var nm a -> NExpr ('(nm, a) : env) b -> Ex (a : UnName env) b
    bind1 v = fromNamedExpr (NPush scope v)

    -- Convert the body of a two-variable binder; the second variable ends up
    -- innermost.
    bind2 :: Var nm1 a -> Var nm2 b -> NExpr ('(nm2, b) : '(nm1, a) : env) c -> Ex (b : a : UnName env) c
    bind2 v1 v2 = fromNamedExpr (NPush (NPush scope v1) v2)

    -- Resolve a variable to its de Bruijn index, scanning innermost first.
    -- The type comparison is only evaluated once the names agree.
    lookupIdx :: Var nm u -> NEnv ctx -> Maybe (Idx (UnName ctx) u)
    lookupIdx _ NTop = Nothing
    lookupIdx v@(Var s sty) (outer `NPush` Var s' sty') =
      case (testEquality s s', testEquality sty sty') of
        (Just Refl, Just Refl) -> Just IZ
        _ -> IS <$> lookupIdx v outer