aboutsummaryrefslogtreecommitdiff
path: root/Gradient.hs
diff options
context:
space:
mode:
Diffstat (limited to 'Gradient.hs')
-rw-r--r--Gradient.hs23
1 files changed, 23 insertions, 0 deletions
diff --git a/Gradient.hs b/Gradient.hs
new file mode 100644
index 0000000..57ee904
--- /dev/null
+++ b/Gradient.hs
@@ -0,0 +1,23 @@
+{-# LANGUAGE GADTs #-}
+module Gradient where
+
+import AD
+import AST
+import qualified Language as L
+import Sink
+
+
+gradient :: Exp env (Array sh Double -> Double) -> Exp env (Array sh Double -> Array sh Double)
+gradient func =
+ let TFun tarr@(TArray sht _) _ = typeof func
+ idxt = shapeTypeType sht
+ func' = ad func
+ in Lam tarr
+ (Build sht
+ (Shape (Var tarr Zero))
+ (Lam idxt
+ (Snd (App (sinkExp2 func')
+ (L.zip (Var tarr (Succ Zero))
+ (L.oneHot sht
+ (Shape (Var tarr (Succ Zero)))
+ (Var idxt Zero)))))))