1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS -Wno-orphans #-}
module Main where
import Control.DeepSeq
import Criterion.Main
import Data.Int (Int64)
import Data.Kind (Constraint)
import GHC.Exts (withDict)
import AST
import AST.Count
import AST.UnMonoid
import Array
import qualified CHAD (defaultConfig)
import CHAD.Top
import CHAD.Types
import Compile
import Data
import Example
import Example.GMM
import Example.Types
import Interpreter.Rep
import Simplify
-- | Differentiate a scalar-valued expression with CHAD and compile the
-- resulting gradient program.
--
-- The term is simplified and passed to 'chad'', with the constant @1.0@ bound
-- as the outermost let. The result then goes through simplification,
-- 'unMonoid', another simplification, 'pruneExpr' and a final simplification
-- before 'compile'. The returned action yields a function from environment
-- values to a @(Double, gradient tuple)@ pair.
gradCHAD :: KnownEnv env
         => CHADConfig
         -> Ex env (TScal TF64)
         -> IO (SList Value env -> IO (Double, Rep (Tup (D2E env))))
gradCHAD config term = compile knownEnv optimised
  where
    -- CHAD-transformed term, with the constant 1.0 let-bound around it.
    differentiated = ELet ext (EConst ext STF64 1.0) (chad' config knownEnv (simplifyFix term))
    -- Simplify, strip monoid operations, simplify again.
    monoidFree = simplifyFix (unMonoid (simplifyFix differentiated))
    -- Prune against the environment and do a last simplification pass.
    optimised = simplifyFix (pruneExpr knownEnv monoidFree)
-- | Constraint requiring an 'NFData' instance for the runtime representation
-- ('Rep') of every type in the given list.
type AllNFDataRep :: [Ty] -> Constraint
type family AllNFDataRep tys where
  AllNFDataRep '[] = ()
  AllNFDataRep (ty ': rest) = (NFData (Rep ty), AllNFDataRep rest)
-- | Deep-force every value in an environment. criterion's 'env' requires an
-- 'NFData' instance for the environment it builds, and 'main' uses it to
-- force the benchmark inputs.
--
-- This is an orphan instance; the module is compiled with @-Wno-orphans@.
instance (KnownEnv env, AllNFDataRep env) => NFData (SList Value env) where
  rnf values = forceAll knownEnv values
    where
      -- Walk the singleton type list and the value list in lockstep.
      -- For each element, the 'KnownTy' dictionary is brought into scope
      -- from its 'STy' singleton. It is presumably what resolves 'rnf' on
      -- the 'Value'.
      forceAll :: SList STy ts -> SList Value ts -> ()
      forceAll SNil SNil = ()
      forceAll (SCons (ty :: STy a) tys) (SCons val vals) =
        withDict @(KnownTy a) ty (rnf val `seq` forceAll tys vals)
-- | Deterministic inputs for the @neural@ benchmark, in this order:
--
-- * an input vector of length 30;
-- * a vector of length 50;
-- * a (50x50 matrix, length-50 vector) pair;
-- * a (50x30 matrix, length-50 vector) pair.
--
-- Every array holds its own linear indices 0, 1, 2, ... as 'Double's.
makeNeuralInputs :: SList Value [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)]
makeNeuralInputs =
  Value inputVec `SCons` Value layer3 `SCons` Value layer2 `SCons` Value layer1 `SCons` SNil
  where
    inWidth = 30
    width1 = 50
    width2 = 50

    -- Array of the given shape whose element at linear index i is i.
    iota shape = arrayGenerateLin shape (\ix -> fromIntegral ix :: Double)

    -- Matrix of shape [outW, inW] paired with a vector of length outW.
    mkLayer inW outW =
      ( iota (ShNil `ShCons` outW `ShCons` inW)
      , iota (ShNil `ShCons` outW)
      )

    inputVec = iota (ShNil `ShCons` inWidth)
    layer1 = mkLayer inWidth width1
    layer2 = mkLayer width1 width2
    layer3 = iota (ShNil `ShCons` width2)
-- | Deterministic inputs for the @gmm@ benchmark. The element order must
-- match the environment of 'gmmObjective', because 'main' passes both to
-- 'bgroupDefaultAccum'.
--
-- The sizes N = D = K = 10 are arbitrary. Arrays hold 0, 1, 2, ... in
-- linear order. @k1 = 0.5 * N * D * log (2 * pi)@ and
-- @k2 = 0.5 * gamma^2@ with @gamma = 0.42@. @k3@ is a fixed stand-in of
-- 0.42; it is not computed from a multivariate gamma function.
makeGMMInputs :: SList Value [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64]
makeGMMInputs =
  Value k3 `SCons` Value k2 `SCons` Value k1 `SCons`
  Value mVal `SCons` xMat `SCons`
  lMat `SCons` qMat `SCons` mMat `SCons` alphaVec `SCons`
  Value (toI64 numK) `SCons` Value (toI64 numD) `SCons` Value (toI64 numN) `SCons`
  SNil
  where
    -- 'Value'-wrapped array of the given shape holding 0, 1, 2, ... in order.
    ramp shape = Value (arrayFromList shape [fromIntegral i | i <- [0 .. shapeSize shape - 1]])

    -- Problem sizes; the particular values are arbitrary.
    numN = 10
    numD = 10
    numK = 10

    toI64 = fromIntegral @Int @Int64

    alphaVec = ramp (ShNil `ShCons` numK)
    mMat = ramp (ShNil `ShCons` numK `ShCons` numD)
    qMat = ramp (ShNil `ShCons` numK `ShCons` numD)
    -- Second dimension is D*(D-1)/2.
    lMat = ramp (ShNil `ShCons` numK `ShCons` (numD * (numD - 1) `div` 2))
    xMat = ramp (ShNil `ShCons` numN `ShCons` numD)

    gammaVal = 0.42
    mVal = 2

    k1 = 0.5 * fromIntegral (numN * numD) * log (2 * pi)
    k2 = 0.5 * gammaVal * gammaVal
    k3 = 0.42 -- fixed stand-in, not an actual multigamma value
-- | CHAD configuration used for the "accum" benchmark variant: the default
-- configuration with 'chcSetAccum' applied.
accumConfig :: CHADConfig
accumConfig = chcSetAccum baseConfig
  where
    baseConfig = CHAD.defaultConfig
-- | Benchmark group that times the compiled gradient of @term@ on @inputs@
-- under two configurations: 'CHAD.defaultConfig' (labelled "default") and
-- 'accumConfig' (labelled "accum").
--
-- 'gradCHAD' runs inside criterion's 'env', so compilation happens outside
-- the measured region. Only applying the compiled function is timed, and
-- its result is deep-forced by 'nfAppIO'.
bgroupDefaultAccum :: (KnownEnv env, NFData (Rep (Tup (D2E env))))
                   => String -> Ex env R -> SList Value env -> Benchmark
bgroupDefaultAccum name term inputs =
  bgroup name [variant label config | (label, config) <- configs]
  where
    configs = [("default", CHAD.defaultConfig), ("accum", accumConfig)]
    variant label config =
      env (gradCHAD config term) $ \compiled ->
        bench label (nfAppIO compiled inputs)
-- | Benchmark entry point.
--
-- The "neural" and "gmm" inputs are built inside criterion's 'env', which
-- forces them through their 'NFData' instance before timing. The
-- "uniform-free" inputs are passed directly.
main :: IO ()
main = defaultMain benchmarks
  where
    benchmarks =
      [ env (pure makeNeuralInputs) (bgroupDefaultAccum "neural" neural)
      , env (pure makeGMMInputs) (bgroupDefaultAccum "gmm" (gmmObjective False))
      , bgroupDefaultAccum "uniform-free" exUniformFree
          (Value 42.0 `SCons` Value 1_000_000 `SCons` SNil)
      ]
|