diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2025-01-28 16:58:51 +0100 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2025-01-28 16:58:51 +0100 |
commit | 3e04b03acd5e7138e0f6241133585f22ddb73060 (patch) | |
tree | 57b60cf7a784e3e1ece6c05afecff52eb4beb6db /src/Analysis/Identity.hs | |
parent | 817cd3c75a2bbbbb355ac33fc7ca3ad8a16bdc92 (diff) |
Pretty-printer that supports extension fields
Diffstat (limited to 'src/Analysis/Identity.hs')
-rw-r--r-- | src/Analysis/Identity.hs | 184 |
1 files changed, 120 insertions, 64 deletions
diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index 9087143..285cfb8 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -1,11 +1,14 @@ {-# LANGUAGE DataKinds #-} -{-# LANGUAGE LambdaCase #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ScopedTypeVariables #-} module Analysis.Identity ( identityAnalysis, ) where import AST +import AST.Pretty (PrettyX(..)) +import CHAD.Types (d1, d2) import Data import Util.IdGen @@ -19,13 +22,14 @@ data ValId t where VIEither :: Either (ValId a) (ValId b) -> ValId (TEither a b) -- ^ known alternative VIEither' :: ValId a -> ValId b -> ValId (TEither a b) -- ^ unknown alternative, but known values in each case VIMaybe :: Maybe (ValId a) -> ValId (TMaybe a) + VIMaybe' :: ValId a -> ValId (TMaybe a) -- ^ if it's Just, it contains this value VIArr :: Int -> ValId (TArr n t) VIScal :: Int -> ValId (TScal t) VIAccum :: Int -> ValId (TAccum t) -- | We don't know what this consists of, but it's a value, and let's just -- give it an ID nevertheless. - VIThing :: Int -> ValId t + VIThing :: STy t -> Int -> ValId t instance Eq (ValId t) where VINil == VINil = True @@ -38,33 +42,40 @@ instance Eq (ValId t) where VIEither'{} == _ = False VIMaybe a == VIMaybe a' = a == a' VIMaybe{} == _ = False + VIMaybe' a == VIMaybe' a' = a == a' + VIMaybe'{} == _ = False VIArr i == VIArr i' = i == i' VIArr{} == _ = False VIScal i == VIScal i' = i == i' VIScal{} == _ = False VIAccum i == VIAccum i' = i == i' VIAccum{} == _ = False - VIThing i == VIThing i' = i == i' + VIThing _ i == VIThing _ i' = i == i' VIThing{} == _ = False +instance PrettyX ValId where + prettyX = \case + VINil -> "" + VIPair a b -> "(" ++ prettyX a ++ "," ++ prettyX b ++ ")" + VIEither (Left a) -> "(L" ++ prettyX a ++ ")" + VIEither (Right a) -> "(R" ++ prettyX a ++ ")" + VIEither' a b -> "(" ++ prettyX a ++ "|" ++ prettyX b ++ ")" + VIMaybe Nothing -> "N" + VIMaybe (Just a) -> 'J' : prettyX a + VIMaybe' a -> 'M' : prettyX a + VIArr i -> 'A' : show i + VIScal i -> show i + VIAccum i -> 'C' : show i + VIThing _ i -> '{' : show i ++ "}" + -- | Symbolic partial evaluation. identityAnalysis :: SList STy env -> Expr x env t -> Expr ValId env t identityAnalysis env term = runIdGen 0 $ do - env' <- slistMapA numberConstant env + env' <- slistMapA genIds env snd <$> idana env' term - where - numberConstant :: STy t -> IdGen (ValId t) - numberConstant = \case - STNil -> pure VINil - STPair a b -> VIPair <$> numberConstant a <*> numberConstant b - STEither a b -> VIEither' <$> numberConstant a <*> numberConstant b - STMaybe{} -> VIThing <$> genId - STArr{} -> VIArr <$> genId - STScal{} -> VIScal <$> genId - STAccum{} -> VIAccum <$> genId idana :: SList ValId env -> Expr x env t -> IdGen (ValId t, Expr ValId env t) -idana env = \case +idana env expr = case expr of EVar _ t i -> do let v = slistIdx env i pure (v, EVar v t i) @@ -82,13 +93,13 @@ idana env = \case EFst _ e -> do (v, e') <- idana env e v' <- case v of VIPair v1 _ -> pure v1 - _ -> VIThing <$> genId + _ -> genIds (typeOf expr) pure (v', EFst v' e') ESnd _ e -> do (v, e') <- idana env e v' <- case v of VIPair _ v2 -> pure v2 - _ -> VIThing <$> genId + _ -> genIds (typeOf expr) pure (v', ESnd v' e') ENil _ -> pure (VINil, ENil VINil) @@ -104,33 +115,31 @@ idana env = \case pure (v, EInr v t1 e2') ECase _ e1 e2 e3 -> do + let STEither t1 t2 = typeOf e1 (v1, e1') <- idana env e1 case v1 of VIEither (Left v1') -> do (v2, e2') <- idana (v1' `SCons` env) e2 - scrap <- VIThing <$> genId + scrap <- genIds t2 (_, e3') <- idana (scrap `SCons` env) e3 pure (v2, ECase v2 e1' e2' e3') VIEither (Right v1') -> do - scrap <- VIThing <$> genId + scrap <- genIds t1 (_, e2') <- idana (scrap `SCons` env) e2 (v3, e3') <- idana (v1' `SCons` env) e3 pure (v3, ECase v3 e1' e2' e3') VIEither' v1'l v1'r -> do (_, e2') <- idana (v1'l `SCons` env) e2 (_, e3') <- idana (v1'r `SCons` env) e3 - res <- genId - pure (VIThing res, ECase (VIThing res) e1' e2' e3') - VIThing _ -> do - x2 <- genId - x3 <- genId - (v2, e2') <- idana (VIThing x2 `SCons` env) e2 - (v3, e3') <- idana (VIThing x3 `SCons` env) e3 - if v2 == v3 - then pure (v2, ECase v2 e1' e2' e3') - else do - res <- genId - pure (VIThing res, ECase (VIThing res) e1' e2' e3') + res <- genIds (typeOf expr) + pure (res, ECase res e1' e2' e3') + VIThing _ _ -> do + x2 <- genIds t1 + x3 <- genIds t2 + (v2, e2') <- idana (x2 `SCons` env) e2 + (v3, e3') <- idana (x3 `SCons` env) e3 + res <- unify v2 v3 + pure (res, ECase res e1' e2' e3') ENothing _ t -> pure (VIMaybe Nothing, ENothing (VIMaybe Nothing) t) @@ -140,26 +149,29 @@ idana env = \case pure (v, EJust v e1') EMaybe _ e1 e2 e3 -> do + let STMaybe t1 = typeOf e3 (v3, e3') <- idana env e3 case v3 of VIMaybe Nothing -> do (v1, e1') <- idana env e1 - scrap <- VIThing <$> genId + scrap <- genIds t1 (_, e2') <- idana (scrap `SCons` env) e2 pure (v1, EMaybe v1 e1' e2' e3') VIMaybe (Just v3j) -> do (v2, e2') <- idana (v3j `SCons` env) e2 (_, e1') <- idana env e1 pure (v2, EMaybe v2 e1' e2' e3') - VIThing _ -> do + VIMaybe' v3' -> do + (v2, e2') <- idana (v3' `SCons` env) e2 (v1, e1') <- idana env e1 - scrap <- VIThing <$> genId + res <- unify v1 v2 + pure (res, EMaybe res e1' e2' e3') + VIThing _ _ -> do + (v1, e1') <- idana env e1 + scrap <- genIds t1 (v2, e2') <- idana (scrap `SCons` env) e2 - if v1 == v2 - then pure (v2, EMaybe v2 e1' e2' e3') - else do - res <- genId - pure (VIThing res, EMaybe (VIThing res) e1' e2' e3') + res <- unify v1 v2 + pure (res, EMaybe res e1' e2' e3') EConstArr _ dim t arr -> do x1 <- VIArr <$> genId @@ -167,15 +179,16 @@ idana env = \case EBuild _ dim e1 e2 -> do (_, e1') <- idana env e1 - scrap <- VIThing <$> genId - (_, e2') <- idana (scrap `SCons` env) e2 + x1 <- genIds (tTup (sreplicate dim tIx)) + (_, e2') <- idana (x1 `SCons` env) e2 res <- VIArr <$> genId pure (res, EBuild res dim e1' e2') EFold1Inner _ e1 e2 e3 -> do - scrap1 <- VIThing <$> genId - scrap2 <- VIThing <$> genId - (_, e1') <- idana (scrap1 `SCons` scrap2 `SCons` env) e1 + let t1 = typeOf e1 + x1 <- genIds t1 + x2 <- genIds t1 + (_, e1') <- idana (x1 `SCons` x2 `SCons` env) e1 (_, e2') <- idana env e2 (_, e3') <- idana env e3 res <- VIArr <$> genId @@ -213,51 +226,53 @@ idana env = \case EIdx0 _ e1 -> do (_, e1') <- idana env e1 - res <- VIThing <$> genId + res <- genIds (typeOf expr) pure (res, EIdx0 res e1') EIdx1 _ e1 e2 -> do (_, e1') <- idana env e1 (_, e2') <- idana env e2 - res <- VIThing <$> genId + res <- genIds (typeOf expr) pure (res, EIdx1 res e1' e2') EIdx _ e1 e2 -> do (_, e1') <- idana env e1 (_, e2') <- idana env e2 - res <- VIThing <$> genId + res <- genIds (typeOf expr) pure (res, EIdx res e1' e2') EShape _ e1 -> do (_, e1') <- idana env e1 - res <- VIThing <$> genId + res <- genIds (typeOf expr) pure (res, EShape res e1') - EOp _ op e1 -> do + EOp _ (op :: SOp a t) e1 -> do (_, e1') <- idana env e1 - res <- VIThing <$> genId + res <- genIds (typeOf expr) pure (res, EOp res op e1') ECustom _ t1 t2 t3 e1 e2 e3 e4 e5 -> do - x1 <- VIThing <$> genId - x2 <- VIThing <$> genId + let t4 = typeOf e1 + x1 <- genIds t2 + x2 <- genIds t1 (_, e1') <- idana (x1 `SCons` x2 `SCons` SNil) e1 - x3 <- VIThing <$> genId - x4 <- VIThing <$> genId + x3 <- genIds (d1 t2) + x4 <- genIds (d1 t1) (_, e2') <- idana (x3 `SCons` x4 `SCons` SNil) e2 - x5 <- VIThing <$> genId - x6 <- VIThing <$> genId + x5 <- genIds (d2 t4) + x6 <- genIds t3 (_, e3') <- idana (x5 `SCons` x6 `SCons` SNil) e3 (_, e4') <- idana env e4 (_, e5') <- idana env e5 - res <- VIThing <$> genId + res <- genIds t4 pure (res, ECustom res t1 t2 t3 e1' e2' e3' e4' e5') EWith _ e1 e2 -> do + let t1 = typeOf e1 (_, e1') <- idana env e1 x1 <- VIAccum <$> genId (v2, e2') <- idana (x1 `SCons` env) e2 - x2 <- VIThing <$> genId + x2 <- genIds t1 let res = VIPair v2 x2 pure (res, EWith res e1' e2') @@ -265,25 +280,66 @@ idana env = \case (_, e1') <- idana env e1 (_, e2') <- idana env e2 (_, e3') <- idana env e3 - res <- VIThing <$> genId - pure (res, EAccum res i e1' e2' e3') + pure (VINil, EAccum VINil i e1' e2' e3') EZero _ t -> do - res <- VIThing <$> genId + res <- genIds (d2 t) pure (res, EZero res t) EPlus _ t e1 e2 -> do (_, e1') <- idana env e1 (_, e2') <- idana env e2 - res <- VIThing <$> genId + res <- genIds (d2 t) pure (res, EPlus res t e1' e2') EOneHot _ t i e1 e2 -> do (_, e1') <- idana env e1 (_, e2') <- idana env e2 - res <- VIThing <$> genId + res <- genIds (d2 t) pure (res, EOneHot res t i e1' e2') EError _ t s -> do - res <- VIThing <$> genId + res <- genIds t pure (res, EError res t s) + +-- | This value might be either of the two arguments; we don't know which. +unify :: ValId t -> ValId t -> IdGen (ValId t) +unify VINil VINil = pure VINil +unify (VIPair a b) (VIPair c d) = VIPair <$> unify a c <*> unify b d +unify (VIEither (Left a)) (VIEither (Left b)) = VIEither . Left <$> unify a b +unify (VIEither (Right a)) (VIEither (Right b)) = VIEither . Right <$> unify a b +unify (VIEither (Left a)) (VIEither (Right b)) = pure $ VIEither' a b +unify (VIEither (Right a)) (VIEither (Left b)) = pure $ VIEither' b a +unify (VIEither (Left a)) (VIEither' b c) = VIEither' <$> unify a b <*> pure c +unify (VIEither (Right a)) (VIEither' b c) = VIEither' <$> pure b <*> unify a c +unify (VIEither' a b) (VIEither (Left c)) = VIEither' <$> unify a c <*> pure b +unify (VIEither' a b) (VIEither (Right c)) = VIEither' <$> pure a <*> unify b c +unify (VIEither' a b) (VIEither' c d) = VIEither' <$> unify a c <*> unify b d +unify (VIMaybe Nothing) (VIMaybe Nothing) = pure $ VIMaybe Nothing +unify (VIMaybe (Just a)) (VIMaybe (Just b)) = VIMaybe . Just <$> unify a b +unify (VIMaybe Nothing) (VIMaybe (Just a)) = pure $ VIMaybe' a +unify (VIMaybe (Just a)) (VIMaybe Nothing) = pure $ VIMaybe' a +unify (VIMaybe Nothing) (VIMaybe' a) = pure $ VIMaybe' a +unify (VIMaybe (Just a)) (VIMaybe' b) = VIMaybe' <$> unify a b +unify (VIMaybe' a) (VIMaybe Nothing) = pure $ VIMaybe' a +unify (VIMaybe' a) (VIMaybe (Just b)) = VIMaybe' <$> unify a b +unify (VIMaybe' a) (VIMaybe' b) = VIMaybe' <$> unify a b +unify (VIArr i) (VIArr j) | i == j = pure $ VIArr i + | otherwise = VIArr <$> genId +unify (VIScal i) (VIScal j) | i == j = pure $ VIScal i + | otherwise = VIScal <$> genId +unify (VIAccum i) (VIAccum j) | i == j = pure $ VIAccum i + | otherwise = VIAccum <$> genId +unify (VIThing t i) (VIThing _ j) | i == j = pure $ VIThing t i + | otherwise = genIds t +unify (VIThing t _) _ = genIds t +unify _ (VIThing t _) = genIds t + +genIds :: STy t -> IdGen (ValId t) +genIds STNil = pure VINil +genIds (STPair a b) = VIPair <$> genIds a <*> genIds b +genIds (STEither a b) = VIEither' <$> genIds a <*> genIds b +genIds (STMaybe t) = VIMaybe' <$> genIds t +genIds STArr{} = VIArr <$> genId +genIds STScal{} = VIScal <$> genId +genIds STAccum{} = VIAccum <$> genId |