summaryrefslogtreecommitdiff
path: root/src/Analysis/Identity.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-04-27 23:34:59 +0200
committerTom Smeding <tom@tomsmeding.com>2025-04-27 23:34:59 +0200
commitb1664532eaebdf0409ab6d93fc0ba2ef8dfbf372 (patch)
treea40c16fd082bbe4183e7b4194b8cea1408cec379 /src/Analysis/Identity.hs
parentc750f8f9f1275d49ff74297e6648e1bfc1c6d918 (diff)
WIP revamp accumulators again: explicit monoid types
No more D2 in accumulators! Paving the way for configurable sparsity of products and arrays. The idea is to make separate monoid types for a "product cotangent" and an "array cotangent" that can be lowered to either a sparse monoid or a non-sparse monoid. Downsides of this approach: lots of API duplication.
Diffstat (limited to 'src/Analysis/Identity.hs')
-rw-r--r--src/Analysis/Identity.hs59
1 files changed, 54 insertions, 5 deletions
diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs
index f34bfbc..20575b3 100644
--- a/src/Analysis/Identity.hs
+++ b/src/Analysis/Identity.hs
@@ -30,6 +30,7 @@ data ValId t where
VIEither' :: ValId a -> ValId b -> ValId (TEither a b) -- ^ unknown alternative, but known values in each case
VIMaybe :: Maybe (ValId a) -> ValId (TMaybe a)
VIMaybe' :: ValId a -> ValId (TMaybe a) -- ^ if it's Just, it contains this value
+ VILEither :: ValId (TMaybe (TEither a b)) -> ValId (TLEither a b)
VIArr :: Int -> Vec n Int -> ValId (TArr n t)
VIScal :: Int -> ValId (TScal t)
VIAccum :: Int -> ValId (TAccum t)
@@ -45,6 +46,13 @@ instance PrettyX ValId where
VIMaybe Nothing -> "N"
VIMaybe (Just a) -> 'J' : prettyX a
VIMaybe' a -> 'M' : prettyX a
+ VILEither (VIMaybe Nothing) -> "lN"
+ VILEither (VIMaybe (Just (VIEither (Left a)))) -> "(lL" ++ prettyX a ++ ")"
+ VILEither (VIMaybe (Just (VIEither (Right a)))) -> "(lR" ++ prettyX a ++ ")"
+ VILEither (VIMaybe (Just (VIEither' a b))) -> "(" ++ prettyX a ++ "⊕" ++ prettyX b ++ ")"
+ VILEither (VIMaybe' (VIEither (Left a))) -> "(mlL" ++ prettyX a ++ ")"
+ VILEither (VIMaybe' (VIEither (Right a))) -> "(mlR" ++ prettyX a ++ ")"
+ VILEither (VIMaybe' (VIEither' a b)) -> "(m(" ++ prettyX a ++ "⊕" ++ prettyX b ++ "))"
VIArr i is -> 'A' : show i ++ "[" ++ intercalate "," (map show (toList is)) ++ "]"
VIScal i -> show i
VIAccum i -> 'C' : show i
@@ -147,6 +155,42 @@ idana env expr = case expr of
res <- unify v1 v2
pure (res, EMaybe res e1' e2' e3')
+ ELNil _ t1 t2 -> do
+ let v = VILEither (VIMaybe Nothing)
+ pure (v, ELNil v t1 t2)
+
+ ELInl _ t2 e1 -> do
+ (v1, e1') <- idana env e1
+ let v = VILEither (VIMaybe (Just (VIEither (Left v1))))
+ pure (v, ELInl v t2 e1')
+
+ ELInr _ t1 e2 -> do
+ (v2, e2') <- idana env e2
+ let v = VILEither (VIMaybe (Just (VIEither (Right v2))))
+ pure (v, ELInr v t1 e2')
+
+ ELCase _ e1 e2 e3 e4 -> do
+ let STLEither t1 t2 = typeOf e1
+ (v1L, e1') <- idana env e1
+ let VILEither v1 = v1L
+ let go mv1'l mv1'r f = do
+ v1'l <- maybe (genIds t1) pure mv1'l
+ v1'r <- maybe (genIds t2) pure mv1'r
+ (v2, e2') <- idana env e2
+ (v3, e3') <- idana (v1'l `SCons` env) e3
+ (v4, e4') <- idana (v1'r `SCons` env) e4
+ res <- f v2 v3 v4
+ pure (res, ELCase res e1' e2' e3' e4')
+ case v1 of
+ VIMaybe Nothing -> go Nothing Nothing (\v2 _ _ -> pure v2)
+ VIMaybe (Just (VIEither (Left v1'))) -> go (Just v1') Nothing (\_ v3 _ -> pure v3)
+ VIMaybe (Just (VIEither (Right v1'))) -> go Nothing (Just v1') (\_ _ v4 -> pure v4)
+ VIMaybe (Just (VIEither' v1'l v1'r)) -> go (Just v1'l) (Just v1'r) (\_ v3 v4 -> unify v3 v4)
+ VIMaybe' (VIEither (Left v1')) -> go (Just v1') Nothing (\v2 v3 _ -> unify v2 v3)
+ VIMaybe' (VIEither (Right v1')) -> go Nothing (Just v1') (\v2 _ v4 -> unify v2 v4)
+ VIMaybe' (VIEither' v1'l v1'r) ->
+ go (Just v1'l) (Just v1'r) (\v2 v3 v4 -> unify v2 =<< unify v3 v4)
+
EConstArr _ dim t arr -> do
x1 <- VIArr <$> genId <*> vecReplicateA dim genId
pure (x1, EConstArr x1 dim t arr)
@@ -265,20 +309,23 @@ idana env expr = case expr of
(_, e3') <- idana env e3
pure (VINil, EAccum VINil t prj e1' e2' e3')
- EZero _ t -> do
- res <- genIds (d2 t)
- pure (res, EZero res t)
+ EZero _ t e1 -> do
+ -- Approximate the result of EZero to be independent from the zero info
+ -- expression; not quite true for shape variables
+ (_, e1') <- idana env e1
+ res <- genIds (fromSMTy t)
+ pure (res, EZero res t e1')
EPlus _ t e1 e2 -> do
(_, e1') <- idana env e1
(_, e2') <- idana env e2
- res <- genIds (d2 t)
+ res <- genIds (fromSMTy t)
pure (res, EPlus res t e1' e2')
EOneHot _ t i e1 e2 -> do
(_, e1') <- idana env e1
(_, e2') <- idana env e2
- res <- genIds (d2 t)
+ res <- genIds (fromSMTy t)
pure (res, EOneHot res t i e1' e2')
EError _ t s -> do
@@ -307,6 +354,7 @@ unify (VIMaybe (Just a)) (VIMaybe' b) = VIMaybe' <$> unify a b
unify (VIMaybe' a) (VIMaybe Nothing) = pure $ VIMaybe' a
unify (VIMaybe' a) (VIMaybe (Just b)) = VIMaybe' <$> unify a b
unify (VIMaybe' a) (VIMaybe' b) = VIMaybe' <$> unify a b
+unify (VILEither a) (VILEither b) = VILEither <$> unify a b
unify (VIArr i is) (VIArr j js) = VIArr <$> unifyID i j <*> vecZipWithA unifyID is js
unify (VIScal i) (VIScal j) = VIScal <$> unifyID i j
unify (VIAccum i) (VIAccum j) = VIAccum <$> unifyID i j
@@ -323,6 +371,7 @@ genIds (STMaybe t) = VIMaybe' <$> genIds t
genIds (STArr n _) = VIArr <$> genId <*> vecReplicateA n genId
genIds STScal{} = VIScal <$> genId
genIds STAccum{} = VIAccum <$> genId
+genIds (STLEither a b) = VILEither . VIMaybe' <$> (VIEither' <$> genIds a <*> genIds b)
shidsToVec :: SNat n -> ValId (Tup (Replicate n TIx)) -> IdGen (Vec n Int)
shidsToVec SZ _ = pure VNil