diff options
Diffstat (limited to 'src/Simplify.hs')
-rw-r--r-- | src/Simplify.hs | 106 |
1 files changed, 60 insertions, 46 deletions
diff --git a/src/Simplify.hs b/src/Simplify.hs index e110206..d3b850f 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -226,19 +226,19 @@ simplify'Rec = \case e1' <- within (\e1' -> EAccum ext t p e1' e2 acc ) $ simplify' e1 e2' <- within (\e2' -> EAccum ext t p e1' e2' acc ) $ simplify' e2 acc' <- within (\acc' -> EAccum ext t p e1' e2' acc') $ simplify' acc - simplifyOneHotTerm (OneHotTerm t p e1' e2') + simplifyOneHotTerm (OneHotTerm SAI_D t p e1' e2') (acted $ return (ENil ext)) (\e -> return (EAccum ext t SAPHere (ENil ext) e acc')) - (\(OneHotTerm t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc')) + (\(OneHotTerm SAI_D t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc')) EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e EOneHot _ t p e1 e2 -> do e1' <- within (\e1' -> EOneHot ext t p e1' e2 ) $ simplify' e1 e2' <- within (\e2' -> EOneHot ext t p e1' e2') $ simplify' e2 - simplifyOneHotTerm (OneHotTerm t p e1' e2') + simplifyOneHotTerm (OneHotTerm SAI_S t p e1' e2') (acted $ return (EZero ext t (zeroInfoFromOneHot t p e1 e2))) (\e -> acted $ return e) - (\(OneHotTerm t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2'')) + (\(OneHotTerm SAI_S t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2'')) -- type-specific equations for plus EPlus _ SMTNil e1 e2 | not (hasAdds e1), not (hasAdds e2) -> @@ -373,27 +373,27 @@ checkAccumInScope = \case SNil -> False check (STScal _) = False check STAccum{} = True -data OneHotTerm env p a b where - OneHotTerm :: SMTy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env b -> OneHotTerm env p a b -deriving instance Show (OneHotTerm env p a b) +data OneHotTerm dense env p a b where + OneHotTerm :: SStillDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Ex env b -> OneHotTerm dense env p a b +deriving instance Show (OneHotTerm dense env p a b) -simplifyOneHotTerm :: OneHotTerm env p a b +simplifyOneHotTerm :: OneHotTerm dense env p a b -> SM tenv tt env t r -- ^ Zero case (onehot is actually zero) -> (Ex env a -> SM tenv tt env t r) -- ^ Trivial case (no zeros in onehot) - -> (forall p' b'. OneHotTerm env p' a b' -> SM tenv tt env t r) + -> (forall p' b'. OneHotTerm dense env p' a b' -> SM tenv tt env t r) -> SM tenv tt env t r -simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 val1) kzero ktriv k = do +simplifyOneHotTerm (OneHotTerm dense t1 prj1 idx1 val1) kzero ktriv k = do val1' <- liftActed $ recogniseMonoid (acPrjTy prj1 t1) val1 case val1' of EZero{} -> kzero EOneHot _ t2 prj2 idx2 val2 | Just Refl <- testEquality (acPrjTy prj1 t1) t2 -> do tellActed -- record, whatever happens later, that we've modified something - concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 -> - simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val2) kzero ktriv k + concatOneHots dense t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 -> + simplifyOneHotTerm (OneHotTerm dense t1 prj12 idx12 val2) kzero ktriv k _ -> case prj1 of SAPHere -> ktriv val1 - _ -> k (OneHotTerm t1 prj1 idx1 val1) + _ -> k (OneHotTerm dense t1 prj1 idx1 val1) -- | Recognises 'EZero' and 'EOneHot'. recogniseMonoid :: SMTy t -> Ex env t -> (Any, Ex env t) @@ -433,52 +433,66 @@ recogniseMonoid typ@(SMTScal sty) e@(EConst _ _ x) = case (sty, x) of _ -> return e recogniseMonoid _ e = return e -concatOneHots :: SMTy a - -> SAcPrj p1 a b -> Ex env (AcIdx p1 a) - -> SAcPrj p2 b c -> Ex env (AcIdx p2 b) - -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx p12 a) -> r) -> r -concatOneHots t1 prj1 idx1 prj2 idx2 k = case (t1, prj1) of - (_, SAPHere) -> k prj2 idx2 - - (SMTPair a _, SAPFst prj1') -> - concatOneHots a prj1' (EFst ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> +concatOneHots :: SStillDense dense -> SMTy a + -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a) + -> SAcPrj p2 b c -> Ex env (AcIdxS p2 b) + -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx dense p12 a) -> r) -> r +concatOneHots dense t1 prj1 idx1 prj2 idx2 k = case (dense, t1, prj1) of + (SAI_D, _, SAPHere) -> k prj2 (reduceAcIdx t1 prj2 idx2) + (SAI_S, _, SAPHere) -> k prj2 idx2 + + (SAI_D, SMTPair a _, SAPFst prj1') -> + concatOneHots SAI_D a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> + k (SAPFst prj12) idx12 + (SAI_S, SMTPair a _, SAPFst prj1') -> + concatOneHots SAI_S a prj1' (EFst ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> k (SAPFst prj12) (ELet ext idx1 $ EPair ext idx12 (ESnd ext (EVar ext (typeOf idx1) IZ))) - (SMTPair _ b, SAPSnd prj1') -> - concatOneHots b prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> + (SAI_D, SMTPair _ b, SAPSnd prj1') -> + concatOneHots dense b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> + k (SAPSnd prj12) idx12 + (SAI_S, SMTPair _ b, SAPSnd prj1') -> + concatOneHots dense b prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> k (SAPSnd prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) - (SMTLEither a _, SAPLeft prj1') -> - concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12 - (SMTLEither _ b, SAPRight prj1') -> - concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12 + (_, SMTLEither a _, SAPLeft prj1') -> + concatOneHots SAI_S a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12 + (_, SMTLEither _ b, SAPRight prj1') -> + concatOneHots SAI_S b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12 - (SMTMaybe a, SAPJust prj1') -> - concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12 + (_, SMTMaybe a, SAPJust prj1') -> + concatOneHots SAI_S a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12 - (SMTArr _ a, SAPArrIdx prj1') -> - concatOneHots a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> + -- yes, twice the same code, but we need a concrete denseness indicator to + -- reduce AcIdx (the only difference between the dense and sparse versions is + -- whether there extra info also contains an array shape, and this code + -- handles the extra info uniformly) + (SAI_D, SMTArr _ a, SAPArrIdx prj1') -> + concatOneHots dense a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> + k (SAPArrIdx prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) + (SAI_S, SMTArr _ a, SAPArrIdx prj1') -> + concatOneHots dense a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> k (SAPArrIdx prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) -zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env (ZeroInfo t) +reduceAcIdx :: SMTy a -> SAcPrj p a b -> Ex env (AcIdx AI_S p a) -> Ex env (AcIdx AI_D p a) +reduceAcIdx topty topprj e = case (topty, topprj) of + (_, SAPHere) -> ENil ext + (SMTPair t1 _, SAPFst p) -> reduceAcIdx t1 p (efst e) + (SMTPair _ t2, SAPSnd p) -> reduceAcIdx t2 p (esnd e) + (SMTLEither{}, SAPLeft{}) -> e + (SMTLEither{}, SAPRight{}) -> e + (SMTMaybe{}, SAPJust{}) -> e + (SMTArr _ t, SAPArrIdx p) -> + eunPair e $ \_ e1 e2 -> + EPair ext (efst e1) (reduceAcIdx t p e2) + +zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t) zeroInfoFromOneHot = \ty prj eidx e -> ELet ext eidx $ go ty prj (EVar ext (typeOf eidx) IZ) (weakenExpr WSink e) where -- invariant: AcIdx expression is duplicable - go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env (ZeroInfo t) + go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t) go t SAPHere _ e = makeZeroInfo t e go (SMTPair t1 _) (SAPFst prj) eidx e = EPair ext (go t1 prj (EFst ext eidx) e) (ESnd ext eidx) go (SMTPair _ t2) (SAPSnd prj) eidx e = EPair ext (EFst ext eidx) (go t2 prj (ESnd ext eidx) e) go SMTLEither{} _ _ _ = ENil ext go SMTMaybe{} _ _ _ = ENil ext go SMTArr{} SAPArrIdx{} eidx _ = ESnd ext (EFst ext eidx) - -makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) -makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ) - where - -- invariant: expression argument is duplicable - go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) - go SMTNil _ = ENil ext - go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e)) - go SMTLEither{} _ = ENil ext - go SMTMaybe{} _ = ENil ext - go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e - go SMTScal{} _ = ENil ext |