diff options
Diffstat (limited to 'test/Main.hs')
-rw-r--r-- | test/Main.hs | 93 |
1 files changed, 8 insertions, 85 deletions
diff --git a/test/Main.hs b/test/Main.hs index e7dda69..d3e55b3 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,19 +1,15 @@ {-# LANGUAGE DataKinds #-} --- {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} --- {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} module Main where import Data.Bifunctor --- import qualified Data.Dependent.Map as DMap --- import Data.Dependent.Map (DMap) import Data.List (intercalate) import Hedgehog import qualified Hedgehog.Gen as Gen @@ -23,7 +19,7 @@ import Hedgehog.Main import Array import AST import AST.Pretty -import CHAD +import CHAD.Top import CHAD.Types import Data import qualified Example @@ -34,63 +30,19 @@ import Language import Simplify -type family MapMerge env where - MapMerge '[] = '[] - MapMerge (t : ts) = "merge" : MapMerge ts - -mapMergeNoAccum :: SList f env -> Select env (MapMerge env) "accum" :~: '[] -mapMergeNoAccum SNil = Refl -mapMergeNoAccum (_ `SCons` env) | Refl <- mapMergeNoAccum env = Refl - -mapMergeOnlyMerge :: SList f env -> Select env (MapMerge env) "merge" :~: env -mapMergeOnlyMerge SNil = Refl -mapMergeOnlyMerge (_ `SCons` env) | Refl <- mapMergeOnlyMerge env = Refl - -primalEnv :: SList STy env' -> SList STy (D1E env') -primalEnv SNil = SNil -primalEnv (t `SCons` env) = d1 t `SCons` primalEnv env - data SimplIters = SimplIters Int | SimplFix deriving (Show) -diffCHAD :: SimplIters -> SList STy env -> Ex env (TScal TF64) - -> Ex (D1E env) (TPair (TScal TF64) (Tup (D2E env))) -diffCHAD = \simplIters env term -> - case (mapMergeNoAccum env, mapMergeOnlyMerge env, envKnown (primalEnv env)) of - (Refl, Refl, Dict) -> - let descr = makeMergeDescr env - simpl = case simplIters of - SimplIters n -> simplifyN n - SimplFix -> simplifyFix - in simpl $ freezeRet descr (drev descr term) (EConst ext STF64 1.0) - where - makeMergeDescr :: SList STy env' -> Descr env' (MapMerge env') - makeMergeDescr SNil = DTop - makeMergeDescr (t `SCons` env) = makeMergeDescr env `DPush` (t, SMerge) - -- In addition to the gradient, also returns the pretty-printed differentiated term. gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env))) gradientByCHAD = \simplIters env term input -> - case (mapMergeNoAccum env, mapMergeOnlyMerge env) of - (Refl, Refl) -> - let dterm = diffCHAD simplIters env term - input1 = toPrimalE env input - (out, grad) = interpretOpen False input1 dterm - in (ppExpr (primalEnv env) dterm, (out, unTup vUnpair (d2e env) (Value grad))) - where - toPrimalE :: SList STy env' -> SList Value env' -> SList Value (D1E env') - toPrimalE SNil SNil = SNil - toPrimalE (t `SCons` env) (Value x `SCons` inp) = Value (toPrimal t x) `SCons` toPrimalE env inp - - toPrimal :: STy t -> Rep t -> Rep (D1 t) - toPrimal = \case - STNil -> id - STPair t1 t2 -> bimap (toPrimal t1) (toPrimal t2) - STEither t1 t2 -> bimap (toPrimal t1) (toPrimal t2) - STMaybe t -> fmap (toPrimal t) - STArr _ t -> fmap (toPrimal t) - STScal _ -> id - STAccum{} -> error "Accumulators not allowed in input program" + let dtermNonSimpl = ELet ext (EConst ext STF64 1.0) $ chad' env term + dterm | Dict <- envKnown env + = case simplIters of + SimplIters n -> simplifyN n dtermNonSimpl + SimplFix -> simplifyFix dtermNonSimpl + (out, grad) = interpretOpen False input dterm + in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad))) -- In addition to the gradient, also returns the pretty-printed differentiated term. gradientByCHAD' :: SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (TanE env))) @@ -172,29 +124,6 @@ genEnv :: SList STy env -> Gen (SList Value env) genEnv SNil = return SNil genEnv (t `SCons` env) = SCons <$> genValue t <*> genEnv env --- data TemplateVar n = TemplateVar (SNat n) String --- deriving (Show) - --- data Template t where --- TpShape :: TemplateVar n -> STy t -> Template (TArr n t) --- TpAny :: STy t -> Template t --- TpPair :: Template a -> Template b -> Template (TPair a b) --- deriving instance Show (Template t) - --- data ShapeConstraint n = ShapeAtLeast (Shape n) --- deriving (Show) - --- genTemplate :: DMap TemplateVar Shape -> Template t -> Gen (Value t) --- genTemplate = _ - --- genEnvTemplateExact :: DMap TemplateVar Shape -> SList Template env -> Gen (SList Value env) --- genEnvTemplateExact shapes env = _ - --- genEnvTemplate :: DMap TemplateVar ShapeConstraint -> SList Template env -> Gen (SList Value env) --- genEnvTemplate constrs env = do --- shapes <- DMap.traverseWithKey _ constrs --- genEnvTemplateExact shapes env - showEnv :: SList STy env -> SList Value env -> String showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" where @@ -205,11 +134,6 @@ showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property adTest = flip adTestGen (genEnv (knownEnv @env)) --- adTestTp :: forall env. KnownEnv env --- => DMap TemplateVar ShapeConstraint -> SList Template env --- -> Ex env (TScal TF64) -> Property --- adTestTp envConstrs envTp = adTestGen (genEnvTemplate envConstrs envTp) - adTestGen :: forall env. KnownEnv env => Ex env (TScal TF64) -> Gen (SList Value env) -> Property adTestGen expr envGenerator = property $ do @@ -268,7 +192,6 @@ tests = checkSequential $ Group "AD" idx0 $ build SZ (shape #x) $ #idx :-> #x ! #idx) - -- :hindentstr ppExpr knownEnv $ diffCHAD 20 knownEnv term_build1_sum ,("build1-sum", adTest term_build1_sum) ,("build2-sum", adTest $ fromNamed $ lambda @(TArr N2 _) #x $ body $ |