{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ViewPatterns #-}
module Main where

import Control.Exception (bracket)
import Control.Monad (when)
import Data.Array.Internal qualified as OI
import Data.Array.Internal.RankedG qualified as RG
import Data.Array.Internal.RankedS qualified as RS
import Data.Foldable (toList)
import Data.Vector.Storable qualified as VS
import Numeric.LinearAlgebra qualified as LA
import Test.Tasty.Bench
import Text.Show (showListWith)

import Data.Array.Mixed.XArray (XArray(..))
import Data.Array.Nested
import Data.Array.Nested.Internal.Mixed (mliftPrim, mliftPrim2, Mixed(M_Primitive), toPrimitive)
import Data.Array.Nested.Internal.Ranked (liftRanked1, liftRanked2)
import qualified Data.Array.Strided.Arith.Internal as Arith


-- | Switch for the "misc" benchmark group: it is passed to 'bgroupIf' when
-- that group is built, so the group's benchmarks are only included when this
-- is 'True'.
enableMisc :: Bool
enableMisc = False

-- | Conditional variant of 'bgroup'.
--
-- With 'True' this is exactly 'bgroup'. With 'False' the group is still
-- created under the given name, but the supplied benchmarks are dropped and
-- the group is left empty.
bgroupIf :: Bool -> String -> [Benchmark] -> Benchmark
bgroupIf enabled label benches =
  bgroup label (if enabled then benches else [])


-- | Program entry point.
--
-- Runs 'main_tests' between a setup and a teardown step via 'bracket', so the
-- teardown also runs when the benchmark runner leaves by throwing an
-- exception. Setup passes the local @collectStats@ flag to
-- 'Arith.statisticsEnable'; teardown passes 'False' to it again and, only if
-- the flag was set, calls 'Arith.statisticsPrintAll'.
main :: IO ()
main = bracket setup teardown body
  where
    -- Flip to True to have the arithmetic statistics printed after the run.
    collectStats = False

    setup = Arith.statisticsEnable collectStats

    teardown () = do
      Arith.statisticsEnable False
      when collectStats Arith.statisticsPrintAll

    body () = main_tests

-- | Hand the complete benchmark tree to 'defaultMain'.
--
-- The tree has four top-level groups: "compare" (defined in 'tests_compare'),
-- "dotprod", "orthotope" and "misc". The last one is only populated when
-- 'enableMisc' is 'True'.
main_tests :: IO ()
main_tests = defaultMain
  [bgroup "compare" tests_compare
  ,bgroup "dotprod" $
    -- stridesOf pattern-matches through the array's wrappers down to the
    -- OI.T constructor and returns the field bound as `strides`; it is used
    -- only to build the benchmark labels below.
    let stridesOf (Ranked (toPrimitive -> M_Primitive _ (XArray (RS.A (RG.A _ (OI.T strides _ _)))))) = strides
        -- dotprodBench builds one benchmark from a pair of arrays. Its label
        -- is the given name, the shape of the first input, and the strides of
        -- both inputs; the measured function is
        -- `rsumAllPrim (rdot1Inner a b)`, evaluated with `nf`.
        dotprodBench name (inp1, inp2) =
          -- showSh renders a list of integers, printing an exact power of ten
          -- greater than 1 as "1e<k>" and anything else with `shows`. Because
          -- `ln` is a lazy binding consulted only after `n > 1` has succeeded,
          -- the logarithm is never evaluated for zero or negative entries.
          let showSh l = showListWith (\n -> let ln = round (logBase 10 (fromIntegral n :: Double)) :: Int
                                             in if n > 1 && n == 10 ^ ln then showString ("1e" ++ show ln) else shows n)
                                      l ""
          in bench (name ++ " " ++ showSh (toList (rshape inp1)) ++
                      " str " ++ showSh (stridesOf inp1) ++ " " ++ showSh (stridesOf inp2)) $
               nf (\(a,b) -> rsumAllPrim (rdot1Inner a b)) (inp1, inp2)

        -- All inputs in this group are Double arrays produced by riota.
        iota n = riota @Double n
    in
    -- Two plain rank-1 inputs of 1e7 elements.
    [dotprodBench "dot 1D"
        (iota 10_000_000
        ,iota 10_000_000)
    -- Same, with rrev1 applied to both inputs.
    ,dotprodBench "revdot"
        (rrev1 (iota 10_000_000)
        ,rrev1 (iota 10_000_000))
    -- Both inputs reshaped to 1000 x 10000.
    ,dotprodBench "dot 2D"
        (rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000)
        ,rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000))
    -- Both inputs are a 1e4-element vector passed through rreplicate with 1000.
    ,dotprodBench "batched dot"
        (rreplicate (1000 :$: ZSR) (iota 10_000)
        ,rreplicate (1000 :$: ZSR) (iota 10_000))
    -- The "dot 2D" inputs with rtranspose [1,0] applied to each.
    ,dotprodBench "transposed dot" $
        let (a, b) = (rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000)
                     ,rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000))
        in (rtranspose [1,0] a, rtranspose [1,0] b)
    -- The "batched dot" inputs with rtranspose [1,0] applied to each.
    ,dotprodBench "repdot" $
        let (a, b) = (rreplicate (1000 :$: ZSR) (iota 10_000)
                     ,rreplicate (1000 :$: ZSR) (iota 10_000))
        in (rtranspose [1,0] a, rtranspose [1,0] b)
    -- A 1000 x 10000 reshaped array against a 1e4-element vector replicated
    -- with 1000.
    ,dotprodBench "matvec" $
        let (m, v) = (rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000)
                     ,iota 10_000)
        in (m, rreplicate (1000 :$: ZSR) v)
    -- A 1e3-element vector replicated with 10000 against the transposed
    -- 1000 x 10000 reshaped array.
    ,dotprodBench "vecmat" $
        let (v, m) = (iota 1_000
                     ,rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000))
        in (rreplicate (10_000 :$: ZSR) v, rtranspose [1,0] m)
    -- m1 is reshaped to n x m and m2 to m x k. m1 is replicated with k and
    -- then transposed; m2 is transposed and then replicated with n.
    -- NOTE(review): presumably both operands end up with shape [n,k,m] --
    -- confirm against the rreplicate/rtranspose documentation.
    ,dotprodBench "matmat" $
       let (n,m,k) = (100, 100, 1000)
           (m1, m2) = (rreshape (n :$: m :$: ZSR) (iota (n*m))
                      ,rreshape (m :$: k :$: ZSR) (iota (m*k)))
       in (rtranspose [1,0] (rreplicate (k :$: ZSR) m1)
          ,rreplicate (n :$: ZSR) (rtranspose [1,0] m2))
    -- As "matmat", except that m2 is reshaped to k x m directly and is
    -- replicated without a transpose.
    ,dotprodBench "matmatT" $
       let (n,m,k) = (100, 100, 1000)
           (m1, m2) = (rreshape (n :$: m :$: ZSR) (iota (n*m))
                      ,rreshape (k :$: m :$: ZSR) (iota (m*k)))
       in (rtranspose [1,0] (rreplicate (k :$: ZSR) m1)
          ,rreplicate (n :$: ZSR) m2)
    ]
  ,bgroup "orthotope"
    -- RS.normalize applied to a 1e6-element iota reversed once along axis 0.
    [bench "normalize [1e6]" $
      let n = 1_000_000
      in nf (\a -> RS.normalize a)
            (RS.rev [0] (RS.iota @Double n))
    -- Same, but the input is reversed twice (the case labelled "noop").
    ,bench "normalize noop [1e6]" $
      let n = 1_000_000
      in nf (\a -> RS.normalize a)
            (RS.rev [0] (RS.rev [0] (RS.iota @Double n)))
    ]
  ,bgroupIf enableMisc "misc"
    -- Summing the concatenation of n storable vectors of k Ints each, with
    -- and without a VS.force between the concat and the sum.
    [let n = 1000
         k = 1000
     in bgroup ("fusion [" ++ show k ++ "]*" ++ show n) $
      [bench "sum (concat)" $
        nf (\as -> VS.sum (VS.concat as))
           (replicate n (VS.enumFromTo (1::Int) k))
      ,bench "sum (force (concat))" $
        nf (\as -> VS.sum (VS.force (VS.concat as)))
              (replicate n (VS.enumFromTo (1::Int) k))]
    -- Concatenating lists of storable vectors with LA.vjoin versus VS.concat.
    ,bgroup "concat"
      -- "N": the vector length is fixed at 500 and the number of vectors is
      -- 10^ni for ni from 1 to 5.
      [bgroup "N"
        [bgroup "hmatrix"
          [bench ("LA.vjoin [500]*1e" ++ show ni) $
            let n = 10 ^ ni
                k = 500
            in nf (\as -> LA.vjoin as)
                  (replicate n (VS.enumFromTo (1::Int) k))
          | ni <- [1::Int ..5]]
        ,bgroup "vectorStorable"
          [bench ("VS.concat [500]*1e" ++ show ni) $
            let n = 10 ^ ni
                k = 500
            in nf (\as -> VS.concat as)
                  (replicate n (VS.enumFromTo (1::Int) k))
          | ni <- [1::Int ..5]]
        ]
      -- "K": the number of vectors is fixed at 500 and the vector length is
      -- 10^ki for ki from 1 to 5.
      ,bgroup "K"
        [bgroup "hmatrix"
          [bench ("LA.vjoin [1e" ++ show ki ++ "]*500") $
            let n = 500
                k = 10 ^ ki
            in nf (\as -> LA.vjoin as)
                  (replicate n (VS.enumFromTo (1::Int) k))
          | ki <- [1::Int ..5]]
        ,bgroup "vectorStorable"
          [bench ("VS.concat [1e" ++ show ki ++ "]*500") $
            let n = 500
                k = 10 ^ ki
            in nf (\as -> VS.concat as)
                  (replicate n (VS.enumFromTo (1::Int) k))
          | ki <- [1::Int ..5]]
        ]
      ]
    ]
  ]

-- | Benchmarks that run the same reductions over 1e6-element inputs through
-- three different code paths, so their timings can be compared side by side:
--
-- * "Num": the operator is applied through 'liftRanked2' / 'liftRanked1'
--   wrapped around 'mliftPrim2' / 'mliftPrim'.
-- * "NumElt": the operator is applied directly to the ranked arrays, plus
--   'rdot' dot products, some with one operand passed through 'rrev1'.
-- * "hmatrix": the same operations on hmatrix vectors built with
--   'LA.linspace'.
tests_compare :: [Benchmark]
tests_compare =
  -- Element count shared by every benchmark below; the "[1e6]" in the labels
  -- refers to this value.
  let n = 1_000_000 in
  [bgroup "Num"
    [bench "sum(+) Double [1e6]" $
      nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (+)) a b)))
         (riota @Double n, riota n)
    ,bench "sum(*) Double [1e6]" $
      nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (*)) a b)))
         (riota @Double n, riota n)
    ,bench "sum(/) Double [1e6]" $
      nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (/)) a b)))
         (riota @Double n, riota n)
    ,bench "sum(**) Double [1e6]" $
      nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (**)) a b)))
         (riota @Double n, riota n)
    ,bench "sum(sin) Double [1e6]" $
      nf (\a -> runScalar (rsumOuter1 (liftRanked1 (mliftPrim sin) a)))
         (riota @Double n)
    ,bench "sum Double [1e6]" $
      nf (\a -> runScalar (rsumOuter1 a))
         (riota @Double n)
    ]
  ,bgroup "NumElt"
    [bench "sum(+) Double [1e6]" $
      nf (\(a, b) -> runScalar (rsumOuter1 (a + b)))
         (riota @Double n, riota n)
    ,bench "sum(*) Double [1e6]" $
      nf (\(a, b) -> runScalar (rsumOuter1 (a * b)))
         (riota @Double n, riota n)
    ,bench "sum(/) Double [1e6]" $
      nf (\(a, b) -> runScalar (rsumOuter1 (a / b)))
         (riota @Double n, riota n)
    ,bench "sum(**) Double [1e6]" $
      nf (\(a, b) -> runScalar (rsumOuter1 (a ** b)))
         (riota @Double n, riota n)
    ,bench "sum(sin) Double [1e6]" $
      nf (\a -> runScalar (rsumOuter1 (sin a)))
         (riota @Double n)
    ,bench "sum Double [1e6]" $
      nf (\a -> runScalar (rsumOuter1 a))
         (riota @Double n)
    -- The "stride 1; -1" variants pass the second operand through rrev1.
    ,bench "sum(*) Double [1e6] stride 1; -1" $
      nf (\(a, b) -> runScalar (rsumOuter1 (a * b)))
         (riota @Double n, rrev1 (riota n))
    ,bench "dotprod Float [1e6]" $
      nf (\(a, b) -> rdot a b)
         (riota @Float n, riota @Float n)
    ,bench "dotprod Float [1e6] stride 1; -1" $
      nf (\(a, b) -> rdot a b)
         (riota @Float n, rrev1 (riota @Float n))
    ,bench "dotprod Double [1e6]" $
      nf (\(a, b) -> rdot a b)
         (riota @Double n, riota @Double n)
    ,bench "dotprod Double [1e6] stride 1; -1" $
      nf (\(a, b) -> rdot a b)
         (riota @Double n, rrev1 (riota @Double n))
    ]
  ,bgroup "hmatrix"
    -- Inputs are linspace vectors running from 0 to n-1 in n steps.
    [bench "sum(+) Double [1e6]" $
      nf (\(a, b) -> LA.sumElements (a + b))
         (LA.linspace @Double n (0.0, fromIntegral (n - 1))
         ,LA.linspace @Double n (0.0, fromIntegral (n - 1)))
    ,bench "sum(*) Double [1e6]" $
      nf (\(a, b) -> LA.sumElements (a * b))
         (LA.linspace @Double n (0.0, fromIntegral (n - 1))
         ,LA.linspace @Double n (0.0, fromIntegral (n - 1)))
    ,bench "sum(/) Double [1e6]" $
      nf (\(a, b) -> LA.sumElements (a / b))
         (LA.linspace @Double n (0.0, fromIntegral (n - 1))
         ,LA.linspace @Double n (0.0, fromIntegral (n - 1)))
    ,bench "sum(**) Double [1e6]" $
      nf (\(a, b) -> LA.sumElements (a ** b))
         (LA.linspace @Double n (0.0, fromIntegral (n - 1))
         ,LA.linspace @Double n (0.0, fromIntegral (n - 1)))
    ,bench "sum(sin) Double [1e6]" $
      nf (\a -> LA.sumElements (sin a))
         (LA.linspace @Double n (0.0, fromIntegral (n - 1)))
    ,bench "sum Double [1e6]" $
      nf (\a -> LA.sumElements a)
         (LA.linspace @Double n (0.0, fromIntegral (n - 1)))
    -- The Float benchmark must build Float vectors; with Double inputs it
    -- would be an exact duplicate of the "dotprod Double" benchmark below and
    -- the single-precision baseline would never be measured.
    ,bench "dotprod Float [1e6]" $
      nf (\(a, b) -> a LA.<.> b)
         (LA.linspace @Float n (0.0, fromIntegral (n - 1))
         ,LA.linspace @Float n (fromIntegral (n - 1), 0.0))
    ,bench "dotprod Double [1e6]" $
      nf (\(a, b) -> a LA.<.> b)
         (LA.linspace @Double n (0.0, fromIntegral (n - 1))
         ,LA.linspace @Double n (fromIntegral (n - 1), 0.0))
    ]
  ]