blob: cccc686e048e2f9d032d852e49db24c75ff436bb (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
|
{-# LANGUAGE TypeApplications #-}
module Main where
import Control.DeepSeq
import Control.Exception (evaluate)
import Control.Monad (forM_)
import Criterion
import Criterion.Main
import qualified System.Clock as Clock
import System.Environment (getArgs)
import System.Mem (performGC)
import qualified Numeric.AD as AD
import qualified Numeric.AD.Double as AD.Double
import qualified Numeric.ADDual as ADD
import qualified Numeric.ADDual.Array.Internal as ADDA
import Numeric.ADDual.Examples
-- | Program entry point. Passing exactly @--neural-graph@ prints the
-- plain-text timing table from 'mainNeuralGraph'; any other command line
-- (including none) runs the criterion suite in 'mainCriterion', which reads
-- the arguments itself.
main :: IO ()
main = do
  cliArgs <- getArgs
  if cliArgs == ["--neural-graph"]
    then mainNeuralGraph
    else mainCriterion
-- | Criterion benchmark suite for the neural-network gradient example.
--
-- For every size in 'sizes' the list-based 'fneural' gradient is timed with
-- "Numeric.ADDual", "Numeric.AD" and "Numeric.AD.Double" (group
-- @neural-<n>@). Afterwards the array-based 'fneural_A' gradient is timed
-- with "Numeric.ADDual.Array.Internal" (group @neuralA-<n>@).
mainCriterion :: IO ()
mainCriterion =
  defaultMain (map benchNeural sizes ++ map benchNeuralA sizes)
  where
    -- Network sizes to benchmark; 180 rather stably incurs 2 GCs.
    sizes :: [Int]
    sizes = [100, 180, 500, 2000]

    -- Compare the three AD implementations on the list-based network.
    -- 'env' builds the input once per group, outside the timed region.
    benchNeural :: Int -> Benchmark
    benchNeural size =
      env (pure (makeNeuralInput size)) $ \netInput ->
        bgroup ("neural-" ++ show size)
          [ bench "dual" $ nf (\x -> ADD.gradient' fneural x 1.0) netInput
          , bench "ad" $ nf (AD.grad fneural) netInput
          , bench "ad.Double" $ nf (AD.Double.grad fneural) netInput
          ]

    -- Array-based variant; only the dual-number implementation supports it.
    benchNeuralA :: Int -> Benchmark
    benchNeuralA size =
      env (pure (makeNeuralInput_A size)) $ \netInput ->
        bgroup ("neuralA-" ++ show size)
          [ bench "dual" $ nf (\x -> ADDA.gradient' fneural_A x 1.0) netInput
          ]
-- | Print a whitespace-separated timing table, one line per network size
-- @n@ in @[10, 20 .. 300]@, formatted as @n dualSeconds adSeconds@.
--
-- Each column is the monotonic-clock time, in seconds, taken to fully
-- evaluate (via 'force') the gradient of 'fneural'. The first column uses
-- "Numeric.ADDual" and the second uses "Numeric.AD". A GC is forced before
-- each measurement so leftover garbage is not charged to it.
mainNeuralGraph :: IO ()
mainNeuralGraph =
  forM_ [10, 20 .. 300] $ \n -> do
    let input = makeNeuralInput n
    -- Build the input up front so its construction is in neither timing.
    _ <- evaluate (force input)
    dualSecs <- timeForced (ADD.gradient' fneural input 1.0)
    adSecs <- timeForced (AD.grad fneural input)
    putStrLn (unwords [show n, show dualSecs, show adSecs])
  where
    -- Collect garbage, then measure how long fully evaluating the given
    -- (still unevaluated) value takes.
    timeForced :: NFData a => a -> IO Double
    timeForced value = do
      performGC
      start <- Clock.getTime Clock.Monotonic
      _ <- evaluate (force value)
      end <- Clock.getTime Clock.Monotonic
      pure (seconds start end)

    -- Absolute distance between two timestamps, in seconds.
    seconds :: Clock.TimeSpec -> Clock.TimeSpec -> Double
    seconds a b = fromIntegral (Clock.toNanoSecs (Clock.diffTimeSpec a b)) / 1e9
|