diff options
author | Tom Smeding <tom@tomsmeding.com> | 2021-06-24 23:14:54 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2021-06-24 23:14:54 +0200 |
commit | 0fefe4822c65bde81ec4c0da1b5b32a9b411ca79 (patch) | |
tree | 0efeffb8b1b6d6126bc806209a2f5a64fb32c96f /Simplify.hs |
Initial
Diffstat (limited to 'Simplify.hs')
-rw-r--r-- | Simplify.hs | 203 |
1 files changed, 203 insertions, 0 deletions
diff --git a/Simplify.hs b/Simplify.hs new file mode 100644 index 0000000..9ceaef9 --- /dev/null +++ b/Simplify.hs @@ -0,0 +1,203 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module Simplify ( + simplify, + simplifyFix, +) where + +import Data.Bifunctor +import Data.GADT.Compare +import qualified Data.Kind as Kind +import Data.List (find) +import Data.Type.Equality + +import AST +import Sink + + +data family Info (env :: [Kind.Type]) a +-- data instance Info env Int = InfoInt +-- data instance Info env Bool = InfoBool +-- data instance Info env Double = InfoDouble +data instance Info env (Array sh t) = InfoArray (Exp env sh) +-- data instance Info env () = InfoNil +data instance Info env (a, b) = InfoPair (Info env a) (Info env b) +-- data instance Info env (a -> b) = InfoFun + +data IEnv env where + ITop :: IEnv env + ICons :: Type a -> Maybe (Info (a ': env) a) -> IEnv env -> IEnv (a ': env) + +sinkInfo1 :: Type a -> Info env a -> Info (t ': env) a +sinkInfo1 TArray{} (InfoArray e) = InfoArray (sinkExp1 e) +sinkInfo1 (TPair t1 t2) (InfoPair a b) = InfoPair (sinkInfo1 t1 a) (sinkInfo1 t2 b) +sinkInfo1 _ _ = error "Unknown info in sinkInfo1" + +iprj :: IEnv env -> Idx env a -> Maybe (Type a, Info env a) +iprj ITop _ = Nothing +iprj (ICons t m _) Zero = (t,) <$> m +iprj (ICons _ _ env) (Succ i) = (\(t, m) -> (t, sinkInfo1 t m)) <$> iprj env i + +simplifyFix :: Exp env a -> Exp env a +simplifyFix e = + let maxTimes = 4 + es = take (maxTimes + 1) (iterate simplify e) + pairs = zip es (tail es) + in case find (\(a,b) -> case geq a b of Just Refl -> True ; _ -> False) pairs of + Just (e', _) -> e' + Nothing -> error "Simplification doesn't converge!" + +simplify :: Exp env a -> Exp env a +simplify = fst . simplify' ITop + +simplify' :: IEnv env -> Exp env a -> (Exp env a, Maybe (Info env a)) +simplify' env = \case + App a b -> (simplifyApp (fst (simplify' env a)) (fst (simplify' env b)), Nothing) + Lam t e -> (Lam t (fst (simplify' (ICons t Nothing env) e)), Nothing) + Var t i -> (Var t i, snd <$> iprj env i) + Let arg e -> + let (arg', info) = simplify' env arg + env' = ICons (typeof arg) (sinkInfo1 (typeof arg) <$> info) env + in (simplifyLet arg' (fst (simplify' env' e)), Nothing) + Lit l -> (Lit l, Nothing) + Cond a b c -> + (Cond (fst (simplify' env a)) (fst (simplify' env b)) (fst (simplify' env c)), Nothing) + Const c -> (Const c, Nothing) + Pair a b -> + let (a', ia) = simplify' env a + (b', ib) = simplify' env b + in (Pair a' b', InfoPair <$> ia <*> ib) + Fst e -> bimap simplifyFst (fmap (\(InfoPair i _) -> i)) (simplify' env e) + Snd e -> bimap simplifySnd (fmap (\(InfoPair _ i) -> i)) (simplify' env e) + Build sht a b -> + let a' = fst (simplify' env a) + in (Build sht a' (fst (simplify' env b)), Just (InfoArray a')) + Ifold sht a b c -> (Ifold sht (fst (simplify' env a)) (fst (simplify' env b)) (fst (simplify' env c)), Nothing) + Index a b -> (simplifyIndex (fst (simplify' env a)) (fst (simplify' env b)), Nothing) + Shape e -> + case simplify' env e of + (_, Just (InfoArray she)) -> (she, Nothing) + (e', _) -> (Shape e', Nothing) + +simplifyApp :: Exp env (a -> b) -> Exp env a -> Exp env b +simplifyApp (Const CAddI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a + b)) +simplifyApp (Const CSubI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a - b)) +simplifyApp (Const CMulI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a * b)) +simplifyApp (Const CDivI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a `div` b)) +simplifyApp (Const CAddF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a + b)) +simplifyApp (Const CSubF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a - b)) +simplifyApp (Const CMulF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a * b)) +simplifyApp (Const CDivF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a / b)) +simplifyApp (Const CLog) (Lit (LDouble a)) = Lit (LDouble (log a)) +simplifyApp (Const CExp) (Lit (LDouble a)) = Lit (LDouble (exp a)) +simplifyApp (Const CtoF) (Lit (LInt a)) = Lit (LDouble (fromIntegral a)) +simplifyApp (Const CRound) (Lit (LDouble a)) = Lit (LInt (round a)) +simplifyApp (Const CLtI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LBool (a < b)) +simplifyApp (Const CLtF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LBool (a < b)) +simplifyApp (Const (CEq _)) (Pair a b) + | Just Refl <- geq a b + = Lit (LBool True) +simplifyApp (Const CAnd) (Pair (Lit (LBool a)) (Lit (LBool b))) = Lit (LBool (a && b)) +simplifyApp (Const COr) (Pair (Lit (LBool a)) (Lit (LBool b))) = Lit (LBool (a || b)) +simplifyApp (Const CNot) (Lit (LBool a)) = Lit (LBool (not a)) + +simplifyApp (Lam _ e) arg + | isDuplicable arg || countOcc Zero e <= 1 + = simplify (subst arg e) +simplifyApp (Lam _ e) arg = simplifyLet arg e + +simplifyApp a b = App a b + +simplifyLet :: Exp env a -> Exp (a ': env) b -> Exp env b +simplifyLet arg e + | isDuplicable arg || countOcc Zero e <= 1 + = simplify (subst arg e) +simplifyLet (Pair a b) e = + simplifyLet a $ + simplifyLet (sinkExp1 b) $ + subst' (\t -> \case Zero -> Pair (Var (typeof a) (Succ Zero)) + (Var (typeof b) Zero) + Succ i -> Var t (Succ (Succ i))) + e +simplifyLet (Cond c a b) e + | isDuplicable a && isDuplicable b + = simplifyLet c $ + (subst' (\t -> \case Zero -> Cond (Var TBool Zero) (sinkExp1 a) (sinkExp1 b) + Succ i -> Var t (Succ i)) + e) +simplifyLet a b = Let (simplify a) (simplify b) + +simplifyFst :: Exp env (a, b) -> Exp env a +simplifyFst (Pair e _) = e +simplifyFst (Let a e) = simplifyLet a (simplifyFst e) +simplifyFst e = Fst e + +simplifySnd :: Exp env (a, b) -> Exp env b +simplifySnd (Pair _ e) = e +simplifySnd (Let a e) = simplifyLet a (simplifySnd e) +simplifySnd e = Snd e + +simplifyIndex :: Exp env (Array sh a) -> Exp env sh -> Exp env a +simplifyIndex (Build _ _ f) e = simplifyApp f e +simplifyIndex a e = Index a e + +isDuplicable :: Exp env a -> Bool +isDuplicable (Lam _ e) = isDuplicable e +isDuplicable (Var _ _) = True +isDuplicable (Let a e) = isDuplicable a && isDuplicable e +isDuplicable (Lit (LInt _)) = True +isDuplicable (Lit (LBool _)) = True +isDuplicable (Lit (LDouble _)) = True +isDuplicable (Lit (LShape _)) = True +isDuplicable (Lit LNil) = True +isDuplicable (Lit (LPair l1 l2)) = isDuplicable (Lit l1) && isDuplicable (Lit l2) +isDuplicable (Const _) = True +isDuplicable (Pair a b) = isDuplicable a && isDuplicable b +isDuplicable (Fst e) = isDuplicable e +isDuplicable (Snd e) = isDuplicable e +isDuplicable _ = False + +countOcc :: Idx env t -> Exp env a -> Int +countOcc i (App a b) = countOcc i a + countOcc i b +countOcc i (Lam _ e) = countOcc (Succ i) e +countOcc i (Var _ j) + | Just Refl <- geq i j = 1 + | otherwise = 0 +countOcc i (Let a b) = countOcc i a + countOcc (Succ i) b +countOcc _ (Lit _) = 0 +countOcc i (Cond a b c) = countOcc i a + countOcc i b + countOcc i c +countOcc _ (Const _) = 0 +countOcc i (Pair a b) = countOcc i a + countOcc i b +countOcc i (Fst e) = countOcc i e +countOcc i (Snd e) = countOcc i e +countOcc i (Build _ a b) = countOcc i a + countOcc i b +countOcc i (Ifold _ a b c) = countOcc i a + countOcc i b + countOcc i c +countOcc i (Index a b) = countOcc i a + countOcc i b +countOcc i (Shape e) = countOcc i e + +subst :: Exp env t -> Exp (t ': env) a -> Exp env a +subst arg e = subst' (\t -> \case Zero -> arg ; Succ i -> Var t i) e + +subst' :: (forall t. Type t -> Idx env t -> Exp env' t) -> Exp env a -> Exp env' a +subst' f (App a b) = App (subst' f a) (subst' f b) +subst' f (Lam t e) = + Lam t (subst' (\t' -> \case Zero -> Var t' Zero ; Succ i -> sinkExp1 (f t' i)) e) +subst' f (Var t i) = f t i +subst' f (Let a b) = + Let (subst' f a) + (subst' (\t -> \case Zero -> Var t Zero ; Succ i -> sinkExp1 (f t i)) b) +subst' _ (Lit l) = Lit l +subst' f (Cond a b c) = Cond (subst' f a) (subst' f b) (subst' f c) +subst' _ (Const c) = Const c +subst' f (Pair a b) = Pair (subst' f a) (subst' f b) +subst' f (Fst e) = Fst (subst' f e) +subst' f (Snd e) = Snd (subst' f e) +subst' f (Build sht a b) = Build sht (subst' f a) (subst' f b) +subst' f (Ifold sht a b c) = Ifold sht (subst' f a) (subst' f b) (subst' f c) +subst' f (Index a b) = Index (subst' f a) (subst' f b) +subst' f (Shape e) = Shape (subst' f e) |