summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-04-25 13:35:57 +0200
committerTom Smeding <tom@tomsmeding.com>2025-04-25 13:36:05 +0200
commitc750f8f9f1275d49ff74297e6648e1bfc1c6d918 (patch)
tree0c3d2afaa281556ab8e4066ffdd37a3b5abc9a0c /src
parent2da201faba6aeba2bf35d220a0e970ac4fa1768e (diff)
simplify: Additional rules inspired by Example.neuralHEADmaster
Diffstat (limited to 'src')
-rw-r--r--src/Simplify.hs32
1 files changed, 29 insertions, 3 deletions
diff --git a/src/Simplify.hs b/src/Simplify.hs
index e0ab37b..ea3bb95 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -129,11 +129,36 @@ simplify' = \case
acted $ simplify' $
ECase ext e1 (ESnd ext e2) (ESnd ext e3)
- -- TODO: array indexing (index of build, index of fold)
+ -- TODO: more array indexing
+ EIdx _ (EReplicate1Inner _ _ e2) e3 -> acted $ simplify' $ EIdx ext e2 (EFst ext e3)
+ EIdx _ (EUnit _ e1) _ -> acted $ simplify' $ e1
- -- TODO: beta rules for maybe
+ -- TODO: more constant folding
+ EOp _ OIf (EConst _ STBool True) -> (Any True, EInl ext STNil (ENil ext))
+ EOp _ OIf (EConst _ STBool False) -> (Any True, EInr ext STNil (ENil ext))
- -- TODO: constant folding for operations
+ -- inline cheap array constructors
+ ELet _ (EReplicate1Inner _ e1 e2) e3 ->
+ acted $ simplify' $
+ ELet ext (EPair ext e1 e2) $
+ let v = EVar ext (STPair tIx (typeOf e2)) IZ
+ in subst0 (EReplicate1Inner ext (EFst ext v) (ESnd ext v)) e3
+ -- -- TODO: This is a bad idea and anyway only helps in practice if (!) is
+ -- -- cheap, which it can't be because (!) is not cheap if you do AD after.
+ -- -- Should do proper SoA representation.
+ -- ELet _ (EBuild _ n e1 e2) e3 | cheapExpr e2 ->
+ -- acted $ simplify' $
+ -- ELet ext e1 $
+ -- subst0 (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) IZ) (weakenExpr (WCopy WSink) e2)) e3
+
+ -- eta rule for unit
+ e | STNil <- typeOf e, not ?accumInScope || not (hasAdds e) ->
+ case e of
+ ENil _ -> (Any False, e)
+ _ -> (Any True, ENil ext)
+
+ EBuild _ SZ _ e ->
+ acted $ simplify' $ EUnit ext (substInline (ENil ext) e)
-- monoid rules
EAccum _ t p e1 e2 acc -> do
@@ -222,6 +247,7 @@ cheapExpr = \case
EConst{} -> True
EFst _ e -> cheapExpr e
ESnd _ e -> cheapExpr e
+ EUnit _ e -> cheapExpr e
_ -> False
-- | This can be made more precise by tracking (and not counting) adds on