diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-04-27 23:34:59 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-04-27 23:34:59 +0200 |
commit | b1664532eaebdf0409ab6d93fc0ba2ef8dfbf372 (patch) | |
tree | a40c16fd082bbe4183e7b4194b8cea1408cec379 /src/Simplify.hs | |
parent | c750f8f9f1275d49ff74297e6648e1bfc1c6d918 (diff) |
WIP revamp accumulators again: explicit monoid types
No more D2 in accumulators! Paving the way for configurable sparsity of
products and arrays. The idea is to make separate monoid types for a
"product cotangent" and an "array cotangent" that can be lowered to
either a sparse monoid or a non-sparse monoid. Downsides of this
approach: lots of API duplication.
Diffstat (limited to 'src/Simplify.hs')
-rw-r--r-- | src/Simplify.hs | 138 |
1 files changed, 83 insertions, 55 deletions
diff --git a/src/Simplify.hs b/src/Simplify.hs index ea3bb95..228f265 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -19,7 +19,6 @@ import Data.Type.Equality (testEquality) import AST import AST.Count -import CHAD.Types import Data @@ -169,35 +168,33 @@ simplify' = \case (Any True, ENil ext) (\e -> (Any False, EAccum ext t SAPHere (ENil ext) e acc')) (\(OneHotTerm t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc')) - EPlus _ _ (EZero _ _) e -> acted $ simplify' e - EPlus _ _ e (EZero _ _) -> acted $ simplify' e + EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e + EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e EOneHot _ t p e1 e2 -> do e1' <- simplify' e1 e2' <- simplify' e2 simplifyOneHotTerm (OneHotTerm t p e1' e2') - (Any True, EZero ext t) + (Any True, EZero ext t (zeroInfoFromOneHot t p e1 e2)) (\e -> (Any True, e)) (\(OneHotTerm t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2'')) -- type-specific equations for plus - EPlus _ STNil _ _ -> (Any True, ENil ext) + EPlus _ SMTNil _ _ -> (Any True, ENil ext) - EPlus _ (STPair t1 t2) (EJust _ (EPair _ a1 b1)) (EJust _ (EPair _ a2 b2)) -> - acted $ simplify' $ EJust ext (EPair ext (EPlus ext t1 a1 a2) (EPlus ext t2 b1 b2)) - EPlus _ STPair{} ENothing{} e -> acted $ simplify' e - EPlus _ STPair{} e ENothing{} -> acted $ simplify' e + EPlus _ (SMTPair t1 t2) (EPair _ a1 b1) (EPair _ a2 b2) -> + acted $ simplify' $ EPair ext (EPlus ext t1 a1 a2) (EPlus ext t2 b1 b2) - EPlus _ (STEither t1 _) (EJust _ (EInl _ dt2 a1)) (EJust _ (EInl _ _ a2)) -> - acted $ simplify' $ EJust ext (EInl ext dt2 (EPlus ext t1 a1 a2)) - EPlus _ (STEither _ t2) (EJust _ (EInr _ dt1 b1)) (EJust _ (EInr _ _ b2)) -> - acted $ simplify' $ EJust ext (EInr ext dt1 (EPlus ext t2 b1 b2)) - EPlus _ STEither{} ENothing{} e -> acted $ simplify' e - EPlus _ STEither{} e ENothing{} -> acted $ simplify' e + EPlus _ (SMTLEither t1 _) (ELInl _ dt2 a1) (ELInl _ _ a2) -> + acted $ simplify' $ ELInl ext dt2 (EPlus ext t1 a1 a2) + EPlus _ (SMTLEither _ t2) (ELInr _ dt1 b1) (ELInr _ _ b2) -> + acted $ simplify' $ ELInr ext dt1 (EPlus ext t2 b1 b2) + EPlus _ SMTLEither{} ELNil{} e -> acted $ simplify' e + EPlus _ SMTLEither{} e ELNil{} -> acted $ simplify' e - EPlus _ (STMaybe t) (EJust _ e1) (EJust _ e2) -> + EPlus _ (SMTMaybe t) (EJust _ e1) (EJust _ e2) -> acted $ simplify' $ EJust ext (EPlus ext t e1 e2) - EPlus _ STMaybe{} ENothing{} e -> acted $ simplify' e - EPlus _ STMaybe{} e ENothing{} -> acted $ simplify' e + EPlus _ SMTMaybe{} ENothing{} e -> acted $ simplify' e + EPlus _ SMTMaybe{} e ENothing{} -> acted $ simplify' e -- fallback recursion EVar _ t i -> pure $ EVar ext t i @@ -212,6 +209,10 @@ simplify' = \case ENothing _ t -> pure $ ENothing ext t EJust _ e -> EJust ext <$> simplify' e EMaybe _ a b e -> EMaybe ext <$> simplify' a <*> simplify' b <*> simplify' e + ELNil _ t1 t2 -> pure $ ELNil ext t1 t2 + ELInl _ t e -> ELInl ext t <$> simplify' e + ELInr _ t e -> ELInr ext t <$> simplify' e + ELCase _ e a b c -> ELCase ext <$> simplify' e <*> simplify' a <*> simplify' b <*> simplify' c EConstArr _ n t v -> pure $ EConstArr ext n t v EBuild _ n a b -> EBuild ext n <$> simplify' a <*> simplify' b EFold1Inner _ cm a b c -> EFold1Inner ext cm <$> simplify' a <*> simplify' b <*> simplify' c @@ -233,7 +234,7 @@ simplify' = \case <*> (let ?accumInScope = False in simplify' c) <*> simplify' e1 <*> simplify' e2 EWith _ t e1 e2 -> EWith ext t <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2) - EZero _ t -> pure $ EZero ext t + EZero _ t e -> EZero ext t <$> simplify' e EPlus _ t a b -> EPlus ext t <$> simplify' a <*> simplify' b EError _ t s -> pure $ EError ext t s @@ -266,6 +267,10 @@ hasAdds = \case ENothing _ _ -> False EJust _ e -> hasAdds e EMaybe _ a b e -> hasAdds a || hasAdds b || hasAdds e + ELNil _ _ _ -> False + ELInl _ _ e -> hasAdds e + ELInr _ _ e -> hasAdds e + ELCase _ e a b c -> hasAdds e || hasAdds a || hasAdds b || hasAdds c EConstArr _ _ _ _ -> False EBuild _ _ a b -> hasAdds a || hasAdds b EFold1Inner _ _ a b c -> hasAdds a || hasAdds b || hasAdds c @@ -283,7 +288,7 @@ hasAdds = \case EOp _ _ e -> hasAdds e EWith _ _ a b -> hasAdds a || hasAdds b EAccum _ _ _ _ _ _ -> True - EZero _ _ -> False + EZero _ _ e -> hasAdds e EPlus _ _ a b -> hasAdds a || hasAdds b EOneHot _ _ _ a b -> hasAdds a || hasAdds b EError _ _ _ -> False @@ -300,17 +305,18 @@ checkAccumInScope = \case SNil -> False check (STArr _ t) = check t check (STScal _) = False check STAccum{} = True + check (STLEither s t) = check s || check t data OneHotTerm env p a b where - OneHotTerm :: STy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env (D2 b) -> OneHotTerm env p a b + OneHotTerm :: SMTy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env b -> OneHotTerm env p a b deriving instance Show (OneHotTerm env p a b) simplifyOneHotTerm :: OneHotTerm env p a b -> (Any, r) -- ^ Zero case (onehot is actually zero) - -> (Ex env (D2 a) -> (Any, r)) -- ^ Trivial case (no zeros in onehot) + -> (Ex env a -> (Any, r)) -- ^ Trivial case (no zeros in onehot) -> (forall p' b'. OneHotTerm env p' a b' -> (Any, r)) -> (Any, r) -simplifyOneHotTerm (OneHotTerm _ _ _ (EZero _ _)) kzero _ _ = kzero +simplifyOneHotTerm (OneHotTerm _ _ _ EZero{}) kzero _ _ = kzero simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val)) kzero ktriv k | Just Refl <- testEquality (acPrjTy prj1 t1) t2 @@ -318,57 +324,79 @@ simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val)) kzero concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 -> simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val) kzero ktriv k -simplifyOneHotTerm (OneHotTerm t SAPHere idx e) kzero ktriv k = case (t, e) of - (STNil, _) -> kzero +simplifyOneHotTerm (OneHotTerm t SAPHere _ e) kzero ktriv k = case (t, e) of + (SMTNil, _) -> kzero - (STPair{}, ENothing _ _) -> kzero - (STPair{}, EJust _ (EPair _ e1 EZero{})) -> - simplifyOneHotTerm (OneHotTerm t (SAPFst SAPHere) idx e1) kzero ktriv k - (STPair{}, EJust _ (EPair _ EZero{} e2)) -> - simplifyOneHotTerm (OneHotTerm t (SAPSnd SAPHere) idx e2) kzero ktriv k + (SMTPair{}, EPair _ e1 (EZero _ _ ezi)) -> + simplifyOneHotTerm (OneHotTerm t (SAPFst SAPHere) (EPair ext (ENil ext) ezi) e1) kzero ktriv k + (SMTPair{}, EPair _ (EZero _ _ ezi) e2) -> + simplifyOneHotTerm (OneHotTerm t (SAPSnd SAPHere) (EPair ext ezi (ENil ext)) e2) kzero ktriv k - (STEither{}, ENothing _ _) -> kzero - (STEither{}, EJust _ (EInl _ _ e1)) -> - simplifyOneHotTerm (OneHotTerm t (SAPLeft SAPHere) idx e1) kzero ktriv k - (STEither{}, EJust _ (EInr _ _ e2)) -> - simplifyOneHotTerm (OneHotTerm t (SAPRight SAPHere) idx e2) kzero ktriv k + (SMTLEither{}, ELNil _ _ _) -> kzero + (SMTLEither{}, ELInl _ _ e1) -> + simplifyOneHotTerm (OneHotTerm t (SAPLeft SAPHere) (ENil ext) e1) kzero ktriv k + (SMTLEither{}, ELInr _ _ e2) -> + simplifyOneHotTerm (OneHotTerm t (SAPRight SAPHere) (ENil ext) e2) kzero ktriv k - (STMaybe{}, ENothing _ _) -> kzero - (STMaybe{}, EJust _ e1) -> - simplifyOneHotTerm (OneHotTerm t (SAPJust SAPHere) idx e1) kzero ktriv k + (SMTMaybe{}, ENothing _ _) -> kzero + (SMTMaybe{}, EJust _ e1) -> + simplifyOneHotTerm (OneHotTerm t (SAPJust SAPHere) (ENil ext) e1) kzero ktriv k - (STArr{}, ENothing _ _) -> kzero - - (STScal STI32, _) -> kzero - (STScal STI64, _) -> kzero - (STScal STF32, EConst _ _ 0.0) -> kzero - (STScal STF64, EConst _ _ 0.0) -> kzero - (STScal STBool, _) -> kzero + (SMTScal STI32, _) -> kzero + (SMTScal STI64, _) -> kzero + (SMTScal STF32, EConst _ _ 0.0) -> kzero + (SMTScal STF64, EConst _ _ 0.0) -> kzero _ -> ktriv e simplifyOneHotTerm term _ _ k = k term -concatOneHots :: STy a +concatOneHots :: SMTy a -> SAcPrj p1 a b -> Ex env (AcIdx p1 a) -> SAcPrj p2 b c -> Ex env (AcIdx p2 b) -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx p12 a) -> r) -> r concatOneHots t1 prj1 idx1 prj2 idx2 k = case (t1, prj1) of (_, SAPHere) -> k prj2 idx2 - (STPair a _, SAPFst prj1') -> - concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPFst prj12) idx12 - (STPair _ b, SAPSnd prj1') -> - concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPSnd prj12) idx12 + (SMTPair a _, SAPFst prj1') -> + concatOneHots a prj1' (EFst ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> + k (SAPFst prj12) (ELet ext idx1 $ EPair ext idx12 (ESnd ext (EVar ext (typeOf idx1) IZ))) + (SMTPair _ b, SAPSnd prj1') -> + concatOneHots b prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> + k (SAPSnd prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) - (STEither a _, SAPLeft prj1') -> + (SMTLEither a _, SAPLeft prj1') -> concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12 - (STEither _ b, SAPRight prj1') -> + (SMTLEither _ b, SAPRight prj1') -> concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12 - (STMaybe a, SAPJust prj1') -> + (SMTMaybe a, SAPJust prj1') -> concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12 - (STArr n a, SAPArrIdx prj1' _) -> + (SMTArr _ a, SAPArrIdx prj1') -> concatOneHots a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> - k (SAPArrIdx prj12 n) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) + k (SAPArrIdx prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) + +zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env (ZeroInfo t) +zeroInfoFromOneHot = \ty prj eidx e -> ELet ext eidx $ go ty prj (EVar ext (typeOf eidx) IZ) (weakenExpr WSink e) + where + -- invariant: AcIdx expression is duplicable + go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env (ZeroInfo t) + go t SAPHere _ e = makeZeroInfo t e + go (SMTPair t1 _) (SAPFst prj) eidx e = EPair ext (go t1 prj (EFst ext eidx) e) (ESnd ext eidx) + go (SMTPair _ t2) (SAPSnd prj) eidx e = EPair ext (EFst ext eidx) (go t2 prj (ESnd ext eidx) e) + go SMTLEither{} _ _ _ = ENil ext + go SMTMaybe{} _ _ _ = ENil ext + go SMTArr{} SAPArrIdx{} eidx _ = ESnd ext (EFst ext eidx) + +makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) +makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ) + where + -- invariant: expression argument is duplicable + go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) + go SMTNil _ = ENil ext + go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e)) + go SMTLEither{} _ = ENil ext + go SMTMaybe{} _ = ENil ext + go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e + go SMTScal{} _ = ENil ext |