{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- I want to bring various type variables into scope using type annotations in
-- patterns, without having to spell out all the other type parameters of the
-- types in question. Partial type signatures (with '_') are useful for this.
{-# LANGUAGE PartialTypeSignatures #-}
{-# OPTIONS -Wno-partial-type-signatures #-}
module ForwardAD.DualNumbers (
  dfwdDN,
  DN, DNS, DNE, dn, dne,
) where

import AST
import Data
import ForwardAD.DualNumbers.Types


-- | Evidence that 'DN' leaves the type of an @n@-component index tuple
-- unchanged, established by induction on the length witness.
dnPreservesTupIx :: SNat n -> DN (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx)
dnPreservesTupIx = \case
  SZ -> Refl
  -- Inductive step: the proof for the shorter tuple discharges the goal.
  SS m -> case dnPreservesTupIx m of
    Refl -> Refl

-- | Re-type an environment index so that it points at the same position in
-- the 'DNE'-transformed environment.
convIdx :: Idx env t -> Idx (DNE env) (DN t)
convIdx = \case
  IZ -> IZ
  IS rest -> IS (convIdx rest)

-- | Case distinction on a scalar type singleton.
--
-- For 'STF32' and 'STF64' the first continuation is returned; it may assume
-- that @DN (TScal t)@ is a pair of two @TScal t@ and that the scalar type is
-- numeric, floating and 'Fractional'. For 'STI32', 'STI64' and 'STBool' the
-- second continuation is returned; it may assume that @DN (TScal t)@ is
-- @TScal t@ itself.
scalTyCase :: SScalTy t
           -> ((ScalIsNumeric t ~ True, ScalIsFloating t ~ True, Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t)) => r)
           -> (DN (TScal t) ~ TScal t => r)
           -> r
scalTyCase ty dual plain = case ty of
  STF32 -> dual
  STF64 -> dual
  STI32 -> plain
  STI64 -> plain
  STBool -> plain

-- | For a scalar type already known to be floating, provide the continuation
-- with the remaining facts about it: a 'Fractional' instance for its
-- representation, that its 'DN' type is a pair, and that it is numeric.
floatingDual :: ScalIsFloating t ~ True
             => SScalTy t
             -> ((Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t), ScalIsNumeric t ~ True) => r) -> r
floatingDual ty k = case ty of
  STF32 -> k
  STF64 -> k

-- | Lift a primitive operation so that it consumes and produces the 'DN'
-- representation of its argument and result.
--
-- For scalar types for which 'scalTyCase' selects its first branch, a value is
-- a pair whose first component is the original value and whose second
-- component is its derivative; each case below computes the second component
-- with the corresponding differentiation rule. For the remaining scalar types
-- the operation is re-emitted unchanged.
--
-- The argument does not need to be duplicable: 'unFloat' and 'binFloat' bind
-- it with 'ELet' and only refer to the bound variable afterwards.
dop :: forall a b env. SOp a b -> Ex env (DN a) -> Ex env (DN b)
dop = \case
  -- d(x + y) = dx + dy
  OAdd t -> scalTyCase t
    (binFloat (\(x, dx) (y, dy) -> EPair ext (add t x y) (add t dx dy)))
    (EOp ext (OAdd t))
  -- d(x * y) = dx * y + dy * x
  OMul t -> scalTyCase t
    (binFloat (\(x, dx) (y, dy) -> EPair ext (mul t x y) (add t (mul t dx y) (mul t dy x))))
    (EOp ext (OMul t))
  -- d(-x) = -dx
  ONeg t -> scalTyCase t
    (unFloat (\(x, dx) -> EPair ext (neg t x) (neg t dx)))
    (EOp ext (ONeg t))
  -- Comparisons only look at the first components; the derivatives are dropped.
  OLt t -> scalTyCase t
    (binFloat (\(x, _) (y, _) -> EOp ext (OLt t) (EPair ext x y)))
    (EOp ext (OLt t))
  OLe t -> scalTyCase t
    (binFloat (\(x, _) (y, _) -> EOp ext (OLe t) (EPair ext x y)))
    (EOp ext (OLe t))
  OEq t -> scalTyCase t
    (binFloat (\(x, _) (y, _) -> EOp ext (OEq t) (EPair ext x y)))
    (EOp ext (OEq t))
  ONot -> EOp ext ONot
  OAnd -> EOp ext OAnd
  OOr -> EOp ext OOr
  OIf -> EOp ext OIf
  -- Rounding applies to the first component only; no derivative is produced.
  ORound64 -> \arg -> EOp ext ORound64 (EFst ext arg)
  -- The converted value is paired with a zero derivative.
  OToFl64 -> \arg -> EPair ext (EOp ext OToFl64 arg) (EConst ext STF64 0.0)
  -- d(1 / x) = -(1 / (x * x)) * dx
  ORecip t -> floatingDual t $ unFloat (\(x, dx) ->
                  EPair ext (recip' t x)
                            (mul t (neg t (recip' t (mul t x x))) dx))
  -- d(exp x) = exp x * dx
  OExp t -> floatingDual t $ unFloat (\(x, dx) ->
                EPair ext (EOp ext (OExp t) x)
                          (mul t (EOp ext (OExp t) x) dx))
  -- d(log x) = (1 / x) * dx
  OLog t -> floatingDual t $ unFloat (\(x, dx) ->
                EPair ext (EOp ext (OLog t) x)
                          (mul t (recip' t x) dx))
  -- The first branch of 'scalTyCase' is handled by an empty case on @t@;
  -- only the unchanged-operation branch produces code.
  OIDiv t -> scalTyCase t
    (case t of {})
    (EOp ext (OIDiv t))
  where
    -- Build an 'OAdd' application on two scalar expressions.
    add :: ScalIsNumeric t ~ True
        => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
    add t a b = EOp ext (OAdd t) (EPair ext a b)

    -- Build an 'OMul' application on two scalar expressions.
    mul :: ScalIsNumeric t ~ True
        => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
    mul t a b = EOp ext (OMul t) (EPair ext a b)

    -- Build an 'ONeg' application on a scalar expression.
    neg :: ScalIsNumeric t ~ True
        => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t)
    neg t = EOp ext (ONeg t)

    -- Build an 'ORecip' application on a scalar expression.
    recip' :: ScalIsFloating t ~ True
          => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t)
    recip' t = EOp ext (ORecip t)

    -- Let-bind a pair-valued argument and hand its two projections, taken
    -- from the bound variable, to the continuation.
    unFloat :: DN a ~ TPair a a
            => (forall env'. (Ex env' a, Ex env' a) -> Ex env' (DN b))
            -> Ex env (DN a) -> Ex env (DN b)
    unFloat f e =
      ELet ext e $
        let var = EVar ext (typeOf e) IZ
        in f (EFst ext var, ESnd ext var)

    -- Let-bind an argument that is a pair of two pairs and hand the
    -- projections of each inner pair, taken from the bound variable, to the
    -- continuation.
    binFloat :: (a ~ TPair s s, DN s ~ TPair s s)
             => (forall env'. (Ex env' s, Ex env' s) -> (Ex env' s, Ex env' s) -> Ex env' (DN b))
             -> Ex env (DN a) -> Ex env (DN b)
    binFloat f e =
      ELet ext e $
        let var = EVar ext (typeOf e) IZ
        in f (EFst ext (EFst ext var), ESnd ext (EFst ext var))
             (EFst ext (ESnd ext var), ESnd ext (ESnd ext var))

-- | The constant expression zero at the given numeric scalar type.
zeroScalarConst :: ScalIsNumeric t ~ True => SScalTy t -> Ex env (TScal t)
zeroScalarConst = \case
  STF32 -> EConst ext STF32 0.0
  STF64 -> EConst ext STF64 0.0
  STI32 -> EConst ext STI32 0
  STI64 -> EConst ext STI64 0

-- | Dual-numbers forward derivative transformation: turn an expression of
-- type @t@ in environment @env@ into an expression of type @DN t@ in
-- environment @DNE env@.
--
-- Most constructors are translated structurally. Scalar-typed constructs go
-- through 'scalTyCase': in its first branch values become pairs of a value
-- and its derivative, in its second branch they are left as they are.
--
-- Calls 'error' on 'EWith', 'EAccum', 'EZero', 'EPlus' and 'EOneHot', which
-- are not supported in the source program.
dfwdDN :: Ex env t -> Ex (DNE env) (DN t)
dfwdDN = \case
  EVar _ t i -> EVar ext (dn t) (convIdx i)
  ELet _ a b -> ELet ext (dfwdDN a) (dfwdDN b)
  EPair _ a b -> EPair ext (dfwdDN a) (dfwdDN b)
  EFst _ e -> EFst ext (dfwdDN e)
  ESnd _ e -> ESnd ext (dfwdDN e)
  ENil _ -> ENil ext
  EInl _ t e -> EInl ext (dn t) (dfwdDN e)
  EInr _ t e -> EInr ext (dn t) (dfwdDN e)
  ECase _ e a b -> ECase ext (dfwdDN e) (dfwdDN a) (dfwdDN b)
  ENothing _ t -> ENothing ext (dn t)
  EJust _ e -> EJust ext (dfwdDN e)
  EMaybe _ e a b -> EMaybe ext (dfwdDN e) (dfwdDN a) (dfwdDN b)
  -- In the dual branch every element of the constant array is paired with a
  -- zero derivative.
  EConstArr _ n t x -> scalTyCase t
    (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0))
          (EConstArr ext n t x))
    (EConstArr ext n t x)
  EBuild _ n a b
    | Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b)
  EFold1Inner _ a b c -> EFold1Inner ext (dfwdDN a) (dfwdDN b) (dfwdDN c)
  -- In the dual branch the first and second components are summed separately
  -- and the two resulting arrays are recombined with 'ezip'.
  ESum1Inner _ e ->
    let STArr n (STScal t) = typeOf e
        pairty = (STPair (STScal t) (STScal t))
    in scalTyCase t
         (ELet ext (dfwdDN e) $
            ezip (ESum1Inner ext (emap (EFst ext (EVar ext pairty IZ))
                                       (EVar ext (STArr n pairty) IZ)))
                 (ESum1Inner ext (emap (ESnd ext (EVar ext pairty IZ))
                                       (EVar ext (STArr n pairty) IZ))))
         (ESum1Inner ext (dfwdDN e))
  EUnit _ e -> EUnit ext (dfwdDN e)
  EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b)
  EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e
  EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e
  -- In the dual branch a constant is paired with a zero derivative.
  EConst _ t x -> scalTyCase t
    (EPair ext (EConst ext t x) (EConst ext t 0.0))
    (EConst ext t x)
  EIdx0 _ e -> EIdx0 ext (dfwdDN e)
  EIdx1 _ a b -> EIdx1 ext (dfwdDN a) (dfwdDN b)
  EIdx _ a b
    | STArr n _ <- typeOf a
    , Refl <- dnPreservesTupIx n
    -> EIdx ext (dfwdDN a) (dfwdDN b)
  EShape _ e
    | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n)
    -> EShape ext (dfwdDN e)
  EOp _ op e -> dop op (dfwdDN e)
  -- Only the @pr@ field is transformed; the two argument expressions are
  -- let-bound around it. @e2@ is weakened with 'WSink' because it sits under
  -- the binding of @e1@.
  ECustom _ _ _ _ pr _ _ e1 e2 ->
    ELet ext (dfwdDN e1) $
    ELet ext (weakenExpr WSink (dfwdDN e2)) $
      weakenExpr (WCopy (WCopy WClosed)) (dfwdDN pr)
  EError t s -> EError (dn t) s

  EWith{} -> err_accum
  EAccum{} -> err_accum
  EZero{} -> err_monoid
  EPlus{} -> err_monoid
  EOneHot{} -> err_monoid
  where
    err_accum = error "Accumulator operations unsupported in the source program"
    err_monoid = error "Monoid operations unsupported in the source program"

    -- Shared translation for 'EMaximum1Inner' and 'EMinimum1Inner'; the
    -- reduction to use is passed in as @extremum@.
    --
    -- Dual branch: the extremum is taken over the first components. The
    -- derivative at each output position is the sum, along the inner
    -- dimension, of the second components of those elements whose first
    -- component equals the extremum; all other elements contribute zero.
    --
    -- Other branch: @extremum@ is applied directly to the transformed array.
    deriv_extremum :: ScalIsNumeric t ~ True
                   => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t)))
                   -> Ex env (TArr (S n) (TScal t)) -> Ex (DNE env) (TArr n (DN (TScal t)))
    deriv_extremum extremum e =
      let STArr (SS n) (STScal t) = typeOf e
          t2 = STPair (STScal t) (STScal t)
          ta2 = STArr (SS n) t2
          tIxN = tTup (sreplicate (SS n) tIx)
      in scalTyCase t
           (ELet ext (dfwdDN e) $
            ELet ext (extremum (emap (EFst ext (EVar ext t2 IZ)) (EVar ext ta2 IZ))) $
              ezip (EVar ext (STArr n (STScal t)) IZ)
                   (ESum1Inner ext
                      {- build (shape SZ) (\i. if fst (SZ ! i) == Z ! tail i then snd (SZ ! i) else zero) -}
                      (EBuild ext (SS n) (EShape ext (EVar ext ta2 (IS IZ))) $
                         ELet ext (EIdx ext (EVar ext ta2 (IS (IS IZ))) (EVar ext tIxN IZ)) $
                           ECase ext (EOp ext OIf (EOp ext (OEq t) (EPair ext
                                        (EFst ext (EVar ext t2 IZ))
                                        (EIdx ext (EVar ext (STArr n (STScal t)) (IS (IS IZ)))
                                                  (EFst ext (EVar ext tIxN (IS IZ)))))))
                             (ESnd ext (EVar ext t2 (IS IZ)))
                             (zeroScalarConst t))))
           -- Use the reduction that was passed in, so that a minimum stays
           -- a minimum.
           (extremum (dfwdDN e))