diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-04-27 23:34:59 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-04-27 23:34:59 +0200 |
commit | b1664532eaebdf0409ab6d93fc0ba2ef8dfbf372 (patch) | |
tree | a40c16fd082bbe4183e7b4194b8cea1408cec379 /src/Analysis/Identity.hs | |
parent | c750f8f9f1275d49ff74297e6648e1bfc1c6d918 (diff) |
WIP revamp accumulators again: explicit monoid types
No more D2 in accumulators! Paving the way for configurable sparsity of
products and arrays. The idea is to make separate monoid types for a
"product cotangent" and an "array cotangent" that can be lowered to
either a sparse monoid or a non-sparse monoid. Downsides of this
approach: lots of API duplication.
Diffstat (limited to 'src/Analysis/Identity.hs')
-rw-r--r-- | src/Analysis/Identity.hs | 59 |
1 files changed, 54 insertions, 5 deletions
diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index f34bfbc..20575b3 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -30,6 +30,7 @@ data ValId t where VIEither' :: ValId a -> ValId b -> ValId (TEither a b) -- ^ unknown alternative, but known values in each case VIMaybe :: Maybe (ValId a) -> ValId (TMaybe a) VIMaybe' :: ValId a -> ValId (TMaybe a) -- ^ if it's Just, it contains this value + VILEither :: ValId (TMaybe (TEither a b)) -> ValId (TLEither a b) VIArr :: Int -> Vec n Int -> ValId (TArr n t) VIScal :: Int -> ValId (TScal t) VIAccum :: Int -> ValId (TAccum t) @@ -45,6 +46,13 @@ instance PrettyX ValId where VIMaybe Nothing -> "N" VIMaybe (Just a) -> 'J' : prettyX a VIMaybe' a -> 'M' : prettyX a + VILEither (VIMaybe Nothing) -> "lN" + VILEither (VIMaybe (Just (VIEither (Left a)))) -> "(lL" ++ prettyX a ++ ")" + VILEither (VIMaybe (Just (VIEither (Right a)))) -> "(lR" ++ prettyX a ++ ")" + VILEither (VIMaybe (Just (VIEither' a b))) -> "(" ++ prettyX a ++ "⊕" ++ prettyX b ++ ")" + VILEither (VIMaybe' (VIEither (Left a))) -> "(mlL" ++ prettyX a ++ ")" + VILEither (VIMaybe' (VIEither (Right a))) -> "(mlR" ++ prettyX a ++ ")" + VILEither (VIMaybe' (VIEither' a b)) -> "(m(" ++ prettyX a ++ "⊕" ++ prettyX b ++ "))" VIArr i is -> 'A' : show i ++ "[" ++ intercalate "," (map show (toList is)) ++ "]" VIScal i -> show i VIAccum i -> 'C' : show i @@ -147,6 +155,42 @@ idana env expr = case expr of res <- unify v1 v2 pure (res, EMaybe res e1' e2' e3') + ELNil _ t1 t2 -> do + let v = VILEither (VIMaybe Nothing) + pure (v, ELNil v t1 t2) + + ELInl _ t2 e1 -> do + (v1, e1') <- idana env e1 + let v = VILEither (VIMaybe (Just (VIEither (Left v1)))) + pure (v, ELInl v t2 e1') + + ELInr _ t1 e2 -> do + (v2, e2') <- idana env e2 + let v = VILEither (VIMaybe (Just (VIEither (Right v2)))) + pure (v, ELInr v t1 e2') + + ELCase _ e1 e2 e3 e4 -> do + let STLEither t1 t2 = typeOf e1 + (v1L, e1') <- idana env e1 + let VILEither v1 = v1L + let go mv1'l mv1'r f = do + v1'l <- maybe (genIds t1) pure mv1'l + v1'r <- maybe (genIds t2) pure mv1'r + (v2, e2') <- idana env e2 + (v3, e3') <- idana (v1'l `SCons` env) e3 + (v4, e4') <- idana (v1'r `SCons` env) e4 + res <- f v2 v3 v4 + pure (res, ELCase res e1' e2' e3' e4') + case v1 of + VIMaybe Nothing -> go Nothing Nothing (\v2 _ _ -> pure v2) + VIMaybe (Just (VIEither (Left v1'))) -> go (Just v1') Nothing (\_ v3 _ -> pure v3) + VIMaybe (Just (VIEither (Right v1'))) -> go Nothing (Just v1') (\_ _ v4 -> pure v4) + VIMaybe (Just (VIEither' v1'l v1'r)) -> go (Just v1'l) (Just v1'r) (\_ v3 v4 -> unify v3 v4) + VIMaybe' (VIEither (Left v1')) -> go (Just v1') Nothing (\v2 v3 _ -> unify v2 v3) + VIMaybe' (VIEither (Right v1')) -> go Nothing (Just v1') (\v2 _ v4 -> unify v2 v4) + VIMaybe' (VIEither' v1'l v1'r) -> + go (Just v1'l) (Just v1'r) (\v2 v3 v4 -> unify v2 =<< unify v3 v4) + EConstArr _ dim t arr -> do x1 <- VIArr <$> genId <*> vecReplicateA dim genId pure (x1, EConstArr x1 dim t arr) @@ -265,20 +309,23 @@ idana env expr = case expr of (_, e3') <- idana env e3 pure (VINil, EAccum VINil t prj e1' e2' e3') - EZero _ t -> do - res <- genIds (d2 t) - pure (res, EZero res t) + EZero _ t e1 -> do + -- Approximate the result of EZero to be independent from the zero info + -- expression; not quite true for shape variables + (_, e1') <- idana env e1 + res <- genIds (fromSMTy t) + pure (res, EZero res t e1') EPlus _ t e1 e2 -> do (_, e1') <- idana env e1 (_, e2') <- idana env e2 - res <- genIds (d2 t) + res <- genIds (fromSMTy t) pure (res, EPlus res t e1' e2') EOneHot _ t i e1 e2 -> do (_, e1') <- idana env e1 (_, e2') <- idana env e2 - res <- genIds (d2 t) + res <- genIds (fromSMTy t) pure (res, EOneHot res t i e1' e2') EError _ t s -> do @@ -307,6 +354,7 @@ unify (VIMaybe (Just a)) (VIMaybe' b) = VIMaybe' <$> unify a b unify (VIMaybe' a) (VIMaybe Nothing) = pure $ VIMaybe' a unify (VIMaybe' a) (VIMaybe (Just b)) = VIMaybe' <$> unify a b unify (VIMaybe' a) (VIMaybe' b) = VIMaybe' <$> unify a b +unify (VILEither a) (VILEither b) = VILEither <$> unify a b unify (VIArr i is) (VIArr j js) = VIArr <$> unifyID i j <*> vecZipWithA unifyID is js unify (VIScal i) (VIScal j) = VIScal <$> unifyID i j unify (VIAccum i) (VIAccum j) = VIAccum <$> unifyID i j @@ -323,6 +371,7 @@ genIds (STMaybe t) = VIMaybe' <$> genIds t genIds (STArr n _) = VIArr <$> genId <*> vecReplicateA n genId genIds STScal{} = VIScal <$> genId genIds STAccum{} = VIAccum <$> genId +genIds (STLEither a b) = VILEither . VIMaybe' <$> (VIEither' <$> genIds a <*> genIds b) shidsToVec :: SNat n -> ValId (Tup (Replicate n TIx)) -> IdGen (Vec n Int) shidsToVec SZ _ = pure VNil |