{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-- | Expression simplifier.
--
-- Two families of entry points are exported:
--
-- * 'simplify' / 'simplifyFix': a single bottom-up pass that also tracks
--   per-variable 'Info' (integer bounds, array shapes) in an 'IEnv'.
-- * 'simbeta', 'simpair', 'simindex', 'simifold1': individual rewrite passes
--   that can be combined into a 'SimList' and iterated with 'simfix'.
module Simplify (
  simplify, simplifyFix,
  simbeta, simpair, simindex, simifold1,
  simfix, SimList(..),
) where

import Data.Bifunctor
import Data.GADT.Compare
import qualified Data.Kind as Kind
import Data.List (find)
import Data.Type.Equality
import Debug.Trace

import AST
import Count
import Sink


-- | An upper bound that either includes or excludes the bounding value.
data Bound a = Inclusive a | Exclusive a
  deriving (Show)

instance Functor Bound where
  fmap f (Inclusive x) = Inclusive (f x)
  fmap f (Exclusive x) = Exclusive (f x)

-- | Statically known facts about an expression of type @a@, phrased in terms
-- of expressions valid in environment @env@.
data family Info (env :: [Kind.Type]) a

-- | Lower bound (inclusive), upper bound (inclusive/exclusive)
data instance Info env Int = InfoInt (Maybe (Exp env Int)) (Maybe (Bound (Exp env Int)))
-- | The shape expression of the array.
data instance Info env (Array sh t) = InfoArray (Exp env sh)
data instance Info env () = InfoNil
-- | Optional info for each component of the pair.
data instance Info env (a, b) = InfoPair (Maybe (Info env a)) (Maybe (Info env b))

-- | Environment of optional 'Info' per bound variable. The info stored in an
-- 'ICons' cell is already expressed in the environment that includes the
-- variable of that cell.
data IEnv env where
  ITop :: IEnv env
  ICons :: Type a -> Maybe (Info (a ': env) a) -> IEnv env -> IEnv (a ': env)

-- | 'showsPrec'-style printer for 'Info'; the 'Type' witness selects the data
-- instance to print.
showsInfo :: Int -> Type a -> Info env a -> ShowS
showsInfo d TInt (InfoInt a b) =
  showParen (d > 10) $
    showString "InfoInt " . showsPrec 11 a . showString " " . showsPrec 11 b
showsInfo d TArray{} (InfoArray a) =
  showParen (d > 10) $
    showString "InfoArray " . showsPrec 11 a
showsInfo _ TNil InfoNil = showString "InfoNil"
showsInfo d (TPair t1 t2) (InfoPair a b) =
  showParen (d > 10) $
    showString "InfoPair " . showsInfo' 11 t1 a . showString " " . showsInfo' 11 t2 b
showsInfo _ _ _ = error "showsInfo: No definition"

-- | Like 'showsInfo', for an optional 'Info'.
showsInfo' :: Int -> Type a -> Maybe (Info env a) -> ShowS
showsInfo' _ _ Nothing = showString "Nothing"
showsInfo' d t (Just x) =
  showParen (d > 10) $
    showString "Just " . showsInfo 11 t x

-- | Weaken an 'Info' by one environment entry, sinking every expression it
-- contains with 'sinkExp1'.
sinkInfo1 :: Type a -> Info env a -> Info (t ': env) a
sinkInfo1 TInt (InfoInt a b) = InfoInt (sinkExp1 <$> a) (fmap sinkExp1 <$> b)
sinkInfo1 TArray{} (InfoArray e) = InfoArray (sinkExp1 e)
sinkInfo1 TNil InfoNil = InfoNil
sinkInfo1 (TPair t1 t2) (InfoPair a b) = InfoPair (sinkInfo1 t1 <$> a) (sinkInfo1 t2 <$> b)
sinkInfo1 _ _ = error "Unknown info in sinkInfo1"

-- | Look up the type and info recorded for a variable. The info is sunk once
-- for every 'ICons' cell that is skipped, so the result is valid in the full
-- environment. Returns 'Nothing' if the environment has no entry or no info.
iprj :: IEnv env -> Idx env a -> Maybe (Type a, Info env a)
iprj ITop _ = Nothing
iprj (ICons t m _) Zero = (t,) <$> m
iprj (ICons _ _ env) (Succ i) = (\(t, m) -> (t, sinkInfo1 t m)) <$> iprj env i

-- | Iterate 'simplify' until two consecutive results are equal under 'geq'
-- and return that fixed point. At most @maxTimes@ applications are tried;
-- if none of them reaches a fixed point this calls 'error'.
simplifyFix :: Exp env a -> Exp env a
simplifyFix e =
  let maxTimes = 4
      es = take (maxTimes + 1) (iterate simplify e)
      -- 'drop 1' instead of 'tail': same result here, but total.
      pairs = zip es (drop 1 es)
  in case find (\(a,b) -> case geq a b of Just Refl -> True ; _ -> False) pairs of
       Just (e', _) -> e'
       Nothing -> error "Simplification doesn't converge!"

-- | One simplification pass starting from an empty info environment.
simplify :: Exp env a -> Exp env a
simplify = fst . simplify' ITop

-- | One bottom-up simplification pass. Returns the simplified expression
-- together with whatever 'Info' could be derived about its value.
simplify' :: IEnv env -> Exp env a -> (Exp env a, Maybe (Info env a))
simplify' env = \case
  App a b -> (simplifyApp (fst (simplify' env a)) (fst (simplify' env b)), Nothing)
  Lam t e -> (Lam t (fst (simplify' (ICons t Nothing env) e)), Nothing)
  Var t i -> (Var t i, snd <$> iprj env i)
  Let arg e ->
    let (arg', info) = simplify' env arg
        -- The info about the bound value must live in the extended environment.
        env' = ICons (typeof arg) (sinkInfo1 (typeof arg) <$> info) env
    in (simplifyLet arg' (fst (simplify' env' e)), Nothing)
  -- An integer literal is its own inclusive lower and upper bound.
  Lit (LInt n) -> (Lit (LInt n), Just (InfoInt (Just (Lit (LInt n))) (Just (Inclusive (Lit (LInt n))))))
  Lit l -> (Lit l, Nothing)
  Cond a b c -> (Cond (fst (simplify' env a)) (fst (simplify' env b)) (fst (simplify' env c)), Nothing)
  Const c -> (Const c, Nothing)
  Pair a b ->
    let (a', ia) = simplify' env a
        (b', ib) = simplify' env b
    in (simplifyPair a' b', Just (InfoPair ia ib))
  Fst e -> bimap simplifyFst (>>= (\(InfoPair i _) -> i)) (simplify' env e)
  Snd e -> bimap simplifySnd (>>= (\(InfoPair _ i) -> i)) (simplify' env e)
  Build sht a (Lam shty fe) ->
    let a' = fst (simplify' env a)
        -- Inside the element function, the index is bounded by the shape.
        env' = ICons shty (Just (shapeBoundInfo sht (sinkExp1 a'))) env
    in (Build sht a' (Lam shty (fst (simplify' env' fe))), Just (InfoArray a'))
  Build sht a b ->
    let a' = fst (simplify' env a)
    in (Build sht a' (fst (simplify' env b)), Just (InfoArray a'))
  Ifold sht a b c ->
    (simplifyIfold env sht (fst (simplify' env a)) (fst (simplify' env b)) (fst (simplify' env c)), Nothing)
  Index a b -> (simplifyIndex (fst (simplify' env a)) (fst (simplify' env b)), Nothing)
  Shape e ->
    case simplify' env e of
      -- The shape of the array is known: use it directly.
      (_, Just (InfoArray she)) -> (she, Nothing)
      (e', _) -> (Shape e', Nothing)
  Undef t -> (Undef t, Nothing)

-- | Info for an index ranging over the shape @she@: every dimension has
-- inclusive lower bound 0 and exclusive upper bound the corresponding
-- component of @she@.
shapeBoundInfo :: ShapeType sh -> Exp env sh -> Info env sh
shapeBoundInfo STZ _ = InfoNil
shapeBoundInfo (STC sht) she =
  InfoPair (Just (shapeBoundInfo sht (Fst she)))
           (Just (InfoInt (Just (Lit (LInt 0))) (Just (Exclusive (Snd she)))))

-- | Simplify an application whose function and argument have already been
-- simplified: constant folding, arithmetic identities, beta reduction, and
-- pushing the application into the branches of a conditional argument.
--
-- NOTE(review): the @x * 0 = 0@ rules for doubles ignore NaN and infinite
-- operands; kept as in the original.
simplifyApp :: Exp env (a -> b) -> Exp env a -> Exp env b
simplifyApp (Const CAddI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a + b))
simplifyApp (Const CAddI) (Pair a (Lit (LInt 0))) = a
simplifyApp (Const CAddI) (Pair (Lit (LInt 0)) a) = a
-- simplifyApp (Const CAddI) (Pair a b) | Just Refl <- geq a b =
--   simplifyApp (Const CMulI) (Pair (Lit (LInt 2)) a)
simplifyApp (Const CSubI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a - b))
simplifyApp (Const CMulI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a * b))
simplifyApp (Const CMulI) (Pair a (Lit (LInt 1))) = a
simplifyApp (Const CMulI) (Pair (Lit (LInt 1)) a) = a
simplifyApp (Const CMulI) (Pair _ (Lit (LInt 0))) = Lit (LInt 0)
simplifyApp (Const CMulI) (Pair (Lit (LInt 0)) _) = Lit (LInt 0)
-- Only fold integer division when the divisor is non-zero: `div` by zero
-- throws, which would crash the simplifier itself. With a zero divisor the
-- guard fails and the application is left in place by the final equation.
simplifyApp (Const CDivI) (Pair (Lit (LInt a)) (Lit (LInt b)))
  | b /= 0 = Lit (LInt (a `div` b))
simplifyApp (Const CAddF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a + b))
simplifyApp (Const CAddF) (Pair a (Lit (LDouble 0))) = a
simplifyApp (Const CAddF) (Pair (Lit (LDouble 0)) a) = a
-- simplifyApp (Const CAddF) (Pair a b) | Just Refl <- geq a b =
--   simplifyApp (Const CMulF) (Pair (Lit (LDouble 2)) a)
simplifyApp (Const CSubF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a - b))
simplifyApp (Const CMulF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a * b))
simplifyApp (Const CMulF) (Pair a (Lit (LDouble 1))) = a
simplifyApp (Const CMulF) (Pair (Lit (LDouble 1)) a) = a
simplifyApp (Const CMulF) (Pair _ (Lit (LDouble 0))) = Lit (LDouble 0)
simplifyApp (Const CMulF) (Pair (Lit (LDouble 0)) _) = Lit (LDouble 0)
simplifyApp (Const CDivF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a / b))
simplifyApp (Const CLog) (Lit (LDouble a)) = Lit (LDouble (log a))
simplifyApp (Const CExp) (Lit (LDouble a)) = Lit (LDouble (exp a))
simplifyApp (Const CtoF) (Lit (LInt a)) = Lit (LDouble (fromIntegral a))
simplifyApp (Const CRound) (Lit (LDouble a)) = Lit (LInt (round a))
simplifyApp (Const CLtI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LBool (a < b))
simplifyApp (Const CLtF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LBool (a < b))
-- Syntactically equal operands compare equal.
simplifyApp (Const (CEq _)) (Pair a b)
  | Just Refl <- geq a b
  = Lit (LBool True)
simplifyApp (Const CAnd) (Pair (Lit (LBool a)) (Lit (LBool b))) = Lit (LBool (a && b))
simplifyApp (Const COr) (Pair (Lit (LBool a)) (Lit (LBool b))) = Lit (LBool (a || b))
simplifyApp (Const CNot) (Lit (LBool a)) = Lit (LBool (not a))
-- Beta-reduce when the argument may be duplicated or is used at most once;
-- otherwise turn the redex into a let-binding.
simplifyApp (Lam _ e) arg
  | isDuplicable arg || usesOf Zero e <= 1
  = simplify (subst arg e)
simplifyApp (Lam _ e) arg = simplifyLet arg e
simplifyApp f (Cond c a b) = simplify $ Cond c (App f a) (App f b)
simplifyApp a b = App a b

-- | Simplify a let-binding: inline the bound expression when it may be
-- duplicated or is used at most once, split a bound pair into two bindings,
-- and push the binding into the branches of a bound conditional.
simplifyLet :: Exp env a -> Exp (a ': env) b -> Exp env b
simplifyLet arg e
  | isDuplicable arg || usesOf Zero e <= 1
  = simplify (subst arg e)
simplifyLet (Pair a b) e =
  -- Bind both components separately and rebuild the pair from the two new
  -- variables at each use site; all other variables shift by one extra slot.
  simplifyLet a $
    simplifyLet (sinkExp1 b) $
      subst' (\t -> \case
                 Zero -> Pair (Var (typeof a) (Succ Zero)) (Var (typeof b) Zero)
                 Succ i -> Var t (Succ (Succ i)))
             e
-- simplifyLet (Cond c a b) e
--   | isDuplicable a && isDuplicable b
--   = simplifyLet c $
--       (subst' (\t -> \case Zero -> Cond (Var TBool Zero) (sinkExp1 a) (sinkExp1 b)
--                            Succ i -> Var t (Succ i))
--               e)
simplifyLet (Cond c a b) e = simplify $ Cond c (Let a e) (Let b e)
simplifyLet a b = Let a b

-- | Build a pair, hoisting a conditional in either component outwards.
simplifyPair :: Exp env a -> Exp env b -> Exp env (a, b)
simplifyPair (Cond c a b) d = simplify $ Cond c (Pair a d) (Pair b d)
simplifyPair d (Cond c a b) = simplify $ Cond c (Pair d a) (Pair d b)
simplifyPair a b = Pair a b

-- | First projection, reduced on a literal pair and pushed under let-bindings
-- and conditionals.
simplifyFst :: Exp env (a, b) -> Exp env a
simplifyFst (Pair e _) = e
simplifyFst (Let a e) = simplifyLet a (simplifyFst e)
simplifyFst (Cond c a b) = simplify $ Cond c (Fst a) (Fst b)
simplifyFst e = Fst e

-- | Second projection, reduced on a literal pair and pushed under
-- let-bindings and conditionals.
simplifySnd :: Exp env (a, b) -> Exp env b
simplifySnd (Pair _ e) = e
simplifySnd (Let a e) = simplifyLet a (simplifySnd e)
simplifySnd (Cond c a b) = simplify $ Cond c (Snd a) (Snd b)
simplifySnd e = Snd e

-- | Simplify an index fold whose components have already been simplified.
simplifyIfold :: IEnv env -> ShapeType sh -> Exp env ((a, sh) -> a) -> Exp env a -> Exp env sh -> Exp env a
-- First try to split a fold over a pair into two independent folds.
simplifyIfold env sht fe e0 she
  | Just res <- splitIfold sht fe e0 she
  = fst (simplify' env res)
-- Given the following:
--   ifold (\(a,i) -> if i == cmpref then val else a) _ she
-- and given that we can prove that 0 <= cmpref < she and that 'cmpref' does
-- not refer to the fold argument, the whole fold can be replaced with 'val'
-- in which the fold argument is bound to (e0, cmpref).
simplifyIfold env sht (Lam argty (Cond (App (Const (CEq _)) (Pair (Snd (Var _ Zero)) cmpref)) val (Fst (Var _ Zero)))) e0 she
  | let env' = ICons argty (Just (InfoPair Nothing (Just (shapeBoundInfo sht (sinkExp1 she))))) env
  , trace ("si: trying") True
  , proveShapeBound env' CLeI sht (zeroShapeExp sht) cmpref
  , trace ("si: prf1 = True") True
  , proveShapeBound env' CLtI sht cmpref (sinkExp1 she)
  , trace ("si: prf2 = True") True
  -- NOTE(review): the label says cmpref but the value shown is val.
  , trace ("si: cmpref = " ++ show val) True
  , usesOf Zero cmpref == 0
  -- cmpref has no uses of the fold argument, so the substituted value is
  -- never inspected.
  = simplifyLet (Pair e0 (subst (error "usesOf == 0 was wrong") cmpref)) val
simplifyIfold _ sht fe e0 she = Ifold sht fe e0 she

-- | The all-zero index of the given shape type.
zeroShapeExp :: ShapeType sh -> Exp env sh
zeroShapeExp STZ = Lit LNil
zeroShapeExp (STC sht) = Pair (zeroShapeExp sht) (Lit (LInt 0))

-- | Try to prove that @e1 `cmpop` e2@ holds in every dimension, using the
-- upper bound known for @e1@ and the lower bound known for @e2@. Only 'CLeI'
-- with an inclusive upper bound and 'CLtI' with an exclusive upper bound are
-- handled; everything else is reported via 'trace' and yields 'False'.
proveShapeBound :: IEnv env -> Constant ((Int, Int) -> Bool) -> ShapeType sh -> Exp env sh -> Exp env sh -> Bool
proveShapeBound _ _ STZ _ _ = True
proveShapeBound env cmpop (STC sht) e1 e2 =
  let (_, info1) = simplify' env (Snd e1)
      (_, info2) = simplify' env (Snd e2)
      inclLo2 = case info2 of
                  Just (InfoInt (Just lo2) _) -> lo2
                  _ -> Snd e2  -- this is also an inclusive lower bound, after all
      restresult = proveShapeBound env cmpop sht (Fst e1) (Fst e2)
  in restresult && case (cmpop, info1) of
       (CLeI, Just (InfoInt _ (Just (Inclusive hi1)))) -> proveLe hi1 inclLo2
       (CLtI, Just (InfoInt _ (Just (Exclusive hi1)))) -> proveLe hi1 inclLo2
       _ -> trace ("proveShapeBound: " ++ show cmpop ++ " (" ++ show e1 ++ ") (" ++ show e2 ++ ")") $
            trace (" e1 = " ++ show e1) $
            trace (" e2 = " ++ show e2) $
            trace (" info1 = " ++ showsInfo' 0 TInt info1 "") $
            trace (" info2 = " ++ showsInfo' 0 TInt info2 "") $
            trace (" inclLo2 = " ++ show inclLo2) $
            False

-- | Tracing wrapper around 'proveLe''.
proveLe :: Exp env Int -> Exp env Int -> Bool
proveLe = \e1 e2 ->
  let res = proveLe' e1 e2
  in trace ("proveLe: '" ++ show e1 ++ "' <= '" ++ show e2 ++ "' -> " ++ show res) res

-- | Try to prove @e1 <= e2@: succeeds for syntactically equal expressions and
-- for ordered integer literals; 'False' means "could not prove".
proveLe' :: Exp env Int -> Exp env Int -> Bool
proveLe' e1 e2 | Just Refl <- geq e1 e2 = True
proveLe' (Lit (LInt a)) (Lit (LInt b)) | a <= b = True
proveLe' _ _ = False

-- | Indexing into a 'Build' applies the element function directly.
simplifyIndex :: Exp env (Array sh a) -> Exp env sh -> Exp env a
simplifyIndex (Build _ _ f) e = simplifyApp f e
simplifyIndex a e = Index a e

-- | Whether an expression is considered fine to duplicate when inlining:
-- variables, constants and literals, closed under lambda, let, pairing and
-- projection. Everything else is not.
isDuplicable :: Exp env a -> Bool
isDuplicable (Lam _ e) = isDuplicable e
isDuplicable (Var _ _) = True
isDuplicable (Let a e) = isDuplicable a && isDuplicable e
isDuplicable (Lit (LInt _)) = True
isDuplicable (Lit (LBool _)) = True
isDuplicable (Lit (LDouble _)) = True
isDuplicable (Lit (LShape _)) = True
isDuplicable (Lit LNil) = True
isDuplicable (Lit (LPair l1 l2)) = isDuplicable (Lit l1) && isDuplicable (Lit l2)
isDuplicable (Const _) = True
isDuplicable (Pair a b) = isDuplicable a && isDuplicable b
isDuplicable (Fst e) = isDuplicable e
isDuplicable (Snd e) = isDuplicable e
isDuplicable _ = False

-- | Substitute @arg@ for the innermost variable of @e@, shifting all other
-- variables down by one.
subst :: Exp env t -> Exp (t ': env) a -> Exp env a
subst arg e = subst' (\t -> \case Zero -> arg ; Succ i -> Var t i) e

-- | Replace every variable by the expression the given function assigns to
-- it. Under a binder ('Lam', body of 'Let') the bound variable is kept and
-- the replacements are sunk by one.
subst' :: (forall t. Type t -> Idx env t -> Exp env' t) -> Exp env a -> Exp env' a
subst' f (App a b) = App (subst' f a) (subst' f b)
subst' f (Lam t e) = Lam t (subst' (\t' -> \case Zero -> Var t' Zero ; Succ i -> sinkExp1 (f t' i)) e)
subst' f (Var t i) = f t i
subst' f (Let a b) = Let (subst' f a) (subst' (\t -> \case Zero -> Var t Zero ; Succ i -> sinkExp1 (f t i)) b)
subst' _ (Lit l) = Lit l
subst' f (Cond a b c) = Cond (subst' f a) (subst' f b) (subst' f c)
subst' _ (Const c) = Const c
subst' f (Pair a b) = Pair (subst' f a) (subst' f b)
subst' f (Fst e) = Fst (subst' f e)
subst' f (Snd e) = Snd (subst' f e)
subst' f (Build sht a b) = Build sht (subst' f a) (subst' f b)
subst' f (Ifold sht a b c) = Ifold sht (subst' f a) (subst' f b) (subst' f c)
subst' f (Index a b) = Index (subst' f a) (subst' f b)
subst' f (Shape e) = Shape (subst' f e)
subst' _ (Undef t) = Undef t

-- | Split an index fold whose accumulator is a pair into a pair of two
-- folds, when the fold function literally returns a 'Pair' and the two usage
-- checks below pass (presumably: each result component does not use the
-- other accumulator component -- confirm against "Count"). The initial value
-- and the shape are let-bound so that both folds can share them.
splitIfold :: ShapeType sh -> Exp env ((s, sh) -> s) -> Exp env s -> Exp env sh -> Maybe (Exp env s)
splitIfold sht (Lam (TPair (TPair t1 t2) tidx) (Pair e1 e2)) e0 she
  | let uses1 = usesOf' PathStart Zero e1
  , let uses2 = usesOf' PathStart Zero e2
  , lycontract (lysnd (lyfst uses1)) == 0
  , lycontract (lyfst (lyfst uses2)) == 0
  -- Substitute the argument in e1 and e2 to refer to just the used
  -- components of their argument. To do this we reconstruct the original,
  -- partially unused argument by putting an 'Undef' in the unused spot.
  , let e1' = subst' (\t -> \case
                         Zero -> Pair (Pair (Fst (Var (TPair t1 tidx) Zero)) (Undef t2))
                                      (Snd (Var (TPair t1 tidx) Zero))
                         Succ i -> Var t (Succ i))
                     e1
  , let e2' = subst' (\t -> \case
                         Zero -> Pair (Pair (Undef t1) (Fst (Var (TPair t2 tidx) Zero)))
                                      (Snd (Var (TPair t2 tidx) Zero))
                         Succ i -> Var t (Succ i))
                     e2
  = Just $
      Let e0 $
      Let (sinkExp1 she) $
        Pair (Ifold sht (sinkExp2 (Lam (TPair t1 tidx) e1')) (Fst (Var (TPair t1 t2) (Succ Zero))) (Var tidx Zero))
             (Ifold sht (sinkExp2 (Lam (TPair t2 tidx) e2')) (Snd (Var (TPair t1 t2) (Succ Zero))) (Var tidx Zero))
splitIfold _ _ _ _ = Nothing


-- | Beta-reduction pass: inline applied lambdas and let-bindings whose
-- argument may be duplicated or is used at most once; other applied lambdas
-- become let-bindings.
simbeta :: Exp env a -> Exp env a
simbeta = \case
  App (Lam _ e) a
    | isDuplicable a || usesOf Zero e <= 1 -> simbeta (subst a e)
    | otherwise -> Let (simbeta a) (simbeta e)
  Let a e
    | isDuplicable a || usesOf Zero e <= 1 -> simbeta (subst a e)
  e -> simrecurse simbeta e

-- | Projection pass: reduce projections of literal pairs and push
-- projections under let-bindings.
simpair :: Exp env a -> Exp env a
simpair = \case
  Fst (Pair a _) -> simpair a
  Snd (Pair _ b) -> simpair b
  Fst (Let a b) -> Let (simpair a) (simpair (Fst b))
  Snd (Let a b) -> Let (simpair a) (simpair (Snd b))
  e -> simrecurse simpair e

-- | Indexing pass: indexing into a 'Build' becomes an application of its
-- element function.
simindex :: Exp env a -> Exp env a
simindex = \case
  Index (Build _ _ f) e -> App (simindex f) (simindex e)
  e -> simrecurse simindex e

-- | Fold-splitting pass: apply 'splitIfold' wherever it succeeds.
simifold1 :: Exp env a -> Exp env a
simifold1 = \case
  Ifold sht fe e0 she
    | Just res <- splitIfold sht (simifold1 fe) (simifold1 e0) (simifold1 she)
    -> res
  e -> simrecurse simifold1 e

-- | Apply a rewrite to every direct subexpression, leaving the node itself
-- unchanged.
simrecurse :: (forall env' a'. Exp env' a' -> Exp env' a') -> Exp env a -> Exp env a
simrecurse f = \case
  App a b -> App (f a) (f b)
  Lam t e -> Lam t (f e)
  Var t i -> Var t i
  Let a e -> Let (f a) (f e)
  Lit l -> Lit l
  Cond a b c -> Cond (f a) (f b) (f c)
  Const c -> Const c
  Pair a b -> Pair (f a) (f b)
  Fst e -> Fst (f e)
  Snd e -> Snd (f e)
  Build sht a b -> Build sht (f a) (f b)
  Ifold sht a b c -> Ifold sht (f a) (f b) (f c)
  Index a b -> Index (f a) (f b)
  Shape e -> Shape (f e)
  Undef t -> Undef t

-- | A list of rewrite passes, applied in order from left to right.
infixr :|
data SimList = (forall env' a'. Exp env' a' -> Exp env' a') :| SimList | SimEnd

-- | Run all passes of the list in order, repeatedly, until one full round
-- leaves the expression unchanged under 'geq'. There is no iteration limit.
simfix :: SimList -> Exp env a -> Exp env a
simfix list = \e ->
  let e' = looponce list e
  in case geq e e' of
       Just Refl -> e'
       Nothing -> simfix list e'
  where
    looponce :: SimList -> Exp env a -> Exp env a
    looponce SimEnd e = e
    looponce (f :| l) e = looponce l (f e)