diff options
-rw-r--r-- | src/AST.hs | 54 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 32 | ||||
-rw-r--r-- | src/CHAD.hs | 22 | ||||
-rw-r--r-- | src/Data.hs | 4 | ||||
-rw-r--r-- | src/ForwardAD/DualNumbers.hs | 2 | ||||
-rw-r--r-- | src/Interpreter.hs | 34 | ||||
-rw-r--r-- | src/Interpreter/Rep.hs | 22 | ||||
-rw-r--r-- | src/Language.hs | 2 | ||||
-rw-r--r-- | test/Main.hs | 82 |
9 files changed, 176 insertions, 78 deletions
@@ -147,9 +147,11 @@ data SOp a t where ONeg :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TScal a) (TScal a) OLt :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) OLe :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) - OEq :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) + OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) ONot :: SOp (TScal TBool) (TScal TBool) - OIf :: SOp (TScal TBool) (TEither TNil TNil) + OAnd :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool) + OOr :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool) + OIf :: SOp (TScal TBool) (TEither TNil TNil) -- True is Left, False is Right ORound64 :: SOp (TScal TF64) (TScal TI64) OToFl64 :: SOp (TScal TI64) (TScal TF64) deriving instance Show (SOp a t) @@ -163,6 +165,8 @@ opt2 = \case OLe _ -> STScal STBool OEq _ -> STScal STBool ONot -> STScal STBool + OAnd -> STScal STBool + OOr -> STScal STBool OIf -> STEither STNil STNil ORound64 -> STScal STI64 OToFl64 -> STScal STF64 @@ -206,23 +210,23 @@ typeOf = \case EError t _ -> t -unSNat :: SNat n -> Nat -unSNat SZ = Z -unSNat (SS n) = S (unSNat n) +-- unSNat :: SNat n -> Nat +-- unSNat SZ = Z +-- unSNat (SS n) = S (unSNat n) -unSTy :: STy t -> Ty -unSTy = \case - STNil -> TNil - STPair a b -> TPair (unSTy a) (unSTy b) - STEither a b -> TEither (unSTy a) (unSTy b) - STMaybe t -> TMaybe (unSTy t) - STArr n t -> TArr (unSNat n) (unSTy t) - STScal t -> TScal (unSScalTy t) - STAccum t -> TAccum (unSTy t) +-- unSTy :: STy t -> Ty +-- unSTy = \case +-- STNil -> TNil +-- STPair a b -> TPair (unSTy a) (unSTy b) +-- STEither a b -> TEither (unSTy a) (unSTy b) +-- STMaybe t -> TMaybe (unSTy t) +-- STArr n t -> TArr (unSNat n) (unSTy t) +-- STScal t -> TScal (unSScalTy t) +-- STAccum t -> TAccum (unSTy t) -unSList :: SList STy env -> [Ty] -unSList SNil = [] -unSList (SCons t l) = unSTy t : unSList l +-- unSEnv :: SList STy env -> [Ty] +-- unSEnv SNil = [] +-- unSEnv (SCons t l) = unSTy t : unSEnv l unSScalTy :: SScalTy t -> ScalTy unSScalTy = \case @@ -335,9 +339,25 @@ sscaltyKnown STF32 = Dict sscaltyKnown STF64 = Dict sscaltyKnown STBool = Dict +envKnown :: SList STy env -> Dict (KnownEnv env) +envKnown SNil = Dict +envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict + ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t) ebuildUp1 n sh size f = EBuild ext (SS n) (EPair ext sh size) $ let arg = EVar ext (tTup (sreplicate (SS n) tIx)) IZ in EIdx ext (ELet ext (ESnd ext arg) (weakenExpr (WCopy WSink) f)) (EFst ext arg) + +eidxEq :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool) +eidxEq SZ _ _ = EConst ext STBool True +eidxEq (SS n) a b + | let ty = tTup (sreplicate (SS n) tIx) + = ELet ext a $ + ELet ext (weakenExpr WSink b) $ + EOp ext OAnd $ EPair ext + (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext ty (IS IZ))) + (ESnd ext (EVar ext ty IZ)))) + (eidxEq n (EFst ext (EVar ext ty (IS IZ))) + (EFst ext (EVar ext ty IZ))) diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index b50506a..acd0dc3 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -6,7 +6,7 @@ {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeOperators #-} -module AST.Pretty (ppExpr) where +module AST.Pretty (ppExpr, ppTy) where import Control.Monad (ap) import Data.List (intersperse) @@ -42,10 +42,10 @@ genNameIfUsedIn' prefix ty idx ex genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String genNameIfUsedIn = genNameIfUsedIn' "x" -ppExpr :: SList STy env -> Expr x env t -> String +ppExpr :: SList f env -> Expr x env t -> String ppExpr senv e = fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) "" where - mkVal :: SList STy env -> M (SVal env) + mkVal :: SList f env -> M (SVal env) mkVal SNil = return SNil mkVal (SCons _ v) = do val <- mkVal v @@ -112,14 +112,14 @@ ppExpr' d val = \case EBuild1 _ a b -> do a' <- ppExpr' 11 val a - name <- genNameIfUsedIn (STScal STI64) IZ b + name <- genNameIfUsedIn' "i" (STScal STI64) IZ b b' <- ppExpr' 0 (Const name `SCons` val) b return $ showParen (d > 10) $ showString "build1 " . a' . showString (" (\\" ++ name ++ " -> ") . b' . showString ")" EBuild _ n a b -> do a' <- ppExpr' 11 val a - name <- genNameIfUsedIn (tTup (sreplicate n tIx)) IZ b + name <- genNameIfUsedIn' "i" (tTup (sreplicate n tIx)) IZ b e' <- ppExpr' 0 (Const name `SCons` val) b return $ showParen (d > 10) $ showString "build " . a' . showString (" (\\" ++ name ++ " -> ") . e' . showString ")" @@ -195,7 +195,7 @@ ppExpr' d val = \case e2' <- ppExpr' 11 val e2 e3' <- ppExpr' 11 val e3 return $ showParen (d > 10) $ - showString ("accum " ++ show (unSNat i) ++ " ") . e1' . showString " " . e2' . showString " " . e3' + showString ("accum " ++ show (fromSNat i) ++ " ") . e1' . showString " " . e2' . showString " " . e3' EZero _ -> return $ showString "zero" @@ -243,6 +243,26 @@ operator OLt{} = (Infix, "<") operator OLe{} = (Infix, "<=") operator OEq{} = (Infix, "==") operator ONot = (Prefix, "not") +operator OAnd = (Infix, "&&") +operator OOr = (Infix, "||") operator OIf = (Prefix, "ifB") operator ORound64 = (Prefix, "round") operator OToFl64 = (Prefix, "toFl64") + +ppTy :: Int -> STy t -> String +ppTy d ty = ppTys d ty "" + +ppTys :: Int -> STy t -> ShowS +ppTys _ STNil = showString "1" +ppTys d (STPair a b) = showParen (d > 7) $ ppTys 8 a . showString " * " . ppTys 8 b +ppTys d (STEither a b) = showParen (d > 6) $ ppTys 7 a . showString " + " . ppTys 7 b +ppTys d (STMaybe t) = showParen (d > 10) $ showString "Maybe " . ppTys 11 t +ppTys d (STArr n t) = showParen (d > 10) $ + showString "Arr " . shows (fromSNat n) . showString " " . ppTys 11 t +ppTys _ (STScal sty) = showString $ case sty of + STI32 -> "i32" + STI64 -> "i64" + STF32 -> "f32" + STF64 -> "f64" + STBool -> "bool" +ppTys d (STAccum t) = showParen (d > 10) $ showString "Accum " . ppTys 11 t diff --git a/src/CHAD.hs b/src/CHAD.hs index e77dbe7..dda434c 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -545,6 +545,8 @@ d1op (OLt t) e = EOp ext (OLt t) e d1op (OLe t) e = EOp ext (OLe t) e d1op (OEq t) e = EOp ext (OEq t) e d1op ONot e = EOp ext ONot e +d1op OAnd e = EOp ext OAnd e +d1op OOr e = EOp ext OOr e d1op OIf e = EOp ext OIf e d1op ORound64 e = EOp ext ORound64 e d1op OToFl64 e = EOp ext OToFl64 e @@ -564,6 +566,8 @@ d2op op = case op of OLe t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) OEq t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) ONot -> Linear $ \_ -> ENil ext + OAnd -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) + OOr -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) OIf -> Linear $ \_ -> ENil ext ORound64 -> Linear $ \_ -> EConst ext STF64 0.0 OToFl64 -> Linear $ \_ -> ENil ext @@ -1078,15 +1082,19 @@ drev des = \case | Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil , STArr n eltty <- typeOf e - , Refl <- indexTupD1Id n -> + , Refl <- indexTupD1Id n + , let tIxN = tTup (sreplicate n tIx) -> Ret (binds `BPush` (STArr n (d1 eltty), e1) - `BPush` (tTup (sreplicate n tIx), EShape ext (EVar ext (typeOf e1) IZ))) - (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS IZ)) - (weakenExpr (WSink .> WSink) ei1)) + `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ)) + `BPush` (tIxN, weakenExpr (WSink .> WSink) ei1)) + (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) + (EVar ext (tTup (sreplicate n tIx)) IZ)) sub - (ELet ext (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) (IS IZ)) - (EVar ext (d2 eltty) (IS IZ))) $ - weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + (ELet ext (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) (IS IZ)) $ + ECase ext (EOp ext OIf (eidxEq n (EVar ext tIxN IZ) (EVar ext tIxN (IS (IS IZ))))) + (EVar ext (d2 eltty) (IS (IS IZ))) + (EZero eltty)) $ + weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2) EShape _ e -- Allowed to ignore e2 here because the output of EShape is discrete, diff --git a/src/Data.hs b/src/Data.hs index e951ef2..c5d6219 100644 --- a/src/Data.hs +++ b/src/Data.hs @@ -30,6 +30,10 @@ slistMap :: (forall t. f t -> g t) -> SList f list -> SList g list slistMap _ SNil = SNil slistMap f (SCons x list) = SCons (f x) (slistMap f list) +unSList :: (forall t. f t -> a) -> SList f list -> [a] +unSList _ SNil = [] +unSList f (x `SCons` l) = f x : unSList f l + sappend :: SList f l1 -> SList f l2 -> SList f (Append l1 l2) sappend SNil l = l sappend (SCons x xs) l = SCons x (sappend xs l) diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index f02b93e..4f84e8d 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -99,6 +99,8 @@ dop = \case (binFloat (\(x, _) (y, _) -> EOp ext (OEq t) (EPair ext x y))) (EOp ext (OEq t)) ONot -> EOp ext ONot + OAnd -> EOp ext OAnd + OOr -> EOp ext OOr OIf -> EOp ext OIf ORound64 -> \arg -> EOp ext ORound64 (EFst ext arg) OToFl64 -> \arg -> EPair ext (EOp ext OToFl64 arg) (EConst ext STF64 0.0) diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 3fb5d7b..3d6f33d 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -3,6 +3,7 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ImplicitParams #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiWayIf #-} @@ -23,13 +24,14 @@ import Data.Char (isSpace) import Data.Kind (Type) import Data.Int (Int64) import Data.IORef -import GHC.Stack (HasCallStack) +import System.IO (hPutStrLn, stderr) import System.IO.Unsafe (unsafePerformIO) import Debug.Trace import Array import AST +import AST.Pretty import CHAD.Types import Data import Interpreter.Rep @@ -42,14 +44,25 @@ newtype AcM s a = AcM { unAcM :: IO a } runAcM :: (forall s. AcM s a) -> a runAcM (AcM m) = unsafePerformIO m +acmDebugLog :: String -> AcM s () +acmDebugLog s = AcM (hPutStrLn stderr s) + interpret :: Ex '[] t -> Rep t interpret = interpretOpen SNil interpretOpen :: SList Value env -> Ex env t -> Rep t -interpretOpen env e = runAcM (interpret' env e) - -interpret' :: forall env t s. HasCallStack => SList Value env -> Ex env t -> AcM s (Rep t) -interpret' env = \case +interpretOpen env e = runAcM (let ?depth = 0 in interpret' env e) + +interpret' :: forall env t s. (?depth :: Int) => SList Value env -> Ex env t -> AcM s (Rep t) +interpret' env e = do + let dep = ?depth + acmDebugLog $ replicate dep ' ' ++ "ev: " ++ ppExpr env e + res <- let ?depth = dep + 1 in interpret'Rec env e + acmDebugLog $ replicate dep ' ' ++ "<- " ++ showValue 0 (typeOf e) res "" + return res + +interpret'Rec :: forall env t s. (?depth :: Int) => SList Value env -> Ex env t -> AcM s (Rep t) +interpret'Rec env = \case EVar _ _ i -> case slistIdx env i of Value x -> return x ELet _ a b -> do x <- interpret' env a @@ -125,11 +138,20 @@ interpretOp op arg = case op of ONeg st -> numericIsNum st $ negate arg OLt st -> numericIsNum st $ uncurry (<) arg OLe st -> numericIsNum st $ uncurry (<=) arg - OEq st -> numericIsNum st $ uncurry (==) arg + OEq st -> styIsEq st $ uncurry (==) arg ONot -> not arg + OAnd -> uncurry (&&) arg + OOr -> uncurry (||) arg OIf -> if arg then Left () else Right () ORound64 -> round arg OToFl64 -> fromIntegral arg + where + styIsEq :: SScalTy t -> (Eq (Rep (TScal t)) => r) -> r + styIsEq STI32 = id + styIsEq STI64 = id + styIsEq STF32 = id + styIsEq STF64 = id + styIsEq STBool = id zeroD2 :: STy t -> Rep (D2 t) zeroD2 typ = case typ of diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index 5c20183..baf38fc 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -3,6 +3,8 @@ {-# LANGUAGE UndecidableInstances #-} module Interpreter.Rep where +import Data.List (intersperse) +import Data.Foldable (toList) import Data.IORef import GHC.TypeError @@ -53,3 +55,23 @@ vPair = liftV2 (,) vUnpair :: Value (TPair a b) -> (Value a, Value b) vUnpair (Value (x, y)) = (Value x, Value y) + +showValue :: Int -> STy t -> Rep t -> ShowS +showValue _ STNil () = showString "()" +showValue _ (STPair a b) (x, y) = showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")" +showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Left " . showValue 11 a x +showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Right " . showValue 11 b y +showValue _ (STMaybe _) Nothing = showString "Nothing" +showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x +showValue d (STArr _ t) arr = showParen (d > 10) $ + showString "arrayFromList " . showsPrec 11 (arrayShape arr) + . showString " [" + . foldr (.) id (intersperse (showString ",") $ map (showValue 0 t) (toList arr)) + . showString "]" +showValue _ (STScal sty) x = case sty of + STF32 -> shows x + STF64 -> shows x + STI32 -> shows x + STI64 -> shows x + STBool -> shows x +showValue _ STAccum{} _ = error "Cannot show accumulators" diff --git a/src/Language.hs b/src/Language.hs index a025236..3a4a36c 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -126,5 +126,7 @@ not_ = oper ONot -- | The "_" variables in scope are unusable and should be ignored. With a -- weakening function on NExprs they could be hidden. +-- +-- The first alternative is the True case; the second is the False case. if_ :: NExpr env (TScal TBool) -> NExpr ('("_", TNil) : env) t -> NExpr ('("_", TNil) : env) t -> NExpr env t if_ e a b = case_ (oper OIf e) (#_ :-> a) (#_ :-> b) diff --git a/test/Main.hs b/test/Main.hs index e325b64..ab01e89 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -14,15 +14,12 @@ module Main where import Data.Bifunctor -- import qualified Data.Dependent.Map as DMap -- import Data.Dependent.Map (DMap) -import Data.Foldable (toList) -import Data.List (intercalate, intersperse) +import Data.List (intercalate) import Hedgehog import qualified Hedgehog.Gen as Gen import qualified Hedgehog.Range as Range import Hedgehog.Main -import Debug.Trace - import Array import AST import AST.Pretty @@ -34,6 +31,7 @@ import ForwardAD import Interpreter import Interpreter.Rep import Language +import Simplify type family MapMerge env where @@ -48,21 +46,33 @@ mapMergeOnlyMerge :: SList f env -> Select env (MapMerge env) "merge" :~: env mapMergeOnlyMerge SNil = Refl mapMergeOnlyMerge (_ `SCons` env) | Refl <- mapMergeOnlyMerge env = Refl -gradientByCHAD :: forall env. SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (D2E env) -gradientByCHAD = \env term input -> +primalEnv :: SList STy env' -> SList STy (D1E env') +primalEnv SNil = SNil +primalEnv (t `SCons` env) = d1 t `SCons` primalEnv env + +diffCHAD :: Int -> SList STy env -> Ex env (TScal TF64) + -> Ex (D1E env) (TPair (TScal TF64) (Tup (D2E env))) +diffCHAD = \simplIters env term -> case (mapMergeNoAccum env, mapMergeOnlyMerge env) of (Refl, Refl) -> let descr = makeMergeDescr env - dterm = freezeRet descr (drev descr term) (EConst ext STF64 1.0) - input1 = toPrimalE env input - (_out, grad) = interpretOpen input1 dterm - in (if False then trace ("gradientByCHAD: Differentiated term:\n" ++ ppExpr (primalEnv env) dterm ++ "\n\n\n") else id) $ - unTup vUnpair (d2e env) (Value grad) + in case envKnown (primalEnv env) of + Dict -> simplifyN simplIters $ freezeRet descr (drev descr term) (EConst ext STF64 1.0) where makeMergeDescr :: SList STy env' -> Descr env' (MapMerge env') makeMergeDescr SNil = DTop makeMergeDescr (t `SCons` env) = makeMergeDescr env `DPush` (t, SMerge) +-- In addition to the gradient, also returns the pretty-printed differentiated term. +gradientByCHAD :: forall env. Int -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, SList Value (D2E env)) +gradientByCHAD = \simplIters env term input -> + case (mapMergeNoAccum env, mapMergeOnlyMerge env) of + (Refl, Refl) -> + let dterm = diffCHAD simplIters env term + input1 = toPrimalE env input + (_out, grad) = interpretOpen input1 dterm + in (ppExpr (primalEnv env) dterm, unTup vUnpair (d2e env) (Value grad)) + where toPrimalE :: SList STy env' -> SList Value env' -> SList Value (D1E env') toPrimalE SNil SNil = SNil toPrimalE (t `SCons` env) (Value x `SCons` inp) = Value (toPrimal t x) `SCons` toPrimalE env inp @@ -77,12 +87,9 @@ gradientByCHAD = \env term input -> STScal _ -> id STAccum{} -> error "Accumulators not allowed in input program" - primalEnv :: SList STy env' -> SList STy (D1E env') - primalEnv SNil = SNil - primalEnv (t `SCons` env) = d1 t `SCons` primalEnv env - -gradientByCHAD' :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env) -gradientByCHAD' = \env term input -> toTanE env input (gradientByCHAD env term input) +-- In addition to the gradient, also returns the pretty-printed differentiated term. +gradientByCHAD' :: Int -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, SList Value (TanE env)) +gradientByCHAD' = \simplIters env term input -> toTanE env input <$> gradientByCHAD simplIters env term input where toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env) toTanE SNil SNil SNil = SNil @@ -183,26 +190,6 @@ genEnv (t `SCons` env) = SCons <$> genValue t <*> genEnv env -- shapes <- DMap.traverseWithKey _ constrs -- genEnvTemplateExact shapes env -showValue :: Int -> STy t -> Rep t -> ShowS -showValue _ STNil () = showString "()" -showValue _ (STPair a b) (x, y) = showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")" -showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Left " . showValue 11 a x -showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Right " . showValue 11 b y -showValue _ (STMaybe _) Nothing = showString "Nothing" -showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x -showValue d (STArr _ t) arr = showParen (d > 10) $ - showString "arrayFromList " . showsPrec 11 (arrayShape arr) - . showString " [" - . foldr (.) id (intersperse (showString ",") $ map (showValue 0 t) (toList arr)) - . showString "]" -showValue _ (STScal sty) x = case sty of - STF32 -> shows x - STF64 -> shows x - STI32 -> shows x - STI64 -> shows x - STBool -> shows x -showValue _ STAccum{} _ = error "Cannot show accumulators" - showEnv :: SList STy env -> SList Value env -> String showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" where @@ -224,15 +211,27 @@ adTestGen expr envGenerator = property $ do let env = knownEnv @env input <- forAllWith (showEnv env) envGenerator let gradFwd = gradientByForward knownEnv expr input - gradCHAD = gradientByCHAD' knownEnv expr input + (ppdterm, gradCHAD) = gradientByCHAD' 0 knownEnv expr input + (ppdterm_S, gradCHAD_S) = gradientByCHAD' 20 knownEnv expr input scFwd = envScalars env gradFwd scCHAD = envScalars env gradCHAD - diff scCHAD (\x y -> and (zipWith closeIsh x y)) scFwd + scCHAD_S = envScalars env gradCHAD_S + annotate (concat (unSList (\t -> ppTy 0 t ++ " -> ") env) ++ ppTy 0 (typeOf expr)) + annotate (ppExpr knownEnv expr) + annotate ppdterm + annotate ppdterm_S + diff scCHAD (\x y -> and (zipWith closeIsh x y)) scCHAD_S + diff scFwd (\x y -> and (zipWith closeIsh x y)) scCHAD_S where envScalars :: SList STy env' -> SList Value (TanE env') -> [Double] envScalars SNil SNil = [] envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs +term_build1_sum :: Ex '[TArr N1 (TScal TF64)] (TScal TF64) +term_build1_sum = fromNamed $ lambda #x $ body $ + idx0 $ sum1i $ + build (SS SZ) (shape #x) $ #idx :-> #x ! #idx + tests :: IO Bool tests = checkSequential $ Group "AD" [("id", adTest $ fromNamed $ lambda #x $ body $ #x) @@ -256,9 +255,8 @@ tests = checkSequential $ Group "AD" idx0 $ build SZ (shape #x) $ #idx :-> #x ! #idx) - ,("build1-sum", adTest $ fromNamed $ lambda @(TArr N1 _) #x $ body $ - idx0 $ sum1i $ - build (SS SZ) (shape #x) $ #idx :-> #x ! #idx) + -- :hindentstr ppExpr knownEnv $ diffCHAD 20 knownEnv term_build1_sum + ,("build1-sum", adTest term_build1_sum) ,("build2-sum", adTest $ fromNamed $ lambda @(TArr N2 _) #x $ body $ idx0 $ sum1i . sum1i $ |