diff options
Diffstat (limited to 'src/Simplify.hs')
| -rw-r--r-- | src/Simplify.hs | 619 |
1 files changed, 0 insertions, 619 deletions
diff --git a/src/Simplify.hs b/src/Simplify.hs deleted file mode 100644 index 19d0c17..0000000 --- a/src/Simplify.hs +++ /dev/null @@ -1,619 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ImplicitParams #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE MultiWayIf #-} -{-# LANGUAGE QuasiQuotes #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -module Simplify ( - simplifyN, simplifyFix, - SimplifyConfig(..), defaultSimplifyConfig, simplifyWith, simplifyFixWith, -) where - -import Control.Monad (ap) -import Data.Bifunctor (first) -import Data.Function (fix) -import Data.Monoid (Any(..)) - -import Debug.Trace - -import AST -import AST.Count -import AST.Pretty -import AST.Sparse.Types -import AST.UnMonoid (acPrjCompose) -import Data -import Simplify.TH - - -data SimplifyConfig = SimplifyConfig - { scLogging :: Bool - } - -defaultSimplifyConfig :: SimplifyConfig -defaultSimplifyConfig = SimplifyConfig False - -simplifyN :: KnownEnv env => Int -> Ex env t -> Ex env t -simplifyN 0 = id -simplifyN n = simplifyN (n - 1) . simplify - -simplify :: forall env t. KnownEnv env => Ex env t -> Ex env t -simplify = - let ?accumInScope = checkAccumInScope @env knownEnv - ?config = defaultSimplifyConfig - in snd . runSM . simplify' - -simplifyWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t -simplifyWith config = - let ?accumInScope = checkAccumInScope @env knownEnv - ?config = config - in snd . runSM . simplify' - -simplifyFix :: forall env t. KnownEnv env => Ex env t -> Ex env t -simplifyFix = simplifyFixWith defaultSimplifyConfig - -simplifyFixWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t -simplifyFixWith config = - let ?accumInScope = checkAccumInScope @env knownEnv - ?config = config - in fix $ \loop e -> - let (act, e') = runSM (simplify' e) - in if act then loop e' else e' - --- | simplify monad -newtype SM tenv tt env t a = SM ((Ex env t -> Ex tenv tt) -> (Any, a)) - deriving (Functor) - -instance Applicative (SM tenv tt env t) where - pure x = SM (\_ -> (Any False, x)) - (<*>) = ap - -instance Monad (SM tenv tt env t) where - SM f >>= g = SM $ \ctx -> f ctx >>= \x -> let SM h = g x in h ctx - -runSM :: SM env t env t a -> (Bool, a) -runSM (SM f) = first getAny (f id) - -smReconstruct :: Ex env t -> SM tenv tt env t (Ex tenv tt) -smReconstruct core = SM (\ctx -> (Any False, ctx core)) - -class Monad m => ActedMonad m where - tellActed :: m () - hideActed :: m a -> m a - liftActed :: (Any, a) -> m a - -instance ActedMonad ((,) Any) where - tellActed = (Any True, ()) - hideActed (_, x) = (Any False, x) - liftActed = id - -instance ActedMonad (SM tenv tt env t) where - tellActed = SM (\_ -> tellActed) - hideActed (SM f) = SM (\ctx -> hideActed (f ctx)) - liftActed pair = SM (\_ -> pair) - --- more convenient in practice -acted :: ActedMonad m => m a -> m a -acted m = tellActed >> m - -within :: (Ex env' t' -> Ex env t) -> SM tenv tt env' t' a -> SM tenv tt env t a -within subctx (SM f) = SM $ \ctx -> f (ctx . subctx) - -simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t) -simplify' expr - | scLogging ?config = do - res <- simplify'Rec expr - full <- smReconstruct res - let printed = ppExpr knownEnv full - replace a bs = concatMap (\x -> if x == a then bs else [x]) - str | '\n' `elem` printed = "--- simplify step:\n " ++ replace '\n' "\n " printed - | otherwise = "--- simplify step: " ++ printed - traceM str - return res - | otherwise = simplify'Rec expr - -simplify'Rec :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t) -simplify'Rec = \case - -- inlining - ELet _ rhs body - | cheapExpr rhs - -> acted $ simplify' (substInline rhs body) - - | Occ lexOcc runOcc <- occCount IZ body - , ((not ?accumInScope || not (hasAdds rhs)) && lexOcc <= One && runOcc <= One) -- without effects, normal rules apply - || (lexOcc == One && runOcc == One) -- with effects, linear inlining is still allowed, but weakening is not - -> acted $ simplify' (substInline rhs body) - - -- let splitting / let peeling - ELet _ (EPair _ a b) body -> - acted $ simplify' $ - ELet ext a $ - ELet ext (weakenExpr WSink b) $ - subst (\_ t -> \case IZ -> EPair ext (EVar ext (typeOf a) (IS IZ)) (EVar ext (typeOf b) IZ) - IS i -> EVar ext t (IS (IS i))) - body - ELet _ (EJust _ a) body -> - acted $ simplify' $ ELet ext a $ subst0 (EJust ext (EVar ext (typeOf a) IZ)) body - ELet _ (EInl _ t2 a) body -> - acted $ simplify' $ ELet ext a $ subst0 (EInl ext t2 (EVar ext (typeOf a) IZ)) body - ELet _ (EInr _ t1 a) body -> - acted $ simplify' $ ELet ext a $ subst0 (EInr ext t1 (EVar ext (typeOf a) IZ)) body - - -- let rotation - ELet _ (ELet _ rhs a) b -> do - b' <- within (ELet ext (ELet ext rhs a)) $ simplify' b - acted $ simplify' $ - ELet ext rhs $ - ELet ext a $ - weakenExpr (WCopy WSink) b' - - -- beta rules for products - EFst _ (EPair _ e e') - | not (hasAdds e') -> acted $ simplify' e - | otherwise -> acted $ simplify' $ ELet ext e' (weakenExpr WSink e) - ESnd _ (EPair _ e' e) - | not (hasAdds e') -> acted $ simplify' e - | otherwise -> acted $ simplify' $ ELet ext e' (weakenExpr WSink e) - - -- beta rules for coproducts - ECase _ (EInl _ _ e) rhs _ -> acted $ simplify' (ELet ext e rhs) - ECase _ (EInr _ _ e) _ rhs -> acted $ simplify' (ELet ext e rhs) - - -- beta rules for maybe - EMaybe _ e1 _ ENothing{} -> acted $ simplify' e1 - EMaybe _ _ e1 (EJust _ e2) -> acted $ simplify' $ ELet ext e2 e1 - - -- let floating - EFst _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EFst ext body)) - ESnd _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (ESnd ext body)) - ECase _ (ELet _ rhs body) e1 e2 -> acted $ simplify' (ELet ext rhs (ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2))) - EIdx0 _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EIdx0 ext body)) - EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e))) - EAccum _ t p e1 sp (ELet _ rhs body) acc -> - acted $ simplify' $ - ELet ext rhs $ - EAccum ext t p (weakenExpr WSink e1) sp body (weakenExpr WSink acc) - - -- let () = e in () ~> e - ELet _ e1 (ENil _) | STNil <- typeOf e1 -> - acted $ simplify' e1 - - -- map (\_ -> x) e ~> build (shape e) (\_ -> x) - EMap _ e1 e2 - | Occ Zero Zero <- occCount IZ e1 - , STArr n _ <- typeOf e2 -> - acted $ simplify' $ - EBuild ext n (EShape ext e2) $ - subst (\_ t' -> \case IZ -> error "Unused variable was used" - IS i -> EVar ext t' (IS i)) - e1 - - -- vertical fusion - EMap _ e1 (EMap _ e2 e3) -> - acted $ simplify' $ - EMap ext (ELet ext e2 (weakenExpr (WCopy WSink) e1)) e3 - - -- projection down-commuting - EFst _ (ECase _ e1 e2 e3) -> - acted $ simplify' $ - ECase ext e1 (EFst ext e2) (EFst ext e3) - ESnd _ (ECase _ e1 e2 e3) -> - acted $ simplify' $ - ECase ext e1 (ESnd ext e2) (ESnd ext e3) - EFst _ (EMaybe _ e1 e2 e3) -> - acted $ simplify' $ - EMaybe ext (EFst ext e1) (EFst ext e2) e3 - ESnd _ (EMaybe _ e1 e2 e3) -> - acted $ simplify' $ - EMaybe ext (ESnd ext e1) (ESnd ext e2) e3 - - -- TODO: more array indexing - EIdx _ (EBuild _ _ e1 e2) e3 | not (hasAdds e1), not (hasAdds e2) -> acted $ simplify' $ elet e3 e2 - EIdx _ (EMap _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ elet (EIdx ext e2 e3) e1 - EIdx _ (EReplicate1Inner _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ EIdx ext e2 (EFst ext e3) - EIdx _ (EUnit _ e1) e2 | not (hasAdds e2) -> acted $ simplify' $ e1 - - -- TODO: more array shape - EShape _ (EBuild _ _ e1 e2) | not (hasAdds e2) -> acted $ simplify' e1 - EShape _ (EMap _ e1 e2) | not (hasAdds e1) -> acted $ simplify' (EShape ext e2) - - -- TODO: more constant folding - EOp _ OIf (EConst _ STBool True) -> acted $ return (EInl ext STNil (ENil ext)) - EOp _ OIf (EConst _ STBool False) -> acted $ return (EInr ext STNil (ENil ext)) - - -- inline cheap array constructors - ELet _ (EReplicate1Inner _ e1 e2) e3 -> - acted $ simplify' $ - ELet ext (EPair ext e1 e2) $ - let v = EVar ext (STPair tIx (typeOf e2)) IZ - in subst0 (EReplicate1Inner ext (EFst ext v) (ESnd ext v)) e3 - -- -- TODO: This is a bad idea and anyway only helps in practice if (!) is - -- -- cheap, which it can't be because (!) is not cheap if you do AD after. - -- -- Should do proper SoA representation. - -- ELet _ (EBuild _ n e1 e2) e3 | cheapExpr e2 -> - -- acted $ simplify' $ - -- ELet ext e1 $ - -- subst0 (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) IZ) (weakenExpr (WCopy WSink) e2)) e3 - - -- eta rule for unit - e | STNil <- typeOf e, not ?accumInScope || not (hasAdds e) -> - case e of - ENil _ -> return e - _ -> acted $ return (ENil ext) - - EBuild _ SZ _ e -> - acted $ simplify' $ EUnit ext (substInline (ENil ext) e) - - -- monoid rules - EAccum _ t p e1 sp e2 acc -> do - e1' <- within (\e1' -> EAccum ext t p e1' sp e2 acc ) $ simplify' e1 - e2' <- within (\e2' -> EAccum ext t p e1' sp e2' acc ) $ simplify' e2 - acc' <- within (\acc' -> EAccum ext t p e1' sp e2' acc') $ simplify' acc - simplifyOHT (OneHotTerm SAID t p e1' sp e2') - (acted $ return (ENil ext)) - (\sp' (InContext w wrap e) -> do - e' <- within (\e' -> wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc')) $ simplify' e - return (wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc'))) - (\(InContext w wrap (OneHotTerm _ t' p' e1'' sp' e2'')) -> do - -- The acted management here is a hideous mess. - e1''' <- hideActed $ within (\e1''' -> wrap $ EAccum ext t' p' e1''' sp' e2'' (weakenExpr w acc')) $ simplify' e1'' - e2''' <- hideActed $ within (\e2''' -> wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc')) $ simplify' e2'' - return (wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc'))) - EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e - EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e - EOneHot _ t p e1 e2 -> do - e1' <- within (\e1' -> EOneHot ext t p e1' e2 ) $ simplify' e1 - e2' <- within (\e2' -> EOneHot ext t p e1' e2') $ simplify' e2 - simplifyOHT (OneHotTerm SAIS t p e1' (spDense (acPrjTy p t)) e2') - (acted $ return (EZero ext t (zeroInfoFromOneHot t p e1 e2))) - (\sp' (InContext _ wrap e) -> - case isDense t sp' of - Just Refl -> do - e' <- hideActed $ within wrap $ simplify' e - return (wrap e') - Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse") - (\(InContext _ wrap (OneHotTerm _ t' p' e1'' sp' e2'')) -> - case isDense (acPrjTy p' t') sp' of - Just Refl -> do - e1''' <- hideActed $ within (\e1''' -> wrap $ EOneHot ext t' p' e1''' e2'') $ simplify' e1'' - e2''' <- hideActed $ within (\e2''' -> wrap $ EOneHot ext t' p' e1''' e2''') $ simplify' e2'' - return (wrap $ EOneHot ext t' p' e1''' e2''') - Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse") - - -- type-specific equations for plus - EPlus _ SMTNil e1 e2 | not (hasAdds e1), not (hasAdds e2) -> - acted $ return (ENil ext) - - EPlus _ (SMTPair t1 t2) (EPair _ a1 b1) (EPair _ a2 b2) -> - acted $ simplify' $ EPair ext (EPlus ext t1 a1 a2) (EPlus ext t2 b1 b2) - - EPlus _ (SMTLEither t1 _) (ELInl _ dt2 a1) (ELInl _ _ a2) -> - acted $ simplify' $ ELInl ext dt2 (EPlus ext t1 a1 a2) - EPlus _ (SMTLEither _ t2) (ELInr _ dt1 b1) (ELInr _ _ b2) -> - acted $ simplify' $ ELInr ext dt1 (EPlus ext t2 b1 b2) - EPlus _ SMTLEither{} ELNil{} e -> acted $ simplify' e - EPlus _ SMTLEither{} e ELNil{} -> acted $ simplify' e - - EPlus _ (SMTMaybe t) (EJust _ e1) (EJust _ e2) -> - acted $ simplify' $ EJust ext (EPlus ext t e1 e2) - EPlus _ SMTMaybe{} ENothing{} e -> acted $ simplify' e - EPlus _ SMTMaybe{} e ENothing{} -> acted $ simplify' e - - -- fallback recursion - EVar _ t i -> pure $ EVar ext t i - ELet _ a b -> [simprec| ELet ext *a *b |] - EPair _ a b -> [simprec| EPair ext *a *b |] - EFst _ e -> [simprec| EFst ext *e |] - ESnd _ e -> [simprec| ESnd ext *e |] - ENil _ -> pure $ ENil ext - EInl _ t e -> [simprec| EInl ext t *e |] - EInr _ t e -> [simprec| EInr ext t *e |] - ECase _ e a b -> [simprec| ECase ext *e *a *b |] - ENothing _ t -> pure $ ENothing ext t - EJust _ e -> [simprec| EJust ext *e |] - EMaybe _ a b e -> [simprec| EMaybe ext *a *b *e |] - ELNil _ t1 t2 -> pure $ ELNil ext t1 t2 - ELInl _ t e -> [simprec| ELInl ext t *e |] - ELInr _ t e -> [simprec| ELInr ext t *e |] - ELCase _ e a b c -> [simprec| ELCase ext *e *a *b *c |] - EConstArr _ n t v -> pure $ EConstArr ext n t v - EBuild _ n a b -> [simprec| EBuild ext n *a *b |] - EMap _ a b -> [simprec| EMap ext *a *b |] - EFold1Inner _ cm a b c -> [simprec| EFold1Inner ext cm *a *b *c |] - ESum1Inner _ e -> [simprec| ESum1Inner ext *e |] - EUnit _ e -> [simprec| EUnit ext *e |] - EReplicate1Inner _ a b -> [simprec| EReplicate1Inner ext *a *b |] - EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |] - EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |] - EReshape _ n a b -> [simprec| EReshape ext n *a *b |] - EZip _ a b -> [simprec| EZip ext *a *b |] - EFold1InnerD1 _ cm a b c -> [simprec| EFold1InnerD1 ext cm *a *b *c |] - EFold1InnerD2 _ cm a b c -> [simprec| EFold1InnerD2 ext cm *a *b *c |] - EConst _ t v -> pure $ EConst ext t v - EIdx0 _ e -> [simprec| EIdx0 ext *e |] - EIdx1 _ a b -> [simprec| EIdx1 ext *a *b |] - EIdx _ a b -> [simprec| EIdx ext *a *b |] - EShape _ e -> [simprec| EShape ext *e |] - EOp _ op e -> [simprec| EOp ext op *e |] - ECustom _ s t p a b c e1 e2 -> do - a' <- within (\a' -> ECustom ext s t p a' b c e1 e2) (let ?accumInScope = False in simplify' a) - b' <- within (\b' -> ECustom ext s t p a' b' c e1 e2) (let ?accumInScope = False in simplify' b) - c' <- within (\c' -> ECustom ext s t p a' b' c' e1 e2) (let ?accumInScope = False in simplify' c) - e1' <- within (\e1' -> ECustom ext s t p a' b' c' e1' e2) (simplify' e1) - e2' <- within (\e2' -> ECustom ext s t p a' b' c' e1' e2') (simplify' e2) - pure (ECustom ext s t p a' b' c' e1' e2') - ERecompute _ e -> [simprec| ERecompute ext *e |] - EWith _ t e1 e2 -> do - e1' <- within (\e1' -> EWith ext t e1' e2) (simplify' e1) - e2' <- within (\e2' -> EWith ext t e1' e2') (let ?accumInScope = True in simplify' e2) - pure (EWith ext t e1' e2') - -- EOneHot _ t p e1 e2 -> [simprec| EOneHot ext t p *e1 *e2 |] - -- EAccum _ t p e1 sp e2 acc -> [simprec| EAccum ext t p *e1 sp *e2 *acc |] - EZero _ t e -> [simprec| EZero ext t *e |] - EDeepZero _ t e -> [simprec| EDeepZero ext t *e |] - EPlus _ t a b -> [simprec| EPlus ext t *a *b |] - EError _ t s -> pure $ EError ext t s - --- | This can be made more precise by tracking (and not counting) adds on --- locally eliminated accumulators. -hasAdds :: Expr x env t -> Bool -hasAdds = \case - EVar _ _ _ -> False - ELet _ rhs body -> hasAdds rhs || hasAdds body - EPair _ a b -> hasAdds a || hasAdds b - EFst _ e -> hasAdds e - ESnd _ e -> hasAdds e - ENil _ -> False - EInl _ _ e -> hasAdds e - EInr _ _ e -> hasAdds e - ECase _ e a b -> hasAdds e || hasAdds a || hasAdds b - ENothing _ _ -> False - EJust _ e -> hasAdds e - EMaybe _ a b e -> hasAdds a || hasAdds b || hasAdds e - ELNil _ _ _ -> False - ELInl _ _ e -> hasAdds e - ELInr _ _ e -> hasAdds e - ELCase _ e a b c -> hasAdds e || hasAdds a || hasAdds b || hasAdds c - EConstArr _ _ _ _ -> False - EBuild _ _ a b -> hasAdds a || hasAdds b - EMap _ a b -> hasAdds a || hasAdds b - EFold1Inner _ _ a b c -> hasAdds a || hasAdds b || hasAdds c - ESum1Inner _ e -> hasAdds e - EUnit _ e -> hasAdds e - EReplicate1Inner _ a b -> hasAdds a || hasAdds b - EMaximum1Inner _ e -> hasAdds e - EMinimum1Inner _ e -> hasAdds e - EReshape _ _ a b -> hasAdds a || hasAdds b - EZip _ a b -> hasAdds a || hasAdds b - EFold1InnerD1 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c - EFold1InnerD2 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c - ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e - EConst _ _ _ -> False - EIdx0 _ e -> hasAdds e - EIdx1 _ a b -> hasAdds a || hasAdds b - EIdx _ a b -> hasAdds a || hasAdds b - EShape _ e -> hasAdds e - EOp _ _ e -> hasAdds e - EWith _ _ a b -> hasAdds a || hasAdds b - ERecompute _ e -> hasAdds e - EAccum _ _ _ _ _ _ _ -> True - EZero _ _ e -> hasAdds e - EDeepZero _ _ e -> hasAdds e - EPlus _ _ a b -> hasAdds a || hasAdds b - EOneHot _ _ _ a b -> hasAdds a || hasAdds b - EError _ _ _ -> False - -checkAccumInScope :: SList STy env -> Bool -checkAccumInScope = \case SNil -> False - SCons t env -> check t || checkAccumInScope env - where - check :: STy t -> Bool - check STNil = False - check (STPair s t) = check s || check t - check (STEither s t) = check s || check t - check (STLEither s t) = check s || check t - check (STMaybe t) = check t - check (STArr _ t) = check t - check (STScal _) = False - check STAccum{} = True - -data OneHotTerm dense env a where - OneHotTerm :: SAIDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Sparse b c -> Ex env c -> OneHotTerm dense env a -deriving instance Show (OneHotTerm dense env a) - -data InContext f env (a :: Ty) where - InContext :: env :> env' -> (forall t. Ex env' t -> Ex env t) -> f env' a -> InContext f env a - -simplifyOHT_recogniseMonoid :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a) -simplifyOHT_recogniseMonoid (OneHotTerm dense t prj idx sp val) = do - val' <- liftActed $ recogniseMonoid (applySparse sp (acPrjTy prj t)) val - return $ OneHotTerm dense t prj idx sp val' - -simplifyOHT_unsparse :: ActedMonad m => OneHotTerm dense env a -> m (InContext (OneHotTerm dense) env a) -simplifyOHT_unsparse (OneHotTerm SAID t prj1 idx1 sp1 val1) = - unsparseOneHotD sp1 val1 $ \w wrap prj2 idx2 sp2 val2 -> - acPrjCompose SAID prj1 (weakenExpr w idx1) prj2 idx2 $ \prj' idx' -> - return $ InContext w wrap (OneHotTerm SAID t prj' idx' sp2 val2) -simplifyOHT_unsparse oht@(OneHotTerm SAIS _ _ _ _ _) = return $ InContext WId id oht - -simplifyOHT_concat :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a) -simplifyOHT_concat (OneHotTerm @dense @_ @_ @_ @env dense t1 prj1 idx1 sp (EOneHot @_ @c @p2 _ t2 prj2 idx2 val)) - | Just Refl <- isDense (acPrjTy prj1 t1) sp = - let idx2' :: Ex env (AcIdx dense p2 c) - idx2' = case dense of - SAID -> reduceAcIdx t2 prj2 idx2 - SAIS -> idx2 - in acPrjCompose dense prj1 idx1 prj2 idx2' $ \prj' idx' -> - acted $ return $ OneHotTerm dense t1 prj' idx' (spDense (acPrjTy prj' t1)) val -simplifyOHT_concat oht = return oht - --- -- Property not expressed in types: if the Sparse in the input OneHotTerm is --- -- dense, then the Sparse in the output will also be dense. This property is --- -- used when simplifying EOneHot, which cannot represent sparsity. -simplifyOHT :: ActedMonad m => OneHotTerm dense env a - -> m r -- ^ Zero case (onehot is actually zero) - -> (forall b. Sparse a b -> InContext Ex env b -> m r) -- ^ Trivial case (no zeros in onehot) - -> (InContext (OneHotTerm dense) env a -> m r) -- ^ Simplified - -> m r -simplifyOHT oht kzero ktriv k = do - -- traceM $ "sOHT: input " ++ show oht - oht1 <- simplifyOHT_recogniseMonoid oht - -- traceM $ "sOHT: recog " ++ show oht1 - InContext w1 wrap1 oht2 <- simplifyOHT_unsparse oht1 - -- traceM $ "sOHT: unspa " ++ show oht2 - oht3 <- simplifyOHT_concat oht2 - -- traceM $ "sOHT: conca " ++ show oht3 - -- traceM "" - case oht3 of - OneHotTerm _ _ _ _ _ EZero{} -> kzero - OneHotTerm _ _ SAPHere _ sp val -> ktriv sp (InContext w1 wrap1 val) - _ -> k (InContext w1 wrap1 oht3) - --- Sets the acted flag whenever a non-trivial projection is returned or the --- output Sparse is different from the input Sparse. -unsparseOneHotD :: ActedMonad m => Sparse a a' -> Ex env a' - -> (forall p b c env'. env :> env' -> (forall s. Ex env' s -> Ex env s) - -> SAcPrj p a b -> Ex env' (AcIdxD p a) -> Sparse b c -> Ex env' c -> m r) -> m r -unsparseOneHotD topsp topval k = case (topsp, topval) of - -- eliminate always-Just sparse onehot - (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) -> - acted $ unsparseOneHotD s (EOneHot ext t prj idx val) k - - -- expand the top levels of a onehot for a sparse type into a onehot for the - -- corresponding non-sparse type - (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) -> - unsparseOneHotD s1 (EOneHot ext t1 prj (efst idx) val) $ \w wrap spprj idx' s1' e' -> - acted $ k w wrap (SAPFst spprj) idx' s1' e' - (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) -> - unsparseOneHotD s2 (EOneHot ext t2 prj (esnd idx) val) $ \w wrap spprj idx' s1' e' -> - acted $ k w wrap (SAPSnd spprj) idx' s1' e' - (SpLEither s1 _, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) -> - unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' -> - acted $ k w wrap (SAPLeft spprj) idx' s1' e' - (SpLEither _ s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) -> - unsparseOneHotD s2 (EOneHot ext t2 prj idx val) $ \w wrap spprj idx' s1' e' -> - acted $ k w wrap (SAPRight spprj) idx' s1' e' - (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) -> - unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' -> - acted $ k w wrap (SAPJust spprj) idx' s1' e' - (SpArr s1, EOneHot _ (SMTArr _ t1) (SAPArrIdx prj) idx val) - | Dict <- styKnown (typeOf idx) -> - unsparseOneHotD s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \w wrap spprj idx' s1' e' -> - acted $ k (w .> WSink) (elet idx . wrap) (SAPArrIdx spprj) (EPair ext (efst (efst (evar (w @> IZ)))) idx') s1' e' - - -- anything else we don't know how to improve - _ -> k WId id SAPHere (ENil ext) topsp topval - -{- -unsparseOneHotS :: ActedMonad m - => Sparse a a' -> Ex env a' - -> (forall b. Sparse a b -> Ex env b -> m r) -> m r -unsparseOneHotS topsp topval k = case (topsp, topval) of - -- order is relevant to make sure we set the acted flag correctly - (SpAbsent, v@ENil{}) -> k SpAbsent v - (SpAbsent, v@EZero{}) -> k SpAbsent v - (SpAbsent, _) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) - (_, EZero{}) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) - (sp, _) | isAbsent sp -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) - - -- the unsparsifying - (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) -> - acted $ unsparseOneHotS s (EOneHot ext t prj idx val) k - - -- recursion - -- TODO: coproducts could safely become projections as they do not need - -- zeroinfo. But that would only work if the coproduct is at the top, because - -- as soon as we hit a product, we need zeroinfo to make it a projection and - -- we don't have that. - (SpSparse s, e) -> k (SpSparse s) e - (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) -> - unsparseOneHotS s1 (EOneHot ext t1 prj (efst idx) val) $ \s1' e' -> - acted $ k (SpPair s1' SpAbsent) (EPair ext e' (ENil ext)) - (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) -> - unsparseOneHotS s2 (EOneHot ext t2 prj (esnd idx) val) $ \s2' e' -> - acted $ k (SpPair SpAbsent s2') (EPair ext (ENil ext) e') - (SpLEither s1 s2, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) -> - unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' -> do - case s2 of SpAbsent -> pure () ; _ -> tellActed - k (SpLEither s1' SpAbsent) (ELInl ext STNil e') - (SpLEither s1 s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) -> - unsparseOneHotS s2 (EOneHot ext t2 prj idx val) $ \s2' e' -> do - case s1 of SpAbsent -> pure () ; _ -> tellActed - acted $ k (SpLEither SpAbsent s2') (ELInr ext STNil e') - (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) -> - unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' -> - k (SpMaybe s1') (EJust ext e') - (SpArr s1, EOneHot _ (SMTArr n t1) (SAPArrIdx prj) idx val) -> - unsparseOneHotS s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \s1' e' -> - k (SpArr s1') (elet idx $ EOneHot ext (SMTArr n (applySparse s1' _)) (SAPArrIdx SAPHere) (EPair ext (efst (evar IZ)) (ENil ext)) e') - _ -> _ --} - --- | Recognises 'EZero' and 'EOneHot'. -recogniseMonoid :: SMTy t -> Ex env t -> (Any, Ex env t) -recogniseMonoid _ e@EOneHot{} = return e -recogniseMonoid SMTNil (ENil _) = acted $ return $ EZero ext SMTNil (ENil ext) -recogniseMonoid typ@(SMTPair t1 t2) (EPair _ a b) = - ((,) <$> recogniseMonoid t1 a <*> recogniseMonoid t2 b) >>= \case - (EZero _ _ ezi1, EZero _ _ ezi2) -> acted $ return $ EZero ext typ (EPair ext ezi1 ezi2) - (a', EZero _ _ ezi2) -> acted $ EOneHot ext typ (SAPFst SAPHere) (EPair ext (ENil ext) ezi2) <$> recogniseMonoid t1 a' - (EZero _ _ ezi1, b') -> acted $ EOneHot ext typ (SAPSnd SAPHere) (EPair ext ezi1 (ENil ext)) <$> recogniseMonoid t2 b' - (a', b') -> return $ EPair ext a' b' -recogniseMonoid typ@(SMTLEither t1 t2) expr = - case expr of - ELNil{} -> acted $ return $ EZero ext typ (ENil ext) - ELInl _ _ e -> acted $ EOneHot ext typ (SAPLeft SAPHere) (ENil ext) <$> recogniseMonoid t1 e - ELInr _ _ e -> acted $ EOneHot ext typ (SAPRight SAPHere) (ENil ext) <$> recogniseMonoid t2 e - _ -> return expr -recogniseMonoid typ@(SMTMaybe t1) expr = - case expr of - ENothing{} -> acted $ return $ EZero ext typ (ENil ext) - EJust _ e -> acted $ EOneHot ext typ (SAPJust SAPHere) (ENil ext) <$> recogniseMonoid t1 e - _ -> return expr -recogniseMonoid typ@(SMTArr SZ t) (EUnit _ e) = - acted $ do - e' <- recogniseMonoid t e - return $ - ELet ext e' $ - EOneHot ext typ (SAPArrIdx SAPHere) - (EPair ext (EPair ext (ENil ext) (EUnit ext (makeZeroInfo t (EVar ext (fromSMTy t) IZ)))) - (ENil ext)) - (EVar ext (fromSMTy t) IZ) -recogniseMonoid typ@(SMTScal sty) e@(EConst _ _ x) = case (sty, x) of - (STI32, 0) -> acted $ return $ EZero ext typ (ENil ext) - (STI64, 0) -> acted $ return $ EZero ext typ (ENil ext) - (STF32, 0) -> acted $ return $ EZero ext typ (ENil ext) - (STF64, 0) -> acted $ return $ EZero ext typ (ENil ext) - _ -> return e -recogniseMonoid _ e = return e - -reduceAcIdx :: SMTy a -> SAcPrj p a b -> Ex env (AcIdxS p a) -> Ex env (AcIdxD p a) -reduceAcIdx topty topprj e = case (topty, topprj) of - (_, SAPHere) -> ENil ext - (SMTPair t1 _, SAPFst p) -> reduceAcIdx t1 p (efst e) - (SMTPair _ t2, SAPSnd p) -> reduceAcIdx t2 p (esnd e) - (SMTLEither t1 _ , SAPLeft p) -> reduceAcIdx t1 p e - (SMTLEither _ t2, SAPRight p) -> reduceAcIdx t2 p e - (SMTMaybe t1, SAPJust p) -> reduceAcIdx t1 p e - (SMTArr _ t, SAPArrIdx p) -> - eunPair e $ \_ e1 e2 -> - EPair ext (efst e1) (reduceAcIdx t p e2) - -zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t) -zeroInfoFromOneHot = \ty prj eidx e -> ELet ext eidx $ go ty prj (EVar ext (typeOf eidx) IZ) (weakenExpr WSink e) - where - -- invariant: AcIdx expression is duplicable - go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t) - go t SAPHere _ e = makeZeroInfo t e - go (SMTPair t1 _) (SAPFst prj) eidx e = EPair ext (go t1 prj (EFst ext eidx) e) (ESnd ext eidx) - go (SMTPair _ t2) (SAPSnd prj) eidx e = EPair ext (EFst ext eidx) (go t2 prj (ESnd ext eidx) e) - go SMTLEither{} _ _ _ = ENil ext - go SMTMaybe{} _ _ _ = ENil ext - go SMTArr{} SAPArrIdx{} eidx _ = ESnd ext (EFst ext eidx) |
