aboutsummaryrefslogtreecommitdiff
path: root/src/Analysis
diff options
context:
space:
mode:
Diffstat (limited to 'src/Analysis')
-rw-r--r--src/Analysis/Identity.hs97
1 files changed, 85 insertions, 12 deletions
diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs
index 5e36dde..b54946b 100644
--- a/src/Analysis/Identity.hs
+++ b/src/Analysis/Identity.hs
@@ -2,8 +2,12 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
module Analysis.Identity (
identityAnalysis,
+ identityAnalysis',
+ ValId(..),
+ validSplitEither,
) where
import Data.Foldable (toList)
@@ -24,11 +28,13 @@ data ValId t where
VIPair :: ValId a -> ValId b -> ValId (TPair a b)
VIEither :: Either (ValId a) (ValId b) -> ValId (TEither a b) -- ^ known alternative
VIEither' :: ValId a -> ValId b -> ValId (TEither a b) -- ^ unknown alternative, but known values in each case
+ VILEither :: ValId (TMaybe (TEither a b)) -> ValId (TLEither a b)
VIMaybe :: Maybe (ValId a) -> ValId (TMaybe a)
VIMaybe' :: ValId a -> ValId (TMaybe a) -- ^ if it's Just, it contains this value
VIArr :: Int -> Vec n Int -> ValId (TArr n t)
VIScal :: Int -> ValId (TScal t)
VIAccum :: Int -> ValId (TAccum t)
+deriving instance Show (ValId t)
instance PrettyX ValId where
prettyX = \case
@@ -40,16 +46,31 @@ instance PrettyX ValId where
VIMaybe Nothing -> "N"
VIMaybe (Just a) -> 'J' : prettyX a
VIMaybe' a -> 'M' : prettyX a
+ VILEither (VIMaybe Nothing) -> "lN"
+ VILEither (VIMaybe (Just (VIEither (Left a)))) -> "(lL" ++ prettyX a ++ ")"
+ VILEither (VIMaybe (Just (VIEither (Right a)))) -> "(lR" ++ prettyX a ++ ")"
+ VILEither (VIMaybe (Just (VIEither' a b))) -> "(" ++ prettyX a ++ "⊕" ++ prettyX b ++ ")"
+ VILEither (VIMaybe' (VIEither (Left a))) -> "(mlL" ++ prettyX a ++ ")"
+ VILEither (VIMaybe' (VIEither (Right a))) -> "(mlR" ++ prettyX a ++ ")"
+ VILEither (VIMaybe' (VIEither' a b)) -> "(m(" ++ prettyX a ++ "⊕" ++ prettyX b ++ "))"
VIArr i is -> 'A' : show i ++ "[" ++ intercalate "," (map show (toList is)) ++ "]"
VIScal i -> show i
VIAccum i -> 'C' : show i
+validSplitEither :: ValId (TEither a b) -> (Maybe (ValId a), Maybe (ValId b))
+validSplitEither (VIEither (Left v)) = (Just v, Nothing)
+validSplitEither (VIEither (Right v)) = (Nothing, Just v)
+validSplitEither (VIEither' v1 v2) = (Just v1, Just v2)
+
-- | Symbolic partial evaluation.
identityAnalysis :: SList STy env -> Expr x env t -> Expr ValId env t
identityAnalysis env term = runIdGen 0 $ do
env' <- slistMapA genIds env
snd <$> idana env' term
+identityAnalysis' :: SList ValId env -> Expr x env t -> Expr ValId env t
+identityAnalysis' env term = snd (runIdGen 0 (idana env term))
+
idana :: SList ValId env -> Expr x env t -> IdGen (ValId t, Expr ValId env t)
idana env expr = case expr of
EVar _ t i -> do
@@ -103,9 +124,9 @@ idana env expr = case expr of
(v3, e3') <- idana (v1' `SCons` env) e3
pure (v3, ECase v3 e1' e2' e3')
VIEither' v1'l v1'r -> do
- (_, e2') <- idana (v1'l `SCons` env) e2
- (_, e3') <- idana (v1'r `SCons` env) e3
- res <- genIds (typeOf expr)
+ (v2, e2') <- idana (v1'l `SCons` env) e2
+ (v3, e3') <- idana (v1'r `SCons` env) e3
+ res <- unify v2 v3
pure (res, ECase res e1' e2' e3')
ENothing _ t -> pure (VIMaybe Nothing, ENothing (VIMaybe Nothing) t)
@@ -134,6 +155,42 @@ idana env expr = case expr of
res <- unify v1 v2
pure (res, EMaybe res e1' e2' e3')
+ ELNil _ t1 t2 -> do
+ let v = VILEither (VIMaybe Nothing)
+ pure (v, ELNil v t1 t2)
+
+ ELInl _ t2 e1 -> do
+ (v1, e1') <- idana env e1
+ let v = VILEither (VIMaybe (Just (VIEither (Left v1))))
+ pure (v, ELInl v t2 e1')
+
+ ELInr _ t1 e2 -> do
+ (v2, e2') <- idana env e2
+ let v = VILEither (VIMaybe (Just (VIEither (Right v2))))
+ pure (v, ELInr v t1 e2')
+
+ ELCase _ e1 e2 e3 e4 -> do
+ let STLEither t1 t2 = typeOf e1
+ (v1L, e1') <- idana env e1
+ let VILEither v1 = v1L
+ let go mv1'l mv1'r f = do
+ v1'l <- maybe (genIds t1) pure mv1'l
+ v1'r <- maybe (genIds t2) pure mv1'r
+ (v2, e2') <- idana env e2
+ (v3, e3') <- idana (v1'l `SCons` env) e3
+ (v4, e4') <- idana (v1'r `SCons` env) e4
+ res <- f v2 v3 v4
+ pure (res, ELCase res e1' e2' e3' e4')
+ case v1 of
+ VIMaybe Nothing -> go Nothing Nothing (\v2 _ _ -> pure v2)
+ VIMaybe (Just (VIEither (Left v1'))) -> go (Just v1') Nothing (\_ v3 _ -> pure v3)
+ VIMaybe (Just (VIEither (Right v1'))) -> go Nothing (Just v1') (\_ _ v4 -> pure v4)
+ VIMaybe (Just (VIEither' v1'l v1'r)) -> go (Just v1'l) (Just v1'r) (\_ v3 v4 -> unify v3 v4)
+ VIMaybe' (VIEither (Left v1')) -> go (Just v1') Nothing (\v2 v3 _ -> unify v2 v3)
+ VIMaybe' (VIEither (Right v1')) -> go Nothing (Just v1') (\v2 _ v4 -> unify v2 v4)
+ VIMaybe' (VIEither' v1'l v1'r) ->
+ go (Just v1'l) (Just v1'r) (\v2 v3 v4 -> unify v2 =<< unify v3 v4)
+
EConstArr _ dim t arr -> do
x1 <- VIArr <$> genId <*> vecReplicateA dim genId
pure (x1, EConstArr x1 dim t arr)
@@ -145,7 +202,7 @@ idana env expr = case expr of
res <- VIArr <$> genId <*> shidsToVec dim shids
pure (res, EBuild res dim e1' e2')
- EFold1Inner _ e1 e2 e3 -> do
+ EFold1Inner _ cm e1 e2 e3 -> do
let t1 = typeOf e1
x1 <- genIds t1
x2 <- genIds t1
@@ -154,7 +211,7 @@ idana env expr = case expr of
(v3, e3') <- idana env e3
let VIArr _ (_ :< sh) = v3
res <- VIArr <$> genId <*> pure sh
- pure (res, EFold1Inner res e1' e2' e3')
+ pure (res, EFold1Inner res cm e1' e2' e3')
ESum1Inner _ e1 -> do
(v1, e1') <- idana env e1
@@ -237,6 +294,10 @@ idana env expr = case expr of
res <- genIds t4
pure (res, ECustom res t1 t2 t3 e1' e2' e3' e4' e5')
+ ERecompute _ e -> do
+ (v, e') <- idana env e
+ pure (v, ERecompute v e')
+
EWith _ t e1 e2 -> do
let t1 = typeOf e1
(_, e1') <- idana env e1
@@ -246,26 +307,36 @@ idana env expr = case expr of
let res = VIPair v2 x2
pure (res, EWith res t e1' e2')
- EAccum _ t prj e1 e2 e3 -> do
+ EAccum _ t prj e1 sp e2 e3 -> do
(_, e1') <- idana env e1
(_, e2') <- idana env e2
(_, e3') <- idana env e3
- pure (VINil, EAccum VINil t prj e1' e2' e3')
+ pure (VINil, EAccum VINil t prj e1' sp e2' e3')
- EZero _ t -> do
- res <- genIds (d2 t)
- pure (res, EZero res t)
+ EZero _ t e1 -> do
+ -- Approximate the result of EZero to be independent from the zero info
+ -- expression; not quite true for shape variables
+ (_, e1') <- idana env e1
+ res <- genIds (fromSMTy t)
+ pure (res, EZero res t e1')
+
+ EDeepZero _ t e1 -> do
+ -- Approximate the result of EDeepZero to be independent from the zero info
+ -- expression; not quite true for shape variables
+ (_, e1') <- idana env e1
+ res <- genIds (fromSMTy t)
+ pure (res, EDeepZero res t e1')
EPlus _ t e1 e2 -> do
(_, e1') <- idana env e1
(_, e2') <- idana env e2
- res <- genIds (d2 t)
+ res <- genIds (fromSMTy t)
pure (res, EPlus res t e1' e2')
EOneHot _ t i e1 e2 -> do
(_, e1') <- idana env e1
(_, e2') <- idana env e2
- res <- genIds (d2 t)
+ res <- genIds (fromSMTy t)
pure (res, EOneHot res t i e1' e2')
EError _ t s -> do
@@ -294,6 +365,7 @@ unify (VIMaybe (Just a)) (VIMaybe' b) = VIMaybe' <$> unify a b
unify (VIMaybe' a) (VIMaybe Nothing) = pure $ VIMaybe' a
unify (VIMaybe' a) (VIMaybe (Just b)) = VIMaybe' <$> unify a b
unify (VIMaybe' a) (VIMaybe' b) = VIMaybe' <$> unify a b
+unify (VILEither a) (VILEither b) = VILEither <$> unify a b
unify (VIArr i is) (VIArr j js) = VIArr <$> unifyID i j <*> vecZipWithA unifyID is js
unify (VIScal i) (VIScal j) = VIScal <$> unifyID i j
unify (VIAccum i) (VIAccum j) = VIAccum <$> unifyID i j
@@ -306,6 +378,7 @@ genIds :: STy t -> IdGen (ValId t)
genIds STNil = pure VINil
genIds (STPair a b) = VIPair <$> genIds a <*> genIds b
genIds (STEither a b) = VIEither' <$> genIds a <*> genIds b
+genIds (STLEither a b) = VILEither . VIMaybe' <$> (VIEither' <$> genIds a <*> genIds b)
genIds (STMaybe t) = VIMaybe' <$> genIds t
genIds (STArr n _) = VIArr <$> genId <*> vecReplicateA n genId
genIds STScal{} = VIScal <$> genId