diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:49:45 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:50:25 +0100 |
| commit | 174af2ba568de66e0d890825b8bda930b8e7bb96 (patch) | |
| tree | 5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/CHAD/Analysis | |
| parent | 92bca235e3aaa287286b6af082d3fce585825a35 (diff) | |
Move module hierarchy under CHAD.
Diffstat (limited to 'src/CHAD/Analysis')
| -rw-r--r-- | src/CHAD/Analysis/Identity.hs | 436 |
1 files changed, 436 insertions, 0 deletions
diff --git a/src/CHAD/Analysis/Identity.hs b/src/CHAD/Analysis/Identity.hs new file mode 100644 index 0000000..212cc7d --- /dev/null +++ b/src/CHAD/Analysis/Identity.hs @@ -0,0 +1,436 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +module CHAD.Analysis.Identity ( + identityAnalysis, + identityAnalysis', + ValId(..), + validSplitEither, +) where + +import Data.Foldable (toList) +import Data.List (intercalate) + +import CHAD.AST +import CHAD.AST.Pretty (PrettyX(..)) +import CHAD.Data +import CHAD.Drev.Types (d1, d2) +import CHAD.Util.IdGen + + +-- | Every array, scalar and accumulator has an ID. Trivial values such as +-- Nothing only have the knowledge that they are indeed Nothing. Compound +-- values know which values they consist of. +data ValId t where + VINil :: ValId TNil + VIPair :: ValId a -> ValId b -> ValId (TPair a b) + VIEither :: Either (ValId a) (ValId b) -> ValId (TEither a b) -- ^ known alternative + VIEither' :: ValId a -> ValId b -> ValId (TEither a b) -- ^ unknown alternative, but known values in each case + VILEither :: ValId (TMaybe (TEither a b)) -> ValId (TLEither a b) + VIMaybe :: Maybe (ValId a) -> ValId (TMaybe a) + VIMaybe' :: ValId a -> ValId (TMaybe a) -- ^ if it's Just, it contains this value + VIArr :: Int -> Vec n Int -> ValId (TArr n t) + VIScal :: Int -> ValId (TScal t) + VIAccum :: Int -> ValId (TAccum t) +deriving instance Show (ValId t) + +instance PrettyX ValId where + prettyX = \case + VINil -> "" + VIPair a b -> "(" ++ prettyX a ++ "," ++ prettyX b ++ ")" + VIEither (Left a) -> "(L" ++ prettyX a ++ ")" + VIEither (Right a) -> "(R" ++ prettyX a ++ ")" + VIEither' a b -> "(" ++ prettyX a ++ "|" ++ prettyX b ++ ")" + VIMaybe Nothing -> "N" + VIMaybe (Just a) -> 'J' : prettyX a + VIMaybe' a -> 'M' : prettyX a + VILEither (VIMaybe Nothing) -> "lN" + VILEither (VIMaybe (Just (VIEither (Left a)))) -> "(lL" ++ prettyX a ++ ")" + VILEither (VIMaybe (Just (VIEither (Right a)))) -> "(lR" ++ prettyX a ++ ")" + VILEither (VIMaybe (Just (VIEither' a b))) -> "(" ++ prettyX a ++ "⊕" ++ prettyX b ++ ")" + VILEither (VIMaybe' (VIEither (Left a))) -> "(mlL" ++ prettyX a ++ ")" + VILEither (VIMaybe' (VIEither (Right a))) -> "(mlR" ++ prettyX a ++ ")" + VILEither (VIMaybe' (VIEither' a b)) -> "(m(" ++ prettyX a ++ "⊕" ++ prettyX b ++ "))" + VIArr i is -> 'A' : show i ++ "[" ++ intercalate "," (map show (toList is)) ++ "]" + VIScal i -> show i + VIAccum i -> 'C' : show i + +validSplitEither :: ValId (TEither a b) -> (Maybe (ValId a), Maybe (ValId b)) +validSplitEither (VIEither (Left v)) = (Just v, Nothing) +validSplitEither (VIEither (Right v)) = (Nothing, Just v) +validSplitEither (VIEither' v1 v2) = (Just v1, Just v2) + +-- | Symbolic partial evaluation. +identityAnalysis :: SList STy env -> Expr x env t -> Expr ValId env t +identityAnalysis env term = runIdGen 0 $ do + env' <- slistMapA genIds env + snd <$> idana env' term + +identityAnalysis' :: SList ValId env -> Expr x env t -> Expr ValId env t +identityAnalysis' env term = snd (runIdGen 0 (idana env term)) + +idana :: SList ValId env -> Expr x env t -> IdGen (ValId t, Expr ValId env t) +idana env expr = case expr of + EVar _ t i -> do + let v = slistIdx env i + pure (v, EVar v t i) + + ELet _ e1 e2 -> do + (v1, e1') <- idana env e1 + (v2, e2') <- idana (v1 `SCons` env) e2 + pure (v2, ELet v2 e1' e2') + + EPair _ e1 e2 -> do + (v1, e1') <- idana env e1 + (v2, e2') <- idana env e2 + pure (VIPair v1 v2, EPair (VIPair v1 v2) e1' e2') + + EFst _ e -> do + (v, e') <- idana env e + let VIPair v1 _ = v + pure (v1, EFst v1 e') + + ESnd _ e -> do + (v, e') <- idana env e + let VIPair _ v2 = v + pure (v2, ESnd v2 e') + + ENil _ -> pure (VINil, ENil VINil) + + EInl _ t2 e1 -> do + (v1, e1') <- idana env e1 + let v = VIEither (Left v1) + pure (v, EInl v t2 e1') + + EInr _ t1 e2 -> do + (v2, e2') <- idana env e2 + let v = VIEither (Right v2) + pure (v, EInr v t1 e2') + + ECase _ e1 e2 e3 -> do + let STEither t1 t2 = typeOf e1 + (v1, e1') <- idana env e1 + case v1 of + VIEither (Left v1') -> do + (v2, e2') <- idana (v1' `SCons` env) e2 + scrap <- genIds t2 + (_, e3') <- idana (scrap `SCons` env) e3 + pure (v2, ECase v2 e1' e2' e3') + VIEither (Right v1') -> do + scrap <- genIds t1 + (_, e2') <- idana (scrap `SCons` env) e2 + (v3, e3') <- idana (v1' `SCons` env) e3 + pure (v3, ECase v3 e1' e2' e3') + VIEither' v1'l v1'r -> do + (v2, e2') <- idana (v1'l `SCons` env) e2 + (v3, e3') <- idana (v1'r `SCons` env) e3 + res <- unify v2 v3 + pure (res, ECase res e1' e2' e3') + + ENothing _ t -> pure (VIMaybe Nothing, ENothing (VIMaybe Nothing) t) + + EJust _ e1 -> do + (v1, e1') <- idana env e1 + let v = VIMaybe (Just v1) + pure (v, EJust v e1') + + EMaybe _ e1 e2 e3 -> do + let STMaybe t1 = typeOf e3 + (v3, e3') <- idana env e3 + case v3 of + VIMaybe Nothing -> do + (v1, e1') <- idana env e1 + scrap <- genIds t1 + (_, e2') <- idana (scrap `SCons` env) e2 + pure (v1, EMaybe v1 e1' e2' e3') + VIMaybe (Just v3j) -> do + (v2, e2') <- idana (v3j `SCons` env) e2 + (_, e1') <- idana env e1 + pure (v2, EMaybe v2 e1' e2' e3') + VIMaybe' v3' -> do + (v2, e2') <- idana (v3' `SCons` env) e2 + (v1, e1') <- idana env e1 + res <- unify v1 v2 + pure (res, EMaybe res e1' e2' e3') + + ELNil _ t1 t2 -> do + let v = VILEither (VIMaybe Nothing) + pure (v, ELNil v t1 t2) + + ELInl _ t2 e1 -> do + (v1, e1') <- idana env e1 + let v = VILEither (VIMaybe (Just (VIEither (Left v1)))) + pure (v, ELInl v t2 e1') + + ELInr _ t1 e2 -> do + (v2, e2') <- idana env e2 + let v = VILEither (VIMaybe (Just (VIEither (Right v2)))) + pure (v, ELInr v t1 e2') + + ELCase _ e1 e2 e3 e4 -> do + let STLEither t1 t2 = typeOf e1 + (v1L, e1') <- idana env e1 + let VILEither v1 = v1L + let go mv1'l mv1'r f = do + v1'l <- maybe (genIds t1) pure mv1'l + v1'r <- maybe (genIds t2) pure mv1'r + (v2, e2') <- idana env e2 + (v3, e3') <- idana (v1'l `SCons` env) e3 + (v4, e4') <- idana (v1'r `SCons` env) e4 + res <- f v2 v3 v4 + pure (res, ELCase res e1' e2' e3' e4') + case v1 of + VIMaybe Nothing -> go Nothing Nothing (\v2 _ _ -> pure v2) + VIMaybe (Just (VIEither (Left v1'))) -> go (Just v1') Nothing (\_ v3 _ -> pure v3) + VIMaybe (Just (VIEither (Right v1'))) -> go Nothing (Just v1') (\_ _ v4 -> pure v4) + VIMaybe (Just (VIEither' v1'l v1'r)) -> go (Just v1'l) (Just v1'r) (\_ v3 v4 -> unify v3 v4) + VIMaybe' (VIEither (Left v1')) -> go (Just v1') Nothing (\v2 v3 _ -> unify v2 v3) + VIMaybe' (VIEither (Right v1')) -> go Nothing (Just v1') (\v2 _ v4 -> unify v2 v4) + VIMaybe' (VIEither' v1'l v1'r) -> + go (Just v1'l) (Just v1'r) (\v2 v3 v4 -> unify v2 =<< unify v3 v4) + + EConstArr _ dim t arr -> do + x1 <- VIArr <$> genId <*> vecReplicateA dim genId + pure (x1, EConstArr x1 dim t arr) + + EBuild _ dim e1 e2 -> do + (shids, e1') <- idana env e1 + x1 <- genIds (tTup (sreplicate dim tIx)) + (_, e2') <- idana (x1 `SCons` env) e2 + res <- VIArr <$> genId <*> shidsToVec dim shids + pure (res, EBuild res dim e1' e2') + + EMap _ e1 e2 -> do + let STArr _ t = typeOf e2 + x1 <- genIds t + (_, e1') <- idana (x1 `SCons` env) e1 + (v2, e2') <- idana env e2 + let VIArr _ sh = v2 + res <- VIArr <$> genId <*> pure sh + pure (res, EMap res e1' e2') + + EFold1Inner _ cm e1 e2 e3 -> do + let t1 = typeOf e1 + x1 <- genIds (STPair t1 t1) + (_, e1') <- idana (x1 `SCons` env) e1 + (_, e2') <- idana env e2 + (v3, e3') <- idana env e3 + let VIArr _ (_ :< sh) = v3 + res <- VIArr <$> genId <*> pure sh + pure (res, EFold1Inner res cm e1' e2' e3') + + ESum1Inner _ e1 -> do + (v1, e1') <- idana env e1 + let VIArr _ (_ :< sh) = v1 + res <- VIArr <$> genId <*> pure sh + pure (res, ESum1Inner res e1') + + EUnit _ e1 -> do + (_, e1') <- idana env e1 + res <- VIArr <$> genId <*> pure VNil + pure (res, EUnit res e1') + + EReplicate1Inner _ e1 e2 -> do + (v1, e1') <- idana env e1 + let VIScal v1' = v1 + (v2, e2') <- idana env e2 + let VIArr _ sh = v2 + res <- VIArr <$> genId <*> pure (v1' :< sh) + pure (res, EReplicate1Inner res e1' e2') + + EMaximum1Inner _ e1 -> do + (v1, e1') <- idana env e1 + let VIArr _ (_ :< sh) = v1 + res <- VIArr <$> genId <*> pure sh + pure (res, EMaximum1Inner res e1') + + EMinimum1Inner _ e1 -> do + (v1, e1') <- idana env e1 + let VIArr _ (_ :< sh) = v1 + res <- VIArr <$> genId <*> pure sh + pure (res, EMinimum1Inner res e1') + + EReshape _ dim e1 e2 -> do + (v1, e1') <- idana env e1 + (_, e2') <- idana env e2 + res <- VIArr <$> genId <*> shidsToVec dim v1 + pure (res, EReshape res dim e1' e2') + + EZip _ e1 e2 -> do + (v1, e1') <- idana env e1 + (_, e2') <- idana env e2 + let VIArr _ sh = v1 + res <- VIArr <$> genId <*> pure sh + pure (res, EZip res e1' e2') + + EFold1InnerD1 _ cm e1 e2 e3 -> do + let t1 = typeOf e2 + x1 <- genIds (STPair t1 t1) + (_, e1') <- idana (x1 `SCons` env) e1 + (_, e2') <- idana env e2 + (v3, e3') <- idana env e3 + let VIArr _ sh'@(_ :< sh) = v3 + res <- VIPair <$> (VIArr <$> genId <*> pure sh) <*> (VIArr <$> genId <*> pure sh') + pure (res, EFold1InnerD1 res cm e1' e2' e3') + + EFold1InnerD2 _ cm ef ebog ed -> do + let STArr _ tB = typeOf ebog + STArr _ t2 = typeOf ed + xf1 <- genIds t2 + xf2 <- genIds tB + (_, e1') <- idana (xf1 `SCons` xf2 `SCons` env) ef + (v2, e2') <- idana env ebog + (_, e3') <- idana env ed + let VIArr _ sh@(_ :< sh') = v2 + res <- VIPair <$> (VIArr <$> genId <*> pure sh') <*> (VIArr <$> genId <*> pure sh) + pure (res, EFold1InnerD2 res cm e1' e2' e3') + + EConst _ t val -> do + res <- VIScal <$> genId + pure (res, EConst res t val) + + EIdx0 _ e1 -> do + (_, e1') <- idana env e1 + res <- genIds (typeOf expr) + pure (res, EIdx0 res e1') + + EIdx1 _ e1 e2 -> do + (v1, e1') <- idana env e1 + let VIArr _ sh = v1 + (_, e2') <- idana env e2 + res <- VIArr <$> genId <*> pure (vecInit sh) + pure (res, EIdx1 res e1' e2') + + EIdx _ e1 e2 -> do + (_, e1') <- idana env e1 + (_, e2') <- idana env e2 + res <- genIds (typeOf expr) + pure (res, EIdx res e1' e2') + + EShape _ e1 -> do + let STArr dim _ = typeOf e1 + (v1, e1') <- idana env e1 + let VIArr _ sh = v1 + res = vecToShids dim sh + pure (res, EShape res e1') + + EOp _ (op :: SOp a t) e1 -> do + (_, e1') <- idana env e1 + res <- genIds (typeOf expr) + pure (res, EOp res op e1') + + ECustom _ t1 t2 t3 e1 e2 e3 e4 e5 -> do + let t4 = typeOf e1 + x1 <- genIds t2 + x2 <- genIds t1 + (_, e1') <- idana (x1 `SCons` x2 `SCons` SNil) e1 + x3 <- genIds (d1 t2) + x4 <- genIds (d1 t1) + (_, e2') <- idana (x3 `SCons` x4 `SCons` SNil) e2 + x5 <- genIds (d2 t4) + x6 <- genIds t3 + (_, e3') <- idana (x5 `SCons` x6 `SCons` SNil) e3 + (_, e4') <- idana env e4 + (_, e5') <- idana env e5 + res <- genIds t4 + pure (res, ECustom res t1 t2 t3 e1' e2' e3' e4' e5') + + ERecompute _ e -> do + (v, e') <- idana env e + pure (v, ERecompute v e') + + EWith _ t e1 e2 -> do + let t1 = typeOf e1 + (_, e1') <- idana env e1 + x1 <- VIAccum <$> genId + (v2, e2') <- idana (x1 `SCons` env) e2 + x2 <- genIds t1 + let res = VIPair v2 x2 + pure (res, EWith res t e1' e2') + + EAccum _ t prj e1 sp e2 e3 -> do + (_, e1') <- idana env e1 + (_, e2') <- idana env e2 + (_, e3') <- idana env e3 + pure (VINil, EAccum VINil t prj e1' sp e2' e3') + + EZero _ t e1 -> do + -- Approximate the result of EZero to be independent from the zero info + -- expression; not quite true for shape variables + (_, e1') <- idana env e1 + res <- genIds (fromSMTy t) + pure (res, EZero res t e1') + + EDeepZero _ t e1 -> do + -- Approximate the result of EDeepZero to be independent from the zero info + -- expression; not quite true for shape variables + (_, e1') <- idana env e1 + res <- genIds (fromSMTy t) + pure (res, EDeepZero res t e1') + + EPlus _ t e1 e2 -> do + (_, e1') <- idana env e1 + (_, e2') <- idana env e2 + res <- genIds (fromSMTy t) + pure (res, EPlus res t e1' e2') + + EOneHot _ t i e1 e2 -> do + (_, e1') <- idana env e1 + (_, e2') <- idana env e2 + res <- genIds (fromSMTy t) + pure (res, EOneHot res t i e1' e2') + + EError _ t s -> do + res <- genIds t + pure (res, EError res t s) + +-- | This value might be either of the two arguments; we don't know which. +unify :: ValId t -> ValId t -> IdGen (ValId t) +unify VINil VINil = pure VINil +unify (VIPair a b) (VIPair c d) = VIPair <$> unify a c <*> unify b d +unify (VIEither (Left a)) (VIEither (Left b)) = VIEither . Left <$> unify a b +unify (VIEither (Right a)) (VIEither (Right b)) = VIEither . Right <$> unify a b +unify (VIEither (Left a)) (VIEither (Right b)) = pure $ VIEither' a b +unify (VIEither (Right a)) (VIEither (Left b)) = pure $ VIEither' b a +unify (VIEither (Left a)) (VIEither' b c) = VIEither' <$> unify a b <*> pure c +unify (VIEither (Right a)) (VIEither' b c) = VIEither' <$> pure b <*> unify a c +unify (VIEither' a b) (VIEither (Left c)) = VIEither' <$> unify a c <*> pure b +unify (VIEither' a b) (VIEither (Right c)) = VIEither' <$> pure a <*> unify b c +unify (VIEither' a b) (VIEither' c d) = VIEither' <$> unify a c <*> unify b d +unify (VIMaybe Nothing) (VIMaybe Nothing) = pure $ VIMaybe Nothing +unify (VIMaybe (Just a)) (VIMaybe (Just b)) = VIMaybe . Just <$> unify a b +unify (VIMaybe Nothing) (VIMaybe (Just a)) = pure $ VIMaybe' a +unify (VIMaybe (Just a)) (VIMaybe Nothing) = pure $ VIMaybe' a +unify (VIMaybe Nothing) (VIMaybe' a) = pure $ VIMaybe' a +unify (VIMaybe (Just a)) (VIMaybe' b) = VIMaybe' <$> unify a b +unify (VIMaybe' a) (VIMaybe Nothing) = pure $ VIMaybe' a +unify (VIMaybe' a) (VIMaybe (Just b)) = VIMaybe' <$> unify a b +unify (VIMaybe' a) (VIMaybe' b) = VIMaybe' <$> unify a b +unify (VILEither a) (VILEither b) = VILEither <$> unify a b +unify (VIArr i is) (VIArr j js) = VIArr <$> unifyID i j <*> vecZipWithA unifyID is js +unify (VIScal i) (VIScal j) = VIScal <$> unifyID i j +unify (VIAccum i) (VIAccum j) = VIAccum <$> unifyID i j + +unifyID :: Int -> Int -> IdGen Int +unifyID i j | i == j = pure i + | otherwise = genId + +genIds :: STy t -> IdGen (ValId t) +genIds STNil = pure VINil +genIds (STPair a b) = VIPair <$> genIds a <*> genIds b +genIds (STEither a b) = VIEither' <$> genIds a <*> genIds b +genIds (STLEither a b) = VILEither . VIMaybe' <$> (VIEither' <$> genIds a <*> genIds b) +genIds (STMaybe t) = VIMaybe' <$> genIds t +genIds (STArr n _) = VIArr <$> genId <*> vecReplicateA n genId +genIds STScal{} = VIScal <$> genId +genIds STAccum{} = VIAccum <$> genId + +shidsToVec :: SNat n -> ValId (Tup (Replicate n TIx)) -> IdGen (Vec n Int) +shidsToVec SZ _ = pure VNil +shidsToVec (SS n) (VIPair is (VIScal i)) = (i :<) <$> shidsToVec n is + +vecToShids :: SNat n -> Vec n Int -> ValId (Tup (Replicate n TIx)) +vecToShids SZ VNil = VINil +vecToShids (SS n) (i :< is) = VIPair (vecToShids n is) (VIScal i) |
