blob: 57ee904baffb6bd1cfd6acc579a4e3d9bd38ad57 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
|
{-# LANGUAGE GADTs #-}
module Gradient where
import AD
import AST
import qualified Language as L
import Sink
-- | Construct the gradient of a scalar-valued function over an array.
--
-- The result is a lambda that, given an input array @x@, 'Build's an array
-- with the same shape as @x@. The entry at index @i@ is computed as follows:
--
-- 1. Pair @x@ element-wise (via 'L.zip') with a one-hot array of the same
--    shape, selected by @i@.
-- 2. Apply @'ad' func@ to that paired array.
-- 3. Keep the second component ('Snd') of the result.
--
-- NOTE(review): this presumes 'ad' yields a derivative whose input pairs each
-- value with a tangent, and whose output is a (value, derivative) pair.
-- Confirm this against the AD module. If it holds, the generated expression
-- applies the derivative once per element of the input array.
gradient
  :: Exp env (Array sh Double -> Double)
  -> Exp env (Array sh Double -> Array sh Double)
gradient func =
  -- Matching with 'case' brings the GADT refinement of the function type into
  -- scope, giving us the argument's array type and its shape type.
  case typeof func of
    TFun argTy@(TArray shapeTy _) _ ->
      let idxTy = shapeTypeType shapeTy
      in Lam argTy $
           Build shapeTy (Shape (Var argTy Zero)) $
             -- de Bruijn indices inside this inner lambda:
             --   Zero      -> the current element index
             --   Succ Zero -> the array bound by the outer 'Lam'
             -- 'sinkExp2' is applied so the derivative can be referenced under
             -- these two binders (presumably a two-level weakening; confirm in Sink).
             Lam idxTy $
               Snd $
                 App (sinkExp2 (ad func)) $
                   L.zip (Var argTy (Succ Zero)) $
                     L.oneHot shapeTy
                              (Shape (Var argTy (Succ Zero)))
                              (Var idxTy Zero)
|