diff options
Diffstat (limited to 'src/Simplify.hs')
-rw-r--r-- | src/Simplify.hs | 315 |
1 files changed, 215 insertions, 100 deletions
diff --git a/src/Simplify.hs b/src/Simplify.hs index e110206..74b6601 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -1,7 +1,9 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE KindSignatures #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE QuasiQuotes #-} @@ -19,13 +21,14 @@ import Control.Monad (ap) import Data.Bifunctor (first) import Data.Function (fix) import Data.Monoid (Any(..)) -import Data.Type.Equality (testEquality) import Debug.Trace import AST import AST.Count import AST.Pretty +import AST.Sparse.Types +import AST.UnMonoid (acPrjCompose) import Data import Simplify.TH @@ -81,22 +84,28 @@ runSM (SM f) = first getAny (f id) smReconstruct :: Ex env t -> SM tenv tt env t (Ex tenv tt) smReconstruct core = SM (\ctx -> (Any False, ctx core)) -tellActed :: SM tenv tt env t () -tellActed = SM (\_ -> (Any True, ())) +class Monad m => ActedMonad m where + tellActed :: m () + hideActed :: m a -> m a + liftActed :: (Any, a) -> m a + +instance ActedMonad ((,) Any) where + tellActed = (Any True, ()) + hideActed (_, x) = (Any False, x) + liftActed = id + +instance ActedMonad (SM tenv tt env t) where + tellActed = SM (\_ -> tellActed) + hideActed (SM f) = SM (\ctx -> hideActed (f ctx)) + liftActed pair = SM (\_ -> pair) -- more convenient in practice -acted :: SM tenv tt env t a -> SM tenv tt env t a +acted :: ActedMonad m => m a -> m a acted m = tellActed >> m within :: (Ex env' t' -> Ex env t) -> SM tenv tt env' t' a -> SM tenv tt env t a within subctx (SM f) = SM $ \ctx -> f (ctx . subctx) -acted' :: (Any, a) -> (Any, a) -acted' (_, x) = (Any True, x) - -liftActed :: (Any, a) -> SM tenv tt env t a -liftActed pair = SM $ \_ -> pair - simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t) simplify' expr | scLogging ?config = do @@ -167,10 +176,10 @@ simplify'Rec = \case ECase _ (ELet _ rhs body) e1 e2 -> acted $ simplify' (ELet ext rhs (ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2))) EIdx0 _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EIdx0 ext body)) EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e))) - EAccum _ t p e1 (ELet _ rhs body) acc -> + EAccum _ t p e1 sp (ELet _ rhs body) acc -> acted $ simplify' $ ELet ext rhs $ - EAccum ext t p (weakenExpr WSink e1) body (weakenExpr WSink acc) + EAccum ext t p (weakenExpr WSink e1) sp body (weakenExpr WSink acc) -- let () = e in () ~> e ELet _ e1 (ENil _) | STNil <- typeOf e1 -> @@ -194,6 +203,9 @@ simplify'Rec = \case EIdx _ (EReplicate1Inner _ _ e2) e3 -> acted $ simplify' $ EIdx ext e2 (EFst ext e3) EIdx _ (EUnit _ e1) _ -> acted $ simplify' $ e1 + -- TODO: more array shape + EShape _ (EBuild _ _ e _) -> acted $ simplify' e + -- TODO: more constant folding EOp _ OIf (EConst _ STBool True) -> acted $ return (EInl ext STNil (ENil ext)) EOp _ OIf (EConst _ STBool False) -> acted $ return (EInr ext STNil (ENil ext)) @@ -222,23 +234,40 @@ simplify'Rec = \case acted $ simplify' $ EUnit ext (substInline (ENil ext) e) -- monoid rules - EAccum _ t p e1 e2 acc -> do - e1' <- within (\e1' -> EAccum ext t p e1' e2 acc ) $ simplify' e1 - e2' <- within (\e2' -> EAccum ext t p e1' e2' acc ) $ simplify' e2 - acc' <- within (\acc' -> EAccum ext t p e1' e2' acc') $ simplify' acc - simplifyOneHotTerm (OneHotTerm t p e1' e2') + EAccum _ t p e1 sp e2 acc -> do + e1' <- within (\e1' -> EAccum ext t p e1' sp e2 acc ) $ simplify' e1 + e2' <- within (\e2' -> EAccum ext t p e1' sp e2' acc ) $ simplify' e2 + acc' <- within (\acc' -> EAccum ext t p e1' sp e2' acc') $ simplify' acc + simplifyOHT (OneHotTerm SAID t p e1' sp e2') (acted $ return (ENil ext)) - (\e -> return (EAccum ext t SAPHere (ENil ext) e acc')) - (\(OneHotTerm t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc')) + (\sp' (InContext w wrap e) -> do + e' <- within (\e' -> wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc')) $ simplify' e + return (wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc'))) + (\(InContext w wrap (OneHotTerm _ t' p' e1'' sp' e2'')) -> do + -- The acted management here is a hideous mess. + e1''' <- hideActed $ within (\e1''' -> wrap $ EAccum ext t' p' e1''' sp' e2'' (weakenExpr w acc')) $ simplify' e1'' + e2''' <- hideActed $ within (\e2''' -> wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc')) $ simplify' e2'' + return (wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc'))) EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e EOneHot _ t p e1 e2 -> do e1' <- within (\e1' -> EOneHot ext t p e1' e2 ) $ simplify' e1 e2' <- within (\e2' -> EOneHot ext t p e1' e2') $ simplify' e2 - simplifyOneHotTerm (OneHotTerm t p e1' e2') + simplifyOHT (OneHotTerm SAIS t p e1' (spDense (acPrjTy p t)) e2') (acted $ return (EZero ext t (zeroInfoFromOneHot t p e1 e2))) - (\e -> acted $ return e) - (\(OneHotTerm t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2'')) + (\sp' (InContext _ wrap e) -> + case isDense t sp' of + Just Refl -> do + e' <- hideActed $ within wrap $ simplify' e + return (wrap e') + Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse") + (\(InContext _ wrap (OneHotTerm _ t' p' e1'' sp' e2'')) -> + case isDense (acPrjTy p' t') sp' of + Just Refl -> do + e1''' <- hideActed $ within (\e1''' -> wrap $ EOneHot ext t' p' e1''' e2'') $ simplify' e1'' + e2''' <- hideActed $ within (\e2''' -> wrap $ EOneHot ext t' p' e1''' e2''') $ simplify' e2'' + return (wrap $ EOneHot ext t' p' e1''' e2''') + Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse") -- type-specific equations for plus EPlus _ SMTNil e1 e2 | not (hasAdds e1), not (hasAdds e2) -> @@ -302,8 +331,9 @@ simplify'Rec = \case e1' <- within (\e1' -> EWith ext t e1' e2) (simplify' e1) e2' <- within (\e2' -> EWith ext t e1' e2') (let ?accumInScope = True in simplify' e2) pure (EWith ext t e1' e2') - EZero _ t e -> [simprec| EZero ext t *e |] -- EZero ext t <$> simplify' e - EPlus _ t a b -> [simprec| EPlus ext t *a *b |] -- EPlus ext t <$> simplify' a <*> simplify' b + EZero _ t e -> [simprec| EZero ext t *e |] + EDeepZero _ t e -> [simprec| EDeepZero ext t *e |] + EPlus _ t a b -> [simprec| EPlus ext t *a *b |] EError _ t s -> pure $ EError ext t s cheapExpr :: Expr x env t -> Bool @@ -353,8 +383,9 @@ hasAdds = \case EOp _ _ e -> hasAdds e EWith _ _ a b -> hasAdds a || hasAdds b ERecompute _ e -> hasAdds e - EAccum _ _ _ _ _ _ -> True + EAccum _ _ _ _ _ _ _ -> True EZero _ _ e -> hasAdds e + EDeepZero _ _ e -> hasAdds e EPlus _ _ a b -> hasAdds a || hasAdds b EOneHot _ _ _ a b -> hasAdds a || hasAdds b EError _ _ _ -> False @@ -373,51 +404,161 @@ checkAccumInScope = \case SNil -> False check (STScal _) = False check STAccum{} = True -data OneHotTerm env p a b where - OneHotTerm :: SMTy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env b -> OneHotTerm env p a b -deriving instance Show (OneHotTerm env p a b) - -simplifyOneHotTerm :: OneHotTerm env p a b - -> SM tenv tt env t r -- ^ Zero case (onehot is actually zero) - -> (Ex env a -> SM tenv tt env t r) -- ^ Trivial case (no zeros in onehot) - -> (forall p' b'. OneHotTerm env p' a b' -> SM tenv tt env t r) - -> SM tenv tt env t r -simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 val1) kzero ktriv k = do - val1' <- liftActed $ recogniseMonoid (acPrjTy prj1 t1) val1 - case val1' of - EZero{} -> kzero - EOneHot _ t2 prj2 idx2 val2 - | Just Refl <- testEquality (acPrjTy prj1 t1) t2 -> do - tellActed -- record, whatever happens later, that we've modified something - concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 -> - simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val2) kzero ktriv k - _ -> case prj1 of - SAPHere -> ktriv val1 - _ -> k (OneHotTerm t1 prj1 idx1 val1) +data OneHotTerm dense env a where + OneHotTerm :: SAIDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Sparse b c -> Ex env c -> OneHotTerm dense env a +deriving instance Show (OneHotTerm dense env a) + +data InContext f env (a :: Ty) where + InContext :: env :> env' -> (forall t. Ex env' t -> Ex env t) -> f env' a -> InContext f env a + +simplifyOHT_recogniseMonoid :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a) +simplifyOHT_recogniseMonoid (OneHotTerm dense t prj idx sp val) = do + val' <- liftActed $ recogniseMonoid (applySparse sp (acPrjTy prj t)) val + return $ OneHotTerm dense t prj idx sp val' + +simplifyOHT_unsparse :: ActedMonad m => OneHotTerm dense env a -> m (InContext (OneHotTerm dense) env a) +simplifyOHT_unsparse (OneHotTerm SAID t prj1 idx1 sp1 val1) = + unsparseOneHotD sp1 val1 $ \w wrap prj2 idx2 sp2 val2 -> + acPrjCompose SAID prj1 (weakenExpr w idx1) prj2 idx2 $ \prj' idx' -> + return $ InContext w wrap (OneHotTerm SAID t prj' idx' sp2 val2) +simplifyOHT_unsparse oht@(OneHotTerm SAIS _ _ _ _ _) = return $ InContext WId id oht + +simplifyOHT_concat :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a) +simplifyOHT_concat (OneHotTerm @dense @_ @_ @_ @env dense t1 prj1 idx1 sp (EOneHot @_ @c @p2 _ t2 prj2 idx2 val)) + | Just Refl <- isDense (acPrjTy prj1 t1) sp = + let idx2' :: Ex env (AcIdx dense p2 c) + idx2' = case dense of + SAID -> reduceAcIdx t2 prj2 idx2 + SAIS -> idx2 + in acPrjCompose dense prj1 idx1 prj2 idx2' $ \prj' idx' -> + acted $ return $ OneHotTerm dense t1 prj' idx' (spDense (acPrjTy prj' t1)) val +simplifyOHT_concat oht = return oht + +-- -- Property not expressed in types: if the Sparse in the input OneHotTerm is +-- -- dense, then the Sparse in the output will also be dense. This property is +-- -- used when simplifying EOneHot, which cannot represent sparsity. +simplifyOHT :: ActedMonad m => OneHotTerm dense env a + -> m r -- ^ Zero case (onehot is actually zero) + -> (forall b. Sparse a b -> InContext Ex env b -> m r) -- ^ Trivial case (no zeros in onehot) + -> (InContext (OneHotTerm dense) env a -> m r) -- ^ Simplified + -> m r +simplifyOHT oht kzero ktriv k = do + -- traceM $ "sOHT: input " ++ show oht + oht1 <- simplifyOHT_recogniseMonoid oht + -- traceM $ "sOHT: recog " ++ show oht1 + InContext w1 wrap1 oht2 <- simplifyOHT_unsparse oht1 + -- traceM $ "sOHT: unspa " ++ show oht2 + oht3 <- simplifyOHT_concat oht2 + -- traceM $ "sOHT: conca " ++ show oht3 + -- traceM "" + case oht3 of + OneHotTerm _ _ _ _ _ EZero{} -> kzero + OneHotTerm _ _ SAPHere _ sp val -> ktriv sp (InContext w1 wrap1 val) + _ -> k (InContext w1 wrap1 oht3) + +-- Sets the acted flag whenever a non-trivial projection is returned or the +-- output Sparse is different from the input Sparse. +unsparseOneHotD :: ActedMonad m => Sparse a a' -> Ex env a' + -> (forall p b c env'. env :> env' -> (forall s. Ex env' s -> Ex env s) + -> SAcPrj p a b -> Ex env' (AcIdxD p a) -> Sparse b c -> Ex env' c -> m r) -> m r +unsparseOneHotD topsp topval k = case (topsp, topval) of + -- eliminate always-Just sparse onehot + (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) -> + acted $ unsparseOneHotD s (EOneHot ext t prj idx val) k + + -- expand the top levels of a onehot for a sparse type into a onehot for the + -- corresponding non-sparse type + (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) -> + unsparseOneHotD s1 (EOneHot ext t1 prj (efst idx) val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPFst spprj) idx' s1' e' + (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) -> + unsparseOneHotD s2 (EOneHot ext t2 prj (esnd idx) val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPSnd spprj) idx' s1' e' + (SpLEither s1 _, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) -> + unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPLeft spprj) idx' s1' e' + (SpLEither _ s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) -> + unsparseOneHotD s2 (EOneHot ext t2 prj idx val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPRight spprj) idx' s1' e' + (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) -> + unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPJust spprj) idx' s1' e' + (SpArr s1, EOneHot _ (SMTArr _ t1) (SAPArrIdx prj) idx val) + | Dict <- styKnown (typeOf idx) -> + unsparseOneHotD s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \w wrap spprj idx' s1' e' -> + acted $ k (w .> WSink) (elet idx . wrap) (SAPArrIdx spprj) (EPair ext (efst (efst (evar (w @> IZ)))) idx') s1' e' + + -- anything else we don't know how to improve + _ -> k WId id SAPHere (ENil ext) topsp topval + +{- +unsparseOneHotS :: ActedMonad m + => Sparse a a' -> Ex env a' + -> (forall b. Sparse a b -> Ex env b -> m r) -> m r +unsparseOneHotS topsp topval k = case (topsp, topval) of + -- order is relevant to make sure we set the acted flag correctly + (SpAbsent, v@ENil{}) -> k SpAbsent v + (SpAbsent, v@EZero{}) -> k SpAbsent v + (SpAbsent, _) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) + (_, EZero{}) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) + (sp, _) | isAbsent sp -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) + + -- the unsparsifying + (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) -> + acted $ unsparseOneHotS s (EOneHot ext t prj idx val) k + + -- recursion + -- TODO: coproducts could safely become projections as they do not need + -- zeroinfo. But that would only work if the coproduct is at the top, because + -- as soon as we hit a product, we need zeroinfo to make it a projection and + -- we don't have that. + (SpSparse s, e) -> k (SpSparse s) e + (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj (efst idx) val) $ \s1' e' -> + acted $ k (SpPair s1' SpAbsent) (EPair ext e' (ENil ext)) + (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) -> + unsparseOneHotS s2 (EOneHot ext t2 prj (esnd idx) val) $ \s2' e' -> + acted $ k (SpPair SpAbsent s2') (EPair ext (ENil ext) e') + (SpLEither s1 s2, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' -> do + case s2 of SpAbsent -> pure () ; _ -> tellActed + k (SpLEither s1' SpAbsent) (ELInl ext STNil e') + (SpLEither s1 s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) -> + unsparseOneHotS s2 (EOneHot ext t2 prj idx val) $ \s2' e' -> do + case s1 of SpAbsent -> pure () ; _ -> tellActed + acted $ k (SpLEither SpAbsent s2') (ELInr ext STNil e') + (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' -> + k (SpMaybe s1') (EJust ext e') + (SpArr s1, EOneHot _ (SMTArr n t1) (SAPArrIdx prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \s1' e' -> + k (SpArr s1') (elet idx $ EOneHot ext (SMTArr n (applySparse s1' _)) (SAPArrIdx SAPHere) (EPair ext (efst (evar IZ)) (ENil ext)) e') + _ -> _ +-} -- | Recognises 'EZero' and 'EOneHot'. recogniseMonoid :: SMTy t -> Ex env t -> (Any, Ex env t) recogniseMonoid _ e@EOneHot{} = return e -recogniseMonoid SMTNil (ENil _) = acted' $ return $ EZero ext SMTNil (ENil ext) +recogniseMonoid SMTNil (ENil _) = acted $ return $ EZero ext SMTNil (ENil ext) recogniseMonoid typ@(SMTPair t1 t2) (EPair _ a b) = ((,) <$> recogniseMonoid t1 a <*> recogniseMonoid t2 b) >>= \case - (EZero _ _ ezi1, EZero _ _ ezi2) -> acted' $ return $ EZero ext typ (EPair ext ezi1 ezi2) - (a', EZero _ _ ezi2) -> acted' $ EOneHot ext typ (SAPFst SAPHere) (EPair ext (ENil ext) ezi2) <$> recogniseMonoid t1 a' - (EZero _ _ ezi1, b') -> acted' $ EOneHot ext typ (SAPSnd SAPHere) (EPair ext ezi1 (ENil ext)) <$> recogniseMonoid t2 b' + (EZero _ _ ezi1, EZero _ _ ezi2) -> acted $ return $ EZero ext typ (EPair ext ezi1 ezi2) + (a', EZero _ _ ezi2) -> acted $ EOneHot ext typ (SAPFst SAPHere) (EPair ext (ENil ext) ezi2) <$> recogniseMonoid t1 a' + (EZero _ _ ezi1, b') -> acted $ EOneHot ext typ (SAPSnd SAPHere) (EPair ext ezi1 (ENil ext)) <$> recogniseMonoid t2 b' (a', b') -> return $ EPair ext a' b' recogniseMonoid typ@(SMTLEither t1 t2) expr = case expr of - ELNil{} -> acted' $ return $ EZero ext typ (ENil ext) - ELInl _ _ e -> acted' $ EOneHot ext typ (SAPLeft SAPHere) (ENil ext) <$> recogniseMonoid t1 e - ELInr _ _ e -> acted' $ EOneHot ext typ (SAPRight SAPHere) (ENil ext) <$> recogniseMonoid t2 e + ELNil{} -> acted $ return $ EZero ext typ (ENil ext) + ELInl _ _ e -> acted $ EOneHot ext typ (SAPLeft SAPHere) (ENil ext) <$> recogniseMonoid t1 e + ELInr _ _ e -> acted $ EOneHot ext typ (SAPRight SAPHere) (ENil ext) <$> recogniseMonoid t2 e _ -> return expr recogniseMonoid typ@(SMTMaybe t1) expr = case expr of - ENothing{} -> acted' $ return $ EZero ext typ (ENil ext) - EJust _ e -> acted' $ EOneHot ext typ (SAPJust SAPHere) (ENil ext) <$> recogniseMonoid t1 e + ENothing{} -> acted $ return $ EZero ext typ (ENil ext) + EJust _ e -> acted $ EOneHot ext typ (SAPJust SAPHere) (ENil ext) <$> recogniseMonoid t1 e _ -> return expr recogniseMonoid typ@(SMTArr SZ t) (EUnit _ e) = - acted' $ do + acted $ do e' <- recogniseMonoid t e return $ ELet ext e' $ @@ -426,59 +567,33 @@ recogniseMonoid typ@(SMTArr SZ t) (EUnit _ e) = (ENil ext)) (EVar ext (fromSMTy t) IZ) recogniseMonoid typ@(SMTScal sty) e@(EConst _ _ x) = case (sty, x) of - (STI32, 0) -> acted' $ return $ EZero ext typ (ENil ext) - (STI64, 0) -> acted' $ return $ EZero ext typ (ENil ext) - (STF32, 0) -> acted' $ return $ EZero ext typ (ENil ext) - (STF64, 0) -> acted' $ return $ EZero ext typ (ENil ext) + (STI32, 0) -> acted $ return $ EZero ext typ (ENil ext) + (STI64, 0) -> acted $ return $ EZero ext typ (ENil ext) + (STF32, 0) -> acted $ return $ EZero ext typ (ENil ext) + (STF64, 0) -> acted $ return $ EZero ext typ (ENil ext) _ -> return e recogniseMonoid _ e = return e -concatOneHots :: SMTy a - -> SAcPrj p1 a b -> Ex env (AcIdx p1 a) - -> SAcPrj p2 b c -> Ex env (AcIdx p2 b) - -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx p12 a) -> r) -> r -concatOneHots t1 prj1 idx1 prj2 idx2 k = case (t1, prj1) of - (_, SAPHere) -> k prj2 idx2 - - (SMTPair a _, SAPFst prj1') -> - concatOneHots a prj1' (EFst ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> - k (SAPFst prj12) (ELet ext idx1 $ EPair ext idx12 (ESnd ext (EVar ext (typeOf idx1) IZ))) - (SMTPair _ b, SAPSnd prj1') -> - concatOneHots b prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> - k (SAPSnd prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) - - (SMTLEither a _, SAPLeft prj1') -> - concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12 - (SMTLEither _ b, SAPRight prj1') -> - concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12 - - (SMTMaybe a, SAPJust prj1') -> - concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12 - - (SMTArr _ a, SAPArrIdx prj1') -> - concatOneHots a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> - k (SAPArrIdx prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) - -zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env (ZeroInfo t) +reduceAcIdx :: SMTy a -> SAcPrj p a b -> Ex env (AcIdxS p a) -> Ex env (AcIdxD p a) +reduceAcIdx topty topprj e = case (topty, topprj) of + (_, SAPHere) -> ENil ext + (SMTPair t1 _, SAPFst p) -> reduceAcIdx t1 p (efst e) + (SMTPair _ t2, SAPSnd p) -> reduceAcIdx t2 p (esnd e) + (SMTLEither t1 _ , SAPLeft p) -> reduceAcIdx t1 p e + (SMTLEither _ t2, SAPRight p) -> reduceAcIdx t2 p e + (SMTMaybe t1, SAPJust p) -> reduceAcIdx t1 p e + (SMTArr _ t, SAPArrIdx p) -> + eunPair e $ \_ e1 e2 -> + EPair ext (efst e1) (reduceAcIdx t p e2) + +zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t) zeroInfoFromOneHot = \ty prj eidx e -> ELet ext eidx $ go ty prj (EVar ext (typeOf eidx) IZ) (weakenExpr WSink e) where -- invariant: AcIdx expression is duplicable - go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env (ZeroInfo t) + go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t) go t SAPHere _ e = makeZeroInfo t e go (SMTPair t1 _) (SAPFst prj) eidx e = EPair ext (go t1 prj (EFst ext eidx) e) (ESnd ext eidx) go (SMTPair _ t2) (SAPSnd prj) eidx e = EPair ext (EFst ext eidx) (go t2 prj (ESnd ext eidx) e) go SMTLEither{} _ _ _ = ENil ext go SMTMaybe{} _ _ _ = ENil ext go SMTArr{} SAPArrIdx{} eidx _ = ESnd ext (EFst ext eidx) - -makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) -makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ) - where - -- invariant: expression argument is duplicable - go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) - go SMTNil _ = ENil ext - go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e)) - go SMTLEither{} _ = ENil ext - go SMTMaybe{} _ = ENil ext - go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e - go SMTScal{} _ = ENil ext |