summaryrefslogtreecommitdiff
path: root/src/Example.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Example.hs')
-rw-r--r--src/Example.hs18
1 files changed, 18 insertions, 0 deletions
diff --git a/src/Example.hs b/src/Example.hs
index fb4e851..e2f1be9 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -6,10 +6,12 @@
{-# LANGUAGE TypeOperators #-}
module Example where
+import Array
import AST
import AST.Pretty
import CHAD
import Data
+import Interpreter
import Language
import Simplify
@@ -172,3 +174,19 @@ neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #
let_ #x2 (let_ #wei (fst_ #layer2) $ let_ #bias (snd_ #layer2) $ let_ #x #x1 $ layer) $
let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $
#x3 ! nil
+
+neuralGo :: (Float
+ ,(((((), Either () (Array N2 Float, Array N1 Float))
+ ,Either () (Array N2 Float, Array N1 Float))
+ ,Array N1 Float)
+ ,Array N1 Float))
+neuralGo =
+ let lay1 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0])
+ lay2 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0])
+ lay3 = arrayFromList (ShNil `ShCons` 2) [1,1]
+ input = arrayFromList (ShNil `ShCons` 2) [1,1]
+ in interpretOpen (Value input `SCons` Value lay3 `SCons` Value lay2 `SCons` Value lay1 `SCons` SNil) $
+ simplifyN 20 $
+ freezeRet mergeDescr
+ (drev mergeDescr neural)
+ (EConst ext STF32 1.0)