{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Analysis.Identity (
  identityAnalysis,
) where

import Data.Foldable (toList)
import Data.List (intercalate)

import AST
import AST.Pretty (PrettyX(..))
import CHAD.Types (d1, d2)
import Data
import Util.IdGen


-- | Every array, scalar and accumulator has an ID. Trivial values such as
-- Nothing only have the knowledge that they are indeed Nothing. Compound
-- values know which values they consist of.
data ValId t where
  -- | The unit value; there is nothing to identify.
  VINil :: ValId TNil
  -- | A pair, described by the descriptions of its two components.
  VIPair :: ValId a -> ValId b -> ValId (TPair a b)
  -- | Known alternative.
  VIEither :: Either (ValId a) (ValId b) -> ValId (TEither a b)
  -- | Unknown alternative, but known values in each case.
  VIEither' :: ValId a -> ValId b -> ValId (TEither a b)
  -- | Known to be Nothing, or known to be Just with the given contents.
  VIMaybe :: Maybe (ValId a) -> ValId (TMaybe a)
  -- | If it's Just, it contains this value.
  VIMaybe' :: ValId a -> ValId (TMaybe a)
  -- | An array: one ID for the array, plus a vector holding one ID per
  -- dimension.
  VIArr :: Int -> Vec n Int -> ValId (TArr n t)
  -- | A scalar with its ID.
  VIScal :: Int -> ValId (TScal t)
  -- | An accumulator with its ID.
  VIAccum :: Int -> ValId (TAccum t)

-- | Compact textual rendering of a 'ValId': pairs and eithers are
-- parenthesised, @N@\/@J@\/@M@ mark the three Maybe forms, @A@ prefixes an
-- array (followed by its bracketed dimension IDs), @C@ prefixes an
-- accumulator, and a scalar is just its ID.
instance PrettyX ValId where
  prettyX VINil = ""
  prettyX (VIPair l r) = concat ["(", prettyX l, ",", prettyX r, ")"]
  prettyX (VIEither (Left l)) = concat ["(L", prettyX l, ")"]
  prettyX (VIEither (Right r)) = concat ["(R", prettyX r, ")"]
  prettyX (VIEither' l r) = concat ["(", prettyX l, "|", prettyX r, ")"]
  prettyX (VIMaybe Nothing) = "N"
  prettyX (VIMaybe (Just inner)) = "J" ++ prettyX inner
  prettyX (VIMaybe' inner) = "M" ++ prettyX inner
  prettyX (VIArr arrId dimIds) =
    concat ["A", show arrId, "[", intercalate "," (show <$> toList dimIds), "]"]
  prettyX (VIScal scalId) = show scalId
  prettyX (VIAccum accId) = "C" ++ show accId

-- | Symbolic partial evaluation.
--
-- Re-annotates every node of the given expression with a 'ValId' describing
-- what is known about the identity of the value computed there. Each
-- variable in the environment is given a description generated by 'genIds'
-- from its type before the expression is traversed.
identityAnalysis :: SList STy env -> Expr x env t -> Expr ValId env t
identityAnalysis env term = runIdGen 0 $ do
  env' <- slistMapA genIds env
  snd <$> idana env' term

-- | Worker for 'identityAnalysis'. Given descriptions for all variables in
-- scope, analyse one expression and return both the description of its
-- result and the expression with every node annotated with its description.
--
-- The existing annotations (of type @x@) are ignored and replaced.
idana :: SList ValId env -> Expr x env t -> IdGen (ValId t, Expr ValId env t)
idana env expr = case expr of
  -- A variable is whatever the environment says it is.
  EVar _ t i -> do
    let v = slistIdx env i
    pure (v, EVar v t i)

  -- The body sees the bound value's description; the let is its body.
  ELet _ e1 e2 -> do
    (v1, e1') <- idana env e1
    (v2, e2') <- idana (v1 `SCons` env) e2
    pure (v2, ELet v2 e1' e2')

  EPair _ e1 e2 -> do
    (v1, e1') <- idana env e1
    (v2, e2') <- idana env e2
    pure (VIPair v1 v2, EPair (VIPair v1 v2) e1' e2')

  -- Projections simply select the matching component description.
  EFst _ e -> do
    (v, e') <- idana env e
    let VIPair v1 _ = v
    pure (v1, EFst v1 e')

  ESnd _ e -> do
    (v, e') <- idana env e
    let VIPair _ v2 = v
    pure (v2, ESnd v2 e')

  ENil _ -> pure (VINil, ENil VINil)

  EInl _ t2 e1 -> do
    (v1, e1') <- idana env e1
    let v = VIEither (Left v1)
    pure (v, EInl v t2 e1')

  EInr _ t1 e2 -> do
    (v2, e2') <- idana env e2
    let v = VIEither (Right v2)
    pure (v, EInr v t1 e2')

  ECase _ e1 e2 e3 -> do
    let STEither t1 t2 = typeOf e1
    (v1, e1') <- idana env e1
    case v1 of
      -- Known alternative: the result is that of the taken arm. The other
      -- arm still has to be annotated, so its variable gets scrap IDs and
      -- its result description is discarded.
      VIEither (Left v1') -> do
        (v2, e2') <- idana (v1' `SCons` env) e2
        scrap <- genIds t2
        (_, e3') <- idana (scrap `SCons` env) e3
        pure (v2, ECase v2 e1' e2' e3')
      VIEither (Right v1') -> do
        scrap <- genIds t1
        (_, e2') <- idana (scrap `SCons` env) e2
        (v3, e3') <- idana (v1' `SCons` env) e3
        pure (v3, ECase v3 e1' e2' e3')
      -- Unknown alternative: the result may come from either arm, so merge
      -- the two arm results with 'unify' (as the unknown-Maybe branch of
      -- EMaybe does). This keeps whatever identity the arms agree on.
      VIEither' v1'l v1'r -> do
        (v2, e2') <- idana (v1'l `SCons` env) e2
        (v3, e3') <- idana (v1'r `SCons` env) e3
        res <- unify v2 v3
        pure (res, ECase res e1' e2' e3')

  ENothing _ t -> pure (VIMaybe Nothing, ENothing (VIMaybe Nothing) t)

  EJust _ e1 -> do
    (v1, e1') <- idana env e1
    let v = VIMaybe (Just v1)
    pure (v, EJust v e1')

  -- e1 is the Nothing arm, e2 the Just arm (binding the contents), and e3
  -- the scrutinee.
  EMaybe _ e1 e2 e3 -> do
    let STMaybe t1 = typeOf e3
    (v3, e3') <- idana env e3
    case v3 of
      VIMaybe Nothing -> do
        (v1, e1') <- idana env e1
        scrap <- genIds t1
        (_, e2') <- idana (scrap `SCons` env) e2
        pure (v1, EMaybe v1 e1' e2' e3')
      VIMaybe (Just v3j) -> do
        (v2, e2') <- idana (v3j `SCons` env) e2
        (_, e1') <- idana env e1
        pure (v2, EMaybe v2 e1' e2' e3')
      -- Unknown whether Nothing or Just: merge the results of both arms.
      VIMaybe' v3' -> do
        (v2, e2') <- idana (v3' `SCons` env) e2
        (v1, e1') <- idana env e1
        res <- unify v1 v2
        pure (res, EMaybe res e1' e2' e3')

  -- A constant array gets a fresh array ID and fresh per-dimension IDs.
  EConstArr _ dim t arr -> do
    x1 <- VIArr <$> genId <*> vecReplicateA dim genId
    pure (x1, EConstArr x1 dim t arr)

  -- The per-dimension IDs of the built array are the scalar IDs of the
  -- shape expression's components; the index variable gets fresh IDs.
  EBuild _ dim e1 e2 -> do
    (shids, e1') <- idana env e1
    x1 <- genIds (tTup (sreplicate dim tIx))
    (_, e2') <- idana (x1 `SCons` env) e2
    res <- VIArr <$> genId <*> shidsToVec dim shids
    pure (res, EBuild res dim e1' e2')

  -- The combination function is analysed with fresh IDs for both of its
  -- arguments. The result is a fresh array that keeps the argument array's
  -- dimension IDs except for the head of the vector.
  EFold1Inner _ e1 e2 e3 -> do
    let t1 = typeOf e1
    x1 <- genIds t1
    x2 <- genIds t1
    (_, e1') <- idana (x1 `SCons` x2 `SCons` env) e1
    (_, e2') <- idana env e2
    (v3, e3') <- idana env e3
    let VIArr _ (_ :< sh) = v3
    res <- VIArr <$> genId <*> pure sh
    pure (res, EFold1Inner res e1' e2' e3')

  ESum1Inner _ e1 -> do
    (v1, e1') <- idana env e1
    let VIArr _ (_ :< sh) = v1
    res <- VIArr <$> genId <*> pure sh
    pure (res, ESum1Inner res e1')

  EUnit _ e1 -> do
    (_, e1') <- idana env e1
    res <- VIArr <$> genId <*> pure VNil
    pure (res, EUnit res e1')

  -- The new head dimension is identified by the ID of the scalar e1.
  EReplicate1Inner _ e1 e2 -> do
    (v1, e1') <- idana env e1
    let VIScal v1' = v1
    (v2, e2') <- idana env e2
    let VIArr _ sh = v2
    res <- VIArr <$> genId <*> pure (v1' :< sh)
    pure (res, EReplicate1Inner res e1' e2')

  EMaximum1Inner _ e1 -> do
    (v1, e1') <- idana env e1
    let VIArr _ (_ :< sh) = v1
    res <- VIArr <$> genId <*> pure sh
    pure (res, EMaximum1Inner res e1')

  EMinimum1Inner _ e1 -> do
    (v1, e1') <- idana env e1
    let VIArr _ (_ :< sh) = v1
    res <- VIArr <$> genId <*> pure sh
    pure (res, EMinimum1Inner res e1')

  EConst _ t val -> do
    res <- VIScal <$> genId
    pure (res, EConst res t val)

  -- Nothing is tracked about array elements: indexing yields fresh IDs.
  EIdx0 _ e1 -> do
    (_, e1') <- idana env e1
    res <- genIds (typeOf expr)
    pure (res, EIdx0 res e1')

  EIdx1 _ e1 e2 -> do
    (v1, e1') <- idana env e1
    let VIArr _ sh = v1
    (_, e2') <- idana env e2
    res <- VIArr <$> genId <*> pure (vecInit sh)
    pure (res, EIdx1 res e1' e2')

  EIdx _ e1 e2 -> do
    (_, e1') <- idana env e1
    (_, e2') <- idana env e2
    res <- genIds (typeOf expr)
    pure (res, EIdx res e1' e2')

  -- The shape of an array consists of scalars carrying exactly the
  -- array's per-dimension IDs; no fresh IDs are needed.
  EShape _ e1 -> do
    let STArr dim _ = typeOf e1
    (v1, e1') <- idana env e1
    let VIArr _ sh = v1
        res = vecToShids dim sh
    pure (res, EShape res e1')

  EOp _ (op :: SOp a t) e1 -> do
    (_, e1') <- idana env e1
    res <- genIds (typeOf expr)
    pure (res, EOp res op e1')

  -- The three function bodies (e1, e2, e3) are analysed in their own
  -- two-variable environments seeded with fresh IDs, not in `env`.
  ECustom _ t1 t2 t3 e1 e2 e3 e4 e5 -> do
    let t4 = typeOf e1
    x1 <- genIds t2
    x2 <- genIds t1
    (_, e1') <- idana (x1 `SCons` x2 `SCons` SNil) e1
    x3 <- genIds (d1 t2)
    x4 <- genIds (d1 t1)
    (_, e2') <- idana (x3 `SCons` x4 `SCons` SNil) e2
    x5 <- genIds (d2 t4)
    x6 <- genIds t3
    (_, e3') <- idana (x5 `SCons` x6 `SCons` SNil) e3
    (_, e4') <- idana env e4
    (_, e5') <- idana env e5
    res <- genIds t4
    pure (res, ECustom res t1 t2 t3 e1' e2' e3' e4' e5')

  -- The body runs with a fresh accumulator in scope; the result pairs the
  -- body's result with fresh IDs of the type of e1.
  EWith _ e1 e2 -> do
    let t1 = typeOf e1
    (_, e1') <- idana env e1
    x1 <- VIAccum <$> genId
    (v2, e2') <- idana (x1 `SCons` env) e2
    x2 <- genIds t1
    let res = VIPair v2 x2
    pure (res, EWith res e1' e2')

  EAccum _ i e1 e2 e3 -> do
    (_, e1') <- idana env e1
    (_, e2') <- idana env e2
    (_, e3') <- idana env e3
    pure (VINil, EAccum VINil i e1' e2' e3')

  EZero _ t -> do
    res <- genIds (d2 t)
    pure (res, EZero res t)

  EPlus _ t e1 e2 -> do
    (_, e1') <- idana env e1
    (_, e2') <- idana env e2
    res <- genIds (d2 t)
    pure (res, EPlus res t e1' e2')

  EOneHot _ t i e1 e2 -> do
    (_, e1') <- idana env e1
    (_, e2') <- idana env e2
    res <- genIds (d2 t)
    pure (res, EOneHot res t i e1' e2')

  EError _ t s -> do
    res <- genIds t
    pure (res, EError res t s)

-- | This value might be either of the two arguments; we don't know which.
--
-- Structure that both arguments agree on is kept. Where a known alternative
-- meets a different or unknown one, the result falls back to the "unknown
-- alternative" constructors. IDs are merged with 'unifyID'.
unify :: ValId t -> ValId t -> IdGen (ValId t)
unify VINil VINil = pure VINil
unify (VIPair a b) (VIPair c d) = VIPair <$> unify a c <*> unify b d
unify (VIEither (Left a)) (VIEither (Left b)) = VIEither . Left <$> unify a b
unify (VIEither (Right a)) (VIEither (Right b)) = VIEither . Right <$> unify a b
unify (VIEither (Left a)) (VIEither (Right b)) = pure $ VIEither' a b
unify (VIEither (Right a)) (VIEither (Left b)) = pure $ VIEither' b a
unify (VIEither (Left a)) (VIEither' b c) = VIEither' <$> unify a b <*> pure c
unify (VIEither (Right a)) (VIEither' b c) = VIEither' <$> pure b <*> unify a c
unify (VIEither' a b) (VIEither (Left c)) = VIEither' <$> unify a c <*> pure b
unify (VIEither' a b) (VIEither (Right c)) = VIEither' <$> pure a <*> unify b c
unify (VIEither' a b) (VIEither' c d) = VIEither' <$> unify a c <*> unify b d
unify (VIMaybe Nothing) (VIMaybe Nothing) = pure $ VIMaybe Nothing
unify (VIMaybe (Just a)) (VIMaybe (Just b)) = VIMaybe . Just <$> unify a b
unify (VIMaybe Nothing) (VIMaybe (Just a)) = pure $ VIMaybe' a
unify (VIMaybe (Just a)) (VIMaybe Nothing) = pure $ VIMaybe' a
unify (VIMaybe Nothing) (VIMaybe' a) = pure $ VIMaybe' a
unify (VIMaybe (Just a)) (VIMaybe' b) = VIMaybe' <$> unify a b
unify (VIMaybe' a) (VIMaybe Nothing) = pure $ VIMaybe' a
unify (VIMaybe' a) (VIMaybe (Just b)) = VIMaybe' <$> unify a b
unify (VIMaybe' a) (VIMaybe' b) = VIMaybe' <$> unify a b
unify (VIArr i is) (VIArr j js) = VIArr <$> unifyID i j <*> vecZipWithA unifyID is js
unify (VIScal i) (VIScal j) = VIScal <$> unifyID i j
unify (VIAccum i) (VIAccum j) = VIAccum <$> unifyID i j

-- | Merge two IDs: equal IDs are kept, differing IDs are replaced by a
-- freshly generated one.
unifyID :: Int -> Int -> IdGen Int
unifyID i j
  | i == j = pure i
  | otherwise = genId

-- | Generate a description, following the structure of the given type, for
-- a value about which nothing is known: fresh IDs at every array, scalar and
-- accumulator, and the "unknown alternative" form for sums and Maybes.
genIds :: STy t -> IdGen (ValId t)
genIds STNil = pure VINil
genIds (STPair a b) = VIPair <$> genIds a <*> genIds b
genIds (STEither a b) = VIEither' <$> genIds a <*> genIds b
genIds (STMaybe t) = VIMaybe' <$> genIds t
genIds (STArr n _) = VIArr <$> genId <*> vecReplicateA n genId
genIds STScal{} = VIScal <$> genId
genIds STAccum{} = VIAccum <$> genId

-- | Collect the scalar IDs of an @n@-component index tuple into a vector.
-- The last (outermost-nested) tuple component becomes the head of the
-- vector. No IDs are generated; the result is in 'IdGen' only for the
-- convenience of callers.
shidsToVec :: SNat n -> ValId (Tup (Replicate n TIx)) -> IdGen (Vec n Int)
shidsToVec SZ _ = pure VNil
shidsToVec (SS n) (VIPair is (VIScal i)) = (i :<) <$> shidsToVec n is

-- | Inverse of 'shidsToVec': turn a vector of IDs into the description of an
-- index tuple whose components are scalars with those IDs.
vecToShids :: SNat n -> Vec n Int -> ValId (Tup (Replicate n TIx))
vecToShids SZ VNil = VINil
vecToShids (SS n) (i :< is) = VIPair (vecToShids n is) (VIScal i)