{-# LANGUAGE GADTs #-}
module Gradient where

import AD
import AST
import qualified Language as L
import Sink

-- | Turn an expression for a scalar-valued function of an array into an
-- expression for a function that returns an array over the same shape type.
--
-- The resulting expression is a lambda over the input array that 'Build's an
-- output array using the input's 'Shape'. Each output element is obtained by
-- applying the 'ad'-transformed function to the input zipped (via 'L.zip')
-- with an 'L.oneHot' array selected by the element's own index, and then
-- taking the 'Snd' component of the result.
--
-- NOTE(review): presumably 'ad' is a forward-mode (dual-number) transform, so
-- that the second component is the directional derivative along the one-hot
-- seed, i.e. one partial derivative per index -- confirm against "AD".
gradient :: Exp env (Array sh Double -> Double)
         -> Exp env (Array sh Double -> Array sh Double)
gradient func = Lam arrTy (Build shapeTy (Shape (Var arrTy Zero)) perIndex)
  where
    -- Type witnesses recovered from the function's own type: the argument
    -- array type, its shape type, and the matching index type.
    TFun arrTy@(TArray shapeTy _) _ = typeof func
    indexTy = shapeTypeType shapeTy

    -- The transformed function; it is sunk with 'sinkExp2' below because it
    -- is used underneath the two lambdas introduced here.
    transformed = ad func

    -- Inside the per-index lambda the index is variable 'Zero', so the input
    -- array bound by the outer lambda sits one position further out.
    input = Var arrTy (Succ Zero)

    -- Seed array with the same shape as the input, selected by the index.
    seed = L.oneHot shapeTy (Shape input) (Var indexTy Zero)

    -- Body evaluated for every index of the output array.
    perIndex =
      Lam indexTy (Snd (App (sinkExp2 transformed) (L.zip input seed)))