{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Main where

import Control.Monad.Trans.Class (lift)
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.State
import Data.Bifunctor
import Data.Int (Int64)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import qualified Data.Text as T
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import Test.Framework

import Array
import AST
import AST.Pretty
import AST.UnMonoid
import CHAD.Top
import CHAD.Types
import CHAD.Types.ToTan
import Compile
import qualified Example
import qualified Example.GMM as Example
import ForwardAD
import ForwardAD.DualNumbers
import Interpreter
import Interpreter.Rep
import Language
import Simplify


-- | Double-precision real scalar type. Every program handed to the AD tests
-- ('adTest' and friends) has this as its result type.
type R = TScal TF64


-- | How many simplifier rounds 'simplifyIters' applies: either a fixed
-- number of 'simplifyN' rounds, or 'simplifyFix' (presumably iterating until
-- a fixpoint is reached).
data SimplIters = SimplIters Int | SimplFix
  deriving (Show)

-- | Simplify a term according to the requested iteration policy.
--
-- The simplifier needs a 'KnownEnv' instance; it is recovered here from the
-- explicit environment witness before dispatching on the policy.
simplifyIters :: SimplIters -> SList STy env -> Ex env t -> Ex env t
simplifyIters iters env = case envKnown env of
  Dict -> case iters of
    SimplFix -> simplifyFix
    SimplIters count -> simplifyN count

-- | Differentiate a real-valued term with CHAD and interpret the result.
--
-- The CHAD-transformed term is wrapped in a let that binds the constant 1.0
-- (presumably the incoming output cotangent). It is then simplified
-- according to @simplIters@ and run by the interpreter on @input@.
--
-- Returns the pretty-printed differentiated term, together with the primal
-- output and the gradient for every environment entry.
gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env)))
gradientByCHAD simplIters env term input = (ppExpr env dterm, (primalOut, gradient))
  where
    seeded = ELet ext (EConst ext STF64 1.0) (chad' defaultConfig env term)
    dterm = simplifyIters simplIters env seeded
    (primalOut, rawGrad) = interpretOpen False input dterm
    gradient = unTup vUnpair (d2e env) (Value rawGrad)

-- | Like 'gradientByCHAD', but converts the gradient into the tangent
-- representation ('TanE') of the environment. That is the representation
-- 'gradientByForward' produces, so the two results can be compared directly.
-- The pretty-printed differentiated term is returned as well.
gradientByCHAD' :: SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (TanE env)))
gradientByCHAD' simplIters env term input =
  let (pretty, (out, grad)) = gradientByCHAD simplIters env term input
  in (pretty, (out, toTanE env input grad))

-- | Gradient of a real-valued program obtained from a compiled forward-AD
-- artifact via 'drevByFwd'. The seed argument is 1.0 (presumably the output
-- sensitivity).
--
-- This is the reference gradient the CHAD results are checked against.
gradientByForward :: FwdADArtifact env R -> SList Value env -> SList Value (TanE env)
gradientByForward artifact input = drevByFwd artifact input seed
  where
    seed = 1.0

-- | Turn a primal value into a dual-number value.
--
-- Every floating-point scalar is paired with a randomly generated tangent
-- in [-1, 1]. Integer and boolean scalars carry no tangent and pass through
-- unchanged. The structure of pairs, sums, maybes and arrays is preserved.
-- Accumulators cannot appear in program inputs.
extendDN :: STy t -> Rep t -> Gen (Rep (DN t))
extendDN STNil () = pure ()
extendDN (STPair ta tb) (l, r) = (,) <$> extendDN ta l <*> extendDN tb r
extendDN (STEither ta _) (Left l) = Left <$> extendDN ta l
extendDN (STEither _ tb) (Right r) = Right <$> extendDN tb r
extendDN (STMaybe _) Nothing = pure Nothing
extendDN (STMaybe te) (Just e) = Just <$> extendDN te e
extendDN (STArr _ te) arr = traverse (extendDN te) arr
-- Floating-point scalars: attach a tangent.
extendDN (STScal STF32) p = Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \tng -> pure (p, tng)
extendDN (STScal STF64) p = Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \tng -> pure (p, tng)
-- Discrete scalars: unchanged.
extendDN (STScal STI32) p = pure p
extendDN (STScal STI64) p = pure p
extendDN (STScal STBool) p = pure p
extendDN (STAccum _) _ = error "Accumulators not supported in input program"

-- | 'extendDN' lifted over a whole environment: each entry gets its own
-- independently generated tangents.
extendDNE :: SList STy env -> SList Value env -> Gen (SList Value (DNE env))
extendDNE SNil SNil = pure SNil
extendDNE (ty `SCons` tys) (Value v `SCons` vs) =
  SCons <$> fmap Value (extendDN ty v) <*> extendDNE tys vs

-- | Approximate equality of two doubles with tolerance @h@.
--
-- The numbers are considered close if any of the following holds:
--
-- * they are exactly equal;
-- * their absolute difference is below @h@;
-- * both magnitudes exceed @10*h@ and their relative difference is below @h@.
--
-- The exact-equality check is what makes equal infinities compare as close.
-- For them @a - b@ is NaN, so both tolerance checks would otherwise fail and
-- two agreeing results would be reported as a mismatch. NaN is never close
-- to anything, including itself.
closeIsh' :: Double -> Double -> Double -> Bool
closeIsh' h a b =
  a == b || abs (a - b) < h || (let scale = min (abs a) (abs b) in scale > 10*h && abs (a - b) / scale < h)

-- | 'closeIsh'' with the default tolerance of 1e-5.
closeIsh :: Double -> Double -> Bool
closeIsh a b = closeIsh' defaultTolerance a b
  where
    defaultTolerance = 1e-5

-- | Structural approximate equality of two values of the given type.
--
-- * Floating-point scalars are compared with 'closeIsh'' at tolerance
--   @tol@; single-precision values are widened to Double first.
-- * Integers and booleans must be exactly equal.
-- * Sums and maybes must agree on the constructor.
-- * Arrays must have equal shapes and element-wise close contents.
-- * Accumulators cannot be compared.
closeIshT' :: Double -> STy t -> Rep t -> Rep t -> Bool
closeIshT' _ STNil () () = True
closeIshT' tol (STPair ta tb) (a1, b1) (a2, b2) =
  closeIshT' tol ta a1 a2 && closeIshT' tol tb b1 b2
closeIshT' tol (STEither ta _) (Left a1) (Left a2) = closeIshT' tol ta a1 a2
closeIshT' tol (STEither _ tb) (Right b1) (Right b2) = closeIshT' tol tb b1 b2
closeIshT' _ STEither{} _ _ = False
closeIshT' _ (STMaybe _) Nothing Nothing = True
closeIshT' tol (STMaybe te) (Just e1) (Just e2) = closeIshT' tol te e1 e2
closeIshT' _ STMaybe{} _ _ = False
closeIshT' tol (STArr _ te) arr1 arr2
  | arrayShape arr1 /= arrayShape arr2 = False
  | otherwise = all (uncurry (closeIshT' tol te)) (zip (arrayToList arr1) (arrayToList arr2))
closeIshT' _ (STScal STI32) i j = i == j
closeIshT' _ (STScal STI64) i j = i == j
closeIshT' tol (STScal STF32) x y = closeIsh' tol (realToFrac x) (realToFrac y)
closeIshT' tol (STScal STF64) x y = closeIsh' tol x y
closeIshT' _ (STScal STBool) x y = x == y
closeIshT' _ STAccum{} _ _ = error "closeIshT': Cannot compare accumulators"

-- | 'closeIshT'' with the default tolerance of 1e-5.
closeIshT :: STy t -> Rep t -> Rep t -> Bool
closeIshT ty x y = closeIshT' 1e-5 ty x y

-- | Left-nested list of per-dimension constraints in an array template (see
-- 'DimNames'). It is declared @infixl@ with the default precedence 9, so
-- @a :$ b :$ c@ means @(a :$ b) :$ c@.
data a :$ b = a :$ b deriving (Show) ; infixl :$

-- | Constraint on one array dimension in an input template.
--
-- The name ties dimensions together. All dimensions generated within one
-- environment (the name map is threaded through 'genEnv') that carry the
-- same non-empty name get the same size. The empty name means the dimension
-- is anonymous and generated independently.
--
-- In both cases the integer is the lower bound the size is drawn from.
data TplConstr = C String  -- ^ name; @""@ means anonymous
                   Int     -- ^ minimum value to generate

-- | Template for the shape of an @n@-dimensional array: one 'TplConstr' per
-- dimension, combined left-nested with ':$'.
--
-- The rightmost constraint belongs to the last dimension of the shape (the
-- one added by the outermost 'ShCons'). A 0-dimensional array has template
-- @()@, and a 1-dimensional one is a bare 'TplConstr'.
type family DimNames n where
  DimNames Z = ()
  DimNames (S Z) = TplConstr
  DimNames (S n) = DimNames n :$ TplConstr

-- | Input-generation template for a value of type @t@:
--
-- * 'DimNames' for arrays;
-- * a pair of templates for pairs;
-- * @()@ (no constraints) for everything else.
--
-- This is a closed type family, so equations are tried in order; the
-- catch-all must stay last.
type family Tpl t where
  Tpl (TArr n t) = DimNames n
  Tpl (TPair a b) = (Tpl a, Tpl b)
  -- If you add equations here, don't forget to update genValue! It currently
  -- just emptyTpl's things out.
  Tpl _ = ()

-- | Separator between per-entry templates in an environment template (see
-- 'TemplateE'). Left-associative with the default precedence 9.
data a :& b = a :& b deriving (Show) ; infixl :&

-- | Template for a whole environment.
--
-- Entry templates are combined left-nested with ':&', with the head of the
-- environment list rightmost. The head is the most recently bound lambda
-- argument (e.g. @lambda #mat $ lambda #vec@ has environment
-- @[vec, mat]@), so the written order matches the lambda order.
--
-- A singleton environment uses the entry's template directly, and the empty
-- environment uses @()@.
type family TemplateE env where
  TemplateE '[] = ()
  TemplateE '[t] = Tpl t
  TemplateE (t : ts) = TemplateE ts :& Tpl t

-- | Dimension template with no constraints: every dimension is anonymous
-- with minimum size 0.
emptyDimNames :: SNat n -> DimNames n
emptyDimNames = \case
  SZ -> ()
  SS SZ -> anonymous
  SS n@SS{} -> emptyDimNames n :$ anonymous
  where
    anonymous = C "" 0

-- | Template that imposes no constraints on a value of the given type.
--
-- * Arrays get anonymous, minimum-0 dimensions.
-- * Pairs get a pair of empty templates.
-- * Every other type that 'Tpl' maps to @()@ gets @()@.
--
-- Unit, sums and maybes used to hit the error case here. That made
-- 'genValue' crash on e.g. @Maybe ()@, @Either (a, b) c@ or arrays of
-- maybes, because it generates their payloads from 'emptyTpl'.
--
-- Accumulators are not valid program inputs and remain an error.
emptyTpl :: STy t -> Tpl t
emptyTpl (STArr n _) = emptyDimNames n
emptyTpl (STPair a b) = (emptyTpl a, emptyTpl b)
emptyTpl (STScal _) = ()
emptyTpl STNil = ()
emptyTpl STEither{} = ()
emptyTpl STMaybe{} = ()
emptyTpl _ = error "emptyTpl: no template for this type (accumulators cannot be inputs)"

-- | Environment template with no constraints: 'emptyTpl' for every entry,
-- laid out as 'TemplateE' requires.
emptyTemplateE :: SList STy env -> TemplateE env
emptyTemplateE = \case
  SNil -> ()
  ty `SCons` SNil -> emptyTpl ty
  ty `SCons` rest@SCons{} -> emptyTemplateE rest :& emptyTpl ty

-- | Generate an array shape that satisfies the given dimension template.
--
-- Each dimension is first drawn from @[lo, 10]@. Named dimensions are looked
-- up in, or recorded into, the state map, so every array sharing a name gets
-- the same extent.
--
-- If the array would have 100 or more elements, its anonymous dimensions are
-- then divided by a common factor to keep it small. Two rules apply:
--
-- * Named dimensions are never shrunk. Shrinking them would silently break
--   the agreement with other arrays that use the same name. Previously a
--   10x10 matrix with a named inner dimension became 5x5 while a vector
--   sharing that name stayed length 10.
-- * A shrunk anonymous dimension never drops below its template minimum.
genShape :: SNat n -> DimNames n -> StateT (Map String Int) Gen (Shape n)
genShape = \n tpl -> do
  sh <- genShapeNaive n tpl
  let sz = shapeSize sh
      factor = sz `div` 100 + 1
  return (shapeDiv n tpl sh factor)
  where
    -- Draw every dimension (respecting name sharing), without any size cap.
    genShapeNaive :: SNat n -> DimNames n -> StateT (Map String Int) Gen (Shape n)
    genShapeNaive SZ () = return ShNil
    genShapeNaive (SS SZ) name = ShCons ShNil <$> genNamedDim name
    genShapeNaive (SS n@SS{}) (tpl :$ name) = ShCons <$> genShapeNaive n tpl <*> genNamedDim name

    -- An anonymous dimension is always fresh. A named one is generated on
    -- first use and reused afterwards.
    genNamedDim :: TplConstr -> StateT (Map String Int) Gen Int
    genNamedDim (C "" lo) = genDim lo
    genNamedDim (C name lo) = gets (Map.lookup name) >>= \case
      Nothing -> do
        dim <- genDim lo
        modify (Map.insert name dim)
        return dim
      Just dim -> return dim

    genDim :: Int -> StateT (Map String Int) Gen Int
    genDim lo = Gen.integral (Range.linear lo 10)

    -- Divide the dimensions by the factor, walking the template alongside
    -- the shape so each dimension knows whether it is named.
    shapeDiv :: SNat n -> DimNames n -> Shape n -> Int -> Shape n
    shapeDiv SZ () ShNil _ = ShNil
    shapeDiv (SS SZ) c (sh `ShCons` d) f = sh `ShCons` shrinkDim c d f
    shapeDiv (SS n@SS{}) (tpl :$ c) (sh `ShCons` d) f = shapeDiv n tpl sh f `ShCons` shrinkDim c d f

    -- Anonymous: divide, but keep the template minimum. Named: untouched.
    shrinkDim :: TplConstr -> Int -> Int -> Int
    shrinkDim (C "" lo) d f = max lo (d `div` f)
    shrinkDim C{} d _ = d

-- | Generate an array of the given shape.
--
-- The elements are independent, unconstrained values of the element type.
-- Each element is generated from 'emptyTpl' with a fresh, empty
-- dimension-name map.
genArray :: STy a -> Shape n -> Gen (Value (TArr n a))
genArray t sh = fmap Value (arrayGenerateLinM sh (const element))
  where
    element = unValue <$> evalStateT (genValue t (emptyTpl t)) mempty

-- | Generate a value of the given type that satisfies the template.
--
-- Only pair and array templates carry information:
--
-- * Sum and maybe payloads are generated from 'emptyTpl'.
-- * Array shapes come from 'genShape', which records named dimensions in
--   the state map.
-- * Numeric scalars are drawn from [-10, 10], with origin 0.
genValue :: STy t -> Tpl t -> StateT (Map String Int) Gen (Value t)
genValue STNil _ = return (Value ())
genValue (STPair ta tb) tpl =
  liftV2 (,) <$> genValue ta (fst tpl) <*> genValue tb (snd tpl)
genValue (STEither ta tb) _ =
  Gen.choice [liftV Left <$> genValue ta (emptyTpl ta)
             ,liftV Right <$> genValue tb (emptyTpl tb)]
genValue (STMaybe te) _ =
  Gen.choice [return (Value Nothing)
             ,liftV Just <$> genValue te (emptyTpl te)]
genValue (STArr n te) tpl = genShape n tpl >>= lift . genArray te
genValue (STScal sty) _ = case sty of
  STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
  STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
  STI32 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
  STI64 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
  STBool -> Gen.choice [return (Value False), return (Value True)]
genValue STAccum{} _ = error "Cannot generate inputs for accumulators"

-- | Generate a whole environment from its template.
--
-- The named-dimension map is threaded through all entries, so equally named
-- dimensions agree across inputs. The head entry is generated before the
-- rest of the environment.
genEnv :: SList STy env -> TemplateE env -> StateT (Map String Int) Gen (SList Value env)
genEnv SNil () = return SNil
genEnv (ty `SCons` SNil) headTpl =
  SCons <$> genValue ty headTpl <*> pure SNil
genEnv (ty `SCons` rest@SCons{}) (restTpl :& headTpl) =
  SCons <$> genValue ty headTpl <*> genEnv rest restTpl

-- | A runtime value paired with its type witness, so it can be rendered with
-- 'showValue'. Used to display results in 'diff'.
data TypedValue t = TypedValue (STy t) (Rep t)

instance Show (TypedValue t) where
  showsPrec prec (TypedValue ty val) = showValue prec ty val

-- | Compiler-vs-interpreter test on unconstrained random inputs.
compileTest :: forall env t. KnownEnv env => TestName -> Ex env t -> TestTree
compileTest name = compileTestTp name (emptyTemplateE (knownEnv @env))

-- | Compiler-vs-interpreter test with inputs generated from an environment
-- template.
compileTestTp :: KnownEnv env => TestName -> TemplateE env -> Ex env t -> TestTree
compileTestTp name tmpl expr = compileTestGen name expr inputs
  where
    inputs = evalStateT (genEnv knownEnv tmpl) mempty

-- | Compile @expr@ once, then for every generated input check that the
-- compiled result matches the interpreter's to within 1e-8 (structurally,
-- via 'closeIshT'').
compileTestGen :: KnownEnv env => TestName -> Ex env t -> Gen (SList Value env) -> TestTree
compileTestGen name expr envGenerator =
  withCompiled env expr $ \compiled ->
    testProperty name $ property $ do
      input <- forAllWith (showEnv env) envGenerator
      let resInterp = interpretOpen False input expr
      resCompiled <- liftIO (compiled input)
      let agrees (TypedValue _ x) (TypedValue _ y) = closeIshT' 1e-8 ty x y
      diff (TypedValue ty resInterp) agrees (TypedValue ty resCompiled)
  where
    env = knownEnv
    ty = typeOf expr

-- | AD test on unconstrained random inputs.
adTest :: forall env. KnownEnv env => TestName -> Ex env R -> TestTree
adTest name term = adTestCon name (\_ -> True) term

-- | AD test on unconstrained random inputs that additionally satisfy the
-- given predicate. Inputs that fail it are discarded via 'Gen.filter'.
adTestCon :: forall env. KnownEnv env => TestName -> (SList Value env -> Bool) -> Ex env R -> TestTree
adTestCon name constr term = adTestGen name term filtered
  where
    env = knownEnv @env
    filtered = Gen.filter constr (evalStateT (genEnv env (emptyTemplateE env)) mempty)

-- | AD test with inputs generated from an environment template.
adTestTp :: forall env. KnownEnv env
         => TestName -> TemplateE env -> Ex env R -> TestTree
adTestTp name tmpl term = adTestGen name term generator
  where
    generator = evalStateT (genEnv (knownEnv @env) tmpl) mempty

-- | Full AD test group for one real-valued program, with inputs from the
-- given generator.
--
-- The program and its 'simplifyFix'-simplified version are each compiled
-- once and shared by the sub-tests:
--
-- * primal: interpreter vs. compiled code, for both versions;
-- * forward AD (dual numbers) on the simplified program;
-- * CHAD reverse-mode AD, checked against forward AD.
--
-- The simplified term is computed once and reused. The previous version
-- recomputed @simplifyFix expr@ for the second compilation.
adTestGen :: forall env. KnownEnv env
          => TestName -> Ex env R -> Gen (SList Value env) -> TestTree
adTestGen name expr envGenerator =
  let env = knownEnv @env
      exprS = simplifyFix expr
  in withCompiled env expr $ \primalfun ->
     withCompiled env exprS $ \primalSfun ->
     testGroupCollapse name
       [adTestGenPrimal env envGenerator expr exprS primalfun primalSfun
       ,adTestGenFwd env envGenerator exprS
       ,adTestGenChad env envGenerator expr exprS primalSfun]

-- | Check that the compiled primal programs agree with the interpreter to
-- within 1e-8, for the original term and for its simplified version.
adTestGenPrimal :: SList STy env -> Gen (SList Value env)
                -> Ex env R -> Ex env R
                -> (SList Value env -> IO Double) -> (SList Value env -> IO Double)
                -> TestTree
adTestGenPrimal env envGenerator expr exprS primalfun primalSfun =
  testProperty "compile primal" $ property $ do
    input <- forAllWith (showEnv env) envGenerator
    let interpreted term = interpretOpen False input term

    -- Original term.
    compiledOut <- liftIO (primalfun input)
    diff (interpreted expr) (closeIsh' 1e-8) compiledOut

    -- Simplified term.
    compiledOutS <- liftIO (primalSfun input)
    diff (interpreted exprS) (closeIsh' 1e-8) compiledOutS

-- | Check compiled forward-mode (dual-number) AD against the interpreter.
--
-- The dual-number program is run on inputs extended with random tangents
-- via 'extendDNE'. Both components of the result (presumably primal and
-- tangent) must agree to within 1e-8.
adTestGenFwd :: SList STy env -> Gen (SList Value env)
             -> Ex env R
             -> TestTree
adTestGenFwd env envGenerator exprS =
  withCompiled (dne env) dnTerm $ \dnfun ->
    testProperty "compile fwdAD" $ property $ do
      input <- forAllWith (showEnv env) envGenerator
      dinput <- forAllWith (showEnv (dne env)) (extendDNE env input)
      let (primalI, tangentI) = interpretOpen False dinput dnTerm
      (primalC, tangentC) <- liftIO (dnfun dinput)
      diff primalI (closeIsh' 1e-8) primalC
      diff tangentI (closeIsh' 1e-8) tangentC
  where
    dnTerm = dfwdDN exprS

-- | CHAD reverse-mode AD test for one program.
--
-- CHAD is applied to the original program (@expr@) and to its simplified
-- version (@exprS@). Each result is also simplified, giving four derivative
-- terms. For every generated input:
--
-- * All four terms are interpreted. The last one is also compiled, after
--   'unMonoid' and another simplification.
-- * Every primal output must be close to that of the compiled simplified
--   primal program (@primalSfun@).
-- * Every gradient, flattened to its list of scalars, must be close to the
--   gradient computed by compiled forward AD, including having the same
--   number of scalars.
--
-- The test also checks that 'simplifyFix' and 20 rounds of 'simplifyN' give
-- the same pretty-printed CHAD terms.
adTestGenChad :: forall env. SList STy env -> Gen (SList Value env)
              -> Ex env R -> Ex env R
              -> (SList Value env -> IO Double)
              -> TestTree
adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =
  let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr
      dtermChadS = simplifyFix dtermChad0
      dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS
      dtermSChadS = simplifyFix dtermSChad0
      -- The term that gets compiled. Also shown in the failure annotation
      -- below, so compute it only once.
      dtermSChadSComp = simplifyFix (unMonoid dtermSChadS)
  in
  withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC ->
  withCompiled env dtermSChadSComp $ \dcompSChadS ->
    testProperty "chad" $ property $ do
      annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))

      -- pack Text for less GC pressure (these values are retained for some reason)
      diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0)))
      diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0)))

      input <- forAllWith (showEnv env) envGenerator
      outPrimal <- liftIO $ primalSfun input

      let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env)
          unpackGrad = unTup vUnpair (d2e env) . Value

      let scFwd = tanEScalars env $ gradientByForward fwdartifactC input

      let (outChad0 , gradChad0)  = second unpackGrad $ interpretOpen False input dtermChad0
          (outChadS , gradChadS)  = second unpackGrad $ interpretOpen False input dtermChadS
          (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False input dtermSChad0
          (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False input dtermSChadS
          scChad   = tanEScalars env $ toTanE env input gradChad0
          scChadS  = tanEScalars env $ toTanE env input gradChadS
          scSChad  = tanEScalars env $ toTanE env input gradSChad0
          scSChadS = tanEScalars env $ toTanE env input gradSChadS

      (outCompSChadS, gradCompSChadS) <- second unpackGrad <$> liftIO (dcompSChadS input)
      let scCompSChadS = tanEScalars env $ toTanE env input gradCompSChadS

      -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0))
      -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS))
      -- annotate (ppExpr knownEnv expr)
      -- annotate (ppExpr env dtermChad0)
      -- annotate (ppExpr env dtermChadS)
      annotate (ppExpr env dtermSChadSComp)
      diff outChad0      closeIsh outPrimal
      diff outChadS      closeIsh outPrimal
      diff outSChad0     closeIsh outPrimal
      diff outSChadS     closeIsh outPrimal
      diff outCompSChadS closeIsh outPrimal
      -- TODO: use closeIshT
      -- The length check matters: 'zipWith' silently truncates. Without it,
      -- a gradient with missing or extra scalars (e.g. an array of the
      -- wrong shape) whose common prefix matches would pass.
      let closeIshList x y = length x == length y && and (zipWith closeIsh x y)
      diff scChad       closeIshList scFwd
      diff scChadS      closeIshList scFwd
      diff scSChad      closeIshList scFwd
      diff scSChadS     closeIshList scFwd
      diff scCompSChadS closeIshList scFwd

-- | Compile a term as a test resource (with a no-op release action) and pass
-- the resulting native function to the test builder.
withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree
withCompiled env expr mkTest = withResource (compile env expr) (const (pure ())) mkTest

-- | Sum of a vector, written as a 1-D 'build' that copies every element of
-- the input before summing.
term_build1_sum :: Ex '[TArr N1 R] R
term_build1_sum = fromNamed $ lambda #x $ body $
  idx0 $ sum1i $
    build (SS SZ) (shape #x) $ #idx :-> #x ! #idx

-- | Scalar program exercising pair construction and projection.
-- With @p = (x, y)@ and @q = (y*x + y, x)@ it evaluates to
-- @(y*x + y)*x + x*x@.
term_pairs :: Ex [R, R] R
term_pairs = fromNamed $ lambda #x $ lambda #y $ body $
  let_ #p (pair #x #y) $
  let_ #q (pair (snd_ #p * fst_ #p + #y) #x) $
    fst_ #q * #x + snd_ #q * fst_ #p

-- | Sparse-gradient program. It builds three length-n vectors that repeat
-- input elements 2, 3 and 4 (n = input length) and sums them all, so only
-- those three input elements get a non-zero gradient.
--
-- Element 4 is read, so the input needs at least 5 elements.
term_sparse :: Ex '[TArr N1 R] R
term_sparse = fromNamed $ lambda #inp $ body $
  let_ #n (snd_ (shape #inp)) $
  let_ #arr (build1 #n (#i :-> #inp ! pair nil #i)) $
  let_ #a (build1 #n (#i :-> #arr ! pair nil 2)) $
  let_ #b (build1 #n (#i :-> #arr ! pair nil 3)) $
  let_ #c (build1 #n (#i :-> #arr ! pair nil 4)) $
    idx0 (sum1i #a) + idx0 (sum1i #b) + idx0 (sum1i #c)

-- | Regression term for a simplifier bug. It builds over the input's shape:
-- position 0 holds input element 0, and every other position holds 1.0
-- (the inner condition @#j .== #j@ is always true). The result is the sum.
--
-- Element 0 is read, so the input needs at least 1 element.
term_regression_simpl1 :: Ex '[TArr N1 R] R
term_regression_simpl1 = fromNamed $ lambda #q $ body $
  idx0 $ sum1i $ build (SS SZ) (shape #q) $ #idx :->
    let_ #j (snd_ #idx) $
      if_ (#j .== 0)
        (#q ! pair nil 0)
        (if_ (#j .== #j) 1.0 2.0)

-- | Sum of the entries of the matrix-vector product @mat * vec@.
--
-- The vector is indexed up to the matrix's last dimension (@wid@), so its
-- length must be at least that. The "mulmatvec" test makes them equal by
-- naming both dimensions "n" in the template.
term_mulmatvec :: Ex [TArr N1 R, TArr N2 R] R
term_mulmatvec = fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec $ body $
  idx0 $ sum1i $
    let_ #hei (snd_ (fst_ (shape #mat))) $
    let_ #wid (snd_ (shape #mat)) $
      build1 #hei $ #i :->
        idx0 (sum1i (build1 #wid $ #j :->
                       #mat ! pair (pair nil #i) #j * #vec ! pair nil #j))

-- | Compiler-vs-interpreter tests for accumulator ('with' / 'accum')
-- programs.
tests_Compile :: TestTree
tests_Compile = testGroup "Compile"
  -- Accumulate #x into a scalar accumulator (created by 'with' from 0.0),
  -- only when #b holds.
  [compileTest "accum f64" $ fromNamed $ lambda #b $ lambda #x $ body $
      with @R 0.0 $ #ac :->
        if_ #b (accum SAPHere nil #x #ac)
               nil

  -- Several accumulations into a pair accumulator, one of them conditional
  -- on #b.
  ,compileTest "accum (f64,f64)" $ fromNamed $ lambda #b $ lambda #x $ body $
      with @(TPair R R) nothing $ #ac :->
        let_ #_ (if_ #b (accum (SAPFst SAPHere) nil 3.0 #ac) nil) $
        let_ #_ (accum SAPHere nil #x #ac) $
        let_ #_ (accum (SAPSnd SAPHere) nil 4.0 #ac) $
          nil

  -- Array accumulator. The template leaves #b unconstrained and gives #x at
  -- least 3 elements, presumably because one accumulation targets index 2.
  ,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda #x $ body $
      let_ #len (snd_ (shape #x)) $
      with @(TArr N1 R) nothing $ #ac :->
        let_ #_ (if_ #b (accum (SAPArrIdx SAPHere (SS SZ)) (pair (pair (pair nil 2) (pair nil #len)) nil) 6.0 #ac)
                        nil) $
        let_ #_ (accum SAPHere nil (just #x) #ac) $
          nil
  ]

-- | End-to-end AD tests. Each entry runs the 'adTestGen' group (compiled
-- primal, compiled forward AD, and CHAD vs. forward AD) on one real-valued
-- program. The entry points differ only in how inputs are generated: plain,
-- filtered ('adTestCon'), templated ('adTestTp') or custom ('adTestGen').
tests_AD :: TestTree
tests_AD = testGroup "AD"
  [adTest "id" $ fromNamed $ lambda #x $ body $ #x

  ,adTest "idx0" $ fromNamed $ lambda #x $ body $ idx0 #x

  -- Float and integer arithmetic, the latter via rounding and converted
  -- back with toFloat_.
  ,adTest "operators" $ fromNamed $ lambda #x $ lambda #y $ body $
      let_ #i (round_ #x) $
      let_ #j (round_ #y) $
      let_ #a1 (#x + #y) $
      let_ #a2 (#x - #y) $
      let_ #a3 (#x * #y) $
      let_ #a4 (#x / (#y * #y + 1)) $
      let_ #b1 (#i + #j) $
      let_ #b2 (#i - #j) $
      let_ #b3 (#i * #j) $
      let_ #b4 (#i `idiv` (#j * #j + 1)) $
        #a1 + #a2 + #a3 + #a4 +
        toFloat_ (#b1 + #b2 + #b3 + #b4)

  ,adTest "order-of-operations" $ fromNamed $ body $
      toFloat_ (3 * (3 `idiv` 2))  -- Compile had a pretty-printing bug at some point

  ,adTest "sum-vec" $ fromNamed $ lambda #x $ body $ idx0 (sum1i #x)

  ,adTest "sum-replicate" $ fromNamed $ lambda #x $ body $
      idx0 $ sum1i $ replicate1i 10 #x

  ,adTest "pairs" term_pairs

  ,adTest "build0 const" $ fromNamed $ lambda @R #x $ body $
      idx0 $ build SZ nil $ #idx :-> const_ 0.0

  ,adTest "build0" $ fromNamed $ lambda @(TArr N0 _) #x $ body $
      idx0 $
        build SZ (shape #x) $ #idx :-> #x ! #idx

  ,adTest "build1-sum" term_build1_sum

  ,adTest "build2-sum" $ fromNamed $ lambda @(TArr N2 _) #x $ body $
      idx0 $ sum1i . sum1i $
        build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx

  -- maximum/minimum: only inputs whose last dimension is non-empty are used,
  -- presumably because the extremum of an empty row is undefined.
  ,adTestCon "maximum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
      fromNamed $ lambda @(TArr N2 R) #x $ body $
        idx0 $ sum1i $ maximum1i #x

  ,adTestCon "minimum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
      fromNamed $ lambda @(TArr N2 R) #x $ body $
        idx0 $ sum1i $ minimum1i #x

  ,adTest "unused" $ fromNamed $ lambda @(TArr N1 R) #x $ body $
      let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $
        42

  -- term_sparse reads element 4, so the vector needs at least 5 elements.
  ,adTestTp "sparse" (C "" 5) term_sparse

  -- Regression test for a simplifier bug (89b78d4)
  -- (the term reads element 0, hence at least 1 element).
  ,adTestTp "regression-simpl1" (C "" 1) term_regression_simpl1

  -- The "-unMonoid" variants test the program after simplifyFix + unMonoid.
  ,adTestGen "neural" Example.neural genNeural

  ,adTestGen "neural-unMonoid" (unMonoid (simplifyFix Example.neural)) genNeural

  -- Non-empty vector, since the term takes the maximum of the vector.
  ,adTestTp "logsumexp" (C "" 1) $
      fromNamed $ lambda @(TArr N1 _) #vec $ body $
      let_ #m (maximum1i #vec) $
        log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m

  -- The matrix's last dimension and the vector's length are both named "n",
  -- so they are generated equal.
  ,adTestTp "mulmatvec" ((C "" 0 :$ C "n" 0) :& C "n" 0) term_mulmatvec

  -- The Bool argument presumably selects a deliberately "wrong" variant of
  -- the GMM objective (see Example.GMM) -- TODO confirm.
  ,adTestGen "gmm-wrong" (Example.gmmObjective True) genGMM

  ,adTestGen "gmm-wrong-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective True))) genGMM

  ,adTestGen "gmm" (Example.gmmObjective False) genGMM

  ,adTestGen "gmm-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective False))) genGMM
  ]
  where
    -- GMM inputs with N, D, K in [1, 8]. Array shapes: alpha is K; M and Q
    -- are K x D; L is K x D(D-1)/2; X is N x D.
    -- The list is in environment order (head first), presumably the reverse
    -- of gmmObjective's parameter order -- TODO confirm against Example.GMM.
    genGMM = do
      -- The input ranges here are completely arbitrary.
      let tR = STScal STF64
      kN <- Gen.integral (Range.linear 1 8)
      kD <- Gen.integral (Range.linear 1 8)
      kK <- Gen.integral (Range.linear 1 8)
      let i2i64 = fromIntegral @Int @Int64
      valpha <- genArray tR (ShNil `ShCons` kK)
      vM <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
      vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
      vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2))
      vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD)
      vgamma <- Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
      vm <- Gen.integral (Range.linear 0 5)
      let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi)
          k2 = 0.5 * vgamma * vgamma
          k3 = 0.42  -- don't feel like multigammaing today
      return (Value k3 `SCons` Value k2 `SCons` Value k1 `SCons`
              Value vm `SCons` vX `SCons`
              vL `SCons` vQ `SCons` vM `SCons` valpha `SCons`
              Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons`
              SNil)

    -- Inputs for the neural-network example. An input vector of length nin,
    -- two (weights nout x nin, bias nout) layers, and a final vector of
    -- length n2. All sizes are in [1, 10].
    genNeural = do
      let tR = STScal STF64
      let genLayer nin nout =
            liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin)
                       <*> genArray tR (ShNil `ShCons` nout)
      nin <- Gen.integral (Range.linear 1 10)
      n1 <- Gen.integral (Range.linear 1 10)
      n2 <- Gen.integral (Range.linear 1 10)
      input <- genArray tR (ShNil `ShCons` nin)
      lay1 <- genLayer nin n1
      lay2 <- genLayer n1 n2
      lay3 <- genArray tR (ShNil `ShCons` n2)
      return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)

-- | Run the compiler tests and the AD tests.
main :: IO ()
main = defaultMain (testGroup "All" [tests_Compile, tests_AD])