summaryrefslogtreecommitdiff
path: root/src/Simplify.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Simplify.hs')
-rw-r--r--src/Simplify.hs106
1 files changed, 60 insertions, 46 deletions
diff --git a/src/Simplify.hs b/src/Simplify.hs
index e110206..d3b850f 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -226,19 +226,19 @@ simplify'Rec = \case
e1' <- within (\e1' -> EAccum ext t p e1' e2 acc ) $ simplify' e1
e2' <- within (\e2' -> EAccum ext t p e1' e2' acc ) $ simplify' e2
acc' <- within (\acc' -> EAccum ext t p e1' e2' acc') $ simplify' acc
- simplifyOneHotTerm (OneHotTerm t p e1' e2')
+ simplifyOneHotTerm (OneHotTerm SAI_D t p e1' e2')
(acted $ return (ENil ext))
(\e -> return (EAccum ext t SAPHere (ENil ext) e acc'))
- (\(OneHotTerm t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc'))
+ (\(OneHotTerm SAI_D t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc'))
EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e
EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e
EOneHot _ t p e1 e2 -> do
e1' <- within (\e1' -> EOneHot ext t p e1' e2 ) $ simplify' e1
e2' <- within (\e2' -> EOneHot ext t p e1' e2') $ simplify' e2
- simplifyOneHotTerm (OneHotTerm t p e1' e2')
+ simplifyOneHotTerm (OneHotTerm SAI_S t p e1' e2')
(acted $ return (EZero ext t (zeroInfoFromOneHot t p e1 e2)))
(\e -> acted $ return e)
- (\(OneHotTerm t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2''))
+ (\(OneHotTerm SAI_S t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2''))
-- type-specific equations for plus
EPlus _ SMTNil e1 e2 | not (hasAdds e1), not (hasAdds e2) ->
@@ -373,27 +373,27 @@ checkAccumInScope = \case SNil -> False
check (STScal _) = False
check STAccum{} = True
-data OneHotTerm env p a b where
- OneHotTerm :: SMTy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env b -> OneHotTerm env p a b
-deriving instance Show (OneHotTerm env p a b)
+data OneHotTerm dense env p a b where
+ OneHotTerm :: SStillDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Ex env b -> OneHotTerm dense env p a b
+deriving instance Show (OneHotTerm dense env p a b)
-simplifyOneHotTerm :: OneHotTerm env p a b
+simplifyOneHotTerm :: OneHotTerm dense env p a b
-> SM tenv tt env t r -- ^ Zero case (onehot is actually zero)
-> (Ex env a -> SM tenv tt env t r) -- ^ Trivial case (no zeros in onehot)
- -> (forall p' b'. OneHotTerm env p' a b' -> SM tenv tt env t r)
+ -> (forall p' b'. OneHotTerm dense env p' a b' -> SM tenv tt env t r)
-> SM tenv tt env t r
-simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 val1) kzero ktriv k = do
+simplifyOneHotTerm (OneHotTerm dense t1 prj1 idx1 val1) kzero ktriv k = do
val1' <- liftActed $ recogniseMonoid (acPrjTy prj1 t1) val1
case val1' of
EZero{} -> kzero
EOneHot _ t2 prj2 idx2 val2
| Just Refl <- testEquality (acPrjTy prj1 t1) t2 -> do
tellActed -- record, whatever happens later, that we've modified something
- concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 ->
- simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val2) kzero ktriv k
+ concatOneHots dense t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 ->
+ simplifyOneHotTerm (OneHotTerm dense t1 prj12 idx12 val2) kzero ktriv k
_ -> case prj1 of
SAPHere -> ktriv val1
- _ -> k (OneHotTerm t1 prj1 idx1 val1)
+ _ -> k (OneHotTerm dense t1 prj1 idx1 val1)
-- | Recognises 'EZero' and 'EOneHot'.
recogniseMonoid :: SMTy t -> Ex env t -> (Any, Ex env t)
@@ -433,52 +433,66 @@ recogniseMonoid typ@(SMTScal sty) e@(EConst _ _ x) = case (sty, x) of
_ -> return e
recogniseMonoid _ e = return e
-concatOneHots :: SMTy a
- -> SAcPrj p1 a b -> Ex env (AcIdx p1 a)
- -> SAcPrj p2 b c -> Ex env (AcIdx p2 b)
- -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx p12 a) -> r) -> r
-concatOneHots t1 prj1 idx1 prj2 idx2 k = case (t1, prj1) of
- (_, SAPHere) -> k prj2 idx2
-
- (SMTPair a _, SAPFst prj1') ->
- concatOneHots a prj1' (EFst ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
+concatOneHots :: SStillDense dense -> SMTy a
+ -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a)
+ -> SAcPrj p2 b c -> Ex env (AcIdxS p2 b)
+ -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx dense p12 a) -> r) -> r
+concatOneHots dense t1 prj1 idx1 prj2 idx2 k = case (dense, t1, prj1) of
+ (SAI_D, _, SAPHere) -> k prj2 (reduceAcIdx t1 prj2 idx2)
+ (SAI_S, _, SAPHere) -> k prj2 idx2
+
+ (SAI_D, SMTPair a _, SAPFst prj1') ->
+ concatOneHots SAI_D a prj1' idx1 prj2 idx2 $ \prj12 idx12 ->
+ k (SAPFst prj12) idx12
+ (SAI_S, SMTPair a _, SAPFst prj1') ->
+ concatOneHots SAI_S a prj1' (EFst ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
k (SAPFst prj12) (ELet ext idx1 $ EPair ext idx12 (ESnd ext (EVar ext (typeOf idx1) IZ)))
- (SMTPair _ b, SAPSnd prj1') ->
- concatOneHots b prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
+ (SAI_D, SMTPair _ b, SAPSnd prj1') ->
+ concatOneHots dense b prj1' idx1 prj2 idx2 $ \prj12 idx12 ->
+ k (SAPSnd prj12) idx12
+ (SAI_S, SMTPair _ b, SAPSnd prj1') ->
+ concatOneHots dense b prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
k (SAPSnd prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12)
- (SMTLEither a _, SAPLeft prj1') ->
- concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12
- (SMTLEither _ b, SAPRight prj1') ->
- concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12
+ (_, SMTLEither a _, SAPLeft prj1') ->
+ concatOneHots SAI_S a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12
+ (_, SMTLEither _ b, SAPRight prj1') ->
+ concatOneHots SAI_S b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12
- (SMTMaybe a, SAPJust prj1') ->
- concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12
+ (_, SMTMaybe a, SAPJust prj1') ->
+ concatOneHots SAI_S a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12
- (SMTArr _ a, SAPArrIdx prj1') ->
- concatOneHots a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
+ -- yes, twice the same code, but we need a concrete denseness indicator to
+ -- reduce AcIdx (the only difference between the dense and sparse versions is
+ -- whether there extra info also contains an array shape, and this code
+ -- handles the extra info uniformly)
+ (SAI_D, SMTArr _ a, SAPArrIdx prj1') ->
+ concatOneHots dense a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
+ k (SAPArrIdx prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12)
+ (SAI_S, SMTArr _ a, SAPArrIdx prj1') ->
+ concatOneHots dense a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
k (SAPArrIdx prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12)
-zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env (ZeroInfo t)
+reduceAcIdx :: SMTy a -> SAcPrj p a b -> Ex env (AcIdx AI_S p a) -> Ex env (AcIdx AI_D p a)
+reduceAcIdx topty topprj e = case (topty, topprj) of
+ (_, SAPHere) -> ENil ext
+ (SMTPair t1 _, SAPFst p) -> reduceAcIdx t1 p (efst e)
+ (SMTPair _ t2, SAPSnd p) -> reduceAcIdx t2 p (esnd e)
+ (SMTLEither{}, SAPLeft{}) -> e
+ (SMTLEither{}, SAPRight{}) -> e
+ (SMTMaybe{}, SAPJust{}) -> e
+ (SMTArr _ t, SAPArrIdx p) ->
+ eunPair e $ \_ e1 e2 ->
+ EPair ext (efst e1) (reduceAcIdx t p e2)
+
+zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t)
zeroInfoFromOneHot = \ty prj eidx e -> ELet ext eidx $ go ty prj (EVar ext (typeOf eidx) IZ) (weakenExpr WSink e)
where
-- invariant: AcIdx expression is duplicable
- go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env (ZeroInfo t)
+ go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t)
go t SAPHere _ e = makeZeroInfo t e
go (SMTPair t1 _) (SAPFst prj) eidx e = EPair ext (go t1 prj (EFst ext eidx) e) (ESnd ext eidx)
go (SMTPair _ t2) (SAPSnd prj) eidx e = EPair ext (EFst ext eidx) (go t2 prj (ESnd ext eidx) e)
go SMTLEither{} _ _ _ = ENil ext
go SMTMaybe{} _ _ _ = ENil ext
go SMTArr{} SAPArrIdx{} eidx _ = ESnd ext (EFst ext eidx)
-
-makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t)
-makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ)
- where
- -- invariant: expression argument is duplicable
- go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t)
- go SMTNil _ = ENil ext
- go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e))
- go SMTLEither{} _ = ENil ext
- go SMTMaybe{} _ = ENil ext
- go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e
- go SMTScal{} _ = ENil ext