diff options
-rw-r--r-- | src/Simplify.hs | 87 | ||||
-rw-r--r-- | test/Main.hs | 36 |
2 files changed, 73 insertions, 50 deletions
diff --git a/src/Simplify.hs b/src/Simplify.hs index 5829a8b..cfbdbb9 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -9,6 +9,9 @@ {-# LANGUAGE TypeOperators #-} module Simplify where +import Data.Function (fix) +import Data.Monoid (Any(..)) + import AST import AST.Count import Data @@ -19,23 +22,30 @@ simplifyN 0 = id simplifyN n = simplifyN (n - 1) . simplify simplify :: forall env t. KnownEnv env => Ex env t -> Ex env t -simplify = let ?accumInScope = checkAccumInScope @env knownEnv in simplify' +simplify = let ?accumInScope = checkAccumInScope @env knownEnv in snd . simplify' + +simplifyFix :: forall env t. KnownEnv env => Ex env t -> Ex env t +simplifyFix = + let ?accumInScope = checkAccumInScope @env knownEnv + in fix $ \loop e -> + let (Any act, e') = simplify' e + in if act then loop e' else e' -simplify' :: (?accumInScope :: Bool) => Ex env t -> Ex env t +simplify' :: (?accumInScope :: Bool) => Ex env t -> (Any, Ex env t) simplify' = \case -- inlining ELet _ rhs body | cheapExpr rhs - -> simplify' (subst1 rhs body) + -> acted $ simplify' (subst1 rhs body) | Occ lexOcc runOcc <- occCount IZ body , ((not ?accumInScope || not (hasAdds rhs)) && lexOcc <= One && runOcc <= One) -- without effects, normal rules apply || (lexOcc == One && runOcc == One) -- with effects, linear inlining is still allowed, but weakening is not - -> simplify' (subst1 rhs body) + -> acted $ simplify' (subst1 rhs body) -- let splitting ELet _ (EPair _ a b) body -> - simplify' $ + acted $ simplify' $ ELet ext a $ ELet ext (weakenExpr WSink b) $ subst (\_ t -> \case IZ -> EPair ext (EVar ext (typeOf a) (IS IZ)) (EVar ext (typeOf b) IZ) @@ -44,10 +54,10 @@ simplify' = \case -- let rotation ELet _ (ELet _ rhs a) b -> - simplify' $ + acted $ simplify' $ ELet ext rhs $ ELet ext a $ - weakenExpr (WCopy WSink) (simplify' b) + weakenExpr (WCopy WSink) (snd (simplify' b)) -- beta rules for products EFst _ (EPair _ e _) -> simplify' e @@ -72,36 +82,39 @@ simplify' = \case -- TODO: accum of zero, plus of zero - EVar _ t i -> EVar ext t i - ELet _ a b -> ELet ext (simplify' a) (simplify' b) - EPair _ a b -> EPair ext (simplify' a) (simplify' b) - EFst _ e -> EFst ext (simplify' e) - ESnd _ e -> ESnd ext (simplify' e) - ENil _ -> ENil ext - EInl _ t e -> EInl ext t (simplify' e) - EInr _ t e -> EInr ext t (simplify' e) - ECase _ e a b -> ECase ext (simplify' e) (simplify' a) (simplify' b) - ENothing _ t -> ENothing ext t - EJust _ e -> EJust ext (simplify' e) - EMaybe _ a b e -> EMaybe ext (simplify' a) (simplify' b) (simplify' e) - EConstArr _ n t v -> EConstArr ext n t v - EBuild1 _ a b -> EBuild1 ext (simplify' a) (simplify' b) - EBuild _ n a b -> EBuild ext n (simplify' a) (simplify' b) - EFold1Inner _ a b c -> EFold1Inner ext (simplify' a) (simplify' b) (simplify' c) - ESum1Inner _ e -> ESum1Inner ext (simplify' e) - EUnit _ e -> EUnit ext (simplify' e) - EReplicate1Inner _ a b -> EReplicate1Inner ext (simplify' a) (simplify' b) - EConst _ t v -> EConst ext t v - EIdx0 _ e -> EIdx0 ext (simplify' e) - EIdx1 _ a b -> EIdx1 ext (simplify' a) (simplify' b) - EIdx _ a b -> EIdx ext (simplify' a) (simplify' b) - EShape _ e -> EShape ext (simplify' e) - EOp _ op e -> EOp ext op (simplify' e) - EWith e1 e2 -> EWith (simplify' e1) (let ?accumInScope = True in simplify' e2) - EAccum i e1 e2 e3 -> EAccum i (simplify' e1) (simplify' e2) (simplify' e3) - EZero t -> EZero t - EPlus t a b -> EPlus t (simplify' a) (simplify' b) - EError t s -> EError t s + EVar _ t i -> pure $ EVar ext t i + ELet _ a b -> ELet ext <$> simplify' a <*> simplify' b + EPair _ a b -> EPair ext <$> simplify' a <*> simplify' b + EFst _ e -> EFst ext <$> simplify' e + ESnd _ e -> ESnd ext <$> simplify' e + ENil _ -> pure $ ENil ext + EInl _ t e -> EInl ext t <$> simplify' e + EInr _ t e -> EInr ext t <$> simplify' e + ECase _ e a b -> ECase ext <$> simplify' e <*> simplify' a <*> simplify' b + ENothing _ t -> pure $ ENothing ext t + EJust _ e -> EJust ext <$> simplify' e + EMaybe _ a b e -> EMaybe ext <$> simplify' a <*> simplify' b <*> simplify' e + EConstArr _ n t v -> pure $ EConstArr ext n t v + EBuild1 _ a b -> EBuild1 ext <$> simplify' a <*> simplify' b + EBuild _ n a b -> EBuild ext n <$> simplify' a <*> simplify' b + EFold1Inner _ a b c -> EFold1Inner ext <$> simplify' a <*> simplify' b <*> simplify' c + ESum1Inner _ e -> ESum1Inner ext <$> simplify' e + EUnit _ e -> EUnit ext <$> simplify' e + EReplicate1Inner _ a b -> EReplicate1Inner ext <$> simplify' a <*> simplify' b + EConst _ t v -> pure $ EConst ext t v + EIdx0 _ e -> EIdx0 ext <$> simplify' e + EIdx1 _ a b -> EIdx1 ext <$> simplify' a <*> simplify' b + EIdx _ a b -> EIdx ext <$> simplify' a <*> simplify' b + EShape _ e -> EShape ext <$> simplify' e + EOp _ op e -> EOp ext op <$> simplify' e + EWith e1 e2 -> EWith <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2) + EAccum i e1 e2 e3 -> EAccum i <$> simplify' e1 <*> simplify' e2 <*> simplify' e3 + EZero t -> pure $ EZero t + EPlus t a b -> EPlus t <$> simplify' a <*> simplify' b + EError t s -> pure $ EError t s + +acted :: (Any, a) -> (Any, a) +acted (_, x) = (Any True, x) cheapExpr :: Expr x env t -> Bool cheapExpr = \case diff --git a/test/Main.hs b/test/Main.hs index a3a614a..e7dda69 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -50,21 +50,26 @@ primalEnv :: SList STy env' -> SList STy (D1E env') primalEnv SNil = SNil primalEnv (t `SCons` env) = d1 t `SCons` primalEnv env -diffCHAD :: Int -> SList STy env -> Ex env (TScal TF64) +data SimplIters = SimplIters Int | SimplFix + deriving (Show) + +diffCHAD :: SimplIters -> SList STy env -> Ex env (TScal TF64) -> Ex (D1E env) (TPair (TScal TF64) (Tup (D2E env))) diffCHAD = \simplIters env term -> - case (mapMergeNoAccum env, mapMergeOnlyMerge env) of - (Refl, Refl) -> + case (mapMergeNoAccum env, mapMergeOnlyMerge env, envKnown (primalEnv env)) of + (Refl, Refl, Dict) -> let descr = makeMergeDescr env - in case envKnown (primalEnv env) of - Dict -> simplifyN simplIters $ freezeRet descr (drev descr term) (EConst ext STF64 1.0) + simpl = case simplIters of + SimplIters n -> simplifyN n + SimplFix -> simplifyFix + in simpl $ freezeRet descr (drev descr term) (EConst ext STF64 1.0) where makeMergeDescr :: SList STy env' -> Descr env' (MapMerge env') makeMergeDescr SNil = DTop makeMergeDescr (t `SCons` env) = makeMergeDescr env `DPush` (t, SMerge) -- In addition to the gradient, also returns the pretty-printed differentiated term. -gradientByCHAD :: forall env. Int -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env))) +gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env))) gradientByCHAD = \simplIters env term input -> case (mapMergeNoAccum env, mapMergeOnlyMerge env) of (Refl, Refl) -> @@ -88,7 +93,7 @@ gradientByCHAD = \simplIters env term input -> STAccum{} -> error "Accumulators not allowed in input program" -- In addition to the gradient, also returns the pretty-printed differentiated term. -gradientByCHAD' :: Int -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (TanE env))) +gradientByCHAD' :: SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (TanE env))) gradientByCHAD' = \simplIters env term input -> second (second (toTanE env input)) $ gradientByCHAD simplIters env term input where toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env) @@ -212,8 +217,9 @@ adTestGen expr envGenerator = property $ do input <- forAllWith (showEnv env) envGenerator let outPrimal = interpretOpen False input expr gradFwd = gradientByForward knownEnv expr input - (ppdterm, (outChad, gradCHAD)) = gradientByCHAD' 0 knownEnv expr input - (ppdterm_S, (outChad_S, gradCHAD_S)) = gradientByCHAD' 20 knownEnv expr input + (ppdterm, (outChad, gradCHAD)) = gradientByCHAD' (SimplIters 0) knownEnv expr input + (ppdterm_S, (outChad_S, gradCHAD_S)) = gradientByCHAD' SimplFix knownEnv expr input + (ppdterm_S20, _) = gradientByCHAD' (SimplIters 20) knownEnv expr input scFwd = envScalars env gradFwd scCHAD = envScalars env gradCHAD scCHAD_S = envScalars env gradCHAD_S @@ -221,6 +227,7 @@ adTestGen expr envGenerator = property $ do annotate (ppExpr knownEnv expr) annotate ppdterm annotate ppdterm_S + diff ppdterm_S20 (==) ppdterm_S diff outChad closeIsh outChad_S diff outPrimal closeIsh outChad_S diff scCHAD (\x y -> and (zipWith closeIsh x y)) scCHAD_S @@ -235,6 +242,12 @@ term_build1_sum = fromNamed $ lambda #x $ body $ idx0 $ sum1i $ build (SS SZ) (shape #x) $ #idx :-> #x ! #idx +term_pairs :: Ex [TScal TF64, TScal TF64] (TScal TF64) +term_pairs = fromNamed $ lambda #x $ lambda #y $ body $ + let_ #p (pair #x #y) $ + let_ #q (pair (snd_ #p * fst_ #p + #y) #x) $ + fst_ #q * #x + snd_ #q * fst_ #p + tests :: IO Bool tests = checkSequential $ Group "AD" [("id", adTest $ fromNamed $ lambda #x $ body $ #x) @@ -246,10 +259,7 @@ tests = checkSequential $ Group "AD" ,("sum-replicate", adTest $ fromNamed $ lambda #x $ body $ idx0 $ sum1i $ replicate1i 10 #x) - ,("pairs", adTest $ fromNamed $ lambda #x $ lambda #y $ body $ - let_ #p (pair #x #y) $ - let_ #q (pair (snd_ #p * fst_ #p + #y) #x) $ - fst_ #q * #x + snd_ #q * fst_ #p) + ,("pairs", adTest term_pairs) ,("build0 const", adTest $ fromNamed $ lambda @(TScal TF64) #x $ body $ idx0 $ build SZ nil $ #idx :-> const_ 0.0) |