summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--chad-fast.cabal16
-rw-r--r--src/Array.hs3
-rw-r--r--src/CHAD.hs7
-rw-r--r--src/Data.hs5
-rw-r--r--src/Example.hs18
-rw-r--r--src/Interpreter.hs9
-rw-r--r--test/example/Main.hs7
7 files changed, 56 insertions, 9 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index 1b95c66..f1facf5 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -38,9 +38,13 @@ library
-- template-haskell,
transformers,
vector,
- hs-source-dirs:
- src
- default-language:
- Haskell2010
- ghc-options:
- -Wall -threaded
+ hs-source-dirs: src
+ default-language: Haskell2010
+ ghc-options: -Wall
+
+test-suite example
+ type: exitcode-stdio-1.0
+ main-is: test/example/Main.hs
+ build-depends: base, chad-fast
+ default-language: Haskell2010
+ ghc-options: -Wall -threaded
diff --git a/src/Array.hs b/src/Array.hs
index 6473bf0..c48e442 100644
--- a/src/Array.hs
+++ b/src/Array.hs
@@ -59,6 +59,9 @@ arraySize (Array sh _) = shapeSize sh
emptyArray :: SNat n -> Array n t
emptyArray n = Array (emptyShape n) V.empty
+arrayFromList :: Shape n -> [t] -> Array n t
+arrayFromList sh l = Array sh (V.fromListN (shapeSize sh) l)
+
arrayUnit :: t -> Array Z t
arrayUnit x = Array ShNil (V.singleton x)
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 1ab2da0..12d28e2 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1158,9 +1158,16 @@ drev des = \case
EReplicate1Inner{} -> err_unsupported "EReplicate1Inner"
EFold1Inner{} -> err_unsupported "EFold1Inner"
+ ENothing{} -> err_unsupported "ENothing"
+ EJust{} -> err_unsupported "EJust"
+ EMaybe{} -> err_unsupported "EMaybe"
+
EWith{} -> err_accum
EAccum{} -> err_accum
+ EZero{} -> err_monoid
+ EPlus{} -> err_monoid
where
err_accum = error "Accumulator operations unsupported in the source program"
+ err_monoid = error "Monoid operations unsupported in the source program"
err_unsupported s = error $ "CHAD: unsupported " ++ s
diff --git a/src/Data.hs b/src/Data.hs
index 4584a53..e951ef2 100644
--- a/src/Data.hs
+++ b/src/Data.hs
@@ -45,6 +45,11 @@ sreplicate (SS n) x = x `SCons` sreplicate n x
data Nat = Z | S Nat
deriving (Show, Eq, Ord)
+type N0 = Z
+type N1 = S N0
+type N2 = S N1
+type N3 = S N2
+
data SNat n where
SZ :: SNat Z
SS :: SNat n -> SNat (S n)
diff --git a/src/Example.hs b/src/Example.hs
index fb4e851..e2f1be9 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -6,10 +6,12 @@
{-# LANGUAGE TypeOperators #-}
module Example where
+import Array
import AST
import AST.Pretty
import CHAD
import Data
+import Interpreter
import Language
import Simplify
@@ -172,3 +174,19 @@ neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #
let_ #x2 (let_ #wei (fst_ #layer2) $ let_ #bias (snd_ #layer2) $ let_ #x #x1 $ layer) $
let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $
#x3 ! nil
+
+neuralGo :: (Float
+ ,(((((), Either () (Array N2 Float, Array N1 Float))
+ ,Either () (Array N2 Float, Array N1 Float))
+ ,Array N1 Float)
+ ,Array N1 Float))
+neuralGo =
+ let lay1 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0])
+ lay2 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0])
+ lay3 = arrayFromList (ShNil `ShCons` 2) [1,1]
+ input = arrayFromList (ShNil `ShCons` 2) [1,1]
+ in interpretOpen (Value input `SCons` Value lay3 `SCons` Value lay2 `SCons` Value lay1 `SCons` SNil) $
+ simplifyN 20 $
+ freezeRet mergeDescr
+ (drev mergeDescr neural)
+ (EConst ext STF32 1.0)
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 01d15f1..316a423 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -14,8 +14,8 @@
{-# LANGUAGE TypeOperators #-}
module Interpreter (
interpret,
- interpret',
- Value,
+ interpretOpen,
+ Value(..),
) where
import Control.Monad (foldM, join)
@@ -39,10 +39,13 @@ runAcM :: (forall s. AcM s a) -> a
runAcM (AcM m) = unsafePerformIO m
interpret :: Ex '[] t -> Rep t
-interpret e = runAcM (interpret' SNil e)
+interpret = interpretOpen SNil
newtype Value t = Value (Rep t)
+interpretOpen :: SList Value env -> Ex env t -> Rep t
+interpretOpen env e = runAcM (interpret' env e)
+
interpret' :: forall env t s. SList Value env -> Ex env t -> AcM s (Rep t)
interpret' env = \case
EVar _ _ i -> case slistIdx env i of Value x -> return x
diff --git a/test/example/Main.hs b/test/example/Main.hs
new file mode 100644
index 0000000..6c36857
--- /dev/null
+++ b/test/example/Main.hs
@@ -0,0 +1,7 @@
+module Main where
+
+import Example
+
+
+main :: IO ()
+main = print neuralGo