summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/AST.hs54
-rw-r--r--src/AST/Pretty.hs32
-rw-r--r--src/CHAD.hs22
-rw-r--r--src/Data.hs4
-rw-r--r--src/ForwardAD/DualNumbers.hs2
-rw-r--r--src/Interpreter.hs34
-rw-r--r--src/Interpreter/Rep.hs22
-rw-r--r--src/Language.hs2
-rw-r--r--test/Main.hs82
9 files changed, 176 insertions, 78 deletions
diff --git a/src/AST.hs b/src/AST.hs
index af137b2..6370148 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -147,9 +147,11 @@ data SOp a t where
ONeg :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
OLt :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
OLe :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
- OEq :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
+ OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
ONot :: SOp (TScal TBool) (TScal TBool)
- OIf :: SOp (TScal TBool) (TEither TNil TNil)
+ OAnd :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool)
+ OOr :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool)
+ OIf :: SOp (TScal TBool) (TEither TNil TNil) -- True is Left, False is Right
ORound64 :: SOp (TScal TF64) (TScal TI64)
OToFl64 :: SOp (TScal TI64) (TScal TF64)
deriving instance Show (SOp a t)
@@ -163,6 +165,8 @@ opt2 = \case
OLe _ -> STScal STBool
OEq _ -> STScal STBool
ONot -> STScal STBool
+ OAnd -> STScal STBool
+ OOr -> STScal STBool
OIf -> STEither STNil STNil
ORound64 -> STScal STI64
OToFl64 -> STScal STF64
@@ -206,23 +210,23 @@ typeOf = \case
EError t _ -> t
-unSNat :: SNat n -> Nat
-unSNat SZ = Z
-unSNat (SS n) = S (unSNat n)
+-- unSNat :: SNat n -> Nat
+-- unSNat SZ = Z
+-- unSNat (SS n) = S (unSNat n)
-unSTy :: STy t -> Ty
-unSTy = \case
- STNil -> TNil
- STPair a b -> TPair (unSTy a) (unSTy b)
- STEither a b -> TEither (unSTy a) (unSTy b)
- STMaybe t -> TMaybe (unSTy t)
- STArr n t -> TArr (unSNat n) (unSTy t)
- STScal t -> TScal (unSScalTy t)
- STAccum t -> TAccum (unSTy t)
+-- unSTy :: STy t -> Ty
+-- unSTy = \case
+-- STNil -> TNil
+-- STPair a b -> TPair (unSTy a) (unSTy b)
+-- STEither a b -> TEither (unSTy a) (unSTy b)
+-- STMaybe t -> TMaybe (unSTy t)
+-- STArr n t -> TArr (unSNat n) (unSTy t)
+-- STScal t -> TScal (unSScalTy t)
+-- STAccum t -> TAccum (unSTy t)
-unSList :: SList STy env -> [Ty]
-unSList SNil = []
-unSList (SCons t l) = unSTy t : unSList l
+-- unSEnv :: SList STy env -> [Ty]
+-- unSEnv SNil = []
+-- unSEnv (SCons t l) = unSTy t : unSEnv l
unSScalTy :: SScalTy t -> ScalTy
unSScalTy = \case
@@ -335,9 +339,25 @@ sscaltyKnown STF32 = Dict
sscaltyKnown STF64 = Dict
sscaltyKnown STBool = Dict
+envKnown :: SList STy env -> Dict (KnownEnv env)
+envKnown SNil = Dict
+envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict
+
ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t)
ebuildUp1 n sh size f =
EBuild ext (SS n) (EPair ext sh size) $
let arg = EVar ext (tTup (sreplicate (SS n) tIx)) IZ
in EIdx ext (ELet ext (ESnd ext arg) (weakenExpr (WCopy WSink) f))
(EFst ext arg)
+
+eidxEq :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool)
+eidxEq SZ _ _ = EConst ext STBool True
+eidxEq (SS n) a b
+ | let ty = tTup (sreplicate (SS n) tIx)
+ = ELet ext a $
+ ELet ext (weakenExpr WSink b) $
+ EOp ext OAnd $ EPair ext
+ (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext ty (IS IZ)))
+ (ESnd ext (EVar ext ty IZ))))
+ (eidxEq n (EFst ext (EVar ext ty (IS IZ)))
+ (EFst ext (EVar ext ty IZ)))
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index b50506a..acd0dc3 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -6,7 +6,7 @@
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeOperators #-}
-module AST.Pretty (ppExpr) where
+module AST.Pretty (ppExpr, ppTy) where
import Control.Monad (ap)
import Data.List (intersperse)
@@ -42,10 +42,10 @@ genNameIfUsedIn' prefix ty idx ex
genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String
genNameIfUsedIn = genNameIfUsedIn' "x"
-ppExpr :: SList STy env -> Expr x env t -> String
+ppExpr :: SList f env -> Expr x env t -> String
ppExpr senv e = fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) ""
where
- mkVal :: SList STy env -> M (SVal env)
+ mkVal :: SList f env -> M (SVal env)
mkVal SNil = return SNil
mkVal (SCons _ v) = do
val <- mkVal v
@@ -112,14 +112,14 @@ ppExpr' d val = \case
EBuild1 _ a b -> do
a' <- ppExpr' 11 val a
- name <- genNameIfUsedIn (STScal STI64) IZ b
+ name <- genNameIfUsedIn' "i" (STScal STI64) IZ b
b' <- ppExpr' 0 (Const name `SCons` val) b
return $ showParen (d > 10) $
showString "build1 " . a' . showString (" (\\" ++ name ++ " -> ") . b' . showString ")"
EBuild _ n a b -> do
a' <- ppExpr' 11 val a
- name <- genNameIfUsedIn (tTup (sreplicate n tIx)) IZ b
+ name <- genNameIfUsedIn' "i" (tTup (sreplicate n tIx)) IZ b
e' <- ppExpr' 0 (Const name `SCons` val) b
return $ showParen (d > 10) $
showString "build " . a' . showString (" (\\" ++ name ++ " -> ") . e' . showString ")"
@@ -195,7 +195,7 @@ ppExpr' d val = \case
e2' <- ppExpr' 11 val e2
e3' <- ppExpr' 11 val e3
return $ showParen (d > 10) $
- showString ("accum " ++ show (unSNat i) ++ " ") . e1' . showString " " . e2' . showString " " . e3'
+ showString ("accum " ++ show (fromSNat i) ++ " ") . e1' . showString " " . e2' . showString " " . e3'
EZero _ -> return $ showString "zero"
@@ -243,6 +243,26 @@ operator OLt{} = (Infix, "<")
operator OLe{} = (Infix, "<=")
operator OEq{} = (Infix, "==")
operator ONot = (Prefix, "not")
+operator OAnd = (Infix, "&&")
+operator OOr = (Infix, "||")
operator OIf = (Prefix, "ifB")
operator ORound64 = (Prefix, "round")
operator OToFl64 = (Prefix, "toFl64")
+
+ppTy :: Int -> STy t -> String
+ppTy d ty = ppTys d ty ""
+
+ppTys :: Int -> STy t -> ShowS
+ppTys _ STNil = showString "1"
+ppTys d (STPair a b) = showParen (d > 7) $ ppTys 8 a . showString " * " . ppTys 8 b
+ppTys d (STEither a b) = showParen (d > 6) $ ppTys 7 a . showString " + " . ppTys 7 b
+ppTys d (STMaybe t) = showParen (d > 10) $ showString "Maybe " . ppTys 11 t
+ppTys d (STArr n t) = showParen (d > 10) $
+ showString "Arr " . shows (fromSNat n) . showString " " . ppTys 11 t
+ppTys _ (STScal sty) = showString $ case sty of
+ STI32 -> "i32"
+ STI64 -> "i64"
+ STF32 -> "f32"
+ STF64 -> "f64"
+ STBool -> "bool"
+ppTys d (STAccum t) = showParen (d > 10) $ showString "Accum " . ppTys 11 t
diff --git a/src/CHAD.hs b/src/CHAD.hs
index e77dbe7..dda434c 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -545,6 +545,8 @@ d1op (OLt t) e = EOp ext (OLt t) e
d1op (OLe t) e = EOp ext (OLe t) e
d1op (OEq t) e = EOp ext (OEq t) e
d1op ONot e = EOp ext ONot e
+d1op OAnd e = EOp ext OAnd e
+d1op OOr e = EOp ext OOr e
d1op OIf e = EOp ext OIf e
d1op ORound64 e = EOp ext ORound64 e
d1op OToFl64 e = EOp ext OToFl64 e
@@ -564,6 +566,8 @@ d2op op = case op of
OLe t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
OEq t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
ONot -> Linear $ \_ -> ENil ext
+ OAnd -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
+ OOr -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
OIf -> Linear $ \_ -> ENil ext
ORound64 -> Linear $ \_ -> EConst ext STF64 0.0
OToFl64 -> Linear $ \_ -> ENil ext
@@ -1078,15 +1082,19 @@ drev des = \case
| Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)
<- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil
, STArr n eltty <- typeOf e
- , Refl <- indexTupD1Id n ->
+ , Refl <- indexTupD1Id n
+ , let tIxN = tTup (sreplicate n tIx) ->
Ret (binds `BPush` (STArr n (d1 eltty), e1)
- `BPush` (tTup (sreplicate n tIx), EShape ext (EVar ext (typeOf e1) IZ)))
- (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS IZ))
- (weakenExpr (WSink .> WSink) ei1))
+ `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ))
+ `BPush` (tIxN, weakenExpr (WSink .> WSink) ei1))
+ (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ)))
+ (EVar ext (tTup (sreplicate n tIx)) IZ))
sub
- (ELet ext (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) (IS IZ))
- (EVar ext (d2 eltty) (IS IZ))) $
- weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
+ (ELet ext (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) (IS IZ)) $
+ ECase ext (EOp ext OIf (eidxEq n (EVar ext tIxN IZ) (EVar ext tIxN (IS (IS IZ)))))
+ (EVar ext (d2 eltty) (IS (IS IZ)))
+ (EZero eltty)) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2)
EShape _ e
-- Allowed to ignore e2 here because the output of EShape is discrete,
diff --git a/src/Data.hs b/src/Data.hs
index e951ef2..c5d6219 100644
--- a/src/Data.hs
+++ b/src/Data.hs
@@ -30,6 +30,10 @@ slistMap :: (forall t. f t -> g t) -> SList f list -> SList g list
slistMap _ SNil = SNil
slistMap f (SCons x list) = SCons (f x) (slistMap f list)
+unSList :: (forall t. f t -> a) -> SList f list -> [a]
+unSList _ SNil = []
+unSList f (x `SCons` l) = f x : unSList f l
+
sappend :: SList f l1 -> SList f l2 -> SList f (Append l1 l2)
sappend SNil l = l
sappend (SCons x xs) l = SCons x (sappend xs l)
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index f02b93e..4f84e8d 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -99,6 +99,8 @@ dop = \case
(binFloat (\(x, _) (y, _) -> EOp ext (OEq t) (EPair ext x y)))
(EOp ext (OEq t))
ONot -> EOp ext ONot
+ OAnd -> EOp ext OAnd
+ OOr -> EOp ext OOr
OIf -> EOp ext OIf
ORound64 -> \arg -> EOp ext ORound64 (EFst ext arg)
OToFl64 -> \arg -> EPair ext (EOp ext OToFl64 arg) (EConst ext STF64 0.0)
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 3fb5d7b..3d6f33d 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -3,6 +3,7 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
@@ -23,13 +24,14 @@ import Data.Char (isSpace)
import Data.Kind (Type)
import Data.Int (Int64)
import Data.IORef
-import GHC.Stack (HasCallStack)
+import System.IO (hPutStrLn, stderr)
import System.IO.Unsafe (unsafePerformIO)
import Debug.Trace
import Array
import AST
+import AST.Pretty
import CHAD.Types
import Data
import Interpreter.Rep
@@ -42,14 +44,25 @@ newtype AcM s a = AcM { unAcM :: IO a }
runAcM :: (forall s. AcM s a) -> a
runAcM (AcM m) = unsafePerformIO m
+acmDebugLog :: String -> AcM s ()
+acmDebugLog s = AcM (hPutStrLn stderr s)
+
interpret :: Ex '[] t -> Rep t
interpret = interpretOpen SNil
interpretOpen :: SList Value env -> Ex env t -> Rep t
-interpretOpen env e = runAcM (interpret' env e)
-
-interpret' :: forall env t s. HasCallStack => SList Value env -> Ex env t -> AcM s (Rep t)
-interpret' env = \case
+interpretOpen env e = runAcM (let ?depth = 0 in interpret' env e)
+
+interpret' :: forall env t s. (?depth :: Int) => SList Value env -> Ex env t -> AcM s (Rep t)
+interpret' env e = do
+ let dep = ?depth
+ acmDebugLog $ replicate dep ' ' ++ "ev: " ++ ppExpr env e
+ res <- let ?depth = dep + 1 in interpret'Rec env e
+ acmDebugLog $ replicate dep ' ' ++ "<- " ++ showValue 0 (typeOf e) res ""
+ return res
+
+interpret'Rec :: forall env t s. (?depth :: Int) => SList Value env -> Ex env t -> AcM s (Rep t)
+interpret'Rec env = \case
EVar _ _ i -> case slistIdx env i of Value x -> return x
ELet _ a b -> do
x <- interpret' env a
@@ -125,11 +138,20 @@ interpretOp op arg = case op of
ONeg st -> numericIsNum st $ negate arg
OLt st -> numericIsNum st $ uncurry (<) arg
OLe st -> numericIsNum st $ uncurry (<=) arg
- OEq st -> numericIsNum st $ uncurry (==) arg
+ OEq st -> styIsEq st $ uncurry (==) arg
ONot -> not arg
+ OAnd -> uncurry (&&) arg
+ OOr -> uncurry (||) arg
OIf -> if arg then Left () else Right ()
ORound64 -> round arg
OToFl64 -> fromIntegral arg
+ where
+ styIsEq :: SScalTy t -> (Eq (Rep (TScal t)) => r) -> r
+ styIsEq STI32 = id
+ styIsEq STI64 = id
+ styIsEq STF32 = id
+ styIsEq STF64 = id
+ styIsEq STBool = id
zeroD2 :: STy t -> Rep (D2 t)
zeroD2 typ = case typ of
diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs
index 5c20183..baf38fc 100644
--- a/src/Interpreter/Rep.hs
+++ b/src/Interpreter/Rep.hs
@@ -3,6 +3,8 @@
{-# LANGUAGE UndecidableInstances #-}
module Interpreter.Rep where
+import Data.List (intersperse)
+import Data.Foldable (toList)
import Data.IORef
import GHC.TypeError
@@ -53,3 +55,23 @@ vPair = liftV2 (,)
vUnpair :: Value (TPair a b) -> (Value a, Value b)
vUnpair (Value (x, y)) = (Value x, Value y)
+
+showValue :: Int -> STy t -> Rep t -> ShowS
+showValue _ STNil () = showString "()"
+showValue _ (STPair a b) (x, y) = showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")"
+showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Left " . showValue 11 a x
+showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Right " . showValue 11 b y
+showValue _ (STMaybe _) Nothing = showString "Nothing"
+showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x
+showValue d (STArr _ t) arr = showParen (d > 10) $
+ showString "arrayFromList " . showsPrec 11 (arrayShape arr)
+ . showString " ["
+ . foldr (.) id (intersperse (showString ",") $ map (showValue 0 t) (toList arr))
+ . showString "]"
+showValue _ (STScal sty) x = case sty of
+ STF32 -> shows x
+ STF64 -> shows x
+ STI32 -> shows x
+ STI64 -> shows x
+ STBool -> shows x
+showValue _ STAccum{} _ = error "Cannot show accumulators"
diff --git a/src/Language.hs b/src/Language.hs
index a025236..3a4a36c 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -126,5 +126,7 @@ not_ = oper ONot
-- | The "_" variables in scope are unusable and should be ignored. With a
-- weakening function on NExprs they could be hidden.
+--
+-- The first alternative is the True case; the second is the False case.
if_ :: NExpr env (TScal TBool) -> NExpr ('("_", TNil) : env) t -> NExpr ('("_", TNil) : env) t -> NExpr env t
if_ e a b = case_ (oper OIf e) (#_ :-> a) (#_ :-> b)
diff --git a/test/Main.hs b/test/Main.hs
index e325b64..ab01e89 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -14,15 +14,12 @@ module Main where
import Data.Bifunctor
-- import qualified Data.Dependent.Map as DMap
-- import Data.Dependent.Map (DMap)
-import Data.Foldable (toList)
-import Data.List (intercalate, intersperse)
+import Data.List (intercalate)
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import Hedgehog.Main
-import Debug.Trace
-
import Array
import AST
import AST.Pretty
@@ -34,6 +31,7 @@ import ForwardAD
import Interpreter
import Interpreter.Rep
import Language
+import Simplify
type family MapMerge env where
@@ -48,21 +46,33 @@ mapMergeOnlyMerge :: SList f env -> Select env (MapMerge env) "merge" :~: env
mapMergeOnlyMerge SNil = Refl
mapMergeOnlyMerge (_ `SCons` env) | Refl <- mapMergeOnlyMerge env = Refl
-gradientByCHAD :: forall env. SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (D2E env)
-gradientByCHAD = \env term input ->
+primalEnv :: SList STy env' -> SList STy (D1E env')
+primalEnv SNil = SNil
+primalEnv (t `SCons` env) = d1 t `SCons` primalEnv env
+
+diffCHAD :: Int -> SList STy env -> Ex env (TScal TF64)
+ -> Ex (D1E env) (TPair (TScal TF64) (Tup (D2E env)))
+diffCHAD = \simplIters env term ->
case (mapMergeNoAccum env, mapMergeOnlyMerge env) of
(Refl, Refl) ->
let descr = makeMergeDescr env
- dterm = freezeRet descr (drev descr term) (EConst ext STF64 1.0)
- input1 = toPrimalE env input
- (_out, grad) = interpretOpen input1 dterm
- in (if False then trace ("gradientByCHAD: Differentiated term:\n" ++ ppExpr (primalEnv env) dterm ++ "\n\n\n") else id) $
- unTup vUnpair (d2e env) (Value grad)
+ in case envKnown (primalEnv env) of
+ Dict -> simplifyN simplIters $ freezeRet descr (drev descr term) (EConst ext STF64 1.0)
where
makeMergeDescr :: SList STy env' -> Descr env' (MapMerge env')
makeMergeDescr SNil = DTop
makeMergeDescr (t `SCons` env) = makeMergeDescr env `DPush` (t, SMerge)
+-- In addition to the gradient, also returns the pretty-printed differentiated term.
+gradientByCHAD :: forall env. Int -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, SList Value (D2E env))
+gradientByCHAD = \simplIters env term input ->
+ case (mapMergeNoAccum env, mapMergeOnlyMerge env) of
+ (Refl, Refl) ->
+ let dterm = diffCHAD simplIters env term
+ input1 = toPrimalE env input
+ (_out, grad) = interpretOpen input1 dterm
+ in (ppExpr (primalEnv env) dterm, unTup vUnpair (d2e env) (Value grad))
+ where
toPrimalE :: SList STy env' -> SList Value env' -> SList Value (D1E env')
toPrimalE SNil SNil = SNil
toPrimalE (t `SCons` env) (Value x `SCons` inp) = Value (toPrimal t x) `SCons` toPrimalE env inp
@@ -77,12 +87,9 @@ gradientByCHAD = \env term input ->
STScal _ -> id
STAccum{} -> error "Accumulators not allowed in input program"
- primalEnv :: SList STy env' -> SList STy (D1E env')
- primalEnv SNil = SNil
- primalEnv (t `SCons` env) = d1 t `SCons` primalEnv env
-
-gradientByCHAD' :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
-gradientByCHAD' = \env term input -> toTanE env input (gradientByCHAD env term input)
+-- In addition to the gradient, also returns the pretty-printed differentiated term.
+gradientByCHAD' :: Int -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, SList Value (TanE env))
+gradientByCHAD' = \simplIters env term input -> toTanE env input <$> gradientByCHAD simplIters env term input
where
toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env)
toTanE SNil SNil SNil = SNil
@@ -183,26 +190,6 @@ genEnv (t `SCons` env) = SCons <$> genValue t <*> genEnv env
-- shapes <- DMap.traverseWithKey _ constrs
-- genEnvTemplateExact shapes env
-showValue :: Int -> STy t -> Rep t -> ShowS
-showValue _ STNil () = showString "()"
-showValue _ (STPair a b) (x, y) = showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")"
-showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Left " . showValue 11 a x
-showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Right " . showValue 11 b y
-showValue _ (STMaybe _) Nothing = showString "Nothing"
-showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x
-showValue d (STArr _ t) arr = showParen (d > 10) $
- showString "arrayFromList " . showsPrec 11 (arrayShape arr)
- . showString " ["
- . foldr (.) id (intersperse (showString ",") $ map (showValue 0 t) (toList arr))
- . showString "]"
-showValue _ (STScal sty) x = case sty of
- STF32 -> shows x
- STF64 -> shows x
- STI32 -> shows x
- STI64 -> shows x
- STBool -> shows x
-showValue _ STAccum{} _ = error "Cannot show accumulators"
-
showEnv :: SList STy env -> SList Value env -> String
showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]"
where
@@ -224,15 +211,27 @@ adTestGen expr envGenerator = property $ do
let env = knownEnv @env
input <- forAllWith (showEnv env) envGenerator
let gradFwd = gradientByForward knownEnv expr input
- gradCHAD = gradientByCHAD' knownEnv expr input
+ (ppdterm, gradCHAD) = gradientByCHAD' 0 knownEnv expr input
+ (ppdterm_S, gradCHAD_S) = gradientByCHAD' 20 knownEnv expr input
scFwd = envScalars env gradFwd
scCHAD = envScalars env gradCHAD
- diff scCHAD (\x y -> and (zipWith closeIsh x y)) scFwd
+ scCHAD_S = envScalars env gradCHAD_S
+ annotate (concat (unSList (\t -> ppTy 0 t ++ " -> ") env) ++ ppTy 0 (typeOf expr))
+ annotate (ppExpr knownEnv expr)
+ annotate ppdterm
+ annotate ppdterm_S
+ diff scCHAD (\x y -> and (zipWith closeIsh x y)) scCHAD_S
+ diff scFwd (\x y -> and (zipWith closeIsh x y)) scCHAD_S
where
envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]
envScalars SNil SNil = []
envScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ envScalars ts xs
+term_build1_sum :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
+term_build1_sum = fromNamed $ lambda #x $ body $
+ idx0 $ sum1i $
+ build (SS SZ) (shape #x) $ #idx :-> #x ! #idx
+
tests :: IO Bool
tests = checkSequential $ Group "AD"
[("id", adTest $ fromNamed $ lambda #x $ body $ #x)
@@ -256,9 +255,8 @@ tests = checkSequential $ Group "AD"
idx0 $
build SZ (shape #x) $ #idx :-> #x ! #idx)
- ,("build1-sum", adTest $ fromNamed $ lambda @(TArr N1 _) #x $ body $
- idx0 $ sum1i $
- build (SS SZ) (shape #x) $ #idx :-> #x ! #idx)
+ -- :hindentstr ppExpr knownEnv $ diffCHAD 20 knownEnv term_build1_sum
+ ,("build1-sum", adTest term_build1_sum)
,("build2-sum", adTest $ fromNamed $ lambda @(TArr N2 _) #x $ body $
idx0 $ sum1i . sum1i $