{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
module Example where

import Array
import AST
import AST.Pretty
import CHAD
import CHAD.Top
import ForwardAD
import Interpreter
import Language
import Simplify

import Debug.Trace
import Example.Format
import Example.Types


-- Example GHCi invocation:
-- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0)


-- | Apply a binary primitive operator: both operands are combined with
-- 'EPair' and the resulting pair is passed to 'EOp'.
bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c
bin op lhs rhs =
  let operands = EPair ext lhs rhs
  in EOp ext op operands

-- | Singleton type environment of two 32-bit float scalars (the
-- environment type shared by 'ex1' .. 'ex4').
senv1 :: SList STy [TScal TF32, TScal TF32]
senv1 = SCons f32 (SCons f32 SNil)
  where f32 = STScal STF32

-- | Descriptor for a two-float environment. The first storage argument is
-- pushed first (onto 'DTop') and the second on top of it, giving the storage
-- list @[b, a]@ in the type.
descr1 :: Storage a -> Storage b
       -> Descr [TScal TF32, TScal TF32] [b, a]
descr1 storA storB = DPush (DPush DTop (f32, storA)) (f32, storB)
  where f32 = STScal STF32

-- x y |- x * y + x
--
-- let x3 = (x1, x2)
--     x4 = ((*) x3, x1)
--  in ( (+) x4
--     , let x5 = 1.0
--           x6 = Inr (x5, x5)
--        in case x6 of
--             Inl x7 -> return ()
--             Inr x8 ->
--               let x9 = fst x8
--                   x10 = Inr (snd x3 * x9, fst x3 * x9)
--                in case x10 of
--                     Inl x11 -> return ()
--                     Inr x12 ->
--                       let x13 = fst x12
--                        in one "v1" x13 >>= \x14 ->
--                             let x15 = snd x12
--                              in one "v2" x15 >>= \x16 ->
--                     let x17 = snd x8
--                      in one "v1" x17)
--
-- ( (x1 * x2) + x1
-- , let x5 = 1.0
--   in do one "v1" (x2 * x5)
--         one "v2" (x1 * x5)
--         one "v1" x5)
-- | @x y |- x * y + x@, written as a named term and converted with
-- 'fromNamed'.
ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex1 = fromNamed (lambda #x (lambda #y (body (#x * #y + #x))))

-- x y |- let z = x + y in z * (z + x)
-- | Named-term form of @let z = x + y in z * (z + x)@.
ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex2 =
  fromNamed
    (lambda #x
      (lambda #y
        (body
          (let_ #z (#x + #y)
            (#z * (#z + #x))))))

-- x y |- if x < y then 2 * x else 3 + x
-- | Named-term form of @if x < y then 2 * x else 3 + x@.
--
-- The else-branch is an addition, matching the stated term (and the
-- identical else-branch of 'ex4'); it previously computed @3 * x@.
ex3 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex3 = fromNamed $ lambda #x $ lambda #y $ body $
  if_ (#x .< #y) (2 * #x) (3 + #x)

-- x y |- if x < y then 2 * x + y * y else 3 + x
-- | Named-term form of @if x < y then 2 * x + y * y else 3 + x@.
ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex4 =
  fromNamed
    (lambda #x
      (lambda #y
        (body
          (if_ (#x .< #y)
               (2 * #x + #y * #y)
               (3 + #x)))))

-- | Type environment of 'ex5': a float scalar followed by a sum of two
-- float scalars. The value is derived from the type via 'knownEnv'.
senv5 :: SList STy [TScal TF32, TEither (TScal TF32) (TScal TF32)]
senv5 = knownEnv

-- | Descriptor for the 'ex5' environment. The first storage argument is
-- pushed first (onto 'DTop'), so the storage list in the type is @[b, a]@.
descr5 :: Storage a -> Storage b
       -> Descr [TScal TF32, TEither (TScal TF32) (TScal TF32)] [b, a]
descr5 storA storB = DPush (DPush DTop (knownTy, storA)) (knownTy, storB)

-- x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)}
-- | Named-term form of a case analysis on the sum-typed variable @x@:
-- the left payload is multiplied by @y@, the right one by @y + 1@.
ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32)
ex5 =
  fromNamed
    (lambda #x
      (lambda #y
        (body
          (case_ #x
             (#a :-> #a * #y)
             (#b :-> #b * (#y + 1))))))

-- | Type environment of 'ex6' (an integer then a float scalar), derived from
-- the type via 'knownEnv'.
senv6 :: SList STy [TScal TI64, TScal TF32]
senv6 = knownEnv

-- | Descriptor for the 'ex6' environment with 'SMerge' storage for both
-- variables.
descr6 :: Descr [TScal TI64, TScal TF32] ["merge", "merge"]
descr6 = DPush (DPush DTop (knownTy, SMerge)) (knownTy, SMerge)

-- x:R n:I |- let a = unit x
--                b = build1 n (\i. let c = idx0 a in c * c)
--            in b ! 3
-- | Named-term form: wrap @x@ in a rank-0 array, build a length-@n@ array
-- whose every element squares that value, and return element 3.
ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32)
ex6 =
  fromNamed
    (lambda #x
      (lambda #n
        (body
          (let_ #a (unit #x)
            (let_ #b (build1 #n (#_ :-> let_ #c (idx0 #a) (#c * #c)))
              (#b ! pair nil 3))))))

-- | Type environment of 'ex7': the scalar input followed by the nested
-- parameter tuple, derived from the type via 'knownEnv'.
senv7 :: SList STy [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)]
senv7 = knownEnv

-- | Descriptor for the 'ex7' environment with 'SMerge' storage for both
-- variables.
descr7 :: Descr [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] ["merge", "merge"]
descr7 = DPush (DPush DTop (knownTy, SMerge)) (knownTy, SMerge)

-- A "neural network" except it's just scalars, not matrices.
-- ps:((((), (R,R)), (R,R)), (R,R)) x:R
-- |- let p1 = snd ps
--        p1' = fst ps
--        x1 = fst p1 * x + snd p1
--        p2 = snd p1'
--        p2' = fst p1'
--        x2 = fst p2 * x1 + snd p2
--        p3 = snd p2'
--        p3' = fst p2'
--        x3 = fst p3 * x2 + snd p3
--    in x3
-- | Three scalar layers. Each computes @fst p * inp + snd p@ for its
-- parameter pair @p@. The parameters arrive as one left-nested tuple; the
-- outermost 'snd' pair is applied first. Each layer's output becomes the
-- next layer's input.
ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R
ex7 = fromNamed $ lambda #pars123 $ lambda #input $ body $
  let tR = STScal STF64
      tpair = STPair tR tR

      -- Emit one layer and recurse on the rest. The 'STy p' argument is a
      -- runtime witness of the type of #parstup and drives the recursion.
      -- Each step binds #par to the last parameter pair and shadows #inp with
      -- this layer's output. It then rebinds #parstup to the remaining
      -- parameters.
      layer :: (Lookup "parstup" env ~ p, Lookup "inp" env ~ R)
            => STy p -> NExpr env R
      -- NOTE(review): 'styKnown' presumably supplies the type-class evidence
      -- for 't' that the recursive call needs — confirm.
      layer (STPair t (STPair (STScal STF64) (STScal STF64))) | Dict <- styKnown t =
        let_ #par (snd_ #parstup) $
        let_ #restpars (fst_ #parstup) $
        let_ #inp (fst_ #par * #inp + snd_ #par) $
        let_ #parstup #restpars $
        layer t
      -- No parameters left: the current #inp is the network's output.
      layer STNil = #inp
      layer _ = error "Invalid layer inputs"

  -- Seed the recursion: #parstup starts as the whole parameter tuple and
  -- #inp as the input; the witness below mirrors the parameter tuple type.
  in let_ #parstup #pars123 $
     let_ #inp #input $
     layer (STPair (STPair (STPair STNil tpair) tpair) tpair)

-- | A small vector network producing a scalar.
--
-- Variables, in binding order: @layer1@ and @layer2@ are (weight matrix,
-- bias vector) pairs, @layer3@ is a vector and @input@ is the input vector.
-- The first two layers each compute @relu (wei · x + bias)@. The result is
-- the dot product of the second layer's output with @layer3@.
neural :: Ex [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)] R
neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #input $ body $
  -- 'layer' is a separate named function of (wei, bias, x); it is inlined
  -- at each use site below via 'inline'.
  let layer = lambda @(TMat R) #wei $ lambda @(TVec R) #bias $ lambda @(TVec R) #x $ body $
        -- prod = wei `matmul` x
        -- (build the products wei[idx] * x[snd idx] over the shape of wei,
        -- then 'sum1i' reduces one dimension — presumably the innermost)
        let_ #prod (sum1i $ build (SS (SS SZ)) (shape #wei) $ #idx :->
                              #wei ! #idx * #x ! pair nil (snd_ #idx)) $
        -- relu (prod + bias)
        build (SS SZ) (shape #prod) $ #idx :->
          let_ #out (#prod ! #idx + #bias ! #idx) $
          if_ (#out .<= const_ 0) (const_ 0) #out

  in let_ #x1 (inline layer (SNil .$ fst_ #layer1 .$ snd_ #layer1 .$ #input)) $
     let_ #x2 (inline layer (SNil .$ fst_ #layer2 .$ snd_ #layer2 .$ #x1)) $
     -- x3 = sum over i of x2[i] * layer3[i], i.e. a dot product
     let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $
     -- index the (presumably rank-0) result with the empty index to get a scalar
     #x3 ! nil

-- | Gradient of 'neural' with respect to all of its inputs, in the order
-- assembled by 'neuralGo': layer1 (weights, bias), layer2 (weights, bias),
-- layer3, input.
type NeuralGrad = ((Array N2 Double, Array N1 Double)
                  ,(Array N2 Double, Array N1 Double)
                  ,Array N1 Double
                  ,Array N1 Double)

-- | Evaluate 'neural' on a fixed argument and compute its gradient twice.
-- The first gradient interprets the CHAD reverse derivative; the second
-- uses dual-number forward AD ('drevByFwd'). The two gradients are expected
-- to agree.
--
-- The fixed argument has all-ones weight matrices, zero biases, and an
-- all-ones @layer3@ vector and input.
--
-- Forcing the result also prints, via 'trace', the pretty-printed and
-- simplified CHAD program.
neuralGo :: (Double  -- primal
            ,NeuralGrad  -- gradient using CHAD
            ,NeuralGrad)  -- gradient using dual-numbers forward AD
neuralGo =
  let lay1 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0])
      lay2 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0])
      lay3 = arrayFromList (ShNil `ShCons` 2) [1,1]
      input = arrayFromList (ShNil `ShCons` 2) [1,1]
      -- Environment values; the last-bound variable of 'neural' (input) comes first.
      argument = (Value input `SCons` Value lay3 `SCons` Value lay2 `SCons` Value lay1 `SCons` SNil)
      -- CHAD derivative of 'neural' under an outer let binding the constant
      -- 1.0 (presumably the output cotangent seed), then simplified.
      revderiv =
        simplifyN 20 $
          ELet ext (EConst ext STF64 1.0) $
            chad defaultConfig knownEnv neural
      -- Only 'Right' gradients for the two (matrix, vector) layers are
      -- handled; any other result shape is reported explicitly rather than
      -- via 'undefined'.
      (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False argument revderiv of
        (primal', (((((), Right dlay1_1'), Right dlay2_1'), dlay3_1'), dinput_1')) -> (primal', dlay1_1', dlay2_1', dlay3_1', dinput_1')
        _ -> error "neuralGo: unexpected shape of CHAD result (expected Right gradients for layer1 and layer2)"
      (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwd knownEnv neural argument 1.0
  in trace (formatter (ppExpr knownEnv revderiv)) $
     (primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2))