diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-10-18 22:53:30 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-10-18 22:53:30 +0200 |
commit | 246439502b78c4a8fcc27ab3296c67471a2b239d (patch) | |
tree | 6f4d398114d9cb6b682f8ea14a18f58d850973dc | |
parent | 6fb15f0b632d3651cc3e6089c20b07b009b578eb (diff) |
WIP testing neural
-rw-r--r-- | chad-fast.cabal | 1 | ||||
-rw-r--r-- | src/Example.hs | 20 | ||||
-rw-r--r-- | test/Main.hs | 33 |
3 files changed, 42 insertions, 12 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index 3a0de52..ae8ddf4 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -60,6 +60,7 @@ test-suite test build-depends: chad-fast, base, + dependent-map, hedgehog, default-language: Haskell2010 ghc-options: -Wall -threaded diff --git a/src/Example.hs b/src/Example.hs index 6701e38..6e8069c 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -119,7 +119,7 @@ ex6 = fromNamed $ lambda #x $ lambda #n $ body $ let_ #b (build1 #n (#_ :-> let_ #c (idx0 #a) $ #c * #c)) $ idx0 (#b .! 3) -type R = TScal TF32 +type R = TScal TF64 senv7 :: SList STy [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] senv7 = knownEnv @@ -141,12 +141,12 @@ descr7 = DTop `DPush` (knownTy, SMerge) `DPush` (knownTy, SMerge) -- in x3 ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R ex7 = fromNamed $ lambda #pars123 $ lambda #input $ body $ - let tR = STScal STF32 + let tR = STScal STF64 tpair = STPair tR tR - layer :: (Lookup "parstup" env ~ p, Lookup "inp" env ~ TScal TF32) + layer :: (Lookup "parstup" env ~ p, Lookup "inp" env ~ R) => STy p -> NExpr env R - layer (STPair t (STPair (STScal STF32) (STScal STF32))) | Dict <- styKnown t = + layer (STPair t (STPair (STScal STF64) (STScal STF64))) | Dict <- styKnown t = let_ #par (snd_ #parstup) $ let_ #restpars (fst_ #parstup) $ let_ #inp (fst_ #par * #inp + snd_ #par) $ @@ -179,12 +179,12 @@ neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda # let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $ #x3 ! nil -type NeuralGrad = ((Array N2 Float, Array N1 Float) - ,(Array N2 Float, Array N1 Float) - ,Array N1 Float - ,Array N1 Float) +type NeuralGrad = ((Array N2 Double, Array N1 Double) + ,(Array N2 Double, Array N1 Double) + ,Array N1 Double + ,Array N1 Double) -neuralGo :: (Float -- primal +neuralGo :: (Double -- primal ,NeuralGrad -- gradient using CHAD ,NeuralGrad) -- gradient using dual-numbers forward AD neuralGo = @@ -197,7 +197,7 @@ neuralGo = simplifyN 20 $ freezeRet mergeDescr (drev mergeDescr neural) - (EConst ext STF32 1.0) + (EConst ext STF64 1.0) (primal, (((((), Right dlay1_1), Right dlay2_1), dlay3_1), dinput_1)) = interpretOpen argument revderiv (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwd knownEnv neural argument 1.0 in trace (formatter (ppExpr knownEnv revderiv)) $ diff --git a/test/Main.hs b/test/Main.hs index 34ab5af..986c8a0 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,15 +1,19 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE TypeApplications #-} module Main where import Data.Bifunctor +import qualified Data.Dependent.Map as DMap +import Data.Dependent.Map (DMap) import Data.List (intercalate) import Hedgehog import qualified Hedgehog.Gen as Gen @@ -21,6 +25,7 @@ import AST import CHAD import CHAD.Types import Data +import qualified Example import ForwardAD import Interpreter import Interpreter.Rep @@ -151,6 +156,29 @@ genEnv :: SList STy env -> Gen (SList Value env) genEnv SNil = return SNil genEnv (t `SCons` env) = SCons <$> genValue t <*> genEnv env +data TemplateVar n = TemplateVar (SNat n) String + deriving (Show) + +data Template t where + TpShape :: TemplateVar n -> STy t -> Template (TArr n t) + TpAny :: STy t -> Template t + TpPair :: Template a -> Template b -> Template (TPair a b) +deriving instance Show (Template t) + +data ShapeConstraint n = ShapeAtLeast (Shape n) + deriving (Show) + +genTemplate :: DMap TemplateVar Shape -> Template t -> Gen (Value t) +genTemplate = _ + +genEnvTemplateExact :: DMap TemplateVar Shape -> SList Template env -> Gen (SList Value env) +genEnvTemplateExact shapes env = _ + +genEnvTemplate :: DMap TemplateVar ShapeConstraint -> SList Template env -> Gen (SList Value env) +genEnvTemplate constrs env = do + shapes <- DMap.traverseWithKey _ constrs + genEnvTemplateExact shapes env + showValue :: Int -> STy t -> Rep t -> ShowS showValue _ STNil () = showString "()" showValue _ (STPair a b) (x, y) = showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")" @@ -190,7 +218,8 @@ adTest expr = property $ do tests :: IO Bool tests = checkParallel $ Group "AD" - [("id", adTest $ fromNamed $ lambda #x $ body $ #x)] + [("id", adTest $ fromNamed $ lambda #x $ body $ #x) + ,("neural", adTest Example.neural)] main :: IO () main = defaultMain [tests] |