From 0fefe4822c65bde81ec4c0da1b5b32a9b411ca79 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 24 Jun 2021 23:14:54 +0200 Subject: Initial --- Gradient.hs | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 Gradient.hs (limited to 'Gradient.hs') diff --git a/Gradient.hs b/Gradient.hs new file mode 100644 index 0000000..57ee904 --- /dev/null +++ b/Gradient.hs @@ -0,0 +1,23 @@ +{-# LANGUAGE GADTs #-} +module Gradient where + +import AD +import AST +import qualified Language as L +import Sink + + +gradient :: Exp env (Array sh Double -> Double) -> Exp env (Array sh Double -> Array sh Double) +gradient func = + let TFun tarr@(TArray sht _) _ = typeof func + idxt = shapeTypeType sht + func' = ad func + in Lam tarr + (Build sht + (Shape (Var tarr Zero)) + (Lam idxt + (Snd (App (sinkExp2 func') + (L.zip (Var tarr (Succ Zero)) + (L.oneHot sht + (Shape (Var tarr (Succ Zero))) + (Var idxt Zero))))))) -- cgit v1.2.3-70-g09d2