summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-14 19:27:06 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-14 19:27:06 +0100
commitbb84f6930702a02ba982795e2bb95a64d61f672b (patch)
tree910b2a119f9758115d1b59e45d558fb983a9286b /src
parent02db8c1929a25dda64e6cee7b7343833ee698f34 (diff)
Benchmark GMM
Diffstat (limited to 'src')
-rw-r--r--src/Example.hs6
-rw-r--r--src/Example/GMM.hs5
-rw-r--r--src/Example/Types.hs11
3 files changed, 13 insertions, 9 deletions
diff --git a/src/Example.hs b/src/Example.hs
index a08724b..390031e 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -19,6 +19,7 @@ import Simplify
import Debug.Trace
import Example.Format
+import Example.Types
-- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0)
@@ -110,8 +111,6 @@ ex6 = fromNamed $ lambda #x $ lambda #n $ body $
let_ #b (build1 #n (#_ :-> let_ #c (idx0 #a) $ #c * #c)) $
#b ! pair nil 3
-type R = TScal TF64
-
senv7 :: SList STy [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)]
senv7 = knownEnv
@@ -150,9 +149,6 @@ ex7 = fromNamed $ lambda #pars123 $ lambda #input $ body $
let_ #inp #input $
layer (STPair (STPair (STPair STNil tpair) tpair) tpair)
-type TVec = TArr (S Z)
-type TMat = TArr (S (S Z))
-
neural :: Ex [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)] R
neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #input $ body $
let layer = lambda @(TMat R) #wei $ lambda @(TVec R) #bias $ lambda @(TVec R) #x $ body $
diff --git a/src/Example/GMM.hs b/src/Example/GMM.hs
index 1db88bd..12bbd98 100644
--- a/src/Example/GMM.hs
+++ b/src/Example/GMM.hs
@@ -3,13 +3,10 @@
{-# LANGUAGE TypeApplications #-}
module Example.GMM where
+import Example.Types
import Language
-type R = TScal TF64
-type I64 = TScal TI64
-type TVec = TArr (S Z)
-type TMat = TArr (S (S Z))
-- N, D, K: integers > 0
-- alpha, M, Q, L: the active parameters
diff --git a/src/Example/Types.hs b/src/Example/Types.hs
new file mode 100644
index 0000000..d63159b
--- /dev/null
+++ b/src/Example/Types.hs
@@ -0,0 +1,11 @@
+{-# LANGUAGE DataKinds #-}
+module Example.Types where
+
+import AST
+import Data
+
+
+type R = TScal TF64
+type I64 = TScal TI64
+type TVec = TArr (S Z)
+type TMat = TArr (S (S Z))