{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module AST (module AST, module AST.Weaken) where

import Data.Functor.Const
import Data.Kind (Type)
import Data.Int
import Data.Type.Equality

import AST.Env
import AST.Weaken
import Data


-- | The types of the object language (the untyped mirror of 'STy').
data Ty
  = TNil  -- ^ empty tuple / unit type; its value is built with 'ENil'
  | TPair Ty Ty  -- ^ product
  | TEither Ty Ty  -- ^ binary sum, built with 'EInl'/'EInr', eliminated with 'ECase'
  | TArr Nat Ty  -- ^ rank, element type
  | TScal ScalTy  -- ^ scalar
  | TAccum Ty  -- ^ accumulator for a value of the given type (see 'EWith', 'EAccum')
  deriving (Show, Eq, Ord)

-- | Scalar types. Their Haskell representations are given by 'ScalRep'
-- (32/64-bit integers, 32/64-bit floats, and booleans).
data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
  deriving (Show, Eq, Ord)

-- | Singleton for 'Ty': a runtime witness of a type-level 'Ty'. Each
-- constructor mirrors the 'Ty' constructor of the same name; 'unSTy' forgets
-- the type index again.
type STy :: Ty -> Type
data STy t where
  STNil :: STy TNil
  STPair :: STy a -> STy b -> STy (TPair a b)
  STEither :: STy a -> STy b -> STy (TEither a b)
  STArr :: SNat n -> STy t -> STy (TArr n t)
  STScal :: SScalTy t -> STy (TScal t)
  STAccum :: STy t -> STy (TAccum t)
deriving instance Show (STy t)

-- | Structural type equality on singletons: two types are equal exactly when
-- their head constructors agree and all their components are equal. On
-- success a proof of the type-level equality is returned.
instance TestEquality STy where
  testEquality STNil STNil = Just Refl
  testEquality (STPair l r) (STPair l' r') =
    case (testEquality l l', testEquality r r') of
      (Just Refl, Just Refl) -> Just Refl
      _ -> Nothing
  testEquality (STEither l r) (STEither l' r') =
    case (testEquality l l', testEquality r r') of
      (Just Refl, Just Refl) -> Just Refl
      _ -> Nothing
  testEquality (STArr rank e) (STArr rank' e') =
    case (testEquality rank rank', testEquality e e') of
      (Just Refl, Just Refl) -> Just Refl
      _ -> Nothing
  testEquality (STScal s) (STScal s') =
    case testEquality s s' of
      Just Refl -> Just Refl
      Nothing -> Nothing
  testEquality (STAccum e) (STAccum e') =
    case testEquality e e' of
      Just Refl -> Just Refl
      Nothing -> Nothing
  -- Different head constructors: never equal.
  testEquality _ _ = Nothing

-- | Singleton for 'ScalTy': a runtime witness of a type-level scalar type.
-- The standalone kind signature matches the style of 'STy', 'SOp' and 'Expr'.
type SScalTy :: ScalTy -> Type
data SScalTy t where
  STI32 :: SScalTy TI32
  STI64 :: SScalTy TI64
  STF32 :: SScalTy TF32
  STF64 :: SScalTy TF64
  STBool :: SScalTy TBool
deriving instance Show (SScalTy t)

-- | Scalar singletons are equal exactly when they are the same constructor.
instance TestEquality SScalTy where
  testEquality s s' = case (s, s') of
    (STI32, STI32) -> Just Refl
    (STI64, STI64) -> Just Refl
    (STF32, STF32) -> Just Refl
    (STF64, STF64) -> Just Refl
    (STBool, STBool) -> Just Refl
    _ -> Nothing

-- | Evidence that the Haskell representation ('ScalRep') of every scalar type
-- has a 'Show' instance. Matching on the singleton fixes @t@, after which GHC
-- can discharge the constraint for the concrete representation type.
scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t))
scalRepIsShow = \case
  STI32 -> Dict
  STI64 -> Dict
  STF32 -> Dict
  STF64 -> Dict
  STBool -> Dict

-- | The type of array indices and shape components: a 64-bit integer scalar
-- (represented as 'Int64', see 'ScalRep').
type TIx = TScal TI64

-- | Singleton for 'TIx', obtained from the 'KnownTy' instance for it.
tIx :: STy TIx
tIx = knownTy

-- | The Haskell type used to represent values of each scalar type.
-- The standalone kind signature matches the style of 'STy', 'SOp' and 'Expr'.
type ScalRep :: ScalTy -> Type
type family ScalRep t where
  ScalRep TI32 = Int32
  ScalRep TI64 = Int64
  ScalRep TF32 = Float
  ScalRep TF64 = Double
  ScalRep TBool = Bool

-- | The type of an index that selects a sub-part of a value of type @t@,
-- descending @i@ structural levels (used by 'EAccum').
--
-- This index is flipped around from the usual direction: the smallest index
-- is at the /heart/ of the nesting, not at the outside. The outermost layer
-- indexes into the /outer/ dimension of the type @t@. This makes indices into
-- compound structures work properly with coproducts.
--
-- Depth zero needs no index ('TNil'). Each further level descends one layer:
-- into one side of a pair or either (the index records which side via
-- 'TEither'), into one dimension of an array (adding a 'TIx'), or through a
-- rank-0 array (adding nothing).
type AcIdx :: Ty -> Nat -> Ty
type family AcIdx t i where
  AcIdx t Z = TNil
  AcIdx (TPair a b) (S i) = TEither (AcIdx a i) (AcIdx b i)
  AcIdx (TEither a b) (S i) = TEither (AcIdx a i) (AcIdx b i)
  AcIdx (TArr Z t) (S i) = AcIdx t i
  AcIdx (TArr (S n) t) (S i) = TPair TIx (AcIdx (TArr n t) i)

-- | The type of the value accumulated at an index of type @'AcIdx' t i@: the
-- sub-part of @t@ reached by descending @i@ levels. It follows the same
-- recursion as 'AcIdx': pairs and eithers select one side ('TEither'), and
-- array levels strip one dimension.
type AcVal :: Ty -> Nat -> Ty
type family AcVal t i where
  AcVal t Z = t
  AcVal (TPair a b) (S i) = TEither (AcVal a i) (AcVal b i)
  AcVal (TEither a b) (S i) = TEither (AcVal a i) (AcVal b i)
  AcVal (TArr Z t) (S i) = AcVal t i
  AcVal (TArr (S n) t) (S i) = AcVal (TArr n t) i

-- | Well-scoped, well-typed expressions of the object language.
--
-- * @x@ is a per-node annotation: most constructors carry a value of type
--   @x t@, where @t@ is the type of that node ('EWith', 'EAccum' and
--   'EError' carry none).
-- * @env@ is the typing environment; variables are de Bruijn indices
--   ('Idx') into it.
-- * @t@ is the type of the expression.
--
-- General assumption: head of the list (whatever way it is associated) is the
-- inner variable / inner array dimension. In pretty printing, the inner
-- variable / inner dimension is printed on the /right/.
type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type
data Expr x env t where
  -- lambda calculus
  -- | Variable reference; the 'STy' records the variable's type.
  EVar :: x t -> STy t -> Idx env t -> Expr x env t
  -- | @ELet x rhs body@: @rhs@ is bound as the innermost variable of @body@.
  ELet :: x t -> Expr x env a -> Expr x (a : env) t -> Expr x env t

  -- base types
  EPair :: x (TPair a b) -> Expr x env a -> Expr x env b -> Expr x env (TPair a b)
  EFst :: x a -> Expr x env (TPair a b) -> Expr x env a
  ESnd :: x b -> Expr x env (TPair a b) -> Expr x env b
  ENil :: x TNil -> Expr x env TNil
  -- | Left injection; stores the type of the other side so 'typeOf' can
  -- rebuild the full sum type.
  EInl :: x (TEither a b) -> STy b -> Expr x env a -> Expr x env (TEither a b)
  -- | Right injection; stores the type of the other side (see 'EInl').
  EInr :: x (TEither a b) -> STy a -> Expr x env b -> Expr x env (TEither a b)
  -- | Sum elimination; each branch sees the payload as its innermost variable.
  ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c

  -- array operations
  -- | One-dimensional build: a size and a body that sees the index as its
  -- innermost variable.
  EBuild1 :: x (TArr (S Z) t) -> Expr x env TIx -> Expr x (TIx : env) t -> Expr x env (TArr (S Z) t)
  -- | Rank-@n@ build: a shape tuple and a body that sees the full index
  -- tuple as its innermost variable.
  EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t)
  -- | Reduce away one dimension with a binary function whose two arguments
  -- are its two innermost variables. NOTE(review): presumably the inner
  -- dimension, per the convention above — confirm against the evaluator.
  EFold1 :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
  -- | Wrap a value in a rank-0 array.
  EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t)
  -- EReplicate :: x (TArr (S n) t) -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t)  -- TODO: unused

  -- expression operations
  -- | Scalar literal. The 'Show' constraint lets the derived 'Show' instance
  -- below print the literal value.
  EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)
  -- | Extract the single element of a rank-0 array.
  EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t
  -- | Index one dimension of an array, lowering its rank by one.
  EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t)
  -- | Index a rank-@n@ array with a full index tuple.
  EIdx :: x t -> SNat n -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t
  -- | The shape of an array, as a tuple of 'TIx'.
  EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx))
  -- | Apply a primitive operation (see 'SOp').
  EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t

  -- accumulation effect
  -- | @EWith init body@: @body@ sees an accumulator (its innermost variable)
  -- whose initial value is given by @init@; the result pairs @body@'s value
  -- with (presumably) the accumulator's final contents.
  EWith :: Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t)
  -- | @EAccum i idx val acc@: accumulate @val@ into the part of @acc@ selected
  -- by @idx@, @i@ levels deep (see 'AcIdx' and 'AcVal'). NOTE(review): the
  -- combining operation (presumably addition) is defined by the evaluator.
  EAccum :: SNat i -> Expr x env (AcIdx t i) -> Expr x env (AcVal t i) -> Expr x env (TAccum t) -> Expr x env TNil
  -- EAccum1 :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum (S Z) t) -> Expr x env TNil

  -- partiality
  -- | An error with the given message, at any type.
  EError :: STy a -> String -> Expr x env a
-- Showing an expression requires the annotation to be showable at every index.
deriving instance (forall ty. Show (x ty)) => Show (Expr x env t)

-- | Expressions whose per-node annotation carries no information.
type Ex = Expr (Const ())

-- | The (only) annotation value for 'Ex' nodes, usable at every type.
ext :: Const () a
ext = Const ()

-- | A list of types as a left-nested tuple ending in 'TNil', with the head of
-- the list outermost:
-- @Tup '[a, b, c] = TPair (TPair (TPair TNil c) b) a@.
type Tup :: [Ty] -> Ty
type family Tup env where
  Tup '[] = TNil
  Tup (t : ts) = TPair (Tup ts) t

-- | Singleton for 'Tup': builds the nested-pair type from a list of types.
tTup :: SList STy env -> STy (Tup env)
tTup = \case
  SNil -> STNil
  SCons ty rest -> STPair (tTup rest) ty

-- | Build a 'Tup' value from a list of expressions, nesting pairs in the same
-- order as 'Tup' (head of the list outermost, 'ENil' at the core).
eTup :: SList (Ex env) list -> Ex env (Tup list)
eTup = \case
  SNil -> ENil ext
  SCons e rest -> EPair ext (eTup rest) e

-- | Like 'Tup' but nested in the opposite order and starting from @core@
-- instead of 'TNil', with the head of the list innermost:
-- @InvTup core '[a, b, c] = TPair (TPair (TPair core a) b) c@.
type InvTup :: Ty -> [Ty] -> Ty
type family InvTup core env where
  InvTup core '[] = core
  InvTup core (t : ts) = InvTup (TPair core t) ts

-- | Primitive operations, indexed by argument type @a@ and result type @t@.
-- Binary operations take both operands as a single 'TPair'; comparisons
-- produce a 'TBool'. 'opt2' recovers the result type.
type SOp :: Ty -> Ty -> Type
data SOp a t where
  OAdd :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
  OMul :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
  ONeg :: SScalTy a -> SOp (TScal a) (TScal a)
  OLt :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
  OLe :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
  OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
  ONot :: SOp (TScal TBool) (TScal TBool)
  -- | Turn a boolean into a sum of two units, so that 'ECase' can branch on
  -- it. NOTE(review): which side corresponds to True is fixed by the
  -- evaluator, not here.
  OIf :: SOp (TScal TBool) (TEither TNil TNil)
deriving instance Show (SOp a t)

-- | The result type of a primitive operation. Arithmetic returns the operand
-- scalar type, comparisons and 'ONot' return 'TBool', and 'OIf' returns a sum
-- of two units.
opt2 :: SOp a t -> STy t
opt2 (OAdd s) = STScal s
opt2 (OMul s) = STScal s
opt2 (ONeg s) = STScal s
opt2 (OLt _) = STScal STBool
opt2 (OLe _) = STScal STBool
opt2 (OEq _) = STScal STBool
opt2 ONot = STScal STBool
opt2 OIf = STEither STNil STNil

-- | Recover the singleton type of an expression. Nodes that store their type
-- (or the missing part of it) return it directly; the rest derive it from the
-- types of their subexpressions.
typeOf :: Expr x env t -> STy t
-- lambda calculus
typeOf (EVar _ ty _) = ty
typeOf (ELet _ _ body) = typeOf body
-- base types
typeOf (EPair _ l r) = STPair (typeOf l) (typeOf r)
typeOf (EFst _ p) = case typeOf p of STPair ty _ -> ty
typeOf (ESnd _ p) = case typeOf p of STPair _ ty -> ty
typeOf (ENil _) = STNil
typeOf (EInl _ tyR l) = STEither (typeOf l) tyR
typeOf (EInr _ tyL r) = STEither tyL (typeOf r)
typeOf (ECase _ _ onLeft _) = typeOf onLeft
-- array operations
typeOf (EBuild1 _ _ body) = STArr (SS SZ) (typeOf body)
typeOf (EBuild _ rank _ body) = STArr rank (typeOf body)
typeOf (EFold1 _ _ arr) = case typeOf arr of STArr (SS rank) ty -> STArr rank ty
typeOf (EUnit _ e) = STArr SZ (typeOf e)
-- expression operations
typeOf (EConst _ sty _) = STScal sty
typeOf (EIdx0 _ arr) = case typeOf arr of STArr _ ty -> ty
typeOf (EIdx1 _ arr _) = case typeOf arr of STArr (SS rank) ty -> STArr rank ty
typeOf (EIdx _ _ arr _) = case typeOf arr of STArr _ ty -> ty
typeOf (EShape _ arr) = case typeOf arr of STArr rank _ -> tTup (sreplicate rank tIx)
typeOf (EOp _ op _) = opt2 op
-- accumulation effect
typeOf (EWith initial body) = STPair (typeOf body) (typeOf initial)
typeOf (EAccum {}) = STNil
-- partiality
typeOf (EError ty _) = ty

-- | Forget the type index of a natural-number singleton.
unSNat :: SNat n -> Nat
unSNat = \case
  SZ -> Z
  SS m -> S (unSNat m)

-- | Forget the type index of a type singleton, producing the plain 'Ty'.
unSTy :: STy t -> Ty
unSTy STNil = TNil
unSTy (STPair l r) = TPair (unSTy l) (unSTy r)
unSTy (STEither l r) = TEither (unSTy l) (unSTy r)
unSTy (STArr rank elt) = TArr (unSNat rank) (unSTy elt)
unSTy (STScal s) = TScal (unSScalTy s)
unSTy (STAccum elt) = TAccum (unSTy elt)

-- | Forget the type index of a list of type singletons, element by element.
unSList :: SList STy env -> [Ty]
unSList = \case
  SNil -> []
  SCons ty rest -> unSTy ty : unSList rest

-- | Forget the type index of a scalar-type singleton.
unSScalTy :: SScalTy t -> ScalTy
unSScalTy STI32 = TI32
unSScalTy STI64 = TI64
unSScalTy STF32 = TF32
unSScalTy STF64 = TF64
unSScalTy STBool = TBool

-- | Substitute @repl@ for the innermost variable ('IZ') of an expression.
-- All other variables are shifted down by one ('IS' is peeled off).
subst1 :: Expr x env a -> Expr x (a : env) t -> Expr x env t
subst1 repl =
  subst (\ann ty i ->
           case i of
             IZ -> repl
             IS j -> EVar ann ty j)

-- | Simultaneous substitution. Every free variable @i@ of the expression is
-- replaced by @f x t i@, an expression in the target environment. Under
-- binders the replacement is weakened as needed, so bound variables are
-- untouched.
subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a)
      -> Expr x env t -> Expr x env' t
subst f = subst' (\ann ty w -> weakenExpr w . f ann ty) WId

-- | The general traversal behind both substitution and weakening.
--
-- @subst' f w e@ rebuilds @e@ in the output environment. Every free variable
-- @i@ (an index into @env@) is replaced by @f x t w' i@, where @w'@ is the
-- weakening from @env'@ into the environment at the point where the variable
-- occurs. Crossing a binder extends the weakening with 'WCopy' and lifts @f@
-- with @sinkF@. Variables bound inside @e@ are therefore kept (re-indexed
-- through the weakening) and never passed to @f@.
--
-- 'subst' calls this with @w = 'WId'@; 'weakenExpr' calls it with an @f@
-- that simply re-emits an 'EVar'.
subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a)
       -> env' :> envOut
       -> Expr x env t
       -> Expr x envOut t
subst' f w = \case
  EVar x t i -> f x t w i
  -- ELet, ECase (both branches), EBuild1, EBuild and EWith bind one variable;
  -- EFold1's combining function binds two.
  ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body)
  EPair x a b -> EPair x (subst' f w a) (subst' f w b)
  EFst x e -> EFst x (subst' f w e)
  ESnd x e -> ESnd x (subst' f w e)
  ENil x -> ENil x
  EInl x t e -> EInl x t (subst' f w e)
  EInr x t e -> EInr x t (subst' f w e)
  ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b)
  EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b)
  EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b)
  EFold1 x a b -> EFold1 x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b)
  EUnit x e -> EUnit x (subst' f w e)
  -- EReplicate x e -> EReplicate x (subst' f w e)
  EConst x t v -> EConst x t v
  EIdx0 x e -> EIdx0 x (subst' f w e)
  EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
  EIdx x n e es -> EIdx x n (subst' f w e) (subst' f w es)
  EShape x e -> EShape x (subst' f w e)
  EOp x op e -> EOp x op (subst' f w e)
  EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
  EAccum i e1 e2 e3 -> EAccum i (subst' f w e1) (subst' f w e2) (subst' f w e3)
  EError t s -> EError t s
  where
    -- Lift a substitution under one binder. The new innermost variable ('IZ')
    -- maps to itself, re-indexed through the weakening. Every other variable
    -- is delegated to @f'@, with the binder popped off the weakening ('WPop').
    sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
          -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t
    sinkF f' x' t w' = \case
      IZ -> EVar x' t (w' @> IZ)
      IS i -> f' x' t (WPop w') i

-- | Transport an expression along a weakening @env :> env'@ by renaming every
-- free variable through it. Bound variables are handled by 'subst''.
weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t
weakenExpr = subst' (\ann ty w' -> EVar ann ty . (w' @>))

-- | Re-embed a sub-environment into the environment it was taken from.
-- 'SEYes' entries become a copied position ('WCopy'). 'SENo' entries become a
-- skipped position ('WSink'), composed with the rest of the weakening.
wUndoSubenv :: Subenv env env' -> env' :> env
wUndoSubenv = \case
  SETop -> WId
  SEYes rest -> WCopy (wUndoSubenv rest)
  SENo rest -> WSink .> wUndoSubenv rest

-- | Look up the element of an 'SList' at a typed index. An empty list has no
-- valid index, which 'EmptyCase' makes explicit.
slistIdx :: SList f list -> Idx list t -> f t
slistIdx SNil i = case i of {}
slistIdx (SCons x rest) i = case i of
  IZ -> x
  IS j -> slistIdx rest j

-- | The numeric value of a de Bruijn index: 0 for the innermost variable,
-- increasing by one per 'IS'.
idx2int :: Idx env t -> Int
idx2int = go 0
  where
    -- Count the 'IS' layers with a strict accumulator.
    go :: Int -> Idx e a -> Int
    go acc IZ = acc
    go acc (IS rest) = let acc' = acc + 1 in acc' `seq` go acc' rest

-- | Scalar types whose singleton can be produced implicitly by instance
-- resolution.
class KnownScalTy t where
  knownScalTy :: SScalTy t

instance KnownScalTy TI32 where
  knownScalTy = STI32

instance KnownScalTy TI64 where
  knownScalTy = STI64

instance KnownScalTy TF32 where
  knownScalTy = STF32

instance KnownScalTy TF64 where
  knownScalTy = STF64

instance KnownScalTy TBool where
  knownScalTy = STBool

-- | Types whose singleton can be produced implicitly by instance resolution.
-- Compound types are known when all their components are.
class KnownTy t where
  knownTy :: STy t

instance KnownTy TNil where
  knownTy = STNil

instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where
  knownTy = STPair knownTy knownTy

instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where
  knownTy = STEither knownTy knownTy

instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where
  knownTy = STArr knownNat knownTy

instance KnownScalTy t => KnownTy (TScal t) where
  knownTy = STScal knownScalTy

instance KnownTy t => KnownTy (TAccum t) where
  knownTy = STAccum knownTy

-- | Environments (type-level lists of 'Ty') whose singleton list can be
-- produced implicitly, one 'KnownTy' per entry.
class KnownEnv env where
  knownEnv :: SList STy env

instance KnownEnv '[] where
  knownEnv = SNil

instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where
  knownEnv = knownTy `SCons` knownEnv

-- | Turn an explicit type singleton into 'KnownTy' evidence. Recursion on the
-- components collects the evidence their instances require.
styKnown :: STy t -> Dict (KnownTy t)
styKnown = \case
  STNil -> Dict
  STPair l r -> case (styKnown l, styKnown r) of (Dict, Dict) -> Dict
  STEither l r -> case (styKnown l, styKnown r) of (Dict, Dict) -> Dict
  STArr rank e -> case (snatKnown rank, styKnown e) of (Dict, Dict) -> Dict
  STScal s -> case sscaltyKnown s of Dict -> Dict
  STAccum e -> case styKnown e of Dict -> Dict

-- | Turn an explicit scalar-type singleton into 'KnownScalTy' evidence.
sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t)
sscaltyKnown = \case
  STI32 -> Dict
  STI64 -> Dict
  STF32 -> Dict
  STF64 -> Dict
  STBool -> Dict

-- | @ebuildUp1 n sh size f@ builds an array of rank @n + 1@. Its new, inner
-- dimension has extent @size@ and its outer @n@ dimensions have shape @sh@.
-- The element at index @(ix, i)@ is element @ix@ of the rank-@n@ array @f i@;
-- @f@ receives the inner index @i@ as its innermost variable.
--
-- NOTE(review): @f@ sits inside the build body, so the whole array @f i@ is
-- presumably recomputed for every element of the result unless a later pass
-- shares it — confirm if this is performance-relevant.
ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t)
ebuildUp1 n sh size f = EBuild ext (SS n) (EPair ext sh size) body
  where
    -- The full (n+1)-dimensional index tuple bound by 'EBuild'; its 'snd'
    -- component is the inner index, its 'fst' the outer n indices.
    fullIdx = EVar ext (tTup (sreplicate (SS n) tIx)) IZ
    -- Bind the inner index and evaluate @f@ under it. @f@ is weakened past
    -- the index tuple so its own variables still point at @env@.
    innerArr = ELet ext (ESnd ext fullIdx) (weakenExpr (WCopy WSink) f)
    -- Pick the element of that rank-n array selected by the outer indices.
    body = EIdx ext n innerArr (EFst ext fullIdx)