summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-10-18 22:53:30 +0200
committerTom Smeding <tom@tomsmeding.com>2024-10-18 22:53:30 +0200
commit246439502b78c4a8fcc27ab3296c67471a2b239d (patch)
tree6f4d398114d9cb6b682f8ea14a18f58d850973dc /test
parent6fb15f0b632d3651cc3e6089c20b07b009b578eb (diff)
WIP testing neural
Diffstat (limited to 'test')
-rw-r--r--test/Main.hs33
1 files changed, 31 insertions, 2 deletions
diff --git a/test/Main.hs b/test/Main.hs
index 34ab5af..986c8a0 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -1,15 +1,19 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE TypeApplications #-}
module Main where
import Data.Bifunctor
+import qualified Data.Dependent.Map as DMap
+import Data.Dependent.Map (DMap)
import Data.List (intercalate)
import Hedgehog
import qualified Hedgehog.Gen as Gen
@@ -21,6 +25,7 @@ import AST
import CHAD
import CHAD.Types
import Data
+import qualified Example
import ForwardAD
import Interpreter
import Interpreter.Rep
@@ -151,6 +156,29 @@ genEnv :: SList STy env -> Gen (SList Value env)
genEnv SNil = return SNil
genEnv (t `SCons` env) = SCons <$> genValue t <*> genEnv env
+data TemplateVar n = TemplateVar (SNat n) String
+ deriving (Show)
+
+data Template t where
+ TpShape :: TemplateVar n -> STy t -> Template (TArr n t)
+ TpAny :: STy t -> Template t
+ TpPair :: Template a -> Template b -> Template (TPair a b)
+deriving instance Show (Template t)
+
+data ShapeConstraint n = ShapeAtLeast (Shape n)
+ deriving (Show)
+
+genTemplate :: DMap TemplateVar Shape -> Template t -> Gen (Value t)
+genTemplate = _
+
+genEnvTemplateExact :: DMap TemplateVar Shape -> SList Template env -> Gen (SList Value env)
+genEnvTemplateExact shapes env = _
+
+genEnvTemplate :: DMap TemplateVar ShapeConstraint -> SList Template env -> Gen (SList Value env)
+genEnvTemplate constrs env = do
+ shapes <- DMap.traverseWithKey _ constrs
+ genEnvTemplateExact shapes env
+
showValue :: Int -> STy t -> Rep t -> ShowS
showValue _ STNil () = showString "()"
showValue _ (STPair a b) (x, y) = showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")"
@@ -190,7 +218,8 @@ adTest expr = property $ do
tests :: IO Bool
tests = checkParallel $ Group "AD"
- [("id", adTest $ fromNamed $ lambda #x $ body $ #x)]
+ [("id", adTest $ fromNamed $ lambda #x $ body $ #x)
+ ,("neural", adTest Example.neural)]
main :: IO ()
main = defaultMain [tests]