aboutsummaryrefslogtreecommitdiff
path: root/Gradient.hs
blob: 57ee904baffb6bd1cfd6acc579a4e3d9bd38ad57 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
{-# LANGUAGE GADTs #-}
module Gradient where

import AD
import AST
import qualified Language as L
import Sink


-- | Gradient of a scalar-valued function of a 'Double' array.
--
-- Produces an object-language lambda over the array argument @x@. Its body
-- is a 'Build' over the shape @Shape x@. For each index @i@, that 'Build'
-- yields the 'Snd' of the 'ad'-transformed @func@ applied to
-- @L.zip x (L.oneHot .. (Shape x) i)@.
--
-- NOTE(review): this presumably computes forward-mode directional
-- derivatives along each one-hot basis vector. That means one evaluation of
-- the transformed function per array element. Confirm against "AD".
gradient
  :: Exp env (Array sh Double -> Double)
  -> Exp env (Array sh Double -> Array sh Double)
gradient func = Lam arrTy $ Build shapeTy (Shape outerX) perIndex
  where
    -- Only the argument type is inspected; the signature already fixes the
    -- result type.
    TFun arrTy@(TArray shapeTy _) _ = typeof func
    idxTy = shapeTypeType shapeTy
    dualFunc = ad func

    -- De Bruijn references. Directly under the outer 'Lam', the array
    -- argument is index 0. Inside the per-index 'Lam' it shifts to index 1,
    -- and the index value itself is 0.
    outerX = Var arrTy Zero
    innerX = Var arrTy (Succ Zero)
    ix = Var idxTy Zero

    -- 'dualFunc' lives in the outer environment, so it is sunk past both
    -- binders before being applied inside the per-index lambda.
    perIndex =
      Lam idxTy $
        Snd $
          App (sinkExp2 dualFunc) $
            L.zip innerX (L.oneHot shapeTy (Shape innerX) ix)