{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE EmptyCase #-}
module AST (module AST, module AST.Weaken) where

import Data.Functor.Const

import Data.Kind (Type)
import Data.Int

import AST.Weaken
import Data


-- | Types of the object language. 'STy' is the corresponding singleton,
-- and 'unSTy' maps a singleton back to a plain 'Ty'.
data Ty
  = TNil  -- ^ presumably the unit type; its dedicated value former is 'ENil'
  | TPair Ty Ty  -- ^ product of two types
  | TEither Ty Ty  -- ^ binary sum of two types
  | TArr Nat Ty  -- ^ rank, element type
  | TScal ScalTy  -- ^ a scalar (see 'ScalTy')
  | TEVM [Ty] Ty  -- ^ EVM computation: accumulator environment, result type (see the EVM constructors of 'Expr')
  deriving (Show, Eq, Ord)

-- | Scalar types: 32/64-bit integers, 32/64-bit floats and booleans. Their
-- Haskell representations are given by 'ScalRep'.
data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
  deriving (Show, Eq, Ord)

-- | Singleton for 'Ty': a value-level witness of a type-level 'Ty'. Each
-- constructor mirrors the 'Ty' constructor of the same shape; 'unSTy'
-- forgets the index again.
type STy :: Ty -> Type
data STy t where
  STNil :: STy TNil
  STPair :: STy a -> STy b -> STy (TPair a b)
  STEither :: STy a -> STy b -> STy (TEither a b)
  STArr :: SNat n -> STy t -> STy (TArr n t)
  STScal :: SScalTy t -> STy (TScal t)
  STEVM :: SList STy env -> STy t -> STy (TEVM env t)
deriving instance Show (STy t)

-- | Singleton for 'ScalTy': each constructor witnesses one scalar type at
-- the type level. 'unSScalTy' forgets the index again.
--
-- The standalone kind signature matches the other indexed types in this
-- module ('STy', 'SOp', 'Expr') instead of relying on kind inference under
-- PolyKinds.
type SScalTy :: ScalTy -> Type
data SScalTy t where
  STI32 :: SScalTy TI32
  STI64 :: SScalTy TI64
  STF32 :: SScalTy TF32
  STF64 :: SScalTy TF64
  STBool :: SScalTy TBool
deriving instance Show (SScalTy t)

-- | The type of array indices ('EIdx', 'EIdx1', 'EBuild'): a 64-bit integer scalar.
type TIx = TScal TI64

-- | The Haskell representation of each scalar type.
--
-- The explicit kind signature mirrors 'ConsN'. It states the kind GHC would
-- otherwise infer from the equations.
type ScalRep :: ScalTy -> Type
type family ScalRep t where
  ScalRep TI32 = Int32
  ScalRep TI64 = Int64
  ScalRep TF32 = Float
  ScalRep TF64 = Double
  ScalRep TBool = Bool

-- | @ConsN n x l@ prepends @n@ copies of @x@ to the list @l@; for example
-- @ConsN (S (S Z)) x l@ reduces to @x : x : l@. Used to extend an
-- environment with @n@ index variables (see 'EBuild' and 'wcopyN').
type ConsN :: Nat -> a -> [a] -> [a]
type family ConsN n x l where
  ConsN Z x l = l
  ConsN (S n) x l = x : ConsN n x l

-- General convention: the head of a type-level list (however the list is
-- built up) is the innermost variable / innermost array dimension. When
-- pretty-printing, the innermost variable / dimension is shown on the
-- _right_.
--
-- Typed expressions. Parameters:
--
--   * @x@   - a per-node annotation indexed by the node's type; 'Ex' fixes
--             it to @Const ()@, i.e. no annotation. NOTE: the EVM
--             constructors and 'EError' carry no annotation field.
--   * @env@ - the typing environment. Variables are de Bruijn-style 'Idx'
--             indices into it, and binders push onto its head.
--   * @t@   - the type of the expression.
type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type
data Expr x env t where
  -- lambda calculus
  -- | Variable: annotation, the variable's type, and its position in @env@.
  EVar :: x t -> STy t -> Idx env t -> Expr x env t
  -- | Let binding: the body sees the right-hand side as its innermost variable.
  ELet :: x t -> Expr x env a -> Expr x (a : env) t -> Expr x env t

  -- base types
  -- | Pair construction.
  EPair :: x (TPair a b) -> Expr x env a -> Expr x env b -> Expr x env (TPair a b)
  -- | First projection.
  EFst :: x a -> Expr x env (TPair a b) -> Expr x env a
  -- | Second projection.
  ESnd :: x b -> Expr x env (TPair a b) -> Expr x env b
  -- | The value of type 'TNil'.
  ENil :: x TNil -> Expr x env TNil
  -- | Left injection. The 'STy' records the right summand's type, which
  -- cannot be recovered from the payload.
  EInl :: x (TEither a b) -> STy b -> Expr x env a -> Expr x env (TEither a b)
  -- | Right injection. The 'STy' records the left summand's type.
  EInr :: x (TEither a b) -> STy a -> Expr x env b -> Expr x env (TEither a b)
  -- | Case analysis on a sum. Each branch sees the payload as its innermost variable.
  ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c

  -- array operations
  -- | Rank-1 array construction. The body is typed under one extra 'TIx'
  -- variable (presumably the element index); the 'TIx' argument is
  -- presumably the length.
  EBuild1 :: x (TArr (S Z) t) -> Expr x env TIx -> Expr x (TIx : env) t -> Expr x env (TArr (S Z) t)
  -- | Rank-@n@ array construction from @n@ 'TIx' expressions (presumably
  -- the shape). The body is typed under @n@ extra 'TIx' variables (see 'ConsN').
  EBuild :: x (TArr n t) -> Vec n (Expr x env TIx) -> Expr x (ConsN n TIx env) t -> Expr x env (TArr n t)
  -- | Reduce a rank-@(n+1)@ array to rank @n@ with a binary function typed
  -- under two extra @t@ variables. Which dimension is reduced is not stated
  -- here; presumably the innermost one, per the convention above.
  EFold1 :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)

  -- expression operations
  -- | Scalar literal. The 'Show' constraint on the Haskell representation is
  -- what allows the derived 'Show' instance below to print it.
  EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)
  -- | Index a rank-@(n+1)@ array with one index, giving a rank-@n@ array.
  EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t)
  -- | Index a rank-@n@ array with @n@ indices, giving a single element.
  EIdx :: x t -> Expr x env (TArr n t) -> Vec n (Expr x env TIx) -> Expr x env t
  -- | Apply a primitive operation (see 'SOp').
  EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t

  -- EVM operations
  -- | Presumably contributes a value to the accumulator at the given index
  -- of @venv@; the computation's result is 'TNil'.
  EMOne :: SList STy venv -> Idx venv t -> Expr x env t -> Expr x env (TEVM venv TNil)
  -- | Runs a computation with an extra accumulator of type @t@ at the head
  -- of @venv@. The result is paired with a value of type @t@ (presumably
  -- that accumulator's final value).
  EMScope :: Expr x env (TEVM (t : venv) a) -> Expr x env (TEVM venv (TPair a t))
  -- | Lifts a value into an EVM computation (presumably monadic @return@).
  EMReturn :: SList STy venv -> Expr x env t -> Expr x env (TEVM venv t)
  -- | Sequences two EVM computations. The second sees the first's result
  -- as its innermost variable.
  EMBind :: Expr x env (TEVM venv a) -> Expr x (a : env) (TEVM venv b) -> Expr x env (TEVM venv b)

  -- partiality
  -- | An error with a message, at any type.
  EError :: STy a -> String -> Expr x env a
-- | Requires the annotation to be showable at every type index.
deriving instance (forall ty. Show (x ty)) => Show (Expr x env t)

-- | Expressions with no per-node annotation: every annotation slot holds a
-- unit 'Const'.
type Ex = Expr (Const ())

-- | The trivial annotation, usable at every type index.
ext :: Const () a
ext = Const mempty

-- | Primitive operations, indexed by argument type @a@ and result type @t@.
-- Binary operations take their two operands as a single 'TPair'. The
-- 'SScalTy' field fixes the scalar type the operation works on.
type SOp :: Ty -> Ty -> Type
data SOp a t where
  OAdd :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
  OMul :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
  ONeg :: SScalTy a -> SOp (TScal a) (TScal a)
  OLt :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
  OLe :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
  OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
  ONot :: SOp (TScal TBool) (TScal TBool)
  -- | Converts a boolean into a two-way sum of 'TNil's. Which side
  -- corresponds to True is not stated here.
  OIf :: SOp (TScal TBool) (TEither TNil TNil)
deriving instance Show (SOp a t)

-- | The result type of a primitive operation, as a singleton.
opt2 :: SOp a t -> STy t
opt2 op = case op of
  OAdd s -> STScal s
  OMul s -> STScal s
  ONeg s -> STScal s
  OLt _ -> boolTy
  OLe _ -> boolTy
  OEq _ -> boolTy
  ONot -> boolTy
  OIf -> STEither STNil STNil
  where
    -- shared by all comparison / boolean operators
    boolTy :: STy (TScal TBool)
    boolTy = STScal STBool

-- | Recover the singleton type of an expression.
--
-- Most constructors store enough type information to rebuild their result
-- type directly. Projections, array reductions/indexing and 'EMScope'
-- instead inspect the type of their argument and take it apart.
typeOf :: Expr x env t -> STy t
typeOf expr = case expr of
  EVar _ ty _ -> ty
  ELet _ _ body -> typeOf body

  EPair _ l r -> STPair (typeOf l) (typeOf r)
  EFst _ p -> case typeOf p of STPair ty _ -> ty
  ESnd _ p -> case typeOf p of STPair _ ty -> ty
  ENil _ -> STNil
  EInl _ rty l -> STEither (typeOf l) rty
  EInr _ lty r -> STEither lty (typeOf r)
  ECase _ _ leftBranch _ -> typeOf leftBranch

  EBuild1 _ _ body -> STArr (SS SZ) (typeOf body)
  EBuild _ shape body -> STArr (vecLength shape) (typeOf body)
  EFold1 _ _ arr -> case typeOf arr of STArr (SS rank) elt -> STArr rank elt

  EConst _ sty _ -> STScal sty
  EIdx1 _ arr _ -> case typeOf arr of STArr (SS rank) elt -> STArr rank elt
  EIdx _ arr _ -> case typeOf arr of STArr _ elt -> elt
  EOp _ op _ -> opt2 op

  EMOne venv _ _ -> STEVM venv STNil
  EMScope m -> case typeOf m of
    STEVM (SCons acc venv) res -> STEVM venv (STPair res acc)
  EMReturn venv e -> STEVM venv (typeOf e)
  EMBind _ k -> typeOf k

  EError ty _ -> ty

-- | Forget the type index of a singleton natural number.
unSNat :: SNat n -> Nat
unSNat = \case
  SZ -> Z
  SS m -> S (unSNat m)

-- | Forget the type index of a singleton type, recovering the plain 'Ty'.
unSTy :: STy t -> Ty
unSTy STNil = TNil
unSTy (STPair l r) = TPair (unSTy l) (unSTy r)
unSTy (STEither l r) = TEither (unSTy l) (unSTy r)
unSTy (STArr rank elt) = TArr (unSNat rank) (unSTy elt)
unSTy (STScal s) = TScal (unSScalTy s)
unSTy (STEVM venv res) = TEVM (unSList venv) (unSTy res)

-- | Forget the type indices of a list of singleton types, element by element.
unSList :: SList STy env -> [Ty]
unSList = \case
  SNil -> []
  SCons ty rest -> unSTy ty : unSList rest

-- | Forget the type index of a scalar-type singleton.
unSScalTy :: SScalTy t -> ScalTy
unSScalTy STI32 = TI32
unSScalTy STI64 = TI64
unSScalTy STF32 = TF32
unSScalTy STF64 = TF64
unSScalTy STBool = TBool

-- | Transport an expression along an environment weakening.
--
-- Each 'EVar' index is mapped through the weakening. Whenever we descend
-- under binders, the weakening is extended with 'WCopy' once per bound
-- variable ('wcopyN' for the @n@ index variables of 'EBuild'), so the newly
-- bound variables keep their positions.
weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t
weakenExpr w expr = case expr of
  -- leaves: only variable indices need rewriting
  EVar ann ty i -> EVar ann ty (w @> i)
  ENil ann -> ENil ann
  EConst ann sty v -> EConst ann sty v
  EError ty msg -> EError ty msg

  -- constructs that bind variables in some subterm
  ELet ann rhs body -> ELet ann (weakenExpr w rhs) (weakenExpr (WCopy w) body)
  ECase ann scrut l r ->
    ECase ann (weakenExpr w scrut) (weakenExpr (WCopy w) l) (weakenExpr (WCopy w) r)
  EBuild1 ann len body -> EBuild1 ann (weakenExpr w len) (weakenExpr (WCopy w) body)
  EBuild ann shape body ->
    EBuild ann (fmap (weakenExpr w) shape) (weakenExpr (wcopyN (vecLength shape) w) body)
  EFold1 ann f arr -> EFold1 ann (weakenExpr (WCopy (WCopy w)) f) (weakenExpr w arr)
  EMBind m k -> EMBind (weakenExpr w m) (weakenExpr (WCopy w) k)

  -- plain structural recursion in the same environment
  EPair ann a b -> EPair ann (weakenExpr w a) (weakenExpr w b)
  EFst ann p -> EFst ann (weakenExpr w p)
  ESnd ann p -> ESnd ann (weakenExpr w p)
  EInl ann ty e -> EInl ann ty (weakenExpr w e)
  EInr ann ty e -> EInr ann ty (weakenExpr w e)
  EIdx1 ann arr ix -> EIdx1 ann (weakenExpr w arr) (weakenExpr w ix)
  EIdx ann arr ixs -> EIdx ann (weakenExpr w arr) (fmap (weakenExpr w) ixs)
  EOp ann op e -> EOp ann op (weakenExpr w e)
  EMOne venv vi e -> EMOne venv vi (weakenExpr w e)
  EMScope m -> EMScope (weakenExpr w m)
  EMReturn venv e -> EMReturn venv (weakenExpr w e)

-- | Embed @env@ into @env@ extended with @n@ fresh 'TIx' variables, by
-- composing @n@ 'WSink' steps.
wsinkN :: SNat n -> env :> ConsN n TIx env
wsinkN = \case
  SZ -> WId
  SS m -> WSink .> wsinkN m

-- | Lift a weakening under @n@ bound 'TIx' variables by wrapping it in
-- 'WCopy' @n@ times.
wcopyN :: SNat n -> env :> env' -> ConsN n TIx env :> ConsN n TIx env'
wcopyN n w = case n of
  SZ -> w
  SS m -> WCopy (wcopyN m w)

-- | Drop @n@ leading 'TIx' variables from the source environment of a
-- weakening by applying 'WPop' @n@ times.
wpopN :: SNat n -> ConsN n TIx env :> env' -> env :> env'
wpopN n w = case n of
  SZ -> w
  SS m -> wpopN m (WPop w)

-- | Look up the element of a singleton list at a type-safe index. The empty
-- list has no valid index, which the empty case makes explicit.
slistIdx :: SList f list -> Idx list t -> f t
slistIdx (SCons hd rest) idx = case idx of
  IZ -> hd
  IS j -> slistIdx rest j
slistIdx SNil idx = case idx of {}

-- | The numeric depth of a de Bruijn-style index: 'IZ' is 0 and each 'IS'
-- adds one.
idx2int :: Idx env t -> Int
idx2int = go 0
  where
    -- polymorphic recursion over the index's environment needs the signature
    go :: Int -> Idx e s -> Int
    go acc IZ = acc
    go acc (IS j) = let acc' = acc + 1 in acc' `seq` go acc' j