summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/Simplify.hs87
-rw-r--r--test/Main.hs36
2 files changed, 73 insertions, 50 deletions
diff --git a/src/Simplify.hs b/src/Simplify.hs
index 5829a8b..cfbdbb9 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -9,6 +9,9 @@
{-# LANGUAGE TypeOperators #-}
module Simplify where
+import Data.Function (fix)
+import Data.Monoid (Any(..))
+
import AST
import AST.Count
import Data
@@ -19,23 +22,30 @@ simplifyN 0 = id
simplifyN n = simplifyN (n - 1) . simplify
simplify :: forall env t. KnownEnv env => Ex env t -> Ex env t
-simplify = let ?accumInScope = checkAccumInScope @env knownEnv in simplify'
+simplify = let ?accumInScope = checkAccumInScope @env knownEnv in snd . simplify'
+
+simplifyFix :: forall env t. KnownEnv env => Ex env t -> Ex env t
+simplifyFix =
+ let ?accumInScope = checkAccumInScope @env knownEnv
+ in fix $ \loop e ->
+ let (Any act, e') = simplify' e
+ in if act then loop e' else e'
-simplify' :: (?accumInScope :: Bool) => Ex env t -> Ex env t
+simplify' :: (?accumInScope :: Bool) => Ex env t -> (Any, Ex env t)
simplify' = \case
-- inlining
ELet _ rhs body
| cheapExpr rhs
- -> simplify' (subst1 rhs body)
+ -> acted $ simplify' (subst1 rhs body)
| Occ lexOcc runOcc <- occCount IZ body
, ((not ?accumInScope || not (hasAdds rhs)) && lexOcc <= One && runOcc <= One) -- without effects, normal rules apply
|| (lexOcc == One && runOcc == One) -- with effects, linear inlining is still allowed, but weakening is not
- -> simplify' (subst1 rhs body)
+ -> acted $ simplify' (subst1 rhs body)
-- let splitting
ELet _ (EPair _ a b) body ->
- simplify' $
+ acted $ simplify' $
ELet ext a $
ELet ext (weakenExpr WSink b) $
subst (\_ t -> \case IZ -> EPair ext (EVar ext (typeOf a) (IS IZ)) (EVar ext (typeOf b) IZ)
@@ -44,10 +54,10 @@ simplify' = \case
-- let rotation
ELet _ (ELet _ rhs a) b ->
- simplify' $
+ acted $ simplify' $
ELet ext rhs $
ELet ext a $
- weakenExpr (WCopy WSink) (simplify' b)
+ weakenExpr (WCopy WSink) (snd (simplify' b))
-- beta rules for products
EFst _ (EPair _ e _) -> simplify' e
@@ -72,36 +82,39 @@ simplify' = \case
-- TODO: accum of zero, plus of zero
- EVar _ t i -> EVar ext t i
- ELet _ a b -> ELet ext (simplify' a) (simplify' b)
- EPair _ a b -> EPair ext (simplify' a) (simplify' b)
- EFst _ e -> EFst ext (simplify' e)
- ESnd _ e -> ESnd ext (simplify' e)
- ENil _ -> ENil ext
- EInl _ t e -> EInl ext t (simplify' e)
- EInr _ t e -> EInr ext t (simplify' e)
- ECase _ e a b -> ECase ext (simplify' e) (simplify' a) (simplify' b)
- ENothing _ t -> ENothing ext t
- EJust _ e -> EJust ext (simplify' e)
- EMaybe _ a b e -> EMaybe ext (simplify' a) (simplify' b) (simplify' e)
- EConstArr _ n t v -> EConstArr ext n t v
- EBuild1 _ a b -> EBuild1 ext (simplify' a) (simplify' b)
- EBuild _ n a b -> EBuild ext n (simplify' a) (simplify' b)
- EFold1Inner _ a b c -> EFold1Inner ext (simplify' a) (simplify' b) (simplify' c)
- ESum1Inner _ e -> ESum1Inner ext (simplify' e)
- EUnit _ e -> EUnit ext (simplify' e)
- EReplicate1Inner _ a b -> EReplicate1Inner ext (simplify' a) (simplify' b)
- EConst _ t v -> EConst ext t v
- EIdx0 _ e -> EIdx0 ext (simplify' e)
- EIdx1 _ a b -> EIdx1 ext (simplify' a) (simplify' b)
- EIdx _ a b -> EIdx ext (simplify' a) (simplify' b)
- EShape _ e -> EShape ext (simplify' e)
- EOp _ op e -> EOp ext op (simplify' e)
- EWith e1 e2 -> EWith (simplify' e1) (let ?accumInScope = True in simplify' e2)
- EAccum i e1 e2 e3 -> EAccum i (simplify' e1) (simplify' e2) (simplify' e3)
- EZero t -> EZero t
- EPlus t a b -> EPlus t (simplify' a) (simplify' b)
- EError t s -> EError t s
+ EVar _ t i -> pure $ EVar ext t i
+ ELet _ a b -> ELet ext <$> simplify' a <*> simplify' b
+ EPair _ a b -> EPair ext <$> simplify' a <*> simplify' b
+ EFst _ e -> EFst ext <$> simplify' e
+ ESnd _ e -> ESnd ext <$> simplify' e
+ ENil _ -> pure $ ENil ext
+ EInl _ t e -> EInl ext t <$> simplify' e
+ EInr _ t e -> EInr ext t <$> simplify' e
+ ECase _ e a b -> ECase ext <$> simplify' e <*> simplify' a <*> simplify' b
+ ENothing _ t -> pure $ ENothing ext t
+ EJust _ e -> EJust ext <$> simplify' e
+ EMaybe _ a b e -> EMaybe ext <$> simplify' a <*> simplify' b <*> simplify' e
+ EConstArr _ n t v -> pure $ EConstArr ext n t v
+ EBuild1 _ a b -> EBuild1 ext <$> simplify' a <*> simplify' b
+ EBuild _ n a b -> EBuild ext n <$> simplify' a <*> simplify' b
+ EFold1Inner _ a b c -> EFold1Inner ext <$> simplify' a <*> simplify' b <*> simplify' c
+ ESum1Inner _ e -> ESum1Inner ext <$> simplify' e
+ EUnit _ e -> EUnit ext <$> simplify' e
+ EReplicate1Inner _ a b -> EReplicate1Inner ext <$> simplify' a <*> simplify' b
+ EConst _ t v -> pure $ EConst ext t v
+ EIdx0 _ e -> EIdx0 ext <$> simplify' e
+ EIdx1 _ a b -> EIdx1 ext <$> simplify' a <*> simplify' b
+ EIdx _ a b -> EIdx ext <$> simplify' a <*> simplify' b
+ EShape _ e -> EShape ext <$> simplify' e
+ EOp _ op e -> EOp ext op <$> simplify' e
+ EWith e1 e2 -> EWith <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2)
+ EAccum i e1 e2 e3 -> EAccum i <$> simplify' e1 <*> simplify' e2 <*> simplify' e3
+ EZero t -> pure $ EZero t
+ EPlus t a b -> EPlus t <$> simplify' a <*> simplify' b
+ EError t s -> pure $ EError t s
+
+acted :: (Any, a) -> (Any, a)
+acted (_, x) = (Any True, x)
cheapExpr :: Expr x env t -> Bool
cheapExpr = \case
diff --git a/test/Main.hs b/test/Main.hs
index a3a614a..e7dda69 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -50,21 +50,26 @@ primalEnv :: SList STy env' -> SList STy (D1E env')
primalEnv SNil = SNil
primalEnv (t `SCons` env) = d1 t `SCons` primalEnv env
-diffCHAD :: Int -> SList STy env -> Ex env (TScal TF64)
+data SimplIters = SimplIters Int | SimplFix
+ deriving (Show)
+
+diffCHAD :: SimplIters -> SList STy env -> Ex env (TScal TF64)
-> Ex (D1E env) (TPair (TScal TF64) (Tup (D2E env)))
diffCHAD = \simplIters env term ->
- case (mapMergeNoAccum env, mapMergeOnlyMerge env) of
- (Refl, Refl) ->
+ case (mapMergeNoAccum env, mapMergeOnlyMerge env, envKnown (primalEnv env)) of
+ (Refl, Refl, Dict) ->
let descr = makeMergeDescr env
- in case envKnown (primalEnv env) of
- Dict -> simplifyN simplIters $ freezeRet descr (drev descr term) (EConst ext STF64 1.0)
+ simpl = case simplIters of
+ SimplIters n -> simplifyN n
+ SimplFix -> simplifyFix
+ in simpl $ freezeRet descr (drev descr term) (EConst ext STF64 1.0)
where
makeMergeDescr :: SList STy env' -> Descr env' (MapMerge env')
makeMergeDescr SNil = DTop
makeMergeDescr (t `SCons` env) = makeMergeDescr env `DPush` (t, SMerge)
-- In addition to the gradient, also returns the pretty-printed differentiated term.
-gradientByCHAD :: forall env. Int -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env)))
+gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env)))
gradientByCHAD = \simplIters env term input ->
case (mapMergeNoAccum env, mapMergeOnlyMerge env) of
(Refl, Refl) ->
@@ -88,7 +93,7 @@ gradientByCHAD = \simplIters env term input ->
STAccum{} -> error "Accumulators not allowed in input program"
-- In addition to the gradient, also returns the pretty-printed differentiated term.
-gradientByCHAD' :: Int -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (TanE env)))
+gradientByCHAD' :: SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (TanE env)))
gradientByCHAD' = \simplIters env term input -> second (second (toTanE env input)) $ gradientByCHAD simplIters env term input
where
toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env)
@@ -212,8 +217,9 @@ adTestGen expr envGenerator = property $ do
input <- forAllWith (showEnv env) envGenerator
let outPrimal = interpretOpen False input expr
gradFwd = gradientByForward knownEnv expr input
- (ppdterm, (outChad, gradCHAD)) = gradientByCHAD' 0 knownEnv expr input
- (ppdterm_S, (outChad_S, gradCHAD_S)) = gradientByCHAD' 20 knownEnv expr input
+ (ppdterm, (outChad, gradCHAD)) = gradientByCHAD' (SimplIters 0) knownEnv expr input
+ (ppdterm_S, (outChad_S, gradCHAD_S)) = gradientByCHAD' SimplFix knownEnv expr input
+ (ppdterm_S20, _) = gradientByCHAD' (SimplIters 20) knownEnv expr input
scFwd = envScalars env gradFwd
scCHAD = envScalars env gradCHAD
scCHAD_S = envScalars env gradCHAD_S
@@ -221,6 +227,7 @@ adTestGen expr envGenerator = property $ do
annotate (ppExpr knownEnv expr)
annotate ppdterm
annotate ppdterm_S
+ diff ppdterm_S20 (==) ppdterm_S
diff outChad closeIsh outChad_S
diff outPrimal closeIsh outChad_S
diff scCHAD (\x y -> and (zipWith closeIsh x y)) scCHAD_S
@@ -235,6 +242,12 @@ term_build1_sum = fromNamed $ lambda #x $ body $
idx0 $ sum1i $
build (SS SZ) (shape #x) $ #idx :-> #x ! #idx
+term_pairs :: Ex [TScal TF64, TScal TF64] (TScal TF64)
+term_pairs = fromNamed $ lambda #x $ lambda #y $ body $
+ let_ #p (pair #x #y) $
+ let_ #q (pair (snd_ #p * fst_ #p + #y) #x) $
+ fst_ #q * #x + snd_ #q * fst_ #p
+
tests :: IO Bool
tests = checkSequential $ Group "AD"
[("id", adTest $ fromNamed $ lambda #x $ body $ #x)
@@ -246,10 +259,7 @@ tests = checkSequential $ Group "AD"
,("sum-replicate", adTest $ fromNamed $ lambda #x $ body $
idx0 $ sum1i $ replicate1i 10 #x)
- ,("pairs", adTest $ fromNamed $ lambda #x $ lambda #y $ body $
- let_ #p (pair #x #y) $
- let_ #q (pair (snd_ #p * fst_ #p + #y) #x) $
- fst_ #q * #x + snd_ #q * fst_ #p)
+ ,("pairs", adTest term_pairs)
,("build0 const", adTest $ fromNamed $ lambda @(TScal TF64) #x $ body $
idx0 $ build SZ nil $ #idx :-> const_ 0.0)