summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--chad-fast.cabal1
-rw-r--r--src/Example.hs20
-rw-r--r--test/Main.hs33
3 files changed, 42 insertions, 12 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index 3a0de52..ae8ddf4 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -60,6 +60,7 @@ test-suite test
build-depends:
chad-fast,
base,
+ dependent-map,
hedgehog,
default-language: Haskell2010
ghc-options: -Wall -threaded
diff --git a/src/Example.hs b/src/Example.hs
index 6701e38..6e8069c 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -119,7 +119,7 @@ ex6 = fromNamed $ lambda #x $ lambda #n $ body $
let_ #b (build1 #n (#_ :-> let_ #c (idx0 #a) $ #c * #c)) $
idx0 (#b .! 3)
-type R = TScal TF32
+type R = TScal TF64
senv7 :: SList STy [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)]
senv7 = knownEnv
@@ -141,12 +141,12 @@ descr7 = DTop `DPush` (knownTy, SMerge) `DPush` (knownTy, SMerge)
-- in x3
ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R
ex7 = fromNamed $ lambda #pars123 $ lambda #input $ body $
- let tR = STScal STF32
+ let tR = STScal STF64
tpair = STPair tR tR
- layer :: (Lookup "parstup" env ~ p, Lookup "inp" env ~ TScal TF32)
+ layer :: (Lookup "parstup" env ~ p, Lookup "inp" env ~ R)
=> STy p -> NExpr env R
- layer (STPair t (STPair (STScal STF32) (STScal STF32))) | Dict <- styKnown t =
+ layer (STPair t (STPair (STScal STF64) (STScal STF64))) | Dict <- styKnown t =
let_ #par (snd_ #parstup) $
let_ #restpars (fst_ #parstup) $
let_ #inp (fst_ #par * #inp + snd_ #par) $
@@ -179,12 +179,12 @@ neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #
let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $
#x3 ! nil
-type NeuralGrad = ((Array N2 Float, Array N1 Float)
- ,(Array N2 Float, Array N1 Float)
- ,Array N1 Float
- ,Array N1 Float)
+type NeuralGrad = ((Array N2 Double, Array N1 Double)
+ ,(Array N2 Double, Array N1 Double)
+ ,Array N1 Double
+ ,Array N1 Double)
-neuralGo :: (Float -- primal
+neuralGo :: (Double -- primal
,NeuralGrad -- gradient using CHAD
,NeuralGrad) -- gradient using dual-numbers forward AD
neuralGo =
@@ -197,7 +197,7 @@ neuralGo =
simplifyN 20 $
freezeRet mergeDescr
(drev mergeDescr neural)
- (EConst ext STF32 1.0)
+ (EConst ext STF64 1.0)
(primal, (((((), Right dlay1_1), Right dlay2_1), dlay3_1), dinput_1)) = interpretOpen argument revderiv
(Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwd knownEnv neural argument 1.0
in trace (formatter (ppExpr knownEnv revderiv)) $
diff --git a/test/Main.hs b/test/Main.hs
index 34ab5af..986c8a0 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -1,15 +1,19 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE TypeApplications #-}
module Main where
import Data.Bifunctor
+import qualified Data.Dependent.Map as DMap
+import Data.Dependent.Map (DMap)
import Data.List (intercalate)
import Hedgehog
import qualified Hedgehog.Gen as Gen
@@ -21,6 +25,7 @@ import AST
import CHAD
import CHAD.Types
import Data
+import qualified Example
import ForwardAD
import Interpreter
import Interpreter.Rep
@@ -151,6 +156,29 @@ genEnv :: SList STy env -> Gen (SList Value env)
genEnv SNil = return SNil
genEnv (t `SCons` env) = SCons <$> genValue t <*> genEnv env
+data TemplateVar n = TemplateVar (SNat n) String
+ deriving (Show)
+
+data Template t where
+ TpShape :: TemplateVar n -> STy t -> Template (TArr n t)
+ TpAny :: STy t -> Template t
+ TpPair :: Template a -> Template b -> Template (TPair a b)
+deriving instance Show (Template t)
+
+data ShapeConstraint n = ShapeAtLeast (Shape n)
+ deriving (Show)
+
+genTemplate :: DMap TemplateVar Shape -> Template t -> Gen (Value t)
+genTemplate = _
+
+genEnvTemplateExact :: DMap TemplateVar Shape -> SList Template env -> Gen (SList Value env)
+genEnvTemplateExact shapes env = _
+
+genEnvTemplate :: DMap TemplateVar ShapeConstraint -> SList Template env -> Gen (SList Value env)
+genEnvTemplate constrs env = do
+ shapes <- DMap.traverseWithKey _ constrs
+ genEnvTemplateExact shapes env
+
showValue :: Int -> STy t -> Rep t -> ShowS
showValue _ STNil () = showString "()"
showValue _ (STPair a b) (x, y) = showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")"
@@ -190,7 +218,8 @@ adTest expr = property $ do
tests :: IO Bool
tests = checkParallel $ Group "AD"
- [("id", adTest $ fromNamed $ lambda #x $ body $ #x)]
+ [("id", adTest $ fromNamed $ lambda #x $ body $ #x)
+ ,("neural", adTest Example.neural)]
main :: IO ()
main = defaultMain [tests]