{-# LANGUAGE DataKinds #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module ForwardAD where

import Data.Bifunctor (bimap)

-- import Debug.Trace
-- import AST.Pretty

import Array
import AST
import Data
import ForwardAD.DualNumbers
import Interpreter
import Interpreter.Rep


-- | Tangent along a type (coincides with cotangent for these types)
--
-- The tangent of a compound type has the same structure as the type itself
-- (pairs, sums, optionals and arrays are mapped component-wise); scalar
-- tangents are decided by 'TanS'. There is no equation for accumulator types:
-- they are not allowed in input programs (see 'tanty').
type family Tan t where
  Tan TNil = TNil
  Tan (TPair a b) = TPair (Tan a) (Tan b)
  Tan (TEither a b) = TEither (Tan a) (Tan b)
  Tan (TMaybe t) = TMaybe (Tan t)
  -- Array tangents keep the rank @n@ of the primal array.
  Tan (TArr n t) = TArr n (Tan t)
  Tan (TScal t) = TanS t

-- | Tangent of a scalar type.
--
-- Floating-point scalars are their own tangent type. Integer and boolean
-- scalars are non-differentiable, so their tangent is the unit type 'TNil'.
type family TanS t where
  TanS TI32 = TNil
  TanS TI64 = TNil
  TanS TF32 = TScal TF32
  TanS TF64 = TScal TF64
  TanS TBool = TNil

-- | 'Tan' applied to every entry of a type-level environment (list of types).
type family TanE env where
  TanE '[] = '[]
  TanE (t : env) = Tan t : TanE env

-- | Type witness of the tangent type: turns a witness for @t@ into a witness
-- for @'Tan' t@. Fails with an error on accumulator types, which may not
-- occur in input programs.
tanty :: STy t -> STy (Tan t)
tanty STNil = STNil
tanty (STPair l r) = STPair (tanty l) (tanty r)
tanty (STEither l r) = STEither (tanty l) (tanty r)
tanty (STMaybe elt) = STMaybe (tanty elt)
tanty (STArr rank elt) = STArr rank (tanty elt)
-- Floating-point scalars are their own tangent ...
tanty (STScal STF32) = STScal STF32
tanty (STScal STF64) = STScal STF64
-- ... while the discrete scalars have a trivial (unit) tangent.
tanty (STScal STI32) = STNil
tanty (STScal STI64) = STNil
tanty (STScal STBool) = STNil
tanty STAccum{} = error "Accumulators not allowed in input program"

-- | The zero tangent at a given primal value.
--
-- The primal is needed because tangents of sums, optionals and arrays mirror
-- the constructor / shape of the primal they belong to.
zeroTan :: STy t -> Rep t -> Rep (Tan t)
zeroTan STNil () = ()
zeroTan (STPair tl tr) (l, r) = (zeroTan tl l, zeroTan tr r)
zeroTan (STEither tl tr) e = either (Left . zeroTan tl) (Right . zeroTan tr) e
zeroTan (STMaybe elt) m = zeroTan elt <$> m
zeroTan (STArr _ elt) arr = zeroTan elt <$> arr
zeroTan (STScal s) _ = case s of
  STF32 -> 0.0
  STF64 -> 0.0
  STI32 -> ()
  STI64 -> ()
  STBool -> ()
zeroTan STAccum{} _ = error "Accumulators not allowed in input program"

-- | Flatten a tangent into its floating-point components, widened to 'Double'.
--
-- Pair components are listed left before right; array elements follow the
-- array's 'Foldable' order. Non-differentiable scalars contribute nothing.
tanScalars :: STy t -> Rep (Tan t) -> [Double]
tanScalars STNil () = []
tanScalars (STPair tl tr) (l, r) = tanScalars tl l ++ tanScalars tr r
tanScalars (STEither tl tr) e = either (tanScalars tl) (tanScalars tr) e
tanScalars (STMaybe elt) m = maybe [] (tanScalars elt) m
tanScalars (STArr _ elt) arr = foldMap id (arrayMap (tanScalars elt) arr)
tanScalars (STScal s) v = case s of
  STF32 -> [realToFrac v]
  STF64 -> [v]
  STI32 -> []
  STI64 -> []
  STBool -> []
tanScalars STAccum{} _ = error "Accumulators not allowed in input program"

-- | Split a dual-number value into its primal part and its tangent part.
--
-- Only floating-point scalars carry a tangent inside a dual number; for the
-- discrete scalars the tangent is @()@.
unzipDN :: STy t -> Rep (DN t) -> (Rep t, Rep (Tan t))
unzipDN STNil _ = ((), ())
unzipDN (STPair tl tr) (dl, dr) = ((pl, pr), (tl', tr'))
  where
    (pl, tl') = unzipDN tl dl
    (pr, tr') = unzipDN tr dr
unzipDN (STEither tl tr) d =
  either (bimap Left Left . unzipDN tl) (bimap Right Right . unzipDN tr) d
unzipDN (STMaybe elt) d =
  maybe (Nothing, Nothing) (bimap Just Just . unzipDN elt) d
unzipDN (STArr _ elt) d = (arrayMap fst halves, arrayMap snd halves)
  where
    halves = arrayMap (unzipDN elt) d
unzipDN (STScal s) d = case s of
  -- Floating-point dual numbers are already (primal, tangent) pairs.
  STF32 -> d
  STF64 -> d
  STI32 -> (d, ())
  STI64 -> (d, ())
  STBool -> (d, ())
unzipDN STAccum{} _ = error "Accumulators not allowed in input program"

-- | Inner product of two tangents of the same type, accumulated in 'Double'.
--
-- Both tangents must have matching structure: the same 'Either' / 'Maybe'
-- alternatives and equal array shapes; mismatches raise an error. An array
-- with zero elements on either side contributes @0@ whatever the other side's
-- shape. NOTE(review): presumably this is because an empty array is used as a
-- zero tangent; confirm. Non-differentiable scalars contribute @0@.
dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double
dotprodTan STNil _ _ = 0.0
dotprodTan (STPair a b) (x, y) (x', y') =
  dotprodTan a x x' + dotprodTan b y y'
dotprodTan (STEither a b) x y = case (x, y) of
  (Left x', Left y') -> dotprodTan a x' y'
  (Right x', Right y') -> dotprodTan b x' y'
  _ -> error "dotprodTan: incompatible Either alternatives"
dotprodTan (STMaybe t) x y = case (x, y) of
  (Nothing, Nothing) -> 0.0
  (Just x', Just y') -> dotprodTan t x' y'
  _ -> error "dotprodTan: incompatible Maybe alternatives"
dotprodTan (STArr _ t) x y =
  let sh1 = arrayShape x
      sh2 = arrayShape y
  in if | shapeSize sh1 == 0 || shapeSize sh2 == 0 -> 0.0
        | sh1 == sh2 -> sum [dotprodTan t (arrayIndex x i) (arrayIndex y i) | i <- enumShape sh1]
        | otherwise -> error "dotprodTan: incompatible array shapes"
dotprodTan (STScal ty) x y = case ty of
  STI32 -> 0.0
  STI64 -> 0.0
  -- Widen each factor before multiplying. The exact product of two Floats
  -- (24-bit significands) fits in a Double (53-bit significand). Multiplying
  -- in Float first would round needlessly, and would overflow to Infinity for
  -- products beyond the Float range.
  STF32 -> realToFrac @Float @Double x * realToFrac @Float @Double y
  STF64 -> x * y
  STBool -> 0.0
dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program"

-- -- Primal expression must be duplicable
-- dnConstE :: STy t -> Ex env t -> Ex env (DN t)
-- dnConstE STNil _ = ENil ext
-- dnConstE (STPair t1 t2) e =
--   -- This creates fst/snd stacks of unbounded size, but let's not care here
--   EPair ext (dnConstE t1 (EFst ext e)) (dnConstE t2 (ESnd ext e))
-- dnConstE (STEither t1 t2) e =
--   ECase ext e
--     (EInl ext (dn t2) (dnConstE t1 (EVar ext t1 IZ)))
--     (EInr ext (dn t1) (dnConstE t2 (EVar ext t2 IZ)))
-- dnConstE (STMaybe t) e =
--   EMaybe ext (ENothing ext (dn t)) (EJust ext (dnConstE t (EVar ext t IZ))) e
-- dnConstE (STArr n t) e =
--   EBuild ext n (EShape ext e)
--     (dnConstE t (EIdx ext n (weakenExpr WSink e) (EVar ext (tTup (sreplicate n tIx)) IZ)))
-- dnConstE (STScal t) e = case t of
--   STI32 -> e
--   STI64 -> e
--   STF32 -> EPair ext e (EConst ext STF32 0.0)
--   STF64 -> EPair ext e (EConst ext STF64 0.0)
--   STBool -> e
-- dnConstE STAccum{} _ = error "Accumulators not allowed in input program"

-- | Embed a primal value as a dual number whose tangent is zero everywhere.
--
-- Floating-point scalars are paired with a @0@ tangent. Discrete scalars have
-- no tangent component and are passed through unchanged.
dnConst :: STy t -> Rep t -> Rep (DN t)
dnConst STNil = const ()
dnConst (STPair tl tr) = bimap (dnConst tl) (dnConst tr)
dnConst (STEither tl tr) = either (Left . dnConst tl) (Right . dnConst tr)
dnConst (STMaybe elt) = maybe Nothing (Just . dnConst elt)
dnConst (STArr _ elt) = arrayMap (dnConst elt)
dnConst (STScal s) = case s of
  STF32 -> \v -> (v, 0.0)
  STF64 -> \v -> (v, 0.0)
  STI32 -> id
  STI64 -> id
  STBool -> id
dnConst STAccum{} = error "Accumulators not allowed in input program"

-- | Given a function that computes the forward derivative for a particular
-- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this
-- @t@ input.
--
-- As produced by 'dnOnehots', the callback is called once per differentiable
-- scalar position of the input. Each call receives a dual-number input whose
-- tangent is @1@ at that position and @0@ elsewhere. Each result becomes the
-- corresponding entry of the returned tangent.
type RevByFwd t = (Rep (DN t) -> Double) -> Rep (Tan t)

-- | Gradient-by-forward-mode for a single value of type @t@.
--
-- For every differentiable (floating-point) scalar position inside @x@:
--
-- * build a dual-number copy of @x@ whose tangent is @1@ at that position and
--   @0@ everywhere else;
-- * pass it to the callback;
-- * store the returned number at that position of the result.
--
-- Non-differentiable scalars get @()@. Sum and optional values keep the
-- alternative of the primal.
dnOnehots :: STy t -> Rep t -> RevByFwd t
dnOnehots STNil _ = \_ -> ()
dnOnehots (STPair t1 t2) (x, y) =
  -- The zero-tangent embedding of the component that is *not* being perturbed
  -- is the same for every one-hot input, so compute it once and share it.
  let xConst = dnConst t1 x
      yConst = dnConst t2 y
  in \f -> (dnOnehots t1 x (f . (,yConst)), dnOnehots t2 y (f . (xConst,)))
dnOnehots (STEither t1 t2) e =
  case e of
    Left x -> \f -> Left (dnOnehots t1 x (f . Left))
    Right y -> \f -> Right (dnOnehots t2 y (f . Right))
dnOnehots (STMaybe t) m =
  case m of
    Nothing -> \_ -> Nothing
    Just x -> \f -> Just (dnOnehots t x (f . Just))
dnOnehots (STArr _ t) a =
  let sh = arrayShape a
      -- Zero-tangent version of every element, computed once. The original
      -- formulation recomputed it for each perturbed index.
      aConst = arrayMap (dnConst t) a
      -- The zero-tangent array with position @idx@ replaced by @oh@.
      withAt idx oh = arrayGenerate sh $ \i -> if i == idx then oh else arrayIndex aConst i
  in \f ->
       -- Each one-hot input is still a full array, so the total work remains
       -- quadratic in the number of elements.
       arrayGenerate sh $ \idx ->
         dnOnehots t (arrayIndex a idx) (f . withAt idx)
dnOnehots (STScal t) x = case t of
  STI32 -> \_ -> ()
  STI64 -> \_ -> ()
  STF32 -> \f -> realToFrac @Double @Float $ f (x, 1.0)
  STF64 -> \f -> f (x, 1.0)
  STBool -> \_ -> ()
dnOnehots STAccum{} _ = error "Accumulators not allowed in input program"

-- | Embed every value of an environment as a dual number with zero tangent
-- (entry-wise 'dnConst').
dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env)
dnConstEnv SNil SNil = SNil
dnConstEnv (SCons ty tys) (SCons (Value v) vs) =
  SCons (Value (dnConst ty v)) (dnConstEnv tys vs)

-- | Environment-level counterpart of 'RevByFwd'. Given a function from a
-- dual-number environment to a 'Double' (a directional derivative), it
-- produces one tangent per environment entry.
type RevByFwdEnv env = (SList Value (DNE env) -> Double) -> SList Value (TanE env)

-- | Gradient-by-forward-mode over a whole environment.
--
-- Each entry is perturbed with 'dnOnehots' while every other entry is held at
-- its zero-tangent embedding ('dnConst' / 'dnConstEnv').
dnOnehotEnvs :: SList STy env -> SList Value env -> RevByFwdEnv env
dnOnehotEnvs SNil SNil = \_ -> SNil
dnOnehotEnvs (t `SCons` env) (Value x `SCons` val) =
  -- The constant (zero-tangent) parts do not depend on which position is
  -- perturbed. Bind them once so they are shared by every callback call
  -- instead of being rebuilt per call.
  let xConst = dnConst t x
      restConst = dnConstEnv env val
  in \f ->
       Value (dnOnehots t x (f . (\oh -> Value oh `SCons` restConst)))
       `SCons` dnOnehotEnvs env val (f . (\oh -> Value xConst `SCons` oh))

-- | Reverse derivative of @expr@ at @input@ for output cotangent @dres@,
-- computed with forward mode only.
--
-- For every differentiable scalar position of the input environment, the
-- forward-mode program ('dfwdDN') is interpreted on a one-hot dual-number
-- input. The resulting output tangent is then dotted with @dres@. Collecting
-- these numbers per position yields the vector-Jacobian product.
drevByFwd :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)
drevByFwd env expr input dres = dnOnehotEnvs env input directional
  where
    outty = typeOf expr
    fwdExpr = dfwdDN expr
    -- Debugging aids (need Debug.Trace and AST.Pretty):
    --   trace ("fwd: running: " ++ ppExpr (dne env) fwdExpr)
    --   trace (showEnv (dne env) dnInput)
    directional dnInput =
      let outtan = snd (unzipDN outty (interpretOpen False dnInput fwdExpr))
      in dotprodTan outty outtan dres