aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/Analysis
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-10 21:49:45 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-10 21:50:25 +0100
commit174af2ba568de66e0d890825b8bda930b8e7bb96 (patch)
tree5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/CHAD/Analysis
parent92bca235e3aaa287286b6af082d3fce585825a35 (diff)
Move module hierarchy under CHAD.
Diffstat (limited to 'src/CHAD/Analysis')
-rw-r--r--src/CHAD/Analysis/Identity.hs436
1 files changed, 436 insertions, 0 deletions
diff --git a/src/CHAD/Analysis/Identity.hs b/src/CHAD/Analysis/Identity.hs
new file mode 100644
index 0000000..212cc7d
--- /dev/null
+++ b/src/CHAD/Analysis/Identity.hs
@@ -0,0 +1,436 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+module CHAD.Analysis.Identity (
+ identityAnalysis,
+ identityAnalysis',
+ ValId(..),
+ validSplitEither,
+) where
+
+import Data.Foldable (toList)
+import Data.List (intercalate)
+
+import CHAD.AST
+import CHAD.AST.Pretty (PrettyX(..))
+import CHAD.Data
+import CHAD.Drev.Types (d1, d2)
+import CHAD.Util.IdGen
+
+
+-- | Every array, scalar and accumulator has an ID. Trivial values such as
+-- Nothing only have the knowledge that they are indeed Nothing. Compound
+-- values know which values they consist of.
+data ValId t where
+ VINil :: ValId TNil
+ VIPair :: ValId a -> ValId b -> ValId (TPair a b)
+ VIEither :: Either (ValId a) (ValId b) -> ValId (TEither a b) -- ^ known alternative
+ VIEither' :: ValId a -> ValId b -> ValId (TEither a b) -- ^ unknown alternative, but known values in each case
+ VILEither :: ValId (TMaybe (TEither a b)) -> ValId (TLEither a b)
+ VIMaybe :: Maybe (ValId a) -> ValId (TMaybe a)
+ VIMaybe' :: ValId a -> ValId (TMaybe a) -- ^ if it's Just, it contains this value
+ VIArr :: Int -> Vec n Int -> ValId (TArr n t)
+ VIScal :: Int -> ValId (TScal t)
+ VIAccum :: Int -> ValId (TAccum t)
+deriving instance Show (ValId t)
+
+instance PrettyX ValId where
+ prettyX = \case
+ VINil -> ""
+ VIPair a b -> "(" ++ prettyX a ++ "," ++ prettyX b ++ ")"
+ VIEither (Left a) -> "(L" ++ prettyX a ++ ")"
+ VIEither (Right a) -> "(R" ++ prettyX a ++ ")"
+ VIEither' a b -> "(" ++ prettyX a ++ "|" ++ prettyX b ++ ")"
+ VIMaybe Nothing -> "N"
+ VIMaybe (Just a) -> 'J' : prettyX a
+ VIMaybe' a -> 'M' : prettyX a
+ VILEither (VIMaybe Nothing) -> "lN"
+ VILEither (VIMaybe (Just (VIEither (Left a)))) -> "(lL" ++ prettyX a ++ ")"
+ VILEither (VIMaybe (Just (VIEither (Right a)))) -> "(lR" ++ prettyX a ++ ")"
+ VILEither (VIMaybe (Just (VIEither' a b))) -> "(" ++ prettyX a ++ "⊕" ++ prettyX b ++ ")"
+ VILEither (VIMaybe' (VIEither (Left a))) -> "(mlL" ++ prettyX a ++ ")"
+ VILEither (VIMaybe' (VIEither (Right a))) -> "(mlR" ++ prettyX a ++ ")"
+ VILEither (VIMaybe' (VIEither' a b)) -> "(m(" ++ prettyX a ++ "⊕" ++ prettyX b ++ "))"
+ VIArr i is -> 'A' : show i ++ "[" ++ intercalate "," (map show (toList is)) ++ "]"
+ VIScal i -> show i
+ VIAccum i -> 'C' : show i
+
+validSplitEither :: ValId (TEither a b) -> (Maybe (ValId a), Maybe (ValId b))
+validSplitEither (VIEither (Left v)) = (Just v, Nothing)
+validSplitEither (VIEither (Right v)) = (Nothing, Just v)
+validSplitEither (VIEither' v1 v2) = (Just v1, Just v2)
+
+-- | Symbolic partial evaluation.
+identityAnalysis :: SList STy env -> Expr x env t -> Expr ValId env t
+identityAnalysis env term = runIdGen 0 $ do
+ env' <- slistMapA genIds env
+ snd <$> idana env' term
+
+identityAnalysis' :: SList ValId env -> Expr x env t -> Expr ValId env t
+identityAnalysis' env term = snd (runIdGen 0 (idana env term))
+
+idana :: SList ValId env -> Expr x env t -> IdGen (ValId t, Expr ValId env t)
+idana env expr = case expr of
+ EVar _ t i -> do
+ let v = slistIdx env i
+ pure (v, EVar v t i)
+
+ ELet _ e1 e2 -> do
+ (v1, e1') <- idana env e1
+ (v2, e2') <- idana (v1 `SCons` env) e2
+ pure (v2, ELet v2 e1' e2')
+
+ EPair _ e1 e2 -> do
+ (v1, e1') <- idana env e1
+ (v2, e2') <- idana env e2
+ pure (VIPair v1 v2, EPair (VIPair v1 v2) e1' e2')
+
+ EFst _ e -> do
+ (v, e') <- idana env e
+ let VIPair v1 _ = v
+ pure (v1, EFst v1 e')
+
+ ESnd _ e -> do
+ (v, e') <- idana env e
+ let VIPair _ v2 = v
+ pure (v2, ESnd v2 e')
+
+ ENil _ -> pure (VINil, ENil VINil)
+
+ EInl _ t2 e1 -> do
+ (v1, e1') <- idana env e1
+ let v = VIEither (Left v1)
+ pure (v, EInl v t2 e1')
+
+ EInr _ t1 e2 -> do
+ (v2, e2') <- idana env e2
+ let v = VIEither (Right v2)
+ pure (v, EInr v t1 e2')
+
+ ECase _ e1 e2 e3 -> do
+ let STEither t1 t2 = typeOf e1
+ (v1, e1') <- idana env e1
+ case v1 of
+ VIEither (Left v1') -> do
+ (v2, e2') <- idana (v1' `SCons` env) e2
+ scrap <- genIds t2
+ (_, e3') <- idana (scrap `SCons` env) e3
+ pure (v2, ECase v2 e1' e2' e3')
+ VIEither (Right v1') -> do
+ scrap <- genIds t1
+ (_, e2') <- idana (scrap `SCons` env) e2
+ (v3, e3') <- idana (v1' `SCons` env) e3
+ pure (v3, ECase v3 e1' e2' e3')
+ VIEither' v1'l v1'r -> do
+ (v2, e2') <- idana (v1'l `SCons` env) e2
+ (v3, e3') <- idana (v1'r `SCons` env) e3
+ res <- unify v2 v3
+ pure (res, ECase res e1' e2' e3')
+
+ ENothing _ t -> pure (VIMaybe Nothing, ENothing (VIMaybe Nothing) t)
+
+ EJust _ e1 -> do
+ (v1, e1') <- idana env e1
+ let v = VIMaybe (Just v1)
+ pure (v, EJust v e1')
+
+ EMaybe _ e1 e2 e3 -> do
+ let STMaybe t1 = typeOf e3
+ (v3, e3') <- idana env e3
+ case v3 of
+ VIMaybe Nothing -> do
+ (v1, e1') <- idana env e1
+ scrap <- genIds t1
+ (_, e2') <- idana (scrap `SCons` env) e2
+ pure (v1, EMaybe v1 e1' e2' e3')
+ VIMaybe (Just v3j) -> do
+ (v2, e2') <- idana (v3j `SCons` env) e2
+ (_, e1') <- idana env e1
+ pure (v2, EMaybe v2 e1' e2' e3')
+ VIMaybe' v3' -> do
+ (v2, e2') <- idana (v3' `SCons` env) e2
+ (v1, e1') <- idana env e1
+ res <- unify v1 v2
+ pure (res, EMaybe res e1' e2' e3')
+
+ ELNil _ t1 t2 -> do
+ let v = VILEither (VIMaybe Nothing)
+ pure (v, ELNil v t1 t2)
+
+ ELInl _ t2 e1 -> do
+ (v1, e1') <- idana env e1
+ let v = VILEither (VIMaybe (Just (VIEither (Left v1))))
+ pure (v, ELInl v t2 e1')
+
+ ELInr _ t1 e2 -> do
+ (v2, e2') <- idana env e2
+ let v = VILEither (VIMaybe (Just (VIEither (Right v2))))
+ pure (v, ELInr v t1 e2')
+
+ ELCase _ e1 e2 e3 e4 -> do
+ let STLEither t1 t2 = typeOf e1
+ (v1L, e1') <- idana env e1
+ let VILEither v1 = v1L
+ let go mv1'l mv1'r f = do
+ v1'l <- maybe (genIds t1) pure mv1'l
+ v1'r <- maybe (genIds t2) pure mv1'r
+ (v2, e2') <- idana env e2
+ (v3, e3') <- idana (v1'l `SCons` env) e3
+ (v4, e4') <- idana (v1'r `SCons` env) e4
+ res <- f v2 v3 v4
+ pure (res, ELCase res e1' e2' e3' e4')
+ case v1 of
+ VIMaybe Nothing -> go Nothing Nothing (\v2 _ _ -> pure v2)
+ VIMaybe (Just (VIEither (Left v1'))) -> go (Just v1') Nothing (\_ v3 _ -> pure v3)
+ VIMaybe (Just (VIEither (Right v1'))) -> go Nothing (Just v1') (\_ _ v4 -> pure v4)
+ VIMaybe (Just (VIEither' v1'l v1'r)) -> go (Just v1'l) (Just v1'r) (\_ v3 v4 -> unify v3 v4)
+ VIMaybe' (VIEither (Left v1')) -> go (Just v1') Nothing (\v2 v3 _ -> unify v2 v3)
+ VIMaybe' (VIEither (Right v1')) -> go Nothing (Just v1') (\v2 _ v4 -> unify v2 v4)
+ VIMaybe' (VIEither' v1'l v1'r) ->
+ go (Just v1'l) (Just v1'r) (\v2 v3 v4 -> unify v2 =<< unify v3 v4)
+
+ EConstArr _ dim t arr -> do
+ x1 <- VIArr <$> genId <*> vecReplicateA dim genId
+ pure (x1, EConstArr x1 dim t arr)
+
+ EBuild _ dim e1 e2 -> do
+ (shids, e1') <- idana env e1
+ x1 <- genIds (tTup (sreplicate dim tIx))
+ (_, e2') <- idana (x1 `SCons` env) e2
+ res <- VIArr <$> genId <*> shidsToVec dim shids
+ pure (res, EBuild res dim e1' e2')
+
+ EMap _ e1 e2 -> do
+ let STArr _ t = typeOf e2
+ x1 <- genIds t
+ (_, e1') <- idana (x1 `SCons` env) e1
+ (v2, e2') <- idana env e2
+ let VIArr _ sh = v2
+ res <- VIArr <$> genId <*> pure sh
+ pure (res, EMap res e1' e2')
+
+ EFold1Inner _ cm e1 e2 e3 -> do
+ let t1 = typeOf e1
+ x1 <- genIds (STPair t1 t1)
+ (_, e1') <- idana (x1 `SCons` env) e1
+ (_, e2') <- idana env e2
+ (v3, e3') <- idana env e3
+ let VIArr _ (_ :< sh) = v3
+ res <- VIArr <$> genId <*> pure sh
+ pure (res, EFold1Inner res cm e1' e2' e3')
+
+ ESum1Inner _ e1 -> do
+ (v1, e1') <- idana env e1
+ let VIArr _ (_ :< sh) = v1
+ res <- VIArr <$> genId <*> pure sh
+ pure (res, ESum1Inner res e1')
+
+ EUnit _ e1 -> do
+ (_, e1') <- idana env e1
+ res <- VIArr <$> genId <*> pure VNil
+ pure (res, EUnit res e1')
+
+ EReplicate1Inner _ e1 e2 -> do
+ (v1, e1') <- idana env e1
+ let VIScal v1' = v1
+ (v2, e2') <- idana env e2
+ let VIArr _ sh = v2
+ res <- VIArr <$> genId <*> pure (v1' :< sh)
+ pure (res, EReplicate1Inner res e1' e2')
+
+ EMaximum1Inner _ e1 -> do
+ (v1, e1') <- idana env e1
+ let VIArr _ (_ :< sh) = v1
+ res <- VIArr <$> genId <*> pure sh
+ pure (res, EMaximum1Inner res e1')
+
+ EMinimum1Inner _ e1 -> do
+ (v1, e1') <- idana env e1
+ let VIArr _ (_ :< sh) = v1
+ res <- VIArr <$> genId <*> pure sh
+ pure (res, EMinimum1Inner res e1')
+
+ EReshape _ dim e1 e2 -> do
+ (v1, e1') <- idana env e1
+ (_, e2') <- idana env e2
+ res <- VIArr <$> genId <*> shidsToVec dim v1
+ pure (res, EReshape res dim e1' e2')
+
+ EZip _ e1 e2 -> do
+ (v1, e1') <- idana env e1
+ (_, e2') <- idana env e2
+ let VIArr _ sh = v1
+ res <- VIArr <$> genId <*> pure sh
+ pure (res, EZip res e1' e2')
+
+ EFold1InnerD1 _ cm e1 e2 e3 -> do
+ let t1 = typeOf e2
+ x1 <- genIds (STPair t1 t1)
+ (_, e1') <- idana (x1 `SCons` env) e1
+ (_, e2') <- idana env e2
+ (v3, e3') <- idana env e3
+ let VIArr _ sh'@(_ :< sh) = v3
+ res <- VIPair <$> (VIArr <$> genId <*> pure sh) <*> (VIArr <$> genId <*> pure sh')
+ pure (res, EFold1InnerD1 res cm e1' e2' e3')
+
+ EFold1InnerD2 _ cm ef ebog ed -> do
+ let STArr _ tB = typeOf ebog
+ STArr _ t2 = typeOf ed
+ xf1 <- genIds t2
+ xf2 <- genIds tB
+ (_, e1') <- idana (xf1 `SCons` xf2 `SCons` env) ef
+ (v2, e2') <- idana env ebog
+ (_, e3') <- idana env ed
+ let VIArr _ sh@(_ :< sh') = v2
+ res <- VIPair <$> (VIArr <$> genId <*> pure sh') <*> (VIArr <$> genId <*> pure sh)
+ pure (res, EFold1InnerD2 res cm e1' e2' e3')
+
+ EConst _ t val -> do
+ res <- VIScal <$> genId
+ pure (res, EConst res t val)
+
+ EIdx0 _ e1 -> do
+ (_, e1') <- idana env e1
+ res <- genIds (typeOf expr)
+ pure (res, EIdx0 res e1')
+
+ EIdx1 _ e1 e2 -> do
+ (v1, e1') <- idana env e1
+ let VIArr _ sh = v1
+ (_, e2') <- idana env e2
+ res <- VIArr <$> genId <*> pure (vecInit sh)
+ pure (res, EIdx1 res e1' e2')
+
+ EIdx _ e1 e2 -> do
+ (_, e1') <- idana env e1
+ (_, e2') <- idana env e2
+ res <- genIds (typeOf expr)
+ pure (res, EIdx res e1' e2')
+
+ EShape _ e1 -> do
+ let STArr dim _ = typeOf e1
+ (v1, e1') <- idana env e1
+ let VIArr _ sh = v1
+ res = vecToShids dim sh
+ pure (res, EShape res e1')
+
+ EOp _ (op :: SOp a t) e1 -> do
+ (_, e1') <- idana env e1
+ res <- genIds (typeOf expr)
+ pure (res, EOp res op e1')
+
+ ECustom _ t1 t2 t3 e1 e2 e3 e4 e5 -> do
+ let t4 = typeOf e1
+ x1 <- genIds t2
+ x2 <- genIds t1
+ (_, e1') <- idana (x1 `SCons` x2 `SCons` SNil) e1
+ x3 <- genIds (d1 t2)
+ x4 <- genIds (d1 t1)
+ (_, e2') <- idana (x3 `SCons` x4 `SCons` SNil) e2
+ x5 <- genIds (d2 t4)
+ x6 <- genIds t3
+ (_, e3') <- idana (x5 `SCons` x6 `SCons` SNil) e3
+ (_, e4') <- idana env e4
+ (_, e5') <- idana env e5
+ res <- genIds t4
+ pure (res, ECustom res t1 t2 t3 e1' e2' e3' e4' e5')
+
+ ERecompute _ e -> do
+ (v, e') <- idana env e
+ pure (v, ERecompute v e')
+
+ EWith _ t e1 e2 -> do
+ let t1 = typeOf e1
+ (_, e1') <- idana env e1
+ x1 <- VIAccum <$> genId
+ (v2, e2') <- idana (x1 `SCons` env) e2
+ x2 <- genIds t1
+ let res = VIPair v2 x2
+ pure (res, EWith res t e1' e2')
+
+ EAccum _ t prj e1 sp e2 e3 -> do
+ (_, e1') <- idana env e1
+ (_, e2') <- idana env e2
+ (_, e3') <- idana env e3
+ pure (VINil, EAccum VINil t prj e1' sp e2' e3')
+
+ EZero _ t e1 -> do
+ -- Approximate the result of EZero to be independent from the zero info
+ -- expression; not quite true for shape variables
+ (_, e1') <- idana env e1
+ res <- genIds (fromSMTy t)
+ pure (res, EZero res t e1')
+
+ EDeepZero _ t e1 -> do
+ -- Approximate the result of EDeepZero to be independent from the zero info
+ -- expression; not quite true for shape variables
+ (_, e1') <- idana env e1
+ res <- genIds (fromSMTy t)
+ pure (res, EDeepZero res t e1')
+
+ EPlus _ t e1 e2 -> do
+ (_, e1') <- idana env e1
+ (_, e2') <- idana env e2
+ res <- genIds (fromSMTy t)
+ pure (res, EPlus res t e1' e2')
+
+ EOneHot _ t i e1 e2 -> do
+ (_, e1') <- idana env e1
+ (_, e2') <- idana env e2
+ res <- genIds (fromSMTy t)
+ pure (res, EOneHot res t i e1' e2')
+
+ EError _ t s -> do
+ res <- genIds t
+ pure (res, EError res t s)
+
+-- | This value might be either of the two arguments; we don't know which.
+unify :: ValId t -> ValId t -> IdGen (ValId t)
+unify VINil VINil = pure VINil
+unify (VIPair a b) (VIPair c d) = VIPair <$> unify a c <*> unify b d
+unify (VIEither (Left a)) (VIEither (Left b)) = VIEither . Left <$> unify a b
+unify (VIEither (Right a)) (VIEither (Right b)) = VIEither . Right <$> unify a b
+unify (VIEither (Left a)) (VIEither (Right b)) = pure $ VIEither' a b
+unify (VIEither (Right a)) (VIEither (Left b)) = pure $ VIEither' b a
+unify (VIEither (Left a)) (VIEither' b c) = VIEither' <$> unify a b <*> pure c
+unify (VIEither (Right a)) (VIEither' b c) = VIEither' <$> pure b <*> unify a c
+unify (VIEither' a b) (VIEither (Left c)) = VIEither' <$> unify a c <*> pure b
+unify (VIEither' a b) (VIEither (Right c)) = VIEither' <$> pure a <*> unify b c
+unify (VIEither' a b) (VIEither' c d) = VIEither' <$> unify a c <*> unify b d
+unify (VIMaybe Nothing) (VIMaybe Nothing) = pure $ VIMaybe Nothing
+unify (VIMaybe (Just a)) (VIMaybe (Just b)) = VIMaybe . Just <$> unify a b
+unify (VIMaybe Nothing) (VIMaybe (Just a)) = pure $ VIMaybe' a
+unify (VIMaybe (Just a)) (VIMaybe Nothing) = pure $ VIMaybe' a
+unify (VIMaybe Nothing) (VIMaybe' a) = pure $ VIMaybe' a
+unify (VIMaybe (Just a)) (VIMaybe' b) = VIMaybe' <$> unify a b
+unify (VIMaybe' a) (VIMaybe Nothing) = pure $ VIMaybe' a
+unify (VIMaybe' a) (VIMaybe (Just b)) = VIMaybe' <$> unify a b
+unify (VIMaybe' a) (VIMaybe' b) = VIMaybe' <$> unify a b
+unify (VILEither a) (VILEither b) = VILEither <$> unify a b
+unify (VIArr i is) (VIArr j js) = VIArr <$> unifyID i j <*> vecZipWithA unifyID is js
+unify (VIScal i) (VIScal j) = VIScal <$> unifyID i j
+unify (VIAccum i) (VIAccum j) = VIAccum <$> unifyID i j
+
+unifyID :: Int -> Int -> IdGen Int
+unifyID i j | i == j = pure i
+ | otherwise = genId
+
+genIds :: STy t -> IdGen (ValId t)
+genIds STNil = pure VINil
+genIds (STPair a b) = VIPair <$> genIds a <*> genIds b
+genIds (STEither a b) = VIEither' <$> genIds a <*> genIds b
+genIds (STLEither a b) = VILEither . VIMaybe' <$> (VIEither' <$> genIds a <*> genIds b)
+genIds (STMaybe t) = VIMaybe' <$> genIds t
+genIds (STArr n _) = VIArr <$> genId <*> vecReplicateA n genId
+genIds STScal{} = VIScal <$> genId
+genIds STAccum{} = VIAccum <$> genId
+
+shidsToVec :: SNat n -> ValId (Tup (Replicate n TIx)) -> IdGen (Vec n Int)
+shidsToVec SZ _ = pure VNil
+shidsToVec (SS n) (VIPair is (VIScal i)) = (i :<) <$> shidsToVec n is
+
+vecToShids :: SNat n -> Vec n Int -> ValId (Tup (Replicate n TIx))
+vecToShids SZ VNil = VINil
+vecToShids (SS n) (i :< is) = VIPair (vecToShids n is) (VIScal i)