diff options
-rw-r--r-- | chad-fast.cabal | 16 | ||||
-rw-r--r-- | src/Array.hs | 3 | ||||
-rw-r--r-- | src/CHAD.hs | 7 | ||||
-rw-r--r-- | src/Data.hs | 5 | ||||
-rw-r--r-- | src/Example.hs | 18 | ||||
-rw-r--r-- | src/Interpreter.hs | 9 | ||||
-rw-r--r-- | test/example/Main.hs | 7 |
7 files changed, 56 insertions, 9 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index 1b95c66..f1facf5 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -38,9 +38,13 @@ library -- template-haskell, transformers, vector, - hs-source-dirs: - src - default-language: - Haskell2010 - ghc-options: - -Wall -threaded + hs-source-dirs: src + default-language: Haskell2010 + ghc-options: -Wall + +test-suite example + type: exitcode-stdio-1.0 + main-is: test/example/Main.hs + build-depends: base, chad-fast + default-language: Haskell2010 + ghc-options: -Wall -threaded diff --git a/src/Array.hs b/src/Array.hs index 6473bf0..c48e442 100644 --- a/src/Array.hs +++ b/src/Array.hs @@ -59,6 +59,9 @@ arraySize (Array sh _) = shapeSize sh emptyArray :: SNat n -> Array n t emptyArray n = Array (emptyShape n) V.empty +arrayFromList :: Shape n -> [t] -> Array n t +arrayFromList sh l = Array sh (V.fromListN (shapeSize sh) l) + arrayUnit :: t -> Array Z t arrayUnit x = Array ShNil (V.singleton x) diff --git a/src/CHAD.hs b/src/CHAD.hs index 1ab2da0..12d28e2 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1158,9 +1158,16 @@ drev des = \case EReplicate1Inner{} -> err_unsupported "EReplicate1Inner" EFold1Inner{} -> err_unsupported "EFold1Inner" + ENothing{} -> err_unsupported "ENothing" + EJust{} -> err_unsupported "EJust" + EMaybe{} -> err_unsupported "EMaybe" + EWith{} -> err_accum EAccum{} -> err_accum + EZero{} -> err_monoid + EPlus{} -> err_monoid where err_accum = error "Accumulator operations unsupported in the source program" + err_monoid = error "Monoid operations unsupported in the source program" err_unsupported s = error $ "CHAD: unsupported " ++ s diff --git a/src/Data.hs b/src/Data.hs index 4584a53..e951ef2 100644 --- a/src/Data.hs +++ b/src/Data.hs @@ -45,6 +45,11 @@ sreplicate (SS n) x = x `SCons` sreplicate n x data Nat = Z | S Nat deriving (Show, Eq, Ord) +type N0 = Z +type N1 = S N0 +type N2 = S N1 +type N3 = S N2 + data SNat n where SZ :: SNat Z SS :: SNat n -> SNat (S n) diff --git a/src/Example.hs b/src/Example.hs index fb4e851..e2f1be9 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -6,10 +6,12 @@ {-# LANGUAGE TypeOperators #-} module Example where +import Array import AST import AST.Pretty import CHAD import Data +import Interpreter import Language import Simplify @@ -172,3 +174,19 @@ neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda # let_ #x2 (let_ #wei (fst_ #layer2) $ let_ #bias (snd_ #layer2) $ let_ #x #x1 $ layer) $ let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $ #x3 ! nil + +neuralGo :: (Float + ,(((((), Either () (Array N2 Float, Array N1 Float)) + ,Either () (Array N2 Float, Array N1 Float)) + ,Array N1 Float) + ,Array N1 Float)) +neuralGo = + let lay1 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0]) + lay2 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0]) + lay3 = arrayFromList (ShNil `ShCons` 2) [1,1] + input = arrayFromList (ShNil `ShCons` 2) [1,1] + in interpretOpen (Value input `SCons` Value lay3 `SCons` Value lay2 `SCons` Value lay1 `SCons` SNil) $ + simplifyN 20 $ + freezeRet mergeDescr + (drev mergeDescr neural) + (EConst ext STF32 1.0) diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 01d15f1..316a423 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -14,8 +14,8 @@ {-# LANGUAGE TypeOperators #-} module Interpreter ( interpret, - interpret', - Value, + interpretOpen, + Value(..), ) where import Control.Monad (foldM, join) @@ -39,10 +39,13 @@ runAcM :: (forall s. AcM s a) -> a runAcM (AcM m) = unsafePerformIO m interpret :: Ex '[] t -> Rep t -interpret e = runAcM (interpret' SNil e) +interpret = interpretOpen SNil newtype Value t = Value (Rep t) +interpretOpen :: SList Value env -> Ex env t -> Rep t +interpretOpen env e = runAcM (interpret' env e) + interpret' :: forall env t s. SList Value env -> Ex env t -> AcM s (Rep t) interpret' env = \case EVar _ _ i -> case slistIdx env i of Value x -> return x diff --git a/test/example/Main.hs b/test/example/Main.hs new file mode 100644 index 0000000..6c36857 --- /dev/null +++ b/test/example/Main.hs @@ -0,0 +1,7 @@ +module Main where + +import Example + + +main :: IO () +main = print neuralGo |