summaryrefslogtreecommitdiff
path: root/src/AST/SplitLets.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-04-27 23:34:59 +0200
committerTom Smeding <tom@tomsmeding.com>2025-04-27 23:34:59 +0200
commitb1664532eaebdf0409ab6d93fc0ba2ef8dfbf372 (patch)
treea40c16fd082bbe4183e7b4194b8cea1408cec379 /src/AST/SplitLets.hs
parentc750f8f9f1275d49ff74297e6648e1bfc1c6d918 (diff)
WIP revamp accumulators again: explicit monoid types
No more D2 in accumulators! Paving the way for configurable sparsity of products and arrays. The idea is to make separate monoid types for a "product cotangent" and an "array cotangent" that can be lowered to either a sparse monoid or a non-sparse monoid. Downsides of this approach: lots of API duplication.
Diffstat (limited to 'src/AST/SplitLets.hs')
-rw-r--r--src/AST/SplitLets.hs26
1 files changed, 17 insertions, 9 deletions
diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs
index dcba1ad..159934d 100644
--- a/src/AST/SplitLets.hs
+++ b/src/AST/SplitLets.hs
@@ -29,6 +29,9 @@ splitLets' = \sub -> \case
EMaybe x a b e ->
let STMaybe t1 = typeOf e
in EMaybe x (splitLets' sub a) (split1 sub t1 b) (splitLets' sub e)
+ ELCase x e a b c ->
+ let STLEither t1 t2 = typeOf e
+ in ELCase x (splitLets' sub e) (splitLets' sub a) (split1 sub t1 b) (split1 sub t2 c)
EFold1Inner x cm a b c ->
let STArr _ t1 = typeOf c
in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c)
@@ -41,6 +44,9 @@ splitLets' = \sub -> \case
EInr x t e -> EInr x t (splitLets' sub e)
ENothing x t -> ENothing x t
EJust x e -> EJust x (splitLets' sub e)
+ ELNil x t1 t2 -> ELNil x t1 t2
+ ELInl x t e -> ELInl x t (splitLets' sub e)
+ ELInr x t e -> ELInr x t (splitLets' sub e)
EConstArr x n t a -> EConstArr x n t a
EBuild x n a b -> EBuild x n (splitLets' sub a) (splitLets' (sinkF sub) b)
ESum1Inner x e -> ESum1Inner x (splitLets' sub e)
@@ -57,7 +63,7 @@ splitLets' = \sub -> \case
ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2)
EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2)
EAccum x t p e1 e2 e3 -> EAccum x t p (splitLets' sub e1) (splitLets' sub e2) (splitLets' sub e3)
- EZero x t -> EZero x t
+ EZero x t ezi -> EZero x t (splitLets' sub ezi)
EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b)
EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b)
EError x t s -> EError x t s
@@ -121,24 +127,26 @@ split typ = case typ of
STArr{} -> other
STScal{} -> other
STAccum{} -> other
+ STLEither{} -> other
where
other :: (Pointers (t : env) t, Bindings Ex (t : env) '[])
other = (Point typ IZ, BTop)
splitRec :: forall env t. Ex env t -> STy t
-> (Pointers (Append (SplitRec t) env) t, Bindings Ex env (SplitRec t))
-splitRec rhs = \case
+splitRec rhs typ = case typ of
STNil -> (PNil, BTop)
STPair (a :: STy a) (b :: STy b)
| Refl <- lemAppendAssoc @(SplitRec b) @(SplitRec a) @env ->
let (p1, bs1) = splitRec (EFst ext rhs) a
(p2, bs2) = splitRec (ESnd ext (sinkWithBindings bs1 `weakenExpr` rhs)) b
in (PPair (PWeak (sinkWithBindings bs2) p1) p2, bconcat bs1 bs2)
- t@STEither{} -> other t
- t@STMaybe{} -> other t
- t@STArr{} -> other t
- t@STScal{} -> other t
- t@STAccum{} -> other t
+ STEither{} -> other
+ STMaybe{} -> other
+ STArr{} -> other
+ STScal{} -> other
+ STAccum{} -> other
+ STLEither{} -> other
where
- other :: STy t -> (Pointers (t : env) t, Bindings Ex env '[t])
- other t = (Point t IZ, BPush BTop (t, rhs))
+ other :: (Pointers (t : env) t, Bindings Ex env '[t])
+ other = (Point typ IZ, BPush BTop (typ, rhs))