aboutsummaryrefslogtreecommitdiff
path: root/Simplify.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2021-06-24 23:14:54 +0200
committerTom Smeding <tom@tomsmeding.com>2021-06-24 23:14:54 +0200
commit0fefe4822c65bde81ec4c0da1b5b32a9b411ca79 (patch)
tree0efeffb8b1b6d6126bc806209a2f5a64fb32c96f /Simplify.hs
Initial
Diffstat (limited to 'Simplify.hs')
-rw-r--r--Simplify.hs203
1 files changed, 203 insertions, 0 deletions
diff --git a/Simplify.hs b/Simplify.hs
new file mode 100644
index 0000000..9ceaef9
--- /dev/null
+++ b/Simplify.hs
@@ -0,0 +1,203 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module Simplify (
+ simplify,
+ simplifyFix,
+) where
+
+import Data.Bifunctor
+import Data.GADT.Compare
+import qualified Data.Kind as Kind
+import Data.List (find)
+import Data.Type.Equality
+
+import AST
+import Sink
+
+
+data family Info (env :: [Kind.Type]) a
+-- data instance Info env Int = InfoInt
+-- data instance Info env Bool = InfoBool
+-- data instance Info env Double = InfoDouble
+data instance Info env (Array sh t) = InfoArray (Exp env sh)
+-- data instance Info env () = InfoNil
+data instance Info env (a, b) = InfoPair (Info env a) (Info env b)
+-- data instance Info env (a -> b) = InfoFun
+
+data IEnv env where
+ ITop :: IEnv env
+ ICons :: Type a -> Maybe (Info (a ': env) a) -> IEnv env -> IEnv (a ': env)
+
+sinkInfo1 :: Type a -> Info env a -> Info (t ': env) a
+sinkInfo1 TArray{} (InfoArray e) = InfoArray (sinkExp1 e)
+sinkInfo1 (TPair t1 t2) (InfoPair a b) = InfoPair (sinkInfo1 t1 a) (sinkInfo1 t2 b)
+sinkInfo1 _ _ = error "Unknown info in sinkInfo1"
+
+iprj :: IEnv env -> Idx env a -> Maybe (Type a, Info env a)
+iprj ITop _ = Nothing
+iprj (ICons t m _) Zero = (t,) <$> m
+iprj (ICons _ _ env) (Succ i) = (\(t, m) -> (t, sinkInfo1 t m)) <$> iprj env i
+
+simplifyFix :: Exp env a -> Exp env a
+simplifyFix e =
+ let maxTimes = 4
+ es = take (maxTimes + 1) (iterate simplify e)
+ pairs = zip es (tail es)
+ in case find (\(a,b) -> case geq a b of Just Refl -> True ; _ -> False) pairs of
+ Just (e', _) -> e'
+ Nothing -> error "Simplification doesn't converge!"
+
+simplify :: Exp env a -> Exp env a
+simplify = fst . simplify' ITop
+
+simplify' :: IEnv env -> Exp env a -> (Exp env a, Maybe (Info env a))
+simplify' env = \case
+ App a b -> (simplifyApp (fst (simplify' env a)) (fst (simplify' env b)), Nothing)
+ Lam t e -> (Lam t (fst (simplify' (ICons t Nothing env) e)), Nothing)
+ Var t i -> (Var t i, snd <$> iprj env i)
+ Let arg e ->
+ let (arg', info) = simplify' env arg
+ env' = ICons (typeof arg) (sinkInfo1 (typeof arg) <$> info) env
+ in (simplifyLet arg' (fst (simplify' env' e)), Nothing)
+ Lit l -> (Lit l, Nothing)
+ Cond a b c ->
+ (Cond (fst (simplify' env a)) (fst (simplify' env b)) (fst (simplify' env c)), Nothing)
+ Const c -> (Const c, Nothing)
+ Pair a b ->
+ let (a', ia) = simplify' env a
+ (b', ib) = simplify' env b
+ in (Pair a' b', InfoPair <$> ia <*> ib)
+ Fst e -> bimap simplifyFst (fmap (\(InfoPair i _) -> i)) (simplify' env e)
+ Snd e -> bimap simplifySnd (fmap (\(InfoPair _ i) -> i)) (simplify' env e)
+ Build sht a b ->
+ let a' = fst (simplify' env a)
+ in (Build sht a' (fst (simplify' env b)), Just (InfoArray a'))
+ Ifold sht a b c -> (Ifold sht (fst (simplify' env a)) (fst (simplify' env b)) (fst (simplify' env c)), Nothing)
+ Index a b -> (simplifyIndex (fst (simplify' env a)) (fst (simplify' env b)), Nothing)
+ Shape e ->
+ case simplify' env e of
+ (_, Just (InfoArray she)) -> (she, Nothing)
+ (e', _) -> (Shape e', Nothing)
+
+simplifyApp :: Exp env (a -> b) -> Exp env a -> Exp env b
+simplifyApp (Const CAddI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a + b))
+simplifyApp (Const CSubI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a - b))
+simplifyApp (Const CMulI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a * b))
+simplifyApp (Const CDivI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a `div` b))
+simplifyApp (Const CAddF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a + b))
+simplifyApp (Const CSubF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a - b))
+simplifyApp (Const CMulF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a * b))
+simplifyApp (Const CDivF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a / b))
+simplifyApp (Const CLog) (Lit (LDouble a)) = Lit (LDouble (log a))
+simplifyApp (Const CExp) (Lit (LDouble a)) = Lit (LDouble (exp a))
+simplifyApp (Const CtoF) (Lit (LInt a)) = Lit (LDouble (fromIntegral a))
+simplifyApp (Const CRound) (Lit (LDouble a)) = Lit (LInt (round a))
+simplifyApp (Const CLtI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LBool (a < b))
+simplifyApp (Const CLtF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LBool (a < b))
+simplifyApp (Const (CEq _)) (Pair a b)
+ | Just Refl <- geq a b
+ = Lit (LBool True)
+simplifyApp (Const CAnd) (Pair (Lit (LBool a)) (Lit (LBool b))) = Lit (LBool (a && b))
+simplifyApp (Const COr) (Pair (Lit (LBool a)) (Lit (LBool b))) = Lit (LBool (a || b))
+simplifyApp (Const CNot) (Lit (LBool a)) = Lit (LBool (not a))
+
+simplifyApp (Lam _ e) arg
+ | isDuplicable arg || countOcc Zero e <= 1
+ = simplify (subst arg e)
+simplifyApp (Lam _ e) arg = simplifyLet arg e
+
+simplifyApp a b = App a b
+
+simplifyLet :: Exp env a -> Exp (a ': env) b -> Exp env b
+simplifyLet arg e
+ | isDuplicable arg || countOcc Zero e <= 1
+ = simplify (subst arg e)
+simplifyLet (Pair a b) e =
+ simplifyLet a $
+ simplifyLet (sinkExp1 b) $
+ subst' (\t -> \case Zero -> Pair (Var (typeof a) (Succ Zero))
+ (Var (typeof b) Zero)
+ Succ i -> Var t (Succ (Succ i)))
+ e
+simplifyLet (Cond c a b) e
+ | isDuplicable a && isDuplicable b
+ = simplifyLet c $
+ (subst' (\t -> \case Zero -> Cond (Var TBool Zero) (sinkExp1 a) (sinkExp1 b)
+ Succ i -> Var t (Succ i))
+ e)
+simplifyLet a b = Let (simplify a) (simplify b)
+
+simplifyFst :: Exp env (a, b) -> Exp env a
+simplifyFst (Pair e _) = e
+simplifyFst (Let a e) = simplifyLet a (simplifyFst e)
+simplifyFst e = Fst e
+
+simplifySnd :: Exp env (a, b) -> Exp env b
+simplifySnd (Pair _ e) = e
+simplifySnd (Let a e) = simplifyLet a (simplifySnd e)
+simplifySnd e = Snd e
+
+simplifyIndex :: Exp env (Array sh a) -> Exp env sh -> Exp env a
+simplifyIndex (Build _ _ f) e = simplifyApp f e
+simplifyIndex a e = Index a e
+
+isDuplicable :: Exp env a -> Bool
+isDuplicable (Lam _ e) = isDuplicable e
+isDuplicable (Var _ _) = True
+isDuplicable (Let a e) = isDuplicable a && isDuplicable e
+isDuplicable (Lit (LInt _)) = True
+isDuplicable (Lit (LBool _)) = True
+isDuplicable (Lit (LDouble _)) = True
+isDuplicable (Lit (LShape _)) = True
+isDuplicable (Lit LNil) = True
+isDuplicable (Lit (LPair l1 l2)) = isDuplicable (Lit l1) && isDuplicable (Lit l2)
+isDuplicable (Const _) = True
+isDuplicable (Pair a b) = isDuplicable a && isDuplicable b
+isDuplicable (Fst e) = isDuplicable e
+isDuplicable (Snd e) = isDuplicable e
+isDuplicable _ = False
+
+countOcc :: Idx env t -> Exp env a -> Int
+countOcc i (App a b) = countOcc i a + countOcc i b
+countOcc i (Lam _ e) = countOcc (Succ i) e
+countOcc i (Var _ j)
+ | Just Refl <- geq i j = 1
+ | otherwise = 0
+countOcc i (Let a b) = countOcc i a + countOcc (Succ i) b
+countOcc _ (Lit _) = 0
+countOcc i (Cond a b c) = countOcc i a + countOcc i b + countOcc i c
+countOcc _ (Const _) = 0
+countOcc i (Pair a b) = countOcc i a + countOcc i b
+countOcc i (Fst e) = countOcc i e
+countOcc i (Snd e) = countOcc i e
+countOcc i (Build _ a b) = countOcc i a + countOcc i b
+countOcc i (Ifold _ a b c) = countOcc i a + countOcc i b + countOcc i c
+countOcc i (Index a b) = countOcc i a + countOcc i b
+countOcc i (Shape e) = countOcc i e
+
+subst :: Exp env t -> Exp (t ': env) a -> Exp env a
+subst arg e = subst' (\t -> \case Zero -> arg ; Succ i -> Var t i) e
+
+subst' :: (forall t. Type t -> Idx env t -> Exp env' t) -> Exp env a -> Exp env' a
+subst' f (App a b) = App (subst' f a) (subst' f b)
+subst' f (Lam t e) =
+ Lam t (subst' (\t' -> \case Zero -> Var t' Zero ; Succ i -> sinkExp1 (f t' i)) e)
+subst' f (Var t i) = f t i
+subst' f (Let a b) =
+ Let (subst' f a)
+ (subst' (\t -> \case Zero -> Var t Zero ; Succ i -> sinkExp1 (f t i)) b)
+subst' _ (Lit l) = Lit l
+subst' f (Cond a b c) = Cond (subst' f a) (subst' f b) (subst' f c)
+subst' _ (Const c) = Const c
+subst' f (Pair a b) = Pair (subst' f a) (subst' f b)
+subst' f (Fst e) = Fst (subst' f e)
+subst' f (Snd e) = Snd (subst' f e)
+subst' f (Build sht a b) = Build sht (subst' f a) (subst' f b)
+subst' f (Ifold sht a b c) = Ifold sht (subst' f a) (subst' f b) (subst' f c)
+subst' f (Index a b) = Index (subst' f a) (subst' f b)
+subst' f (Shape e) = Shape (subst' f e)