summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/AST/Types.hs1
-rw-r--r--src/Simplify.hs93
2 files changed, 56 insertions, 38 deletions
diff --git a/src/AST/Types.hs b/src/AST/Types.hs
index c8515fc..efb1e04 100644
--- a/src/AST/Types.hs
+++ b/src/AST/Types.hs
@@ -73,7 +73,6 @@ type SMTy :: Ty -> Type
data SMTy t where
SMTNil :: SMTy TNil
SMTPair :: SMTy a -> SMTy b -> SMTy (TPair a b)
- -- TODO: call this SMTLEither
SMTLEither :: SMTy a -> SMTy b -> SMTy (TLEither a b)
SMTMaybe :: SMTy a -> SMTy (TMaybe a)
SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t)
diff --git a/src/Simplify.hs b/src/Simplify.hs
index 140e673..f5eb0a1 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -91,6 +91,12 @@ acted m = tellActed >> m
within :: (Ex env' t' -> Ex env t) -> SM tenv tt env' t' a -> SM tenv tt env t a
within subctx (SM f) = SM $ \ctx -> f (ctx . subctx)
+acted' :: (Any, a) -> (Any, a)
+acted' (_, x) = (Any True, x)
+
+liftActed :: (Any, a) -> SM tenv tt env t a
+liftActed pair = SM $ \_ -> pair
+
simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t)
simplify' expr
| scLogging ?config = do
@@ -368,43 +374,56 @@ simplifyOneHotTerm :: OneHotTerm env p a b
-> (Ex env a -> SM tenv tt env t r) -- ^ Trivial case (no zeros in onehot)
-> (forall p' b'. OneHotTerm env p' a b' -> SM tenv tt env t r)
-> SM tenv tt env t r
-simplifyOneHotTerm (OneHotTerm _ _ _ EZero{}) kzero _ _ = kzero
-
-simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val)) kzero ktriv k
- | Just Refl <- testEquality (acPrjTy prj1 t1) t2
- = do tellActed -- record, whatever happens later, that we've modified something
- concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 ->
- simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val) kzero ktriv k
-
--- TODO: This does not actually recurse unless it just so happens to contain
--- another EZero or EOnehot in the final position. Should match on something
--- more general than SAPHere here.
-simplifyOneHotTerm (OneHotTerm t SAPHere _ e) kzero ktriv k = case (t, e) of
- (SMTNil, _) -> kzero
-
- (SMTPair{}, EPair _ e1 (EZero _ _ ezi)) ->
- simplifyOneHotTerm (OneHotTerm t (SAPFst SAPHere) (EPair ext (ENil ext) ezi) e1) kzero ktriv k
- (SMTPair{}, EPair _ (EZero _ _ ezi) e2) ->
- simplifyOneHotTerm (OneHotTerm t (SAPSnd SAPHere) (EPair ext ezi (ENil ext)) e2) kzero ktriv k
-
- (SMTLEither{}, ELNil _ _ _) -> kzero
- (SMTLEither{}, ELInl _ _ e1) ->
- simplifyOneHotTerm (OneHotTerm t (SAPLeft SAPHere) (ENil ext) e1) kzero ktriv k
- (SMTLEither{}, ELInr _ _ e2) ->
- simplifyOneHotTerm (OneHotTerm t (SAPRight SAPHere) (ENil ext) e2) kzero ktriv k
-
- (SMTMaybe{}, ENothing _ _) -> kzero
- (SMTMaybe{}, EJust _ e1) ->
- simplifyOneHotTerm (OneHotTerm t (SAPJust SAPHere) (ENil ext) e1) kzero ktriv k
-
- (SMTScal STI32, _) -> kzero
- (SMTScal STI64, _) -> kzero
- (SMTScal STF32, EConst _ _ 0.0) -> kzero
- (SMTScal STF64, EConst _ _ 0.0) -> kzero
-
- _ -> ktriv e
-
-simplifyOneHotTerm term _ _ k = k term
+simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 val1) kzero ktriv k = do
+ val1' <- liftActed $ recogniseMonoid (acPrjTy prj1 t1) val1
+ case val1' of
+ EZero{} -> kzero
+ EOneHot _ t2 prj2 idx2 val2
+ | Just Refl <- testEquality (acPrjTy prj1 t1) t2 -> do
+ tellActed -- record, whatever happens later, that we've modified something
+ concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 ->
+ simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val2) kzero ktriv k
+ _ -> case prj1 of
+ SAPHere -> ktriv val1
+ _ -> k (OneHotTerm t1 prj1 idx1 val1)
+
+-- | Recognises 'EZero' and 'EOneHot'.
+recogniseMonoid :: SMTy t -> Ex env t -> (Any, Ex env t)
+recogniseMonoid _ e@EOneHot{} = return e
+recogniseMonoid SMTNil (ENil _) = acted' $ return $ EZero ext SMTNil (ENil ext)
+recogniseMonoid typ@(SMTPair t1 t2) (EPair _ a b) =
+ ((,) <$> recogniseMonoid t1 a <*> recogniseMonoid t2 b) >>= \case
+ (EZero _ _ ezi1, EZero _ _ ezi2) -> acted' $ return $ EZero ext typ (EPair ext ezi1 ezi2)
+ (a', EZero _ _ ezi2) -> acted' $ EOneHot ext typ (SAPFst SAPHere) (EPair ext (ENil ext) ezi2) <$> recogniseMonoid t1 a'
+ (EZero _ _ ezi1, b') -> acted' $ EOneHot ext typ (SAPSnd SAPHere) (EPair ext ezi1 (ENil ext)) <$> recogniseMonoid t2 b'
+ (a', b') -> return $ EPair ext a' b'
+recogniseMonoid typ@(SMTLEither t1 t2) expr =
+ case expr of
+ ELNil{} -> acted' $ return $ EZero ext typ (ENil ext)
+ ELInl _ _ e -> acted' $ EOneHot ext typ (SAPLeft SAPHere) (ENil ext) <$> recogniseMonoid t1 e
+ ELInr _ _ e -> acted' $ EOneHot ext typ (SAPRight SAPHere) (ENil ext) <$> recogniseMonoid t2 e
+ _ -> return expr
+recogniseMonoid typ@(SMTMaybe t1) expr =
+ case expr of
+ ENothing{} -> acted' $ return $ EZero ext typ (ENil ext)
+ EJust _ e -> acted' $ EOneHot ext typ (SAPJust SAPHere) (ENil ext) <$> recogniseMonoid t1 e
+ _ -> return expr
+recogniseMonoid typ@(SMTArr SZ t) (EUnit _ e) =
+ acted' $ do
+ e' <- recogniseMonoid t e
+ return $
+ ELet ext e' $
+ EOneHot ext typ (SAPArrIdx SAPHere)
+ (EPair ext (EPair ext (ENil ext) (EUnit ext (makeZeroInfo t (EVar ext (fromSMTy t) IZ))))
+ (ENil ext))
+ (EVar ext (fromSMTy t) IZ)
+recogniseMonoid typ@(SMTScal sty) e@(EConst _ _ x) = case (sty, x) of
+ (STI32, 0) -> acted' $ return $ EZero ext typ (ENil ext)
+ (STI64, 0) -> acted' $ return $ EZero ext typ (ENil ext)
+ (STF32, 0) -> acted' $ return $ EZero ext typ (ENil ext)
+ (STF64, 0) -> acted' $ return $ EZero ext typ (ENil ext)
+ _ -> return e
+recogniseMonoid _ e = return e
concatOneHots :: SMTy a
-> SAcPrj p1 a b -> Ex env (AcIdx p1 a)