summaryrefslogtreecommitdiff
path: root/src/Simplify.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-04-27 23:34:59 +0200
committerTom Smeding <tom@tomsmeding.com>2025-04-27 23:34:59 +0200
commitb1664532eaebdf0409ab6d93fc0ba2ef8dfbf372 (patch)
treea40c16fd082bbe4183e7b4194b8cea1408cec379 /src/Simplify.hs
parentc750f8f9f1275d49ff74297e6648e1bfc1c6d918 (diff)
WIP revamp accumulators again: explicit monoid types
No more D2 in accumulators! Paving the way for configurable sparsity of products and arrays. The idea is to make separate monoid types for a "product cotangent" and an "array cotangent" that can be lowered to either a sparse monoid or a non-sparse monoid. Downsides of this approach: lots of API duplication.
Diffstat (limited to 'src/Simplify.hs')
-rw-r--r--src/Simplify.hs138
1 files changed, 83 insertions, 55 deletions
diff --git a/src/Simplify.hs b/src/Simplify.hs
index ea3bb95..228f265 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -19,7 +19,6 @@ import Data.Type.Equality (testEquality)
import AST
import AST.Count
-import CHAD.Types
import Data
@@ -169,35 +168,33 @@ simplify' = \case
(Any True, ENil ext)
(\e -> (Any False, EAccum ext t SAPHere (ENil ext) e acc'))
(\(OneHotTerm t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc'))
- EPlus _ _ (EZero _ _) e -> acted $ simplify' e
- EPlus _ _ e (EZero _ _) -> acted $ simplify' e
+ EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e
+ EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e
EOneHot _ t p e1 e2 -> do
e1' <- simplify' e1
e2' <- simplify' e2
simplifyOneHotTerm (OneHotTerm t p e1' e2')
- (Any True, EZero ext t)
+ (Any True, EZero ext t (zeroInfoFromOneHot t p e1 e2))
(\e -> (Any True, e))
(\(OneHotTerm t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2''))
-- type-specific equations for plus
- EPlus _ STNil _ _ -> (Any True, ENil ext)
+ EPlus _ SMTNil _ _ -> (Any True, ENil ext)
- EPlus _ (STPair t1 t2) (EJust _ (EPair _ a1 b1)) (EJust _ (EPair _ a2 b2)) ->
- acted $ simplify' $ EJust ext (EPair ext (EPlus ext t1 a1 a2) (EPlus ext t2 b1 b2))
- EPlus _ STPair{} ENothing{} e -> acted $ simplify' e
- EPlus _ STPair{} e ENothing{} -> acted $ simplify' e
+ EPlus _ (SMTPair t1 t2) (EPair _ a1 b1) (EPair _ a2 b2) ->
+ acted $ simplify' $ EPair ext (EPlus ext t1 a1 a2) (EPlus ext t2 b1 b2)
- EPlus _ (STEither t1 _) (EJust _ (EInl _ dt2 a1)) (EJust _ (EInl _ _ a2)) ->
- acted $ simplify' $ EJust ext (EInl ext dt2 (EPlus ext t1 a1 a2))
- EPlus _ (STEither _ t2) (EJust _ (EInr _ dt1 b1)) (EJust _ (EInr _ _ b2)) ->
- acted $ simplify' $ EJust ext (EInr ext dt1 (EPlus ext t2 b1 b2))
- EPlus _ STEither{} ENothing{} e -> acted $ simplify' e
- EPlus _ STEither{} e ENothing{} -> acted $ simplify' e
+ EPlus _ (SMTLEither t1 _) (ELInl _ dt2 a1) (ELInl _ _ a2) ->
+ acted $ simplify' $ ELInl ext dt2 (EPlus ext t1 a1 a2)
+ EPlus _ (SMTLEither _ t2) (ELInr _ dt1 b1) (ELInr _ _ b2) ->
+ acted $ simplify' $ ELInr ext dt1 (EPlus ext t2 b1 b2)
+ EPlus _ SMTLEither{} ELNil{} e -> acted $ simplify' e
+ EPlus _ SMTLEither{} e ELNil{} -> acted $ simplify' e
- EPlus _ (STMaybe t) (EJust _ e1) (EJust _ e2) ->
+ EPlus _ (SMTMaybe t) (EJust _ e1) (EJust _ e2) ->
acted $ simplify' $ EJust ext (EPlus ext t e1 e2)
- EPlus _ STMaybe{} ENothing{} e -> acted $ simplify' e
- EPlus _ STMaybe{} e ENothing{} -> acted $ simplify' e
+ EPlus _ SMTMaybe{} ENothing{} e -> acted $ simplify' e
+ EPlus _ SMTMaybe{} e ENothing{} -> acted $ simplify' e
-- fallback recursion
EVar _ t i -> pure $ EVar ext t i
@@ -212,6 +209,10 @@ simplify' = \case
ENothing _ t -> pure $ ENothing ext t
EJust _ e -> EJust ext <$> simplify' e
EMaybe _ a b e -> EMaybe ext <$> simplify' a <*> simplify' b <*> simplify' e
+ ELNil _ t1 t2 -> pure $ ELNil ext t1 t2
+ ELInl _ t e -> ELInl ext t <$> simplify' e
+ ELInr _ t e -> ELInr ext t <$> simplify' e
+ ELCase _ e a b c -> ELCase ext <$> simplify' e <*> simplify' a <*> simplify' b <*> simplify' c
EConstArr _ n t v -> pure $ EConstArr ext n t v
EBuild _ n a b -> EBuild ext n <$> simplify' a <*> simplify' b
EFold1Inner _ cm a b c -> EFold1Inner ext cm <$> simplify' a <*> simplify' b <*> simplify' c
@@ -233,7 +234,7 @@ simplify' = \case
<*> (let ?accumInScope = False in simplify' c)
<*> simplify' e1 <*> simplify' e2
EWith _ t e1 e2 -> EWith ext t <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2)
- EZero _ t -> pure $ EZero ext t
+ EZero _ t e -> EZero ext t <$> simplify' e
EPlus _ t a b -> EPlus ext t <$> simplify' a <*> simplify' b
EError _ t s -> pure $ EError ext t s
@@ -266,6 +267,10 @@ hasAdds = \case
ENothing _ _ -> False
EJust _ e -> hasAdds e
EMaybe _ a b e -> hasAdds a || hasAdds b || hasAdds e
+ ELNil _ _ _ -> False
+ ELInl _ _ e -> hasAdds e
+ ELInr _ _ e -> hasAdds e
+ ELCase _ e a b c -> hasAdds e || hasAdds a || hasAdds b || hasAdds c
EConstArr _ _ _ _ -> False
EBuild _ _ a b -> hasAdds a || hasAdds b
EFold1Inner _ _ a b c -> hasAdds a || hasAdds b || hasAdds c
@@ -283,7 +288,7 @@ hasAdds = \case
EOp _ _ e -> hasAdds e
EWith _ _ a b -> hasAdds a || hasAdds b
EAccum _ _ _ _ _ _ -> True
- EZero _ _ -> False
+ EZero _ _ e -> hasAdds e
EPlus _ _ a b -> hasAdds a || hasAdds b
EOneHot _ _ _ a b -> hasAdds a || hasAdds b
EError _ _ _ -> False
@@ -300,17 +305,18 @@ checkAccumInScope = \case SNil -> False
check (STArr _ t) = check t
check (STScal _) = False
check STAccum{} = True
+ check (STLEither s t) = check s || check t
data OneHotTerm env p a b where
- OneHotTerm :: STy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env (D2 b) -> OneHotTerm env p a b
+ OneHotTerm :: SMTy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env b -> OneHotTerm env p a b
deriving instance Show (OneHotTerm env p a b)
simplifyOneHotTerm :: OneHotTerm env p a b
-> (Any, r) -- ^ Zero case (onehot is actually zero)
- -> (Ex env (D2 a) -> (Any, r)) -- ^ Trivial case (no zeros in onehot)
+ -> (Ex env a -> (Any, r)) -- ^ Trivial case (no zeros in onehot)
-> (forall p' b'. OneHotTerm env p' a b' -> (Any, r))
-> (Any, r)
-simplifyOneHotTerm (OneHotTerm _ _ _ (EZero _ _)) kzero _ _ = kzero
+simplifyOneHotTerm (OneHotTerm _ _ _ EZero{}) kzero _ _ = kzero
simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val)) kzero ktriv k
| Just Refl <- testEquality (acPrjTy prj1 t1) t2
@@ -318,57 +324,79 @@ simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val)) kzero
concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 ->
simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val) kzero ktriv k
-simplifyOneHotTerm (OneHotTerm t SAPHere idx e) kzero ktriv k = case (t, e) of
- (STNil, _) -> kzero
+simplifyOneHotTerm (OneHotTerm t SAPHere _ e) kzero ktriv k = case (t, e) of
+ (SMTNil, _) -> kzero
- (STPair{}, ENothing _ _) -> kzero
- (STPair{}, EJust _ (EPair _ e1 EZero{})) ->
- simplifyOneHotTerm (OneHotTerm t (SAPFst SAPHere) idx e1) kzero ktriv k
- (STPair{}, EJust _ (EPair _ EZero{} e2)) ->
- simplifyOneHotTerm (OneHotTerm t (SAPSnd SAPHere) idx e2) kzero ktriv k
+ (SMTPair{}, EPair _ e1 (EZero _ _ ezi)) ->
+ simplifyOneHotTerm (OneHotTerm t (SAPFst SAPHere) (EPair ext (ENil ext) ezi) e1) kzero ktriv k
+ (SMTPair{}, EPair _ (EZero _ _ ezi) e2) ->
+ simplifyOneHotTerm (OneHotTerm t (SAPSnd SAPHere) (EPair ext ezi (ENil ext)) e2) kzero ktriv k
- (STEither{}, ENothing _ _) -> kzero
- (STEither{}, EJust _ (EInl _ _ e1)) ->
- simplifyOneHotTerm (OneHotTerm t (SAPLeft SAPHere) idx e1) kzero ktriv k
- (STEither{}, EJust _ (EInr _ _ e2)) ->
- simplifyOneHotTerm (OneHotTerm t (SAPRight SAPHere) idx e2) kzero ktriv k
+ (SMTLEither{}, ELNil _ _ _) -> kzero
+ (SMTLEither{}, ELInl _ _ e1) ->
+ simplifyOneHotTerm (OneHotTerm t (SAPLeft SAPHere) (ENil ext) e1) kzero ktriv k
+ (SMTLEither{}, ELInr _ _ e2) ->
+ simplifyOneHotTerm (OneHotTerm t (SAPRight SAPHere) (ENil ext) e2) kzero ktriv k
- (STMaybe{}, ENothing _ _) -> kzero
- (STMaybe{}, EJust _ e1) ->
- simplifyOneHotTerm (OneHotTerm t (SAPJust SAPHere) idx e1) kzero ktriv k
+ (SMTMaybe{}, ENothing _ _) -> kzero
+ (SMTMaybe{}, EJust _ e1) ->
+ simplifyOneHotTerm (OneHotTerm t (SAPJust SAPHere) (ENil ext) e1) kzero ktriv k
- (STArr{}, ENothing _ _) -> kzero
-
- (STScal STI32, _) -> kzero
- (STScal STI64, _) -> kzero
- (STScal STF32, EConst _ _ 0.0) -> kzero
- (STScal STF64, EConst _ _ 0.0) -> kzero
- (STScal STBool, _) -> kzero
+ (SMTScal STI32, _) -> kzero
+ (SMTScal STI64, _) -> kzero
+ (SMTScal STF32, EConst _ _ 0.0) -> kzero
+ (SMTScal STF64, EConst _ _ 0.0) -> kzero
_ -> ktriv e
simplifyOneHotTerm term _ _ k = k term
-concatOneHots :: STy a
+concatOneHots :: SMTy a
-> SAcPrj p1 a b -> Ex env (AcIdx p1 a)
-> SAcPrj p2 b c -> Ex env (AcIdx p2 b)
-> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx p12 a) -> r) -> r
concatOneHots t1 prj1 idx1 prj2 idx2 k = case (t1, prj1) of
(_, SAPHere) -> k prj2 idx2
- (STPair a _, SAPFst prj1') ->
- concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPFst prj12) idx12
- (STPair _ b, SAPSnd prj1') ->
- concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPSnd prj12) idx12
+ (SMTPair a _, SAPFst prj1') ->
+ concatOneHots a prj1' (EFst ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
+ k (SAPFst prj12) (ELet ext idx1 $ EPair ext idx12 (ESnd ext (EVar ext (typeOf idx1) IZ)))
+ (SMTPair _ b, SAPSnd prj1') ->
+ concatOneHots b prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
+ k (SAPSnd prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12)
- (STEither a _, SAPLeft prj1') ->
+ (SMTLEither a _, SAPLeft prj1') ->
concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12
- (STEither _ b, SAPRight prj1') ->
+ (SMTLEither _ b, SAPRight prj1') ->
concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12
- (STMaybe a, SAPJust prj1') ->
+ (SMTMaybe a, SAPJust prj1') ->
concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12
- (STArr n a, SAPArrIdx prj1' _) ->
+ (SMTArr _ a, SAPArrIdx prj1') ->
concatOneHots a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
- k (SAPArrIdx prj12 n) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12)
+ k (SAPArrIdx prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12)
+
+zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env (ZeroInfo t)
+zeroInfoFromOneHot = \ty prj eidx e -> ELet ext eidx $ go ty prj (EVar ext (typeOf eidx) IZ) (weakenExpr WSink e)
+ where
+ -- invariant: AcIdx expression is duplicable
+ go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env (ZeroInfo t)
+ go t SAPHere _ e = makeZeroInfo t e
+ go (SMTPair t1 _) (SAPFst prj) eidx e = EPair ext (go t1 prj (EFst ext eidx) e) (ESnd ext eidx)
+ go (SMTPair _ t2) (SAPSnd prj) eidx e = EPair ext (EFst ext eidx) (go t2 prj (ESnd ext eidx) e)
+ go SMTLEither{} _ _ _ = ENil ext
+ go SMTMaybe{} _ _ _ = ENil ext
+ go SMTArr{} SAPArrIdx{} eidx _ = ESnd ext (EFst ext eidx)
+
+makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t)
+makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ)
+ where
+ -- invariant: expression argument is duplicable
+ go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t)
+ go SMTNil _ = ENil ext
+ go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e))
+ go SMTLEither{} _ = ENil ext
+ go SMTMaybe{} _ = ENil ext
+ go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e
+ go SMTScal{} _ = ENil ext