diff options
Diffstat (limited to 'src/Simplify.hs')
-rw-r--r-- | src/Simplify.hs | 545 |
1 files changed, 422 insertions, 123 deletions
diff --git a/src/Simplify.hs b/src/Simplify.hs index ac1bb8b..74b6601 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -1,8 +1,12 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE KindSignatures #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -10,24 +14,31 @@ {-# LANGUAGE TypeOperators #-} module Simplify ( simplifyN, simplifyFix, - SimplifyConfig(..), simplifyWith, simplifyFixWith, + SimplifyConfig(..), defaultSimplifyConfig, simplifyWith, simplifyFixWith, ) where +import Control.Monad (ap) +import Data.Bifunctor (first) import Data.Function (fix) import Data.Monoid (Any(..)) -import Data.Type.Equality (testEquality) + +import Debug.Trace import AST import AST.Count -import CHAD.Types +import AST.Pretty +import AST.Sparse.Types +import AST.UnMonoid (acPrjCompose) import Data +import Simplify.TH --- | This has no fields now, hence this type is useless as-is. When debugging, however, it's useful to be able to add some. data SimplifyConfig = SimplifyConfig + { scLogging :: Bool + } defaultSimplifyConfig :: SimplifyConfig -defaultSimplifyConfig = SimplifyConfig +defaultSimplifyConfig = SimplifyConfig False simplifyN :: KnownEnv env => Int -> Ex env t -> Ex env t simplifyN 0 = id @@ -37,13 +48,13 @@ simplify :: forall env t. KnownEnv env => Ex env t -> Ex env t simplify = let ?accumInScope = checkAccumInScope @env knownEnv ?config = defaultSimplifyConfig - in snd . simplify' + in snd . runSM . simplify' simplifyWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t simplifyWith config = let ?accumInScope = checkAccumInScope @env knownEnv ?config = config - in snd . simplify' + in snd . runSM . simplify' simplifyFix :: forall env t. KnownEnv env => Ex env t -> Ex env t simplifyFix = simplifyFixWith defaultSimplifyConfig @@ -53,22 +64,74 @@ simplifyFixWith config = let ?accumInScope = checkAccumInScope @env knownEnv ?config = config in fix $ \loop e -> - let (Any act, e') = simplify' e + let (act, e') = runSM (simplify' e) in if act then loop e' else e' -simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig) => Ex env t -> (Any, Ex env t) -simplify' = \case +-- | simplify monad +newtype SM tenv tt env t a = SM ((Ex env t -> Ex tenv tt) -> (Any, a)) + deriving (Functor) + +instance Applicative (SM tenv tt env t) where + pure x = SM (\_ -> (Any False, x)) + (<*>) = ap + +instance Monad (SM tenv tt env t) where + SM f >>= g = SM $ \ctx -> f ctx >>= \x -> let SM h = g x in h ctx + +runSM :: SM env t env t a -> (Bool, a) +runSM (SM f) = first getAny (f id) + +smReconstruct :: Ex env t -> SM tenv tt env t (Ex tenv tt) +smReconstruct core = SM (\ctx -> (Any False, ctx core)) + +class Monad m => ActedMonad m where + tellActed :: m () + hideActed :: m a -> m a + liftActed :: (Any, a) -> m a + +instance ActedMonad ((,) Any) where + tellActed = (Any True, ()) + hideActed (_, x) = (Any False, x) + liftActed = id + +instance ActedMonad (SM tenv tt env t) where + tellActed = SM (\_ -> tellActed) + hideActed (SM f) = SM (\ctx -> hideActed (f ctx)) + liftActed pair = SM (\_ -> pair) + +-- more convenient in practice +acted :: ActedMonad m => m a -> m a +acted m = tellActed >> m + +within :: (Ex env' t' -> Ex env t) -> SM tenv tt env' t' a -> SM tenv tt env t a +within subctx (SM f) = SM $ \ctx -> f (ctx . subctx) + +simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t) +simplify' expr + | scLogging ?config = do + res <- simplify'Rec expr + full <- smReconstruct res + let printed = ppExpr knownEnv full + replace a bs = concatMap (\x -> if x == a then bs else [x]) + str | '\n' `elem` printed = "--- simplify step:\n " ++ replace '\n' "\n " printed + | otherwise = "--- simplify step: " ++ printed + traceM str + return res + | otherwise = simplify'Rec expr + +simplify'Rec :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t) +simplify'Rec = \case -- inlining ELet _ rhs body | cheapExpr rhs - -> acted $ simplify' (subst1 rhs body) + -> acted $ simplify' (substInline rhs body) | Occ lexOcc runOcc <- occCount IZ body , ((not ?accumInScope || not (hasAdds rhs)) && lexOcc <= One && runOcc <= One) -- without effects, normal rules apply || (lexOcc == One && runOcc == One) -- with effects, linear inlining is still allowed, but weakening is not - -> acted $ simplify' (subst1 rhs body) + -> acted $ simplify' (substInline rhs body) - -- let splitting + -- let splitting / let peeling ELet _ (EPair _ a b) body -> acted $ simplify' $ ELet ext a $ @@ -76,13 +139,20 @@ simplify' = \case subst (\_ t -> \case IZ -> EPair ext (EVar ext (typeOf a) (IS IZ)) (EVar ext (typeOf b) IZ) IS i -> EVar ext t (IS (IS i))) body + ELet _ (EJust _ a) body -> + acted $ simplify' $ ELet ext a $ subst0 (EJust ext (EVar ext (typeOf a) IZ)) body + ELet _ (EInl _ t2 a) body -> + acted $ simplify' $ ELet ext a $ subst0 (EInl ext t2 (EVar ext (typeOf a) IZ)) body + ELet _ (EInr _ t1 a) body -> + acted $ simplify' $ ELet ext a $ subst0 (EInr ext t1 (EVar ext (typeOf a) IZ)) body -- let rotation - ELet _ (ELet _ rhs a) b -> + ELet _ (ELet _ rhs a) b -> do + b' <- within (ELet ext (ELet ext rhs a)) $ simplify' b acted $ simplify' $ ELet ext rhs $ ELet ext a $ - weakenExpr (WCopy WSink) (snd (simplify' b)) + weakenExpr (WCopy WSink) b' -- beta rules for products EFst _ (EPair _ e e') @@ -100,12 +170,20 @@ simplify' = \case EMaybe _ e1 _ ENothing{} -> acted $ simplify' e1 EMaybe _ _ e1 (EJust _ e2) -> acted $ simplify' $ ELet ext e2 e1 - -- let floating to facilitate beta reduction + -- let floating EFst _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EFst ext body)) ESnd _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (ESnd ext body)) ECase _ (ELet _ rhs body) e1 e2 -> acted $ simplify' (ELet ext rhs (ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2))) EIdx0 _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EIdx0 ext body)) EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e))) + EAccum _ t p e1 sp (ELet _ rhs body) acc -> + acted $ simplify' $ + ELet ext rhs $ + EAccum ext t p (weakenExpr WSink e1) sp body (weakenExpr WSink acc) + + -- let () = e in () ~> e + ELet _ e1 (ENil _) | STNil <- typeOf e1 -> + acted $ simplify' e1 -- projection down-commuting EFst _ (ECase _ e1 e2 e3) -> @@ -114,89 +192,150 @@ simplify' = \case ESnd _ (ECase _ e1 e2 e3) -> acted $ simplify' $ ECase ext e1 (ESnd ext e2) (ESnd ext e3) + EFst _ (EMaybe _ e1 e2 e3) -> + acted $ simplify' $ + EMaybe ext (EFst ext e1) (EFst ext e2) e3 + ESnd _ (EMaybe _ e1 e2 e3) -> + acted $ simplify' $ + EMaybe ext (ESnd ext e1) (ESnd ext e2) e3 + + -- TODO: more array indexing + EIdx _ (EReplicate1Inner _ _ e2) e3 -> acted $ simplify' $ EIdx ext e2 (EFst ext e3) + EIdx _ (EUnit _ e1) _ -> acted $ simplify' $ e1 - -- TODO: array indexing (index of build, index of fold) + -- TODO: more array shape + EShape _ (EBuild _ _ e _) -> acted $ simplify' e - -- TODO: beta rules for maybe + -- TODO: more constant folding + EOp _ OIf (EConst _ STBool True) -> acted $ return (EInl ext STNil (ENil ext)) + EOp _ OIf (EConst _ STBool False) -> acted $ return (EInr ext STNil (ENil ext)) - -- TODO: constant folding for operations + -- inline cheap array constructors + ELet _ (EReplicate1Inner _ e1 e2) e3 -> + acted $ simplify' $ + ELet ext (EPair ext e1 e2) $ + let v = EVar ext (STPair tIx (typeOf e2)) IZ + in subst0 (EReplicate1Inner ext (EFst ext v) (ESnd ext v)) e3 + -- -- TODO: This is a bad idea and anyway only helps in practice if (!) is + -- -- cheap, which it can't be because (!) is not cheap if you do AD after. + -- -- Should do proper SoA representation. + -- ELet _ (EBuild _ n e1 e2) e3 | cheapExpr e2 -> + -- acted $ simplify' $ + -- ELet ext e1 $ + -- subst0 (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) IZ) (weakenExpr (WCopy WSink) e2)) e3 + + -- eta rule for unit + e | STNil <- typeOf e, not ?accumInScope || not (hasAdds e) -> + case e of + ENil _ -> return e + _ -> acted $ return (ENil ext) + + EBuild _ SZ _ e -> + acted $ simplify' $ EUnit ext (substInline (ENil ext) e) -- monoid rules - EAccum _ t p e1 e2 acc -> do - acc' <- simplify' acc - simplifyOneHotTerm (OneHotTerm t p e1 e2) - (Any True, ENil ext) - (\e -> (Any False, EAccum ext t SAPHere (ENil ext) e acc')) - (\(OneHotTerm t' p' e1' e2') -> return (EAccum ext t' p' e1' e2' acc')) - EPlus _ _ (EZero _ _) e -> acted $ simplify' e - EPlus _ _ e (EZero _ _) -> acted $ simplify' e - EOneHot _ t p e1 e2 -> - simplifyOneHotTerm (OneHotTerm t p e1 e2) - (Any True, EZero ext t) - (\e -> (Any True, e)) - (\(OneHotTerm t' p' e1' e2') -> return (EOneHot ext t' p' e1' e2')) + EAccum _ t p e1 sp e2 acc -> do + e1' <- within (\e1' -> EAccum ext t p e1' sp e2 acc ) $ simplify' e1 + e2' <- within (\e2' -> EAccum ext t p e1' sp e2' acc ) $ simplify' e2 + acc' <- within (\acc' -> EAccum ext t p e1' sp e2' acc') $ simplify' acc + simplifyOHT (OneHotTerm SAID t p e1' sp e2') + (acted $ return (ENil ext)) + (\sp' (InContext w wrap e) -> do + e' <- within (\e' -> wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc')) $ simplify' e + return (wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc'))) + (\(InContext w wrap (OneHotTerm _ t' p' e1'' sp' e2'')) -> do + -- The acted management here is a hideous mess. + e1''' <- hideActed $ within (\e1''' -> wrap $ EAccum ext t' p' e1''' sp' e2'' (weakenExpr w acc')) $ simplify' e1'' + e2''' <- hideActed $ within (\e2''' -> wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc')) $ simplify' e2'' + return (wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc'))) + EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e + EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e + EOneHot _ t p e1 e2 -> do + e1' <- within (\e1' -> EOneHot ext t p e1' e2 ) $ simplify' e1 + e2' <- within (\e2' -> EOneHot ext t p e1' e2') $ simplify' e2 + simplifyOHT (OneHotTerm SAIS t p e1' (spDense (acPrjTy p t)) e2') + (acted $ return (EZero ext t (zeroInfoFromOneHot t p e1 e2))) + (\sp' (InContext _ wrap e) -> + case isDense t sp' of + Just Refl -> do + e' <- hideActed $ within wrap $ simplify' e + return (wrap e') + Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse") + (\(InContext _ wrap (OneHotTerm _ t' p' e1'' sp' e2'')) -> + case isDense (acPrjTy p' t') sp' of + Just Refl -> do + e1''' <- hideActed $ within (\e1''' -> wrap $ EOneHot ext t' p' e1''' e2'') $ simplify' e1'' + e2''' <- hideActed $ within (\e2''' -> wrap $ EOneHot ext t' p' e1''' e2''') $ simplify' e2'' + return (wrap $ EOneHot ext t' p' e1''' e2''') + Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse") -- type-specific equations for plus - EPlus _ STNil _ _ -> (Any True, ENil ext) + EPlus _ SMTNil e1 e2 | not (hasAdds e1), not (hasAdds e2) -> + acted $ return (ENil ext) - EPlus _ (STPair t1 t2) (EJust _ (EPair _ a1 b1)) (EJust _ (EPair _ a2 b2)) -> - acted $ simplify' $ EJust ext (EPair ext (EPlus ext t1 a1 a2) (EPlus ext t2 b1 b2)) - EPlus _ STPair{} ENothing{} e -> acted $ simplify' e - EPlus _ STPair{} e ENothing{} -> acted $ simplify' e + EPlus _ (SMTPair t1 t2) (EPair _ a1 b1) (EPair _ a2 b2) -> + acted $ simplify' $ EPair ext (EPlus ext t1 a1 a2) (EPlus ext t2 b1 b2) - EPlus _ (STEither t1 _) (EJust _ (EInl _ dt2 a1)) (EJust _ (EInl _ _ a2)) -> - acted $ simplify' $ EJust ext (EInl ext dt2 (EPlus ext t1 a1 a2)) - EPlus _ (STEither _ t2) (EJust _ (EInr _ dt1 b1)) (EJust _ (EInr _ _ b2)) -> - acted $ simplify' $ EJust ext (EInr ext dt1 (EPlus ext t2 b1 b2)) - EPlus _ STEither{} ENothing{} e -> acted $ simplify' e - EPlus _ STEither{} e ENothing{} -> acted $ simplify' e + EPlus _ (SMTLEither t1 _) (ELInl _ dt2 a1) (ELInl _ _ a2) -> + acted $ simplify' $ ELInl ext dt2 (EPlus ext t1 a1 a2) + EPlus _ (SMTLEither _ t2) (ELInr _ dt1 b1) (ELInr _ _ b2) -> + acted $ simplify' $ ELInr ext dt1 (EPlus ext t2 b1 b2) + EPlus _ SMTLEither{} ELNil{} e -> acted $ simplify' e + EPlus _ SMTLEither{} e ELNil{} -> acted $ simplify' e - EPlus _ (STMaybe t) (EJust _ e1) (EJust _ e2) -> + EPlus _ (SMTMaybe t) (EJust _ e1) (EJust _ e2) -> acted $ simplify' $ EJust ext (EPlus ext t e1 e2) - EPlus _ STMaybe{} ENothing{} e -> acted $ simplify' e - EPlus _ STMaybe{} e ENothing{} -> acted $ simplify' e + EPlus _ SMTMaybe{} ENothing{} e -> acted $ simplify' e + EPlus _ SMTMaybe{} e ENothing{} -> acted $ simplify' e -- fallback recursion EVar _ t i -> pure $ EVar ext t i - ELet _ a b -> ELet ext <$> simplify' a <*> simplify' b - EPair _ a b -> EPair ext <$> simplify' a <*> simplify' b - EFst _ e -> EFst ext <$> simplify' e - ESnd _ e -> ESnd ext <$> simplify' e + ELet _ a b -> [simprec| ELet ext *a *b |] + EPair _ a b -> [simprec| EPair ext *a *b |] + EFst _ e -> [simprec| EFst ext *e |] + ESnd _ e -> [simprec| ESnd ext *e |] ENil _ -> pure $ ENil ext - EInl _ t e -> EInl ext t <$> simplify' e - EInr _ t e -> EInr ext t <$> simplify' e - ECase _ e a b -> ECase ext <$> simplify' e <*> simplify' a <*> simplify' b + EInl _ t e -> [simprec| EInl ext t *e |] + EInr _ t e -> [simprec| EInr ext t *e |] + ECase _ e a b -> [simprec| ECase ext *e *a *b |] ENothing _ t -> pure $ ENothing ext t - EJust _ e -> EJust ext <$> simplify' e - EMaybe _ a b e -> EMaybe ext <$> simplify' a <*> simplify' b <*> simplify' e + EJust _ e -> [simprec| EJust ext *e |] + EMaybe _ a b e -> [simprec| EMaybe ext *a *b *e |] + ELNil _ t1 t2 -> pure $ ELNil ext t1 t2 + ELInl _ t e -> [simprec| ELInl ext t *e |] + ELInr _ t e -> [simprec| ELInr ext t *e |] + ELCase _ e a b c -> [simprec| ELCase ext *e *a *b *c |] EConstArr _ n t v -> pure $ EConstArr ext n t v - EBuild _ n a b -> EBuild ext n <$> simplify' a <*> simplify' b - EFold1Inner _ a b c -> EFold1Inner ext <$> simplify' a <*> simplify' b <*> simplify' c - ESum1Inner _ e -> ESum1Inner ext <$> simplify' e - EUnit _ e -> EUnit ext <$> simplify' e - EReplicate1Inner _ a b -> EReplicate1Inner ext <$> simplify' a <*> simplify' b - EMaximum1Inner _ e -> EMaximum1Inner ext <$> simplify' e - EMinimum1Inner _ e -> EMinimum1Inner ext <$> simplify' e + EBuild _ n a b -> [simprec| EBuild ext n *a *b |] + EFold1Inner _ cm a b c -> [simprec| EFold1Inner ext cm *a *b *c |] + ESum1Inner _ e -> [simprec| ESum1Inner ext *e |] + EUnit _ e -> [simprec| EUnit ext *e |] + EReplicate1Inner _ a b -> [simprec| EReplicate1Inner ext *a *b |] + EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |] + EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |] EConst _ t v -> pure $ EConst ext t v - EIdx0 _ e -> EIdx0 ext <$> simplify' e - EIdx1 _ a b -> EIdx1 ext <$> simplify' a <*> simplify' b - EIdx _ a b -> EIdx ext <$> simplify' a <*> simplify' b - EShape _ e -> EShape ext <$> simplify' e - EOp _ op e -> EOp ext op <$> simplify' e - ECustom _ s t p a b c e1 e2 -> - ECustom ext s t p - <$> (let ?accumInScope = False in simplify' a) - <*> (let ?accumInScope = False in simplify' b) - <*> (let ?accumInScope = False in simplify' c) - <*> simplify' e1 <*> simplify' e2 - EWith _ t e1 e2 -> EWith ext t <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2) - EZero _ t -> pure $ EZero ext t - EPlus _ t a b -> EPlus ext t <$> simplify' a <*> simplify' b + EIdx0 _ e -> [simprec| EIdx0 ext *e |] + EIdx1 _ a b -> [simprec| EIdx1 ext *a *b |] + EIdx _ a b -> [simprec| EIdx ext *a *b |] + EShape _ e -> [simprec| EShape ext *e |] + EOp _ op e -> [simprec| EOp ext op *e |] + ECustom _ s t p a b c e1 e2 -> do + a' <- within (\a' -> ECustom ext s t p a' b c e1 e2) (let ?accumInScope = False in simplify' a) + b' <- within (\b' -> ECustom ext s t p a' b' c e1 e2) (let ?accumInScope = False in simplify' b) + c' <- within (\c' -> ECustom ext s t p a' b' c' e1 e2) (let ?accumInScope = False in simplify' c) + e1' <- within (\e1' -> ECustom ext s t p a' b' c' e1' e2) (simplify' e1) + e2' <- within (\e2' -> ECustom ext s t p a' b' c' e1' e2') (simplify' e2) + pure (ECustom ext s t p a' b' c' e1' e2') + ERecompute _ e -> [simprec| ERecompute ext *e |] + EWith _ t e1 e2 -> do + e1' <- within (\e1' -> EWith ext t e1' e2) (simplify' e1) + e2' <- within (\e2' -> EWith ext t e1' e2') (let ?accumInScope = True in simplify' e2) + pure (EWith ext t e1' e2') + EZero _ t e -> [simprec| EZero ext t *e |] + EDeepZero _ t e -> [simprec| EDeepZero ext t *e |] + EPlus _ t a b -> [simprec| EPlus ext t *a *b |] EError _ t s -> pure $ EError ext t s -acted :: (Any, a) -> (Any, a) -acted (_, x) = (Any True, x) - cheapExpr :: Expr x env t -> Bool cheapExpr = \case EVar{} -> True @@ -204,6 +343,7 @@ cheapExpr = \case EConst{} -> True EFst _ e -> cheapExpr e ESnd _ e -> cheapExpr e + EUnit _ e -> cheapExpr e _ -> False -- | This can be made more precise by tracking (and not counting) adds on @@ -222,9 +362,13 @@ hasAdds = \case ENothing _ _ -> False EJust _ e -> hasAdds e EMaybe _ a b e -> hasAdds a || hasAdds b || hasAdds e + ELNil _ _ _ -> False + ELInl _ _ e -> hasAdds e + ELInr _ _ e -> hasAdds e + ELCase _ e a b c -> hasAdds e || hasAdds a || hasAdds b || hasAdds c EConstArr _ _ _ _ -> False EBuild _ _ a b -> hasAdds a || hasAdds b - EFold1Inner _ a b c -> hasAdds a || hasAdds b || hasAdds c + EFold1Inner _ _ a b c -> hasAdds a || hasAdds b || hasAdds c ESum1Inner _ e -> hasAdds e EUnit _ e -> hasAdds e EReplicate1Inner _ a b -> hasAdds a || hasAdds b @@ -238,8 +382,10 @@ hasAdds = \case EShape _ e -> hasAdds e EOp _ _ e -> hasAdds e EWith _ _ a b -> hasAdds a || hasAdds b - EAccum _ _ _ _ _ _ -> True - EZero _ _ -> False + ERecompute _ e -> hasAdds e + EAccum _ _ _ _ _ _ _ -> True + EZero _ _ e -> hasAdds e + EDeepZero _ _ e -> hasAdds e EPlus _ _ a b -> hasAdds a || hasAdds b EOneHot _ _ _ a b -> hasAdds a || hasAdds b EError _ _ _ -> False @@ -252,49 +398,202 @@ checkAccumInScope = \case SNil -> False check STNil = False check (STPair s t) = check s || check t check (STEither s t) = check s || check t + check (STLEither s t) = check s || check t check (STMaybe t) = check t check (STArr _ t) = check t check (STScal _) = False check STAccum{} = True -data OneHotTerm env p a b where - OneHotTerm :: STy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env (D2 b) -> OneHotTerm env p a b -deriving instance Show (OneHotTerm env p a b) - -simplifyOneHotTerm :: OneHotTerm env p a b - -> (Any, r) -- ^ Zero case (onehot is actually zero) - -> (Ex env (D2 a) -> (Any, r)) -- ^ Trivial case (no zeros in onehot) - -> (forall p' b'. OneHotTerm env p' a b' -> (Any, r)) - -> (Any, r) -simplifyOneHotTerm (OneHotTerm _ _ _ (EZero _ _)) kzero _ _ = kzero -simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val)) kzero ktriv k - | Just Refl <- testEquality (acPrjTy prj1 t1) t2 - = do (Any True, ()) -- record, whatever happens later, that we've modified something - concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 -> - simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val) kzero ktriv k -simplifyOneHotTerm (OneHotTerm _ SAPHere _ e) _ ktriv _ = ktriv e -simplifyOneHotTerm term _ _ k = k term - -concatOneHots :: STy a - -> SAcPrj p1 a b -> Ex env (AcIdx p1 a) - -> SAcPrj p2 b c -> Ex env (AcIdx p2 b) - -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx p12 a) -> r) -> r -concatOneHots t1 prj1 idx1 prj2 idx2 k = case (t1, prj1) of - (_, SAPHere) -> k prj2 idx2 - - (STPair a _, SAPFst prj1') -> - concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPFst prj12) idx12 - (STPair _ b, SAPSnd prj1') -> - concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPSnd prj12) idx12 - - (STEither a _, SAPLeft prj1') -> - concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12 - (STEither _ b, SAPRight prj1') -> - concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12 - - (STMaybe a, SAPJust prj1') -> - concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12 - - (STArr n a, SAPArrIdx prj1' _) -> - concatOneHots a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> - k (SAPArrIdx prj12 n) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) +data OneHotTerm dense env a where + OneHotTerm :: SAIDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Sparse b c -> Ex env c -> OneHotTerm dense env a +deriving instance Show (OneHotTerm dense env a) + +data InContext f env (a :: Ty) where + InContext :: env :> env' -> (forall t. Ex env' t -> Ex env t) -> f env' a -> InContext f env a + +simplifyOHT_recogniseMonoid :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a) +simplifyOHT_recogniseMonoid (OneHotTerm dense t prj idx sp val) = do + val' <- liftActed $ recogniseMonoid (applySparse sp (acPrjTy prj t)) val + return $ OneHotTerm dense t prj idx sp val' + +simplifyOHT_unsparse :: ActedMonad m => OneHotTerm dense env a -> m (InContext (OneHotTerm dense) env a) +simplifyOHT_unsparse (OneHotTerm SAID t prj1 idx1 sp1 val1) = + unsparseOneHotD sp1 val1 $ \w wrap prj2 idx2 sp2 val2 -> + acPrjCompose SAID prj1 (weakenExpr w idx1) prj2 idx2 $ \prj' idx' -> + return $ InContext w wrap (OneHotTerm SAID t prj' idx' sp2 val2) +simplifyOHT_unsparse oht@(OneHotTerm SAIS _ _ _ _ _) = return $ InContext WId id oht + +simplifyOHT_concat :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a) +simplifyOHT_concat (OneHotTerm @dense @_ @_ @_ @env dense t1 prj1 idx1 sp (EOneHot @_ @c @p2 _ t2 prj2 idx2 val)) + | Just Refl <- isDense (acPrjTy prj1 t1) sp = + let idx2' :: Ex env (AcIdx dense p2 c) + idx2' = case dense of + SAID -> reduceAcIdx t2 prj2 idx2 + SAIS -> idx2 + in acPrjCompose dense prj1 idx1 prj2 idx2' $ \prj' idx' -> + acted $ return $ OneHotTerm dense t1 prj' idx' (spDense (acPrjTy prj' t1)) val +simplifyOHT_concat oht = return oht + +-- -- Property not expressed in types: if the Sparse in the input OneHotTerm is +-- -- dense, then the Sparse in the output will also be dense. This property is +-- -- used when simplifying EOneHot, which cannot represent sparsity. +simplifyOHT :: ActedMonad m => OneHotTerm dense env a + -> m r -- ^ Zero case (onehot is actually zero) + -> (forall b. Sparse a b -> InContext Ex env b -> m r) -- ^ Trivial case (no zeros in onehot) + -> (InContext (OneHotTerm dense) env a -> m r) -- ^ Simplified + -> m r +simplifyOHT oht kzero ktriv k = do + -- traceM $ "sOHT: input " ++ show oht + oht1 <- simplifyOHT_recogniseMonoid oht + -- traceM $ "sOHT: recog " ++ show oht1 + InContext w1 wrap1 oht2 <- simplifyOHT_unsparse oht1 + -- traceM $ "sOHT: unspa " ++ show oht2 + oht3 <- simplifyOHT_concat oht2 + -- traceM $ "sOHT: conca " ++ show oht3 + -- traceM "" + case oht3 of + OneHotTerm _ _ _ _ _ EZero{} -> kzero + OneHotTerm _ _ SAPHere _ sp val -> ktriv sp (InContext w1 wrap1 val) + _ -> k (InContext w1 wrap1 oht3) + +-- Sets the acted flag whenever a non-trivial projection is returned or the +-- output Sparse is different from the input Sparse. +unsparseOneHotD :: ActedMonad m => Sparse a a' -> Ex env a' + -> (forall p b c env'. env :> env' -> (forall s. Ex env' s -> Ex env s) + -> SAcPrj p a b -> Ex env' (AcIdxD p a) -> Sparse b c -> Ex env' c -> m r) -> m r +unsparseOneHotD topsp topval k = case (topsp, topval) of + -- eliminate always-Just sparse onehot + (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) -> + acted $ unsparseOneHotD s (EOneHot ext t prj idx val) k + + -- expand the top levels of a onehot for a sparse type into a onehot for the + -- corresponding non-sparse type + (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) -> + unsparseOneHotD s1 (EOneHot ext t1 prj (efst idx) val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPFst spprj) idx' s1' e' + (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) -> + unsparseOneHotD s2 (EOneHot ext t2 prj (esnd idx) val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPSnd spprj) idx' s1' e' + (SpLEither s1 _, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) -> + unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPLeft spprj) idx' s1' e' + (SpLEither _ s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) -> + unsparseOneHotD s2 (EOneHot ext t2 prj idx val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPRight spprj) idx' s1' e' + (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) -> + unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPJust spprj) idx' s1' e' + (SpArr s1, EOneHot _ (SMTArr _ t1) (SAPArrIdx prj) idx val) + | Dict <- styKnown (typeOf idx) -> + unsparseOneHotD s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \w wrap spprj idx' s1' e' -> + acted $ k (w .> WSink) (elet idx . wrap) (SAPArrIdx spprj) (EPair ext (efst (efst (evar (w @> IZ)))) idx') s1' e' + + -- anything else we don't know how to improve + _ -> k WId id SAPHere (ENil ext) topsp topval + +{- +unsparseOneHotS :: ActedMonad m + => Sparse a a' -> Ex env a' + -> (forall b. Sparse a b -> Ex env b -> m r) -> m r +unsparseOneHotS topsp topval k = case (topsp, topval) of + -- order is relevant to make sure we set the acted flag correctly + (SpAbsent, v@ENil{}) -> k SpAbsent v + (SpAbsent, v@EZero{}) -> k SpAbsent v + (SpAbsent, _) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) + (_, EZero{}) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) + (sp, _) | isAbsent sp -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) + + -- the unsparsifying + (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) -> + acted $ unsparseOneHotS s (EOneHot ext t prj idx val) k + + -- recursion + -- TODO: coproducts could safely become projections as they do not need + -- zeroinfo. But that would only work if the coproduct is at the top, because + -- as soon as we hit a product, we need zeroinfo to make it a projection and + -- we don't have that. + (SpSparse s, e) -> k (SpSparse s) e + (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj (efst idx) val) $ \s1' e' -> + acted $ k (SpPair s1' SpAbsent) (EPair ext e' (ENil ext)) + (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) -> + unsparseOneHotS s2 (EOneHot ext t2 prj (esnd idx) val) $ \s2' e' -> + acted $ k (SpPair SpAbsent s2') (EPair ext (ENil ext) e') + (SpLEither s1 s2, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' -> do + case s2 of SpAbsent -> pure () ; _ -> tellActed + k (SpLEither s1' SpAbsent) (ELInl ext STNil e') + (SpLEither s1 s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) -> + unsparseOneHotS s2 (EOneHot ext t2 prj idx val) $ \s2' e' -> do + case s1 of SpAbsent -> pure () ; _ -> tellActed + acted $ k (SpLEither SpAbsent s2') (ELInr ext STNil e') + (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' -> + k (SpMaybe s1') (EJust ext e') + (SpArr s1, EOneHot _ (SMTArr n t1) (SAPArrIdx prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \s1' e' -> + k (SpArr s1') (elet idx $ EOneHot ext (SMTArr n (applySparse s1' _)) (SAPArrIdx SAPHere) (EPair ext (efst (evar IZ)) (ENil ext)) e') + _ -> _ +-} + +-- | Recognises 'EZero' and 'EOneHot'. +recogniseMonoid :: SMTy t -> Ex env t -> (Any, Ex env t) +recogniseMonoid _ e@EOneHot{} = return e +recogniseMonoid SMTNil (ENil _) = acted $ return $ EZero ext SMTNil (ENil ext) +recogniseMonoid typ@(SMTPair t1 t2) (EPair _ a b) = + ((,) <$> recogniseMonoid t1 a <*> recogniseMonoid t2 b) >>= \case + (EZero _ _ ezi1, EZero _ _ ezi2) -> acted $ return $ EZero ext typ (EPair ext ezi1 ezi2) + (a', EZero _ _ ezi2) -> acted $ EOneHot ext typ (SAPFst SAPHere) (EPair ext (ENil ext) ezi2) <$> recogniseMonoid t1 a' + (EZero _ _ ezi1, b') -> acted $ EOneHot ext typ (SAPSnd SAPHere) (EPair ext ezi1 (ENil ext)) <$> recogniseMonoid t2 b' + (a', b') -> return $ EPair ext a' b' +recogniseMonoid typ@(SMTLEither t1 t2) expr = + case expr of + ELNil{} -> acted $ return $ EZero ext typ (ENil ext) + ELInl _ _ e -> acted $ EOneHot ext typ (SAPLeft SAPHere) (ENil ext) <$> recogniseMonoid t1 e + ELInr _ _ e -> acted $ EOneHot ext typ (SAPRight SAPHere) (ENil ext) <$> recogniseMonoid t2 e + _ -> return expr +recogniseMonoid typ@(SMTMaybe t1) expr = + case expr of + ENothing{} -> acted $ return $ EZero ext typ (ENil ext) + EJust _ e -> acted $ EOneHot ext typ (SAPJust SAPHere) (ENil ext) <$> recogniseMonoid t1 e + _ -> return expr +recogniseMonoid typ@(SMTArr SZ t) (EUnit _ e) = + acted $ do + e' <- recogniseMonoid t e + return $ + ELet ext e' $ + EOneHot ext typ (SAPArrIdx SAPHere) + (EPair ext (EPair ext (ENil ext) (EUnit ext (makeZeroInfo t (EVar ext (fromSMTy t) IZ)))) + (ENil ext)) + (EVar ext (fromSMTy t) IZ) +recogniseMonoid typ@(SMTScal sty) e@(EConst _ _ x) = case (sty, x) of + (STI32, 0) -> acted $ return $ EZero ext typ (ENil ext) + (STI64, 0) -> acted $ return $ EZero ext typ (ENil ext) + (STF32, 0) -> acted $ return $ EZero ext typ (ENil ext) + (STF64, 0) -> acted $ return $ EZero ext typ (ENil ext) + _ -> return e +recogniseMonoid _ e = return e + +reduceAcIdx :: SMTy a -> SAcPrj p a b -> Ex env (AcIdxS p a) -> Ex env (AcIdxD p a) +reduceAcIdx topty topprj e = case (topty, topprj) of + (_, SAPHere) -> ENil ext + (SMTPair t1 _, SAPFst p) -> reduceAcIdx t1 p (efst e) + (SMTPair _ t2, SAPSnd p) -> reduceAcIdx t2 p (esnd e) + (SMTLEither t1 _ , SAPLeft p) -> reduceAcIdx t1 p e + (SMTLEither _ t2, SAPRight p) -> reduceAcIdx t2 p e + (SMTMaybe t1, SAPJust p) -> reduceAcIdx t1 p e + (SMTArr _ t, SAPArrIdx p) -> + eunPair e $ \_ e1 e2 -> + EPair ext (efst e1) (reduceAcIdx t p e2) + +zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t) +zeroInfoFromOneHot = \ty prj eidx e -> ELet ext eidx $ go ty prj (EVar ext (typeOf eidx) IZ) (weakenExpr WSink e) + where + -- invariant: AcIdx expression is duplicable + go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t) + go t SAPHere _ e = makeZeroInfo t e + go (SMTPair t1 _) (SAPFst prj) eidx e = EPair ext (go t1 prj (EFst ext eidx) e) (ESnd ext eidx) + go (SMTPair _ t2) (SAPSnd prj) eidx e = EPair ext (EFst ext eidx) (go t2 prj (ESnd ext eidx) e) + go SMTLEither{} _ _ _ = ENil ext + go SMTMaybe{} _ _ _ = ENil ext + go SMTArr{} SAPArrIdx{} eidx _ = ESnd ext (EFst ext eidx) |