summaryrefslogtreecommitdiff
path: root/bench/Main.hs
blob: 932da9d455ec14f1f4acce11654fbd0ff74681ac (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleInstances #-}

{-# OPTIONS -Wno-orphans #-}
module Main where

import Control.DeepSeq
import Criterion.Main
import Data.Coerce
import Data.Int (Int64)
import Data.Kind (Constraint)
import GHC.Exts (withDict)

import AST
import Array
import qualified CHAD (defaultConfig)
import CHAD (CHADConfig(..))
import CHAD.Top
import CHAD.Types
import Data
import Example
import Example.GMM
import Example.Types
import Interpreter
import Interpreter.Rep
import Simplify


-- | Gradient of a scalar ('TF64') term computed with the CHAD transformation
-- and then evaluated by the interpreter.
--
-- Pipeline:
--
--   1. transform @term@ with 'chad'' under the given 'CHADConfig' and the
--      'knownEnv' singleton;
--   2. wrap the result in a @let@ that binds the output cotangent @ctg@ as an
--      'F64' constant;
--   3. simplify with 'simplifyFix';
--   4. run the result with 'interpretOpen' against the environment @input@.
--
-- The result pairs a 'Double' with a @Tup (D2E env)@ value. Presumably this is
-- (primal result, cotangent for each environment entry); confirm against
-- 'chad''.
gradCHAD :: KnownEnv env => CHADConfig -> SList Value env -> Double -> Ex env (TScal TF64) -> (Double, Rep (Tup (D2E env)))
gradCHAD config input ctg term =
  let transformed = chad' config knownEnv term
      seeded = ELet ext (EConst ext STF64 ctg) transformed
  in interpretOpen False input (simplifyFix seeded)

-- | Deep-forces a 'Value' by walking its type singleton (obtained from
-- 'knownTy') in lockstep with the underlying 'Rep' payload.
--
-- Forcing an accumulator ('STAccum') is not supported and raises an error.
instance KnownTy t => NFData (Value t) where
  rnf = \(Value x) -> go (knownTy @t) x
    where
      -- Structural recursion on the singleton: matching each 'STy'
      -- constructor refines @Rep t'@ so the payload can be destructured.
      go :: STy t' -> Rep t' -> ()
      go STNil () = ()
      go (STPair a b) (x, y) = go a x `seq` go b y
      go (STEither a _) (Left x) = go a x
      go (STEither _ b) (Right y) = go b y
      go (STMaybe _) Nothing = ()
      go (STMaybe t) (Just x) = go t x
      -- Arrays: build a @KnownTy t2@ dictionary from the element singleton,
      -- then reuse this very instance for the elements by coercing the array
      -- to @Array n (Value t2)@. This relies on 'Value' being a newtype over
      -- 'Rep' (required by 'coerce'). It presumably relies on the 'Array'
      -- NFData instance forcing every element; confirm in the Array module.
      go (STArr (_ :: SNat n) (t :: STy t2)) arr =
        withDict @(KnownTy t2) t $ rnf (coerce @(Array n (Rep t2)) @(Array n (Value t2)) arr)
      -- Scalars: matching on the scalar tag fixes the concrete
      -- representation type, so each branch picks its own NFData instance.
      go (STScal t) x = case t of
        STI32 -> rnf x
        STI64 -> rnf x
        STF32 -> rnf x
        STF64 -> rnf x
        STBool -> rnf x
      go STAccum{} _ = error "Cannot rnf accumulators"

-- | Constraint stating that @NFData (Rep ty)@ holds for every type @ty@ in a
-- type-level list; the empty list yields the trivial constraint @()@.
type AllNFDataRep :: [Ty] -> Constraint
type family AllNFDataRep tys where
  AllNFDataRep '[] = ()
  AllNFDataRep (ty ': rest) = (NFData (Rep ty), AllNFDataRep rest)

-- | Deep-forces every 'Value' in an environment. It walks the 'knownEnv'
-- singleton list in lockstep with the value list.
--
-- NOTE(review): the @AllNFDataRep env@ context does not appear to be used by
-- 'go'. 'go' forces each entry through the @KnownTy t => NFData (Value t)@
-- instance instead. Confirm whether that context is still needed.
instance (KnownEnv env, AllNFDataRep env) => NFData (SList Value env) where
  rnf = go knownEnv
    where
      -- Both lists are indexed by the same @env'@, so the types rule out
      -- length mismatches and the two clauses below are exhaustive.
      go :: SList STy env' -> SList Value env' -> ()
      go SNil SNil = ()
      -- The head's type singleton supplies the KnownTy dictionary that the
      -- @NFData (Value t)@ instance needs; then recurse on the tails.
      go ((t :: STy t) `SCons` ts) (v `SCons` vs) =
        withDict @(KnownTy t) t $ rnf v `seq` go ts vs

-- | Deterministic environment for the @neural@ benchmark.
--
-- Entries, in environment order:
--
--   * a vector of length 'inputWidth';
--   * a vector of length 'hidden2';
--   * a (matrix, vector) pair sized from 'hidden1' and 'hidden2';
--   * a (matrix, vector) pair sized from 'inputWidth' and 'hidden1'.
--
-- Every array comes from 'arrayGenerateLin', with each element set to its
-- generator index converted to 'Double'. The values are only meant to be
-- reproducible, not meaningful.
makeNeuralInputs :: SList Value [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)]
makeNeuralInputs =
  Value (rampVec inputWidth)
    `SCons` Value (rampVec hidden2)
    `SCons` Value (rampLayer hidden1 hidden2)
    `SCons` Value (rampLayer inputWidth hidden1)
    `SCons` SNil
  where
    -- Array of the given shape whose elements are their generator index.
    rampArr sh = arrayGenerateLin sh (\i -> fromIntegral i :: Double)
    -- Rank-1 ramp of length @n@.
    rampVec n = rampArr (ShNil `ShCons` n)
    -- Parameter pair: a @nout@-by-@nin@ shaped matrix plus a length-@nout@ vector.
    rampLayer nin nout = (rampArr (ShNil `ShCons` nout `ShCons` nin), rampVec nout)
    inputWidth = 30
    hidden1 = 50
    hidden2 = 50

-- | Deterministic environment for the @gmmObjective@ benchmark.
--
-- All sizes are 10. Each array holds @0, 1, 2, ...@ in 'arrayFromList' order.
-- The sizes and values are arbitrary benchmarking choices and carry no
-- statistical meaning.
makeGMMInputs :: SList Value [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64]
makeGMMInputs =
  Value k3 `SCons` Value k2 `SCons` Value k1
    `SCons` Value m
    `SCons` xArr `SCons` lArr `SCons` qArr `SCons` mArr `SCons` alphaArr
    `SCons` Value (toI64 sizeK) `SCons` Value (toI64 sizeD) `SCons` Value (toI64 sizeN)
    `SCons` SNil
  where
    -- Array of the given shape filled with 0, 1, 2, ... (via fromIntegral).
    countingArr sh = Value (arrayFromList sh (map fromIntegral [0 .. shapeSize sh - 1]))
    toI64 = fromIntegral @Int @Int64
    sizeN = 10
    sizeD = 10
    sizeK = 10
    alphaArr = countingArr (ShNil `ShCons` sizeK)
    mArr = countingArr (ShNil `ShCons` sizeK `ShCons` sizeD)
    qArr = countingArr (ShNil `ShCons` sizeK `ShCons` sizeD)
    -- sizeK rows of sizeD * (sizeD - 1) / 2 entries each
    lArr = countingArr (ShNil `ShCons` sizeK `ShCons` (sizeD * (sizeD - 1) `div` 2))
    xArr = countingArr (ShNil `ShCons` sizeN `ShCons` sizeD)
    gamma = 0.42
    m = 2
    k1 = 0.5 * fromIntegral (sizeN * sizeD) * log (2 * pi)
    k2 = 0.5 * gamma * gamma
    -- Fixed placeholder. Per the original author, the multigamma-based value
    -- is not computed here.
    k3 = 0.42

-- | CHAD configuration for the @-accum@ benchmark variants:
-- 'CHAD.defaultConfig' with 'chcLetArrayAccum' and 'chcCaseArrayAccum' both
-- set to 'True'.
--
-- The config is derived by record update instead of being built from
-- scratch, for two reasons. First, any other 'CHADConfig' fields keep their
-- default values; a from-scratch record would leave them undefined. Second,
-- each @-accum@ benchmark then differs from its default-config counterpart
-- only in these two flags.
accumConfig :: CHADConfig
accumConfig = CHAD.defaultConfig
  { chcLetArrayAccum = True
  , chcCaseArrayAccum = True }

-- | Criterion entry point. It times 'gradCHAD' on the @neural@ and
-- @gmmObjective False@ examples, once each with 'CHAD.defaultConfig' and
-- once each with 'accumConfig'.
--
-- Inputs are supplied through 'env', so they are evaluated to normal form
-- before timing. 'nf' fully forces each gradient result.
main :: IO ()
main = defaultMain
  [ neuralBench "neural" CHAD.defaultConfig
  , neuralBench "neural-accum" accumConfig
  , gmmBench "gmm" CHAD.defaultConfig
  , gmmBench "gmm-accum" accumConfig
  ]
  where
    -- Benchmark of the neural example under the given CHAD configuration.
    neuralBench name config =
      env (return makeNeuralInputs) $ \inputs ->
        bench name (nf (\(inp, ctg) -> gradCHAD config inp ctg neural) (inputs, 1.0))
    -- Benchmark of the GMM objective under the given CHAD configuration.
    gmmBench name config =
      env (return makeGMMInputs) $ \inputs ->
        bench name (nf (\(inp, ctg) -> gradCHAD config inp ctg (gmmObjective False)) (inputs, 1.0))