aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/Simplify.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD/Simplify.hs')
-rw-r--r--src/CHAD/Simplify.hs10
1 files changed, 10 insertions, 0 deletions
diff --git a/src/CHAD/Simplify.hs b/src/CHAD/Simplify.hs
index ea253d6..bbc2db8 100644
--- a/src/CHAD/Simplify.hs
+++ b/src/CHAD/Simplify.hs
@@ -12,6 +12,8 @@
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
+
+{-# OPTIONS_GHC -fmax-pmcheck-models=50 #-}
module CHAD.Simplify (
simplifyN, simplifyFix,
SimplifyConfig(..), defaultSimplifyConfig, simplifyWith, simplifyFixWith,
@@ -307,6 +309,9 @@ simplify'Rec = \case
EPlus _ SMTMaybe{} ENothing{} e -> acted $ simplify' e
EPlus _ SMTMaybe{} e ENothing{} -> acted $ simplify' e
+ -- user types
+ EUnUser _ (EUser _ _ e) -> acted $ simplify' e
+
-- fallback recursion
EVar _ t i -> pure $ EVar ext t i
ELet _ a b -> [simprec| ELet ext *a *b |]
@@ -361,6 +366,8 @@ simplify'Rec = \case
EDeepZero _ t e -> [simprec| EDeepZero ext t *e |]
EPlus _ t a b -> [simprec| EPlus ext t *a *b |]
EError _ t s -> pure $ EError ext t s
+ EUser _ t e -> [simprec| EUser ext t *e |]
+ EUnUser _ e -> [simprec| EUnUser ext *e |]
-- | This can be made more precise by tracking (and not counting) adds on
-- locally eliminated accumulators.
@@ -410,6 +417,8 @@ hasAdds = \case
EPlus _ _ a b -> hasAdds a || hasAdds b
EOneHot _ _ _ a b -> hasAdds a || hasAdds b
EError _ _ _ -> False
+ EUser _ _ e -> hasAdds e
+ EUnUser _ e -> hasAdds e
checkAccumInScope :: SList STy env -> Bool
checkAccumInScope = \case SNil -> False
@@ -424,6 +433,7 @@ checkAccumInScope = \case SNil -> False
check (STArr _ t) = check t
check (STScal _) = False
check STAccum{} = True
+ check (STUser t) = check (userRepTy t)
data OneHotTerm dense env a where
OneHotTerm :: SAIDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Sparse b c -> Ex env c -> OneHotTerm dense env a