diff options
Diffstat (limited to 'Gradient.hs')
-rw-r--r-- | Gradient.hs | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/Gradient.hs b/Gradient.hs new file mode 100644 index 0000000..57ee904 --- /dev/null +++ b/Gradient.hs @@ -0,0 +1,23 @@ +{-# LANGUAGE GADTs #-} +module Gradient where + +import AD +import AST +import qualified Language as L +import Sink + + +gradient :: Exp env (Array sh Double -> Double) -> Exp env (Array sh Double -> Array sh Double) +gradient func = + let TFun tarr@(TArray sht _) _ = typeof func + idxt = shapeTypeType sht + func' = ad func + in Lam tarr + (Build sht + (Shape (Var tarr Zero)) + (Lam idxt + (Snd (App (sinkExp2 func') + (L.zip (Var tarr (Succ Zero)) + (L.oneHot sht + (Shape (Var tarr (Succ Zero))) + (Var idxt Zero))))))) |