{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- | Algebraic simplification of 'Exp' terms: constant folding of primitive
-- operations, beta reduction, let inlining / splitting, and replacing
-- 'Shape' queries by statically known shape expressions.
module Simplify
  ( simplify,
    simplifyFix,
  )
where

import Data.Bifunctor
import Data.GADT.Compare
import qualified Data.Kind as Kind
import Data.List (find)
import Data.Type.Equality

import AST
import Sink

-- | Static information tracked about a value of type @a@, expressed in the
-- environment @env@. In this module only arrays (their shape expression)
-- and pairs of tracked values have an instance.
data family Info (env :: [Kind.Type]) a

-- data instance Info env Int = InfoInt
-- data instance Info env Bool = InfoBool
-- data instance Info env Double = InfoDouble

-- | For an array we remember the expression that computes its shape.
data instance Info env (Array sh t) = InfoArray (Exp env sh)

-- data instance Info env () = InfoNil

-- | For a pair we remember information about both components.
data instance Info env (a, b) = InfoPair (Info env a) (Info env b)

-- data instance Info env (a -> b) = InfoFun

-- | Per-variable static information, indexed like the de Bruijn environment.
--
-- 'ITop' stands for the enclosing environment, about which nothing is known.
-- Each 'ICons' stores the bound variable's type and, when available, its
-- 'Info' expressed in the environment that already contains the variable.
data IEnv env where
  ITop :: IEnv env
  ICons :: Type a -> Maybe (Info (a ': env) a) -> IEnv env -> IEnv (a ': env)

-- | Weaken an 'Info' under one extra binder. The 'Type' selects which data
-- instance is being traversed; types with no clause here raise an error.
sinkInfo1 :: Type a -> Info env a -> Info (t ': env) a
sinkInfo1 TArray{} (InfoArray e) = InfoArray (sinkExp1 e)
sinkInfo1 (TPair t1 t2) (InfoPair a b) = InfoPair (sinkInfo1 t1 a) (sinkInfo1 t2 b)
sinkInfo1 _ _ = error "Unknown info in sinkInfo1"

-- | Look up a variable's type and static information. The stored 'Info' is
-- weakened across every binder between its binding site and the current
-- scope. Variables bound outside the tracked environment ('ITop') yield
-- 'Nothing'.
iprj :: IEnv env -> Idx env a -> Maybe (Type a, Info env a)
iprj ITop _ = Nothing
iprj (ICons t m _) Zero = (t,) <$> m
iprj (ICons _ _ env) (Succ i) = (\(t, m) -> (t, sinkInfo1 t m)) <$> iprj env i

-- | Apply 'simplify' repeatedly until the expression no longer changes
-- (structural equality via 'geq'). Raises an error if no fixed point is
-- reached within four rounds.
simplifyFix :: Exp env a -> Exp env a
simplifyFix e =
  let maxTimes = 4
      -- 'iterate' is infinite, so 'es' always has maxTimes + 1 elements and
      -- 'tail' is safe here.
      es = take (maxTimes + 1) (iterate simplify e)
      pairs = zip es (tail es)
   in case find (\(a, b) -> case geq a b of Just Refl -> True; _ -> False) pairs of
        Just (e', _) -> e'
        Nothing -> error "Simplification doesn't converge!"

-- | A single simplification pass, starting with no information about the
-- free variables of the expression.
simplify :: Exp env a -> Exp env a
simplify = fst . simplify' ITop

-- | Simplify an expression and, where possible, also return static
-- information about its value:
--
-- * the shape of arrays built with 'Build';
-- * pairs whose components both carry information;
-- * projections of such pairs;
-- * variables whose 'Let'-bound value carried information.
simplify' :: IEnv env -> Exp env a -> (Exp env a, Maybe (Info env a))
simplify' env = \case
  App a b -> (simplifyApp (fst (simplify' env a)) (fst (simplify' env b)), Nothing)
  -- Nothing is known about a lambda-bound variable.
  Lam t e -> (Lam t (fst (simplify' (ICons t Nothing env) e)), Nothing)
  Var t i -> (Var t i, snd <$> iprj env i)
  Let arg e ->
    let (arg', info) = simplify' env arg
        -- Record what we know about the binder, weakened into the body's scope.
        env' = ICons (typeof arg) (sinkInfo1 (typeof arg) <$> info) env
     in (simplifyLet arg' (fst (simplify' env' e)), Nothing)
  Lit l -> (Lit l, Nothing)
  Cond a b c ->
    (Cond (fst (simplify' env a)) (fst (simplify' env b)) (fst (simplify' env c)), Nothing)
  Const c -> (Const c, Nothing)
  Pair a b ->
    let (a', ia) = simplify' env a
        (b', ib) = simplify' env b
     in (Pair a' b', InfoPair <$> ia <*> ib)
  Fst e -> bimap simplifyFst (fmap (\(InfoPair i _) -> i)) (simplify' env e)
  Snd e -> bimap simplifySnd (fmap (\(InfoPair _ i) -> i)) (simplify' env e)
  Build sht a b ->
    -- Remember the simplified shape so a later 'Shape' query can be folded.
    let a' = fst (simplify' env a)
     in (Build sht a' (fst (simplify' env b)), Just (InfoArray a'))
  Ifold sht a b c ->
    (Ifold sht (fst (simplify' env a)) (fst (simplify' env b)) (fst (simplify' env c)), Nothing)
  Index a b -> (simplifyIndex (fst (simplify' env a)) (fst (simplify' env b)), Nothing)
  Shape e -> case simplify' env e of
    -- The array's shape is statically known: replace the query by it.
    (_, Just (InfoArray she)) -> (she, Nothing)
    (e', _) -> (Shape e', Nothing)

-- | Simplify an application whose function and argument are already
-- simplified. Primitive operations on literal arguments are folded, and
-- applications of lambdas become (simplified) lets.
simplifyApp :: Exp env (a -> b) -> Exp env a -> Exp env b
simplifyApp (Const CAddI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a + b))
simplifyApp (Const CSubI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a - b))
simplifyApp (Const CMulI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a * b))
-- Never fold a division by zero. The resulting thunk would throw inside the
-- simplifier as soon as the literal is compared (see 'simplifyFix') or
-- printed, even when the division sits in a branch the program never
-- evaluates. Leave the application for run time instead.
simplifyApp (Const CDivI) (Pair (Lit (LInt a)) (Lit (LInt b)))
  | b /= 0 = Lit (LInt (a `div` b))
simplifyApp (Const CAddF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a + b))
simplifyApp (Const CSubF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a - b))
simplifyApp (Const CMulF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a * b))
simplifyApp (Const CDivF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a / b))
simplifyApp (Const CLog) (Lit (LDouble a)) = Lit (LDouble (log a))
simplifyApp (Const CExp) (Lit (LDouble a)) = Lit (LDouble (exp a))
simplifyApp (Const CtoF) (Lit (LInt a)) = Lit (LDouble (fromIntegral a))
simplifyApp (Const CRound) (Lit (LDouble a)) = Lit (LInt (round a))
simplifyApp (Const CLtI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LBool (a < b))
simplifyApp (Const CLtF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LBool (a < b))
simplifyApp (Const (CEq _)) (Pair a b)
  -- Structurally identical operands are folded to True.
  -- NOTE(review): if CEq can be used at Double this is unsound for NaN
  -- (NaN == NaN is False at run time) -- confirm against the evaluator.
  | Just Refl <- geq a b = Lit (LBool True)
simplifyApp (Const CAnd) (Pair (Lit (LBool a)) (Lit (LBool b))) = Lit (LBool (a && b))
simplifyApp (Const COr) (Pair (Lit (LBool a)) (Lit (LBool b))) = Lit (LBool (a || b))
simplifyApp (Const CNot) (Lit (LBool a)) = Lit (LBool (not a))
-- Beta reduction: (\x. e) arg  ~>  let x = arg in e. 'simplifyLet' performs
-- the inlining decision (duplicable argument or at most one use) itself.
simplifyApp (Lam _ e) arg = simplifyLet arg e
simplifyApp a b = App a b

-- | Build @let arg in e@ and simplify it:
--
-- * inline the binding when the argument is cheap to duplicate, or when the
--   bound variable occurs at most once;
-- * split a bound pair into two nested lets;
-- * for a bound conditional whose branches are duplicable, keep only the
--   condition let-bound and copy the conditional to each use site.
simplifyLet :: Exp env a -> Exp (a ': env) b -> Exp env b
simplifyLet arg e
  -- NOTE(review): 'countOcc' counts a single occurrence under a 'Lam' or
  -- inside a 'Build'/'Ifold' as one use. Inlining a non-duplicable argument
  -- there may duplicate work if that body is evaluated many times.
  | isDuplicable arg || countOcc Zero e <= 1 = simplify (subst arg e)
simplifyLet (Pair a b) e =
  -- let v = (a, b) in e  ~>  let x = a in let y = b in e[(x, y) / v]
  simplifyLet a $
    simplifyLet (sinkExp1 b) $
      subst'
        ( \t -> \case
            Zero -> Pair (Var (typeof a) (Succ Zero)) (Var (typeof b) Zero)
            Succ i -> Var t (Succ (Succ i))
        )
        e
simplifyLet (Cond c a b) e
  | isDuplicable a && isDuplicable b =
      -- let v = (if c then a else b) in e  ~>  let x = c in e[(if x then a else b) / v]
      simplifyLet c $
        subst'
          ( \t -> \case
              Zero -> Cond (Var TBool Zero) (sinkExp1 a) (sinkExp1 b)
              Succ i -> Var t (Succ i)
          )
          e
simplifyLet a b = Let (simplify a) (simplify b)

-- | Project the first component, looking through literal pairs and lets.
simplifyFst :: Exp env (a, b) -> Exp env a
simplifyFst (Pair e _) = e
simplifyFst (Let a e) = simplifyLet a (simplifyFst e)
simplifyFst e = Fst e

-- | Project the second component, looking through literal pairs and lets.
simplifySnd :: Exp env (a, b) -> Exp env b
simplifySnd (Pair _ e) = e
simplifySnd (Let a e) = simplifyLet a (simplifySnd e)
simplifySnd e = Snd e

-- | Index into an array. Indexing a 'Build' applies its generator function
-- to the index directly. Note that the folded form does not compare the
-- index against the 'Build''s shape expression.
simplifyIndex :: Exp env (Array sh a) -> Exp env sh -> Exp env a
simplifyIndex (Build _ _ f) e = simplifyApp f e
simplifyIndex a e = Index a e

-- | Whether an expression may be copied to several use sites. This holds for
-- variables, constants, and scalar / shape / nil literals. It also holds for
-- lambdas, lets, pairs, pair literals and projections built only from such
-- parts.
isDuplicable :: Exp env a -> Bool
isDuplicable (Lam _ e) = isDuplicable e
isDuplicable (Var _ _) = True
isDuplicable (Let a e) = isDuplicable a && isDuplicable e
isDuplicable (Lit (LInt _)) = True
isDuplicable (Lit (LBool _)) = True
isDuplicable (Lit (LDouble _)) = True
isDuplicable (Lit (LShape _)) = True
isDuplicable (Lit LNil) = True
isDuplicable (Lit (LPair l1 l2)) = isDuplicable (Lit l1) && isDuplicable (Lit l2)
isDuplicable (Const _) = True
isDuplicable (Pair a b) = isDuplicable a && isDuplicable b
isDuplicable (Fst e) = isDuplicable e
isDuplicable (Snd e) = isDuplicable e
isDuplicable _ = False

-- | Number of syntactic occurrences of variable @i@ in an expression. The
-- index is shifted under each binder ('Lam', the body of 'Let'). Each
-- occurrence counts once, regardless of how often it would be evaluated.
countOcc :: Idx env t -> Exp env a -> Int
countOcc i (App a b) = countOcc i a + countOcc i b
countOcc i (Lam _ e) = countOcc (Succ i) e
countOcc i (Var _ j)
  | Just Refl <- geq i j = 1
  | otherwise = 0
countOcc i (Let a b) = countOcc i a + countOcc (Succ i) b
countOcc _ (Lit _) = 0
countOcc i (Cond a b c) = countOcc i a + countOcc i b + countOcc i c
countOcc _ (Const _) = 0
countOcc i (Pair a b) = countOcc i a + countOcc i b
countOcc i (Fst e) = countOcc i e
countOcc i (Snd e) = countOcc i e
countOcc i (Build _ a b) = countOcc i a + countOcc i b
countOcc i (Ifold _ a b c) = countOcc i a + countOcc i b + countOcc i c
countOcc i (Index a b) = countOcc i a + countOcc i b
countOcc i (Shape e) = countOcc i e

-- | Substitute @arg@ for the innermost variable (index 'Zero') and shift all
-- other variables down by one.
subst :: Exp env t -> Exp (t ': env) a -> Exp env a
subst arg e = subst' (\t -> \case Zero -> arg; Succ i -> Var t i) e

-- | Apply a simultaneous substitution mapping every variable of @env@ to an
-- expression in @env'@. Under a binder the substitution is lifted: the new
-- variable maps to itself and every other image is weakened by one.
subst' :: (forall t. Type t -> Idx env t -> Exp env' t) -> Exp env a -> Exp env' a
subst' f (App a b) = App (subst' f a) (subst' f b)
subst' f (Lam t e) =
  Lam t (subst' (\t' -> \case Zero -> Var t' Zero; Succ i -> sinkExp1 (f t' i)) e)
subst' f (Var t i) = f t i
subst' f (Let a b) =
  Let (subst' f a) (subst' (\t -> \case Zero -> Var t Zero; Succ i -> sinkExp1 (f t i)) b)
subst' _ (Lit l) = Lit l
subst' f (Cond a b c) = Cond (subst' f a) (subst' f b) (subst' f c)
subst' _ (Const c) = Const c
subst' f (Pair a b) = Pair (subst' f a) (subst' f b)
subst' f (Fst e) = Fst (subst' f e)
subst' f (Snd e) = Snd (subst' f e)
subst' f (Build sht a b) = Build sht (subst' f a) (subst' f b)
subst' f (Ifold sht a b c) = Ifold sht (subst' f a) (subst' f b) (subst' f c)
subst' f (Index a b) = Index (subst' f a) (subst' f b)
subst' f (Shape e) = Shape (subst' f e)