{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module AST (module AST, module AST.Types, module AST.Weaken) where

import Data.Functor.Const
import Data.Kind (Type)

import Array
import AST.Types
import AST.Weaken
import CHAD.Types
import Data
import ForwardAD.DualNumbers.Types


-- | Type of an index into a value of type @t@, where @i@ counts how many
-- layers of @t@ the index descends through.
--
-- This index is flipped around from the usual direction: the smallest index
-- is at the _heart_ of the nesting, not at the outside. The outermost layer
-- indexes into the _outer_ dimension of the type @t@. This makes indices into
-- compound structures work properly with coproducts.
--
-- Per equation:
--
-- * at depth 'Z' no index information is needed ('TNil');
-- * for pairs and sums the index chooses which component to descend into
--   ('TEither');
-- * 'TMaybe' and rank-0 arrays add no index information but still use up
--   one level of depth;
-- * a rank-@S n@ array consumes one 'TIx' per dimension it descends through.
type family AcIdx t i where
  AcIdx t Z = TNil
  AcIdx (TPair a b) (S i) = TEither (AcIdx a i) (AcIdx b i)
  AcIdx (TEither a b) (S i) = TEither (AcIdx a i) (AcIdx b i)
  AcIdx (TMaybe t) (S i) = AcIdx t i
  AcIdx (TArr Z t) (S i) = AcIdx t i
  AcIdx (TArr (S n) t) (S i) = TPair TIx (AcIdx (TArr n t) i)

-- | Type of the value that goes with an @'AcIdx' t i@ index: 'EAccum' and
-- 'EOneHot' both take an @'AcIdx' t i@ together with an @'AcVal' t i@.
--
-- At depth 'Z' the value is a whole @t@. For pairs and sums it is a
-- 'TEither' of the component values, and 'TMaybe' is transparent. For an
-- array of rank @n@ the value is paired with an @n@-tuple of 'TIx'
-- (NOTE(review): presumably the array shape, TODO confirm); the remainder is
-- computed by 'AcValArr'.
type family AcVal t i where
  AcVal t Z = t
  AcVal (TPair a b) (S i) = TEither (AcVal a i) (AcVal b i)
  AcVal (TEither a b) (S i) = TEither (AcVal a i) (AcVal b i)
  AcVal (TMaybe t) (S i) = AcVal t i
  AcVal (TArr n t) (S i) = TPair (Tup (Replicate n TIx)) (AcValArr n t (S i))

-- | Helper for the array case of 'AcVal'. It peels off one array dimension
-- per level of depth.
--
-- If the depth runs out first, the remaining @n@ dimensions form an array
-- value ('TArr'). If the dimensions run out first, the step from the rank-0
-- array into its element uses up one more level, and the rest of the depth
-- descends into the element type via 'AcVal'.
type family AcValArr n t i where
  AcValArr n t Z = TArr n t
  AcValArr Z t (S i) = AcVal t i
  AcValArr (S n) t (S i) = AcValArr n t i

-- General assumption: head of the list (whatever way it is associated) is the
-- inner variable / inner array dimension. In pretty printing, the inner
-- variable / inner dimension is printed on the _right_.
--
-- Note that the 'EZero' and 'EPlus' constructs have typing that depends on the
-- type transformation of CHAD. Indeed, these constructors are created _by_
-- CHAD, and are intended to be eliminated after simplification, so that the
-- input program as well as the output program do not contain these
-- constructors. ('EOneHot' is likewise typed in terms of 'D2'.)
-- TODO: ensure this by a "stage" type parameter.
--
-- @Expr x env t@ is an expression of type @t@ in the environment @env@.
-- Variables are 'Idx' indices into @env@, and binders push the new variable
-- onto the head of the list. Constructors that carry an @x@ field are
-- annotated with an @x@ value indexed by their result type; 'Ex' instantiates
-- this with @'Const' ()@, i.e. no annotation.
type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type
data Expr x env t where
  -- lambda calculus
  -- ELet: the bound value becomes the head (innermost variable) of the body's environment.
  EVar :: x t -> STy t -> Idx env t -> Expr x env t
  ELet :: x t -> Expr x env a -> Expr x (a : env) t -> Expr x env t

  -- base types
  -- ECase: each branch binds the payload of the corresponding injection.
  -- EMaybe: arguments are the 'Nothing' case, the 'Just' case (binding the
  -- payload), and finally the scrutinee.
  EPair :: x (TPair a b) -> Expr x env a -> Expr x env b -> Expr x env (TPair a b)
  EFst :: x a -> Expr x env (TPair a b) -> Expr x env a
  ESnd :: x b -> Expr x env (TPair a b) -> Expr x env b
  ENil :: x TNil -> Expr x env TNil
  EInl :: x (TEither a b) -> STy b -> Expr x env a -> Expr x env (TEither a b)
  EInr :: x (TEither a b) -> STy a -> Expr x env b -> Expr x env (TEither a b)
  ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c
  ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t)
  EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t)
  EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b

  -- array operations
  -- EBuild: shape (an n-tuple of 'TIx'), then the element expression, which
  -- binds the index tuple.
  -- EFold1Inner / ESum1Inner: remove one dimension (the inner one, per the
  -- names). EFold1Inner's combining function binds two values of the element
  -- type, followed by a further element-typed argument (NOTE(review):
  -- presumably the initial value) and the array.
  -- EReplicate1Inner: adds a dimension; NOTE(review): the 'TIx' argument is
  -- presumably the new dimension's length.
  EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t))
  EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t)
  EFold1Inner :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
  ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
  EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t)
  EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t)

  -- expression operations
  -- EIdx0 indexes a rank-0 array; EIdx1 indexes with one 'TIx', removing one
  -- dimension; EIdx indexes with a full n-tuple of 'TIx'.
  EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)
  EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t
  EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t)
  EIdx :: x t -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t
  EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx))
  EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t

  -- custom derivatives
  -- The three bodies below have fixed (closed) environments; the final two
  -- arguments supply the actual @a@ and @b@ operands.
  ECustom :: x t -> STy a -> STy b
          -> Expr x '[b, a] t  -- ^ regular operation
          -> Expr x '[DN b, a] (DN t)  -- ^ dual-numbers forward derivative
          -> Expr x '[D2 t, D1 b, D1 a] (D2 b)  -- ^ CHAD reverse derivative
          -> Expr x env a -> Expr x env b
          -> Expr x env t

  -- accumulation effect
  -- EWith: the body sees an accumulator of type @TAccum t@ as its innermost
  -- variable; the result pairs the body's value with a @t@.
  -- NOTE(review): presumably the first argument is the initial accumulator
  -- value and the returned @t@ its final value; confirm against the evaluator.
  EWith :: Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t)
  EAccum :: SNat i -> Expr x env (AcIdx t i) -> Expr x env (AcVal t i) -> Expr x env (TAccum t) -> Expr x env TNil
  -- EAccum1 :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum (S Z) t) -> Expr x env TNil

  -- monoidal operations (to be desugared to regular operations after simplification)
  EZero :: STy t -> Expr x env (D2 t)
  EPlus :: STy t -> Expr x env (D2 t) -> Expr x env (D2 t) -> Expr x env (D2 t)
  EOneHot :: STy t -> SNat i -> Expr x env (AcIdx (D2 t) i) -> Expr x env (AcVal (D2 t) i) -> Expr x env (D2 t)

  -- partiality
  -- EError: carries its result type and a message string.
  EError :: STy a -> String -> Expr x env a
deriving instance (forall ty. Show (x ty)) => Show (Expr x env t)

-- | Expressions that carry no annotation: every annotation slot is a
-- @'Const' ()@.
type Ex = Expr (Const ())

-- | The unique annotation value for 'Ex', usable at any index type.
ext :: Const () a
ext = Const mempty

-- | Encode a type-level list as a left-nested tuple. The empty list becomes
-- 'TNil', and the head of the list becomes the second component of the
-- outermost 'TPair'. For example, @Tup '[a, b] = TPair (TPair TNil b) a@.
type family Tup env where
  Tup '[] = TNil
  Tup (t : ts) = TPair (Tup ts) t

-- | Fold a singleton list into a 'Tup'-shaped value. The caller supplies the
-- empty tuple and a pairing function. The head of the list ends up as the
-- second component of the outermost pair, matching the 'Tup' type family.
mkTup :: f TNil -> (forall a b. f a -> f b -> f (TPair a b))
      -> SList f list -> f (Tup list)
mkTup nil pair = \case
  SNil -> nil
  e `SCons` es -> pair (mkTup nil pair es) e

-- | Singleton type of the tuple built from a list of element types.
tTup :: SList STy env -> STy (Tup env)
tTup tys = mkTup STNil STPair tys

-- | Build a tuple expression from a list of component expressions.
eTup :: SList (Ex env) list -> Ex env (Tup list)
eTup es = mkTup (ENil ext) (EPair ext) es

-- | Inverse of 'mkTup': split a 'Tup'-shaped value back into a list, using
-- @unpack@ to take each pair apart. The @SList f list@ argument only provides
-- the spine of the list; its elements are never inspected. The pair returned
-- by @unpack@ is bound lazily.
unTup :: (forall a b. c (TPair a b) -> (c a, c b))
      -> SList f list -> c (Tup list) -> SList c list
unTup _ SNil _ = SNil
unTup unpack (_ `SCons` rest) tup = x `SCons` unTup unpack rest xs
  where
    (xs, x) = unpack tup

-- | Pair the elements of a type-level list onto @core@, from left to right.
-- Each element is paired onto the right of what has been built so far, so
-- the _last_ element of the list ends up as the outermost second component.
-- Compare 'Tup', where the _head_ of the list is outermost.
-- For example, @InvTup c '[a, b] = TPair (TPair c a) b@.
type family InvTup core env where
  InvTup core '[] = core
  InvTup core (t : ts) = InvTup (TPair core t) ts

-- | Primitive operations. @SOp a t@ takes an argument of type @a@ and
-- produces a result of type @t@ (see 'EOp'). Binary operations take their
-- two operands as a 'TPair'.
type SOp :: Ty -> Ty -> Type
data SOp a t where
  -- arithmetic and comparisons on numeric scalar types
  OAdd :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
  OMul :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
  ONeg :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
  OLt :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
  OLe :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
  -- equality on any scalar type
  OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
  -- boolean logic
  ONot :: SOp (TScal TBool) (TScal TBool)
  OAnd :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool)
  OOr :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool)
  OIf :: SOp (TScal TBool) (TEither TNil TNil)  -- True is Left, False is Right
  -- conversions between TF64 and TI64
  ORound64 :: SOp (TScal TF64) (TScal TI64)
  OToFl64 :: SOp (TScal TI64) (TScal TF64)
deriving instance Show (SOp a t)

-- | Singleton for the result type of a primitive operation.
opt2 :: SOp a t -> STy t
opt2 (OAdd ty) = STScal ty
opt2 (OMul ty) = STScal ty
opt2 (ONeg ty) = STScal ty
opt2 (OLt _) = STScal STBool
opt2 (OLe _) = STScal STBool
opt2 (OEq _) = STScal STBool
opt2 ONot = STScal STBool
opt2 OAnd = STScal STBool
opt2 OOr = STScal STBool
opt2 OIf = STEither STNil STNil
opt2 ORound64 = STScal STI64
opt2 OToFl64 = STScal STF64

-- | Recover the singleton type of an expression.
--
-- Many constructors store the needed 'STy' (or its ingredients) directly.
-- The rest recompute it from a subexpression. Projections and array
-- operations pattern-match on that subexpression's type; the pattern guards
-- always succeed, because the GADT index pins down the constructor.
typeOf :: Expr x env t -> STy t
typeOf = \case
  EVar _ t _ -> t
  ELet _ _ e -> typeOf e

  EPair _ a b -> STPair (typeOf a) (typeOf b)
  EFst _ e | STPair t _ <- typeOf e -> t
  ESnd _ e | STPair _ t <- typeOf e -> t
  ENil _ -> STNil
  EInl _ t2 e -> STEither (typeOf e) t2
  EInr _ t1 e -> STEither t1 (typeOf e)
  -- both branches have the result type; the first one is used
  ECase _ _ a _ -> typeOf a
  ENothing _ t -> STMaybe t
  EJust _ e -> STMaybe (typeOf e)
  EMaybe _ e _ _ -> typeOf e

  EConstArr _ n t _ -> STArr n (STScal t)
  EBuild _ n _ e -> STArr n (typeOf e)
  -- folds/sums drop one dimension of the input array's rank
  EFold1Inner _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t
  ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
  EUnit _ e -> STArr SZ (typeOf e)
  EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t

  EConst _ t _ -> STScal t
  EIdx0 _ e | STArr _ t <- typeOf e -> t
  EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t
  EIdx _ e _ | STArr _ t <- typeOf e -> t
  EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx)
  EOp _ op _ -> opt2 op

  -- the type of the regular-operation body is the result type
  ECustom _ _ _ e _ _ _ _ -> typeOf e

  EWith e1 e2 -> STPair (typeOf e2) (typeOf e1)
  EAccum _ _ _ _ -> STNil

  -- monoidal operations: result type is the 'd2' transform of the stored type
  EZero t -> d2 t
  EPlus t _ _ -> d2 t
  EOneHot t _ _ _ -> d2 t

  EError t _ -> t

-- unSNat :: SNat n -> Nat
-- unSNat SZ = Z
-- unSNat (SS n) = S (unSNat n)

-- unSTy :: STy t -> Ty
-- unSTy = \case
--   STNil -> TNil
--   STPair a b -> TPair (unSTy a) (unSTy b)
--   STEither a b -> TEither (unSTy a) (unSTy b)
--   STMaybe t -> TMaybe (unSTy t)
--   STArr n t -> TArr (unSNat n) (unSTy t)
--   STScal t -> TScal (unSScalTy t)
--   STAccum t -> TAccum (unSTy t)

-- unSEnv :: SList STy env -> [Ty]
-- unSEnv SNil = []
-- unSEnv (SCons t l) = unSTy t : unSEnv l

-- | Forget the type-level index of a scalar type singleton, producing the
-- corresponding term-level 'ScalTy'.
unSScalTy :: SScalTy t -> ScalTy
unSScalTy STI32 = TI32
unSScalTy STI64 = TI64
unSScalTy STF32 = TF32
unSScalTy STF64 = TF64
unSScalTy STBool = TBool

-- | Substitute @repl@ for the innermost variable ('IZ') of an expression.
-- All other variables move one position outward ('IS' is stripped).
subst1 :: Expr x env a -> Expr x (a : env) t -> Expr x env t
subst1 repl = subst $ \x ty idx -> case idx of
  IZ -> repl
  IS idx' -> EVar x ty idx'

-- | Simultaneous substitution: every variable of @env@ is replaced by the
-- expression @f@ produces for it in @env'@. The result of @f@ is weakened to
-- the environment at each variable occurrence, which handles binders crossed
-- on the way down.
subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a)
      -> Expr x env t -> Expr x env' t
subst f expr = subst' (\x ty weak idx -> weakenExpr weak (f x ty idx)) WId expr

-- | Generalised substitution that threads a weakening @w@ from the
-- substitution's target environment @env'@ into the current output
-- environment @envOut@.
--
-- @f@ receives, for each free variable occurrence, the current weakening
-- together with the variable, and must produce an expression in the weakened
-- environment. Under every binder, @w@ is extended with 'WCopy' and @f@ is
-- adapted by 'sinkF', so the newly bound variable maps to itself. The three
-- 'ECustom' bodies have fixed (closed) environments, so only its two operand
-- expressions are traversed.
subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a)
       -> env' :> envOut
       -> Expr x env t
       -> Expr x envOut t
subst' f w = \case
  EVar x t i -> f x t w i
  ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body)
  EPair x a b -> EPair x (subst' f w a) (subst' f w b)
  EFst x e -> EFst x (subst' f w e)
  ESnd x e -> ESnd x (subst' f w e)
  ENil x -> ENil x
  EInl x t e -> EInl x t (subst' f w e)
  EInr x t e -> EInr x t (subst' f w e)
  ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b)
  ENothing x t -> ENothing x t
  EJust x e -> EJust x (subst' f w e)
  EMaybe x a b e -> EMaybe x (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' f w e)
  EConstArr x n t a -> EConstArr x n t a
  EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b)
  -- the combining function binds two variables, hence the double sink/copy
  EFold1Inner x a b c -> EFold1Inner x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
  ESum1Inner x e -> ESum1Inner x (subst' f w e)
  EUnit x e -> EUnit x (subst' f w e)
  EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b)
  EConst x t v -> EConst x t v
  EIdx0 x e -> EIdx0 x (subst' f w e)
  EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
  EIdx x e es -> EIdx x (subst' f w e) (subst' f w es)
  EShape x e -> EShape x (subst' f w e)
  EOp x op e -> EOp x op (subst' f w e)
  -- a, b, c are closed terms; only the operands e1, e2 live in env
  ECustom x s t a b c e1 e2 -> ECustom x s t a b c (subst' f w e1) (subst' f w e2)
  EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
  EAccum i e1 e2 e3 -> EAccum i (subst' f w e1) (subst' f w e2) (subst' f w e3)
  EZero t -> EZero t
  EPlus t a b -> EPlus t (subst' f w a) (subst' f w b)
  EOneHot t i a b -> EOneHot t i (subst' f w a) (subst' f w b)
  EError t s -> EError t s
  where
    -- Lift a substitution under one binder: the new variable ('IZ') becomes a
    -- variable again (through the extended weakening); older variables are
    -- delegated to the original substitution with the binder popped off.
    sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
          -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t
    sinkF f' x' t w' = \case
      IZ -> EVar x' t (w' @> IZ)
      IS i -> f' x' t (WPop w') i

-- | Move an expression into a larger environment. Every variable is mapped
-- through the weakening; no variable is replaced by a non-variable.
weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t
weakenExpr w expr = subst' (\x ty w' idx -> EVar x ty (w' @> idx)) w expr

-- | Look up the element at a typed index in a singleton list. The empty-list
-- case is unreachable, since no 'Idx' into @'[]@ exists.
slistIdx :: SList f list -> Idx list t -> f t
slistIdx SNil i = case i of {}
slistIdx (SCons y ys) idx = case idx of
  IZ -> y
  IS i -> slistIdx ys i

-- | Numeric position of a typed index: 'IZ' is 0, and every 'IS' adds one.
idx2int :: Idx env t -> Int
idx2int = \case
  IZ -> 0
  IS i -> idx2int i + 1

-- | Implicitly available 'SScalTy' singleton for a scalar type.
class KnownScalTy t where knownScalTy :: SScalTy t
instance KnownScalTy TI32 where knownScalTy = STI32
instance KnownScalTy TI64 where knownScalTy = STI64
instance KnownScalTy TF32 where knownScalTy = STF32
instance KnownScalTy TF64 where knownScalTy = STF64
instance KnownScalTy TBool where knownScalTy = STBool

-- | Implicitly available 'STy' singleton. Instances build it structurally
-- from the singletons of the component types (see 'styKnown' for the
-- converse direction).
class KnownTy t where knownTy :: STy t
instance KnownTy TNil where knownTy = STNil
instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy
instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy
instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy
instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy
instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy
instance KnownTy t => KnownTy (TAccum t) where knownTy = STAccum knownTy

-- | Implicitly available list of 'STy' singletons for a type-level
-- environment, built element by element from 'KnownTy'.
class KnownEnv env where knownEnv :: SList STy env
instance KnownEnv '[] where knownEnv = SNil
instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv

-- | Turn an explicit 'STy' singleton into a 'KnownTy' dictionary. Component
-- dictionaries are obtained first (in left-to-right order) and then combined
-- through the matching 'KnownTy' instance.
styKnown :: STy t -> Dict (KnownTy t)
styKnown = \case
  STNil -> Dict
  STPair l r -> case (styKnown l, styKnown r) of (Dict, Dict) -> Dict
  STEither l r -> case (styKnown l, styKnown r) of (Dict, Dict) -> Dict
  STMaybe el -> case styKnown el of Dict -> Dict
  STArr rank el -> case (snatKnown rank, styKnown el) of (Dict, Dict) -> Dict
  STScal sc -> case sscaltyKnown sc of Dict -> Dict
  STAccum el -> case styKnown el of Dict -> Dict

-- | Turn an explicit 'SScalTy' singleton into a 'KnownScalTy' dictionary.
sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t)
sscaltyKnown = \case
  STI32 -> Dict
  STI64 -> Dict
  STF32 -> Dict
  STF64 -> Dict
  STBool -> Dict

-- | Turn an explicit list of 'STy' singletons into a 'KnownEnv' dictionary.
envKnown :: SList STy env -> Dict (KnownEnv env)
envKnown = \case
  SNil -> Dict
  ty `SCons` rest -> case (styKnown ty, envKnown rest) of (Dict, Dict) -> Dict

-- | @ebuildUp1 n sh size f@ builds an array of rank @S n@. Its shape is @sh@
-- extended with @size@ as the new inner dimension (the 'snd' of the shape
-- tuple). The element at index @(is, i)@ is element @is@ of the array @f i@:
-- @f@ is evaluated with the inner index bound as its innermost variable.
--
-- NOTE(review): @f@ appears inside the 'EBuild' body, so it is (re)computed
-- per element unless a later pass shares it; confirm this is intended.
ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t)
ebuildUp1 n sh size f =
  EBuild ext (SS n) (EPair ext sh size) $
    -- the full (S n)-dimensional index tuple bound by EBuild
    let arg = EVar ext (tTup (sreplicate (SS n) tIx)) IZ
    -- bind the inner index for f (skipping the EBuild binder when weakening
    -- f), then index its result with the outer indices
    in EIdx ext (ELet ext (ESnd ext arg) (weakenExpr (WCopy WSink) f))
                (EFst ext arg)

-- | Build a boolean expression that holds iff two index tuples of length @n@
-- are equal componentwise. Both tuples are let-bound and then referenced by
-- variable rather than duplicated. The 'snd' components are compared with
-- 'OEq', and the 'fst' remainders are compared recursively. Empty tuples are
-- trivially equal.
eidxEq :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool)
eidxEq SZ _ _ = EConst ext STBool True
eidxEq (SS n) a b =
  ELet ext a $
  ELet ext (weakenExpr WSink b) $
    let ixTy = tTup (sreplicate (SS n) tIx)
        lhs = EVar ext ixTy (IS IZ)
        rhs = EVar ext ixTy IZ
        headsEq = EOp ext (OEq STI64) (EPair ext (ESnd ext lhs) (ESnd ext rhs))
        restEq = eidxEq n (EFst ext lhs) (EFst ext rhs)
    in EOp ext OAnd (EPair ext headsEq restEq)

-- | Map a function body over every element of an array. @f@ sees the
-- current element as its innermost variable. The array is let-bound once;
-- the result is built with the same shape and indexes into it.
emap :: Ex (a : env) b -> Ex env (TArr n a) -> Ex env (TArr n b)
emap f arr =
  let STArr n t = typeOf arr
  in ELet ext arr $
       EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $
         -- environment here: index tuple : arr : env
         ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ))
                            (EVar ext (tTup (sreplicate n tIx)) IZ)) $
           -- keep f's own variable, skip the index-tuple and arr binders
           weakenExpr (WCopy (WSink .> WSink)) f

-- | Zip two arrays of the same rank into an array of pairs. Both arrays are
-- let-bound once, and each output element pairs the elements at the same
-- index.
--
-- NOTE(review): the output shape is taken from the first array only; no
-- check is emitted that the second array has the same shape.
ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b))
ezip a b =
  let STArr n t1 = typeOf a
      STArr _ t2 = typeOf b
  in ELet ext a $
     ELet ext (weakenExpr WSink b) $
       -- shape expression environment: b : a : env
       EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $
         -- element environment: index tuple : b : a : env
         EPair ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ)))
                             (EVar ext (tTup (sreplicate n tIx)) IZ))
                   (EIdx ext (EVar ext (STArr n t2) (IS IZ))
                             (EVar ext (tTup (sreplicate n tIx)) IZ))

-- | Convert an array index tuple, as used by 'EIdx', into an 'AcIdx' that
-- descends through all @n@ dimensions of the array.
--
-- In the 'EIdx' tuple the head of the list (the innermost dimension) is the
-- outermost 'snd'. The 'AcIdx' format instead lists the outermost dimension
-- first, so the order of the components is reversed.
arrIdxToAcIdx :: proxy t -> SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (AcIdx (TArr n t) n)
arrIdxToAcIdx = \p (n :: SNat n) e -> case lemPlusZero @n of Refl -> go p n SZ e (ENil ext)
  where
    -- symbolic version of 'invert' in Interpreter
    -- Accumulator-style reversal: @n@ components of the input tuple remain to
    -- be moved, and @m@ have already been consed onto @acidx@.
    go :: forall n m t env proxy. proxy t -> SNat n -> SNat m
       -> Ex env (Tup (Replicate n TIx)) -> Ex env (AcIdx (TArr m t) m) -> Ex env (AcIdx (TArr (n + m) t) (n + m))
    go _ SZ _ _ acidx = acidx
    go p (SS n) m idx acidx
      | Refl <- lemPlusSuccRight @n @m
      -- bind idx once; move its 'snd' onto the front of the accumulator and
      -- continue with its 'fst'
      = ELet ext idx $
          go p n (SS m)
             (EFst ext (EVar ext (typeOf idx) IZ))
             (EPair ext (ESnd ext (EVar ext (typeOf idx) IZ))
                        (weakenExpr WSink acidx))

-- | Descending through all @n@ dimensions of a rank-@n@ array with
-- 'AcValArr' leaves a rank-0 array. Proved by induction on @n@.
lemAcValArrN :: proxy t -> SNat n -> AcValArr n t n :~: TArr Z t
lemAcValArrN p = \case
  SZ -> Refl
  SS n -> case lemAcValArrN p n of Refl -> Refl