From 246439502b78c4a8fcc27ab3296c67471a2b239d Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Fri, 18 Oct 2024 22:53:30 +0200
Subject: WIP testing neural

---
 src/Example.hs | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

(limited to 'src/Example.hs')

diff --git a/src/Example.hs b/src/Example.hs
index 6701e38..6e8069c 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -119,7 +119,7 @@ ex6 = fromNamed $ lambda #x $ lambda #n $ body $
   let_ #b (build1 #n (#_ :-> let_ #c (idx0 #a) $ #c * #c)) $
   idx0 (#b .! 3)
 
-type R = TScal TF32
+type R = TScal TF64
 
 senv7 :: SList STy [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)]
 senv7 = knownEnv
@@ -141,12 +141,12 @@ descr7 = DTop `DPush` (knownTy, SMerge) `DPush` (knownTy, SMerge)
 --    in x3
 ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R
 ex7 = fromNamed $ lambda #pars123 $ lambda #input $ body $
-  let tR = STScal STF32
+  let tR = STScal STF64
       tpair = STPair tR tR
 
-      layer :: (Lookup "parstup" env ~ p, Lookup "inp" env ~ TScal TF32)
+      layer :: (Lookup "parstup" env ~ p, Lookup "inp" env ~ R)
             => STy p -> NExpr env R
-      layer (STPair t (STPair (STScal STF32) (STScal STF32))) | Dict <- styKnown t =
+      layer (STPair t (STPair (STScal STF64) (STScal STF64))) | Dict <- styKnown t =
         let_ #par (snd_ #parstup) $
         let_ #restpars (fst_ #parstup) $
         let_ #inp (fst_ #par * #inp + snd_ #par) $
@@ -179,12 +179,12 @@ neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #
      let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $
      #x3 ! nil
 
-type NeuralGrad = ((Array N2 Float, Array N1 Float)
-                  ,(Array N2 Float, Array N1 Float)
-                  ,Array N1 Float
-                  ,Array N1 Float)
+type NeuralGrad = ((Array N2 Double, Array N1 Double)
+                  ,(Array N2 Double, Array N1 Double)
+                  ,Array N1 Double
+                  ,Array N1 Double)
 
-neuralGo :: (Float  -- primal
+neuralGo :: (Double  -- primal
             ,NeuralGrad  -- gradient using CHAD
             ,NeuralGrad)  -- gradient using dual-numbers forward AD
 neuralGo =
@@ -197,7 +197,7 @@ neuralGo =
         simplifyN 20 $
         freezeRet mergeDescr
           (drev mergeDescr neural)
-          (EConst ext STF32 1.0)
+          (EConst ext STF64 1.0)
       (primal, (((((), Right dlay1_1), Right dlay2_1), dlay3_1), dinput_1)) = interpretOpen argument revderiv
       (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwd knownEnv neural argument 1.0
   in trace (formatter (ppExpr knownEnv revderiv)) $
-- 
cgit v1.2.3-70-g09d2