diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-04-27 23:34:59 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-04-27 23:34:59 +0200 | 
| commit | b1664532eaebdf0409ab6d93fc0ba2ef8dfbf372 (patch) | |
| tree | a40c16fd082bbe4183e7b4194b8cea1408cec379 /src/Analysis | |
| parent | c750f8f9f1275d49ff74297e6648e1bfc1c6d918 (diff) | |
WIP revamp accumulators again: explicit monoid types
No more D2 in accumulators! Paving the way for configurable sparsity of
products and arrays. The idea is to make separate monoid types for a
"product cotangent" and an "array cotangent" that can be lowered to
either a sparse monoid or a non-sparse monoid. Downsides of this
approach: lots of API duplication.
Diffstat (limited to 'src/Analysis')
| -rw-r--r-- | src/Analysis/Identity.hs | 59 | 
1 files changed, 54 insertions, 5 deletions
| diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index f34bfbc..20575b3 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -30,6 +30,7 @@ data ValId t where    VIEither' :: ValId a -> ValId b -> ValId (TEither a b)  -- ^ unknown alternative, but known values in each case    VIMaybe :: Maybe (ValId a) -> ValId (TMaybe a)    VIMaybe' :: ValId a -> ValId (TMaybe a)  -- ^ if it's Just, it contains this value +  VILEither :: ValId (TMaybe (TEither a b)) -> ValId (TLEither a b)    VIArr :: Int -> Vec n Int -> ValId (TArr n t)    VIScal :: Int -> ValId (TScal t)    VIAccum :: Int -> ValId (TAccum t) @@ -45,6 +46,13 @@ instance PrettyX ValId where      VIMaybe Nothing -> "N"      VIMaybe (Just a) -> 'J' : prettyX a      VIMaybe' a -> 'M' : prettyX a +    VILEither (VIMaybe Nothing) -> "lN" +    VILEither (VIMaybe (Just (VIEither (Left a)))) -> "(lL" ++ prettyX a ++ ")" +    VILEither (VIMaybe (Just (VIEither (Right a)))) -> "(lR" ++ prettyX a ++ ")" +    VILEither (VIMaybe (Just (VIEither' a b))) -> "(" ++ prettyX a ++ "⊕" ++ prettyX b ++ ")" +    VILEither (VIMaybe' (VIEither (Left a))) -> "(mlL" ++ prettyX a ++ ")" +    VILEither (VIMaybe' (VIEither (Right a))) -> "(mlR" ++ prettyX a ++ ")" +    VILEither (VIMaybe' (VIEither' a b)) -> "(m(" ++ prettyX a ++ "⊕" ++ prettyX b ++ "))"      VIArr i is -> 'A' : show i ++ "[" ++ intercalate "," (map show (toList is)) ++ "]"      VIScal i -> show i      VIAccum i -> 'C' : show i @@ -147,6 +155,42 @@ idana env expr = case expr of          res <- unify v1 v2          pure (res, EMaybe res e1' e2' e3') +  ELNil _ t1 t2 -> do +    let v = VILEither (VIMaybe Nothing) +    pure (v, ELNil v t1 t2) + +  ELInl _ t2 e1 -> do +    (v1, e1') <- idana env e1 +    let v = VILEither (VIMaybe (Just (VIEither (Left v1)))) +    pure (v, ELInl v t2 e1') + +  ELInr _ t1 e2 -> do +    (v2, e2') <- idana env e2 +    let v = VILEither (VIMaybe (Just (VIEither (Right v2)))) +    pure (v, ELInr v t1 e2') + +  ELCase _ e1 e2 e3 e4 -> do +    let STLEither t1 t2 = typeOf e1 +    (v1L, e1') <- idana env e1 +    let VILEither v1 = v1L +    let go mv1'l mv1'r f = do +          v1'l <- maybe (genIds t1) pure mv1'l +          v1'r <- maybe (genIds t2) pure mv1'r +          (v2, e2') <- idana env e2 +          (v3, e3') <- idana (v1'l `SCons` env) e3 +          (v4, e4') <- idana (v1'r `SCons` env) e4 +          res <- f v2 v3 v4 +          pure (res, ELCase res e1' e2' e3' e4') +    case v1 of +      VIMaybe Nothing -> go Nothing Nothing (\v2 _ _ -> pure v2) +      VIMaybe (Just (VIEither (Left v1'))) -> go (Just v1') Nothing (\_ v3 _ -> pure v3) +      VIMaybe (Just (VIEither (Right v1'))) -> go Nothing (Just v1') (\_ _ v4 -> pure v4) +      VIMaybe (Just (VIEither' v1'l v1'r)) -> go (Just v1'l) (Just v1'r) (\_ v3 v4 -> unify v3 v4) +      VIMaybe' (VIEither (Left v1')) -> go (Just v1') Nothing (\v2 v3 _ -> unify v2 v3) +      VIMaybe' (VIEither (Right v1')) -> go Nothing (Just v1') (\v2 _ v4 -> unify v2 v4) +      VIMaybe' (VIEither' v1'l v1'r) -> +        go (Just v1'l) (Just v1'r) (\v2 v3 v4 -> unify v2 =<< unify v3 v4) +    EConstArr _ dim t arr -> do      x1 <- VIArr <$> genId <*> vecReplicateA dim genId      pure (x1, EConstArr x1 dim t arr) @@ -265,20 +309,23 @@ idana env expr = case expr of      (_, e3') <- idana env e3      pure (VINil, EAccum VINil t prj e1' e2' e3') -  EZero _ t -> do -    res <- genIds (d2 t) -    pure (res, EZero res t) +  EZero _ t e1 -> do +    -- Approximate the result of EZero to be independent from the zero info +    -- expression; not quite true for shape variables +    (_, e1') <- idana env e1 +    res <- genIds (fromSMTy t) +    pure (res, EZero res t e1')    EPlus _ t e1 e2 -> do      (_, e1') <- idana env e1      (_, e2') <- idana env e2 -    res <- genIds (d2 t) +    res <- genIds (fromSMTy t)      pure (res, EPlus res t e1' e2')    EOneHot _ t i e1 e2 -> do      (_, e1') <- idana env e1      (_, e2') <- idana env e2 -    res <- genIds (d2 t) +    res <- genIds (fromSMTy t)      pure (res, EOneHot res t i e1' e2')    EError _ t s -> do @@ -307,6 +354,7 @@ unify (VIMaybe (Just a)) (VIMaybe' b) = VIMaybe' <$> unify a b  unify (VIMaybe' a) (VIMaybe Nothing) = pure $ VIMaybe' a  unify (VIMaybe' a) (VIMaybe (Just b)) = VIMaybe' <$> unify a b  unify (VIMaybe' a) (VIMaybe' b) = VIMaybe' <$> unify a b +unify (VILEither a) (VILEither b) = VILEither <$> unify a b  unify (VIArr i is) (VIArr j js) = VIArr <$> unifyID i j <*> vecZipWithA unifyID is js  unify (VIScal i) (VIScal j) = VIScal <$> unifyID i j  unify (VIAccum i) (VIAccum j) = VIAccum <$> unifyID i j @@ -323,6 +371,7 @@ genIds (STMaybe t) = VIMaybe' <$> genIds t  genIds (STArr n _) = VIArr <$> genId <*> vecReplicateA n genId  genIds STScal{} = VIScal <$> genId  genIds STAccum{} = VIAccum <$> genId +genIds (STLEither a b) = VILEither . VIMaybe' <$> (VIEither' <$> genIds a <*> genIds b)  shidsToVec :: SNat n -> ValId (Tup (Replicate n TIx)) -> IdGen (Vec n Int)  shidsToVec SZ _ = pure VNil | 
