{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
module Example where

import AST
import AST.Pretty
import CHAD
import Data
import Language
import Simplify


-- Example expression to evaluate (e.g. in GHCi):
-- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0)


-- Apply a binary primitive operator: the operator's input type is a pair,
-- so the two operands are packed into a single 'EPair' argument for 'EOp'.
bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c
bin op lhs rhs = EOp ext op $ EPair ext lhs rhs

-- Runtime type representation of the environment [TScal TF32, TScal TF32]
-- (the environment type of ex1-ex4).
senv1 :: SList STy [TScal TF32, TScal TF32]
senv1 = SCons f32 (SCons f32 SNil)
  where
    f32 = STScal STF32

-- Descriptor for an environment of two F32 variables with the given storage
-- choices. The first argument belongs to the variable pushed first, which
-- ends up as the LAST element of the index lists (hence [b, a]).
descr1 :: Storage a -> Storage b
       -> Descr [TScal TF32, TScal TF32] [b, a]
descr1 sa sb = DPush (DPush DTop (f32, sa)) (f32, sb)
  where
    f32 = STScal STF32

-- x y |- x * y + x
--
-- let x3 = (x1, x2)
--     x4 = ((*) x3, x1)
--  in ( (+) x4
--     , let x5 = 1.0
--           x6 = Inr (x5, x5)
--        in case x6 of
--             Inl x7 -> return ()
--             Inr x8 ->
--               let x9 = fst x8
--                   x10 = Inr (snd x3 * x9, fst x3 * x9)
--                in case x10 of
--                     Inl x11 -> return ()
--                     Inr x12 ->
--                       let x13 = fst x12
--                        in one "v1" x13 >>= \x14 ->
--                             let x15 = snd x12
--                              in one "v2" x15 >>= \x16 ->
--                     let x17 = snd x8
--                      in one "v1" x17)
--
-- which simplifies to:
--
-- ( (x1 * x2) + x1
-- , let x5 = 1.0
--   in do one "v1" (x2 * x5)
--         one "v2" (x1 * x5)
--         one "v1" x5)
ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex1 = scopeCheck (lambda (\x -> lambda (\y -> body ((x * y) + x))))

-- x y |- let z = x + y in z * (z + x)
ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex2 = scopeCheck $
  lambda $ \x ->
    lambda $ \y ->
      body (let_ (x + y) (\z -> z * (z + x)))

-- x y |- if x < y then 2 * x else 3 + x
--
-- Same shape as ex4, minus the y * y term in the then-branch; the
-- else-branch is an addition, as specified above.
ex3 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex3 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $
  if_ (x .< y) (2 * x) (3 + x)

-- x y |- if x < y then 2 * x + y * y else 3 + x
ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex4 =
  scopeCheck (lambda (\x -> lambda (\y ->
    body (if_ (x .< y)
              ((2 * x) + (y * y))
              (3 + x)))))

-- Runtime type representation of the environment
-- [TScal TF32, TEither (TScal TF32) (TScal TF32)] (the environment of ex5).
senv5 :: SList STy [TScal TF32, TEither (TScal TF32) (TScal TF32)]
senv5 = SCons f32 (SCons (STEither f32 f32) SNil)
  where
    f32 = STScal STF32

-- Descriptor for the ex5 environment. The first storage argument belongs to
-- the sum-typed variable (pushed first, so last in the lists); the second
-- to the F32 variable.
descr5 :: Storage a -> Storage b
       -> Descr [TScal TF32, TEither (TScal TF32) (TScal TF32)] [b, a]
descr5 sa sb = DPush (DPush DTop (sumTy, sa)) (f32, sb)
  where
    f32 = STScal STF32
    sumTy = STEither f32 f32

-- x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)}
ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32)
ex5 = scopeCheck $
  lambda $ \x ->
    lambda $ \y ->
      body $ case_ x
        (\l -> l * y)
        (\r -> r * (y + 1))

-- Runtime type representation of the environment [TScal TI64, TScal TF32]
-- (the environment of ex6).
senv6 :: SList STy [TScal TI64, TScal TF32]
senv6 = SCons (STScal STI64) (SCons (STScal STF32) SNil)

-- Descriptor for the ex6 environment, with SMerge storage for both
-- variables. The F32 variable is pushed first, so it is last in the list.
descr6 :: Descr [TScal TI64, TScal TF32] ["merge", "merge"]
descr6 = DPush (DPush DTop (STScal STF32, SMerge)) (STScal STI64, SMerge)

-- x:R n:I |- let a = unit x
--                b = build1 n (\i. let c = idx0 a in c * c)
--            in idx0 (b ! 3)
ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32)
ex6 = scopeCheck $ lambda $ \x -> lambda $ \n -> body $
  let_ (unit x) (\a ->
    let_ (build1 n (\_ -> let_ (idx0 a) (\c -> c * c))) (\b ->
      idx0 (b .! 3)))

-- Shorthand for the F32 scalar type, used by the senv7/descr7/ex7 signatures.
type R = TScal TF32

-- Runtime type representation of the ex7 environment: the input scalar,
-- then the nested tuple of three (R, R) layer-parameter pairs.
senv7 :: SList STy [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)]
senv7 = SCons r (SCons params SNil)
  where
    r = STScal STF32
    layerPars = STPair r r
    params = STPair (STPair (STPair STNil layerPars) layerPars) layerPars

-- Descriptor for the ex7 environment, with SMerge storage for both the
-- parameter tuple (pushed first, last in the list) and the input scalar.
descr7 :: Descr [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] ["merge", "merge"]
descr7 = DPush (DPush DTop (params, SMerge)) (r, SMerge)
  where
    r = STScal STF32
    layerPars = STPair r r
    params = STPair (STPair (STPair STNil layerPars) layerPars) layerPars

-- A "neural network" except it's just scalars, not matrices.
-- ps:((((), (R,R)), (R,R)), (R,R)) x:R
-- |- let p1 = snd ps
--        p1' = fst ps
--        x1 = fst p1 * x + snd p1
--        p2 = snd p1'
--        p2' = fst p1'
--        x2 = fst p2 * x1 + snd p2
--        p3 = snd p2'
--        p3' = fst p2'
--        x3 = fst p3 * x2 + snd p3
--    in x3
--
-- Each layer consumes the output of the previous one; the outermost
-- parameter pair (snd ps) is applied first, to the input x.
ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R
ex7 = scopeCheck $ lambda $ \pars123 -> lambda $ \input -> body $
  let tR = STScal STF32
      tpair = STPair tR tR

      -- Peel the outermost (weight, bias) pair off the parameter tuple,
      -- compute 'weight * inp + bias', and recurse on the remaining
      -- parameters with that result as the new input. The 'Dict' guard from
      -- 'styKnown' supplies type-class evidence for the remaining-parameter
      -- type t (presumably required by 'let_' on restpars). An STNil
      -- parameter type means no layers are left, so the input is returned.
      layer :: STy p -> SExpr p -> SExpr R -> SExpr R
      layer (STPair t (STPair (STScal STF32) (STScal STF32))) pars inp | Dict <- styKnown t =
        let_ (snd_ pars) $ \par ->
        let_ (fst_ pars) $ \restpars ->
        let_ (fst_ par * inp + snd_ par) $ \res ->
        layer t restpars res
      layer STNil _ inp = inp
      -- Any other parameter type does not have the expected layered shape.
      layer _ _ _ = error "Invalid layer inputs"

  in layer (STPair (STPair (STPair STNil tpair) tpair) tpair) pars123 input