From d4abcc3b2dfefbbcb7cd4a182eec64f1da42d951 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sun, 27 Jun 2021 18:34:35 +0200 Subject: Things --- AD.hs | 1 + AST.hs | 58 +++++++++++-- Count.hs | 94 ++++++++++++++++++++ Eval.hs | 128 +++++++++++++++++++++++++++ Examples.hs | 15 ++-- Language.hs | 78 ++++++++++++++--- Pretty.hs | 229 ++++++++++++++++++++++++++++++++++++++++++++++++ Repl.hs | 2 + Simplify.hs | 280 +++++++++++++++++++++++++++++++++++++++++++++++++++-------- Sink.hs | 1 + ftilde.cabal | 4 + 11 files changed, 827 insertions(+), 63 deletions(-) create mode 100644 Count.hs create mode 100644 Eval.hs create mode 100644 Pretty.hs diff --git a/AD.hs b/AD.hs index 76fefe4..c9ac72e 100644 --- a/AD.hs +++ b/AD.hs @@ -99,6 +99,7 @@ ad' env = \case | TArray sht _ <- typeof e , Refl <- prfDualSht sht -> Shape (ad' env e) + Undef t -> Undef (dual t) convIdx :: DEnv env env' -> Idx env a -> Either (Idx env' a) (Idx env' (Dual a)) convIdx ETop i = Left i diff --git a/AST.hs b/AST.hs index 7e9c69c..3e1d2f6 100644 --- a/AST.hs +++ b/AST.hs @@ -1,6 +1,8 @@ {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeOperators #-} module AST where @@ -26,6 +28,7 @@ data Exp env a where Ifold :: ShapeType sh -> Exp env ((s, sh) -> s) -> Exp env s -> Exp env sh -> Exp env s Index :: Exp env (Array sh a) -> Exp env sh -> Exp env a Shape :: Exp env (Array sh a) -> Exp env sh + Undef :: Type a -> Exp env a data Constant a where CAddI :: Constant ((Int, Int) -> Int) @@ -42,6 +45,7 @@ data Constant a where CRound :: Constant (Double -> Int) CLtI :: Constant ((Int, Int) -> Bool) + CLeI :: Constant ((Int, Int) -> Bool) CLtF :: Constant ((Double, Double) -> Bool) CEq :: Type a -> Constant ((a, a) -> Bool) CAnd :: Constant ((Bool, Bool) -> Bool) @@ -72,11 +76,11 @@ data Literal a where data Shape sh where Z :: Shape () - (:.) :: Int -> Shape sh -> Shape (Int, sh) + (:.) :: Shape sh -> Int -> Shape (sh, Int) data ShapeType sh where STZ :: ShapeType () - STC :: ShapeType sh -> ShapeType (Int, sh) + STC :: ShapeType sh -> ShapeType (sh, Int) data Array sh a where Array :: Shape sh -> Type a -> Vector a -> Array sh a @@ -86,15 +90,19 @@ deriving instance Show (Constant a) deriving instance Show (Type a) deriving instance Show (Idx env a) deriving instance Show (Literal a) -deriving instance Show (Shape a) deriving instance Show (ShapeType a) +instance Show (Shape sh) where + showsPrec _ Z = showString "Z" + showsPrec p (sh :. n) = showParen (p > 0) $ + showsPrec 10 sh . showString " :. " . shows n + instance Show (Array sh a) where showsPrec p (Array sh t v) = showParen (p > 10) $ showString "Array " - . showsPrec 11 sh - . showsPrec 11 t + . showsPrec 11 sh . showString " " + . showsPrec 11 t . showString " " . (case typeHasShow t of Just Has -> showsPrec 11 v Nothing -> showString ("[_ * " ++ show (V.length v) ++ "]")) @@ -134,6 +142,8 @@ instance GEq (Exp env) where geq Index{} _ = Nothing geq (Shape a) (Shape a') | Just Refl <- geq a a' = Just Refl geq Shape{} _ = Nothing + geq (Undef a) (Undef a') | Just Refl <- geq a a' = Just Refl + geq Undef{} _ = Nothing instance GEq Constant where geq CAddI CAddI = Just Refl ; geq CAddI _ = Nothing @@ -149,6 +159,7 @@ instance GEq Constant where geq CtoF CtoF = Just Refl ; geq CtoF _ = Nothing geq CRound CRound = Just Refl ; geq CRound _ = Nothing geq CLtI CLtI = Just Refl ; geq CLtI _ = Nothing + geq CLeI CLeI = Just Refl ; geq CLeI _ = Nothing geq CLtF CLtF = Just Refl ; geq CLtF _ = Nothing geq (CEq t) (CEq t') | Just Refl <- geq t t' = Just Refl ; geq CEq{} _ = Nothing geq CAnd CAnd = Just Refl ; geq CAnd _ = Nothing @@ -187,7 +198,7 @@ instance GEq Literal where instance GEq Shape where geq Z Z = Just Refl - geq (n :. sh) (n' :. sh') | n == n', Just Refl <- geq sh sh' = Just Refl + geq (sh :. n) (sh' :. n') | n == n', Just Refl <- geq sh sh' = Just Refl geq _ _ = Nothing instance GEq ShapeType where @@ -195,17 +206,36 @@ instance GEq ShapeType where geq (STC sht) (STC sht') | Just Refl <- geq sht sht' = Just Refl geq _ _ = Nothing +-- Requires that the given term is neither 'Lam' nor 'Let'. +recurseMon :: Monoid s => (forall t. Exp env t -> s) -> Exp env a -> s +recurseMon f = \case + App a b -> f a <> f b + Lam _ _ -> error "recurseMon: Given Lam" + Var _ _ -> mempty + Let _ _ -> error "recurseMon: Given Let" + Lit _ -> mempty + Cond a b c -> f a <> f b <> f c + Const _ -> mempty + Pair a b -> f a <> f b + Fst a -> f a + Snd a -> f a + Build _ a b -> f a <> f b + Ifold _ a b c -> f a <> f b <> f c + Index a b -> f a <> f b + Shape a -> f a + Undef _ -> mempty + shapeType :: Shape sh -> ShapeType sh shapeType Z = STZ -shapeType (_ :. sh) = STC (shapeType sh) +shapeType (sh :. _) = STC (shapeType sh) shapeType' :: Shape sh -> Type sh shapeType' Z = TNil -shapeType' (_ :. sh) = TPair TInt (shapeType' sh) +shapeType' (sh :. _) = TPair (shapeType' sh) TInt shapeTypeType :: ShapeType sh -> Type sh shapeTypeType STZ = TNil -shapeTypeType (STC sht) = TPair TInt (shapeTypeType sht) +shapeTypeType (STC sht) = TPair (shapeTypeType sht) TInt literalType :: Literal a -> Type a literalType LInt{} = TInt @@ -230,6 +260,7 @@ constType CExp = TFun TDouble TDouble constType CtoF = TFun TInt TDouble constType CRound = TFun TDouble TInt constType CLtI = TFun (TPair TInt TInt) TBool +constType CLeI = TFun (TPair TInt TInt) TBool constType CLtF = TFun (TPair TDouble TDouble) TBool constType (CEq t) = TFun (TPair t t) TBool constType CAnd = TFun (TPair TBool TBool) TBool @@ -251,6 +282,15 @@ typeof (Build sht _ e) = let TFun _ t = typeof e in TArray sht t typeof (Ifold _ _ e _) = typeof e typeof (Index e _) = let TArray _ t = typeof e in t typeof (Shape e) = let TArray sht _ = typeof e in shapeTypeType sht +typeof (Undef t) = t + +idxToInt :: Idx env a -> Int +idxToInt Zero = 0 +idxToInt (Succ i) = idxToInt i + 1 + +shtToInt :: ShapeType sh -> Int +shtToInt STZ = 0 +shtToInt (STC sht) = shtToInt sht + 1 data Has c a where Has :: c a => Has c a diff --git a/Count.hs b/Count.hs new file mode 100644 index 0000000..d9fe661 --- /dev/null +++ b/Count.hs @@ -0,0 +1,94 @@ +{-# LANGUAGE GADTs #-} +module Count where + +import Data.GADT.Compare +import Data.Type.Equality + +import AST + + +data Layout = LyLeaf Integer | LyPair Layout Layout + deriving (Show) + +instance Semigroup Layout where + LyLeaf a <> LyLeaf b = LyLeaf (a + b) + l@(LyLeaf _) <> LyPair l1 l2 = LyPair (l <> l1) (l <> l2) + LyPair l1 l2 <> l@(LyLeaf _) = LyPair (l1 <> l) (l2 <> l) + LyPair l1 l2 <> LyPair l3 l4 = LyPair (l1 <> l3) (l2 <> l4) + +instance Monoid Layout where + mempty = LyLeaf 0 + +-- | Returns the maximum usage count of any component of the layout. +lycontract :: Layout -> Integer +lycontract (LyLeaf n) = n +lycontract (LyPair ly1 ly2) = max (lycontract ly1) (lycontract ly2) + +-- | Given the usage count of a pair, return the usage count of its first component +lyfst :: Layout -> Layout +lyfst (LyLeaf n) = LyLeaf n +lyfst (LyPair ly _) = ly + +-- | Given the usage count of a pair, return the usage count of its second component +lysnd :: Layout -> Layout +lysnd (LyLeaf n) = LyLeaf n +lysnd (LyPair _ ly) = ly + +-- | Returns the usage count of an expression given its usage counts in two +-- mutually exclusive program branches. +lymax :: Layout -> Layout -> Layout +lymax (LyLeaf n) (LyLeaf m) = LyLeaf (max n m) +lymax (LyPair a1 a2) (LyPair b1 b2) = LyPair (lymax a1 b1) (lymax a2 b2) +lymax (LyLeaf n) ly@LyPair{} = lymax (LyPair (LyLeaf n) (LyLeaf n)) ly +lymax ly@LyPair{} (LyLeaf n) = lymax ly (LyPair (LyLeaf n) (LyLeaf n)) + +-- | Count the uses of a variable in an expression +usesOf :: Idx env t -> Exp env a -> Integer +usesOf x e = lycontract (usesOf' PathStart x e) + +-- Path upwards in the tuple, starting at Start. +data Path part t where + PathFst :: Path part t -> Path part (t, a) + PathSnd :: Path part t -> Path part (a, t) + PathStart :: Path part part + +-- | Count the uses of the components of a variable in an expression +usesOf' :: Path large a -> Idx env t -> Exp env a -> Layout +usesOf' _ i (App a b) = usesOf' PathStart i a <> usesOf' PathStart i b +usesOf' _ i (Lam _ e) = usesOf' PathStart (Succ i) e +usesOf' pt i (Var _ i') = + let leaf = case geq i i' of Just Refl -> LyLeaf 1 + Nothing -> mempty + build :: Path a large -> Layout -> Layout + build (PathFst pt') ly = build pt' (LyPair ly mempty) + build (PathSnd pt') ly = build pt' (LyPair mempty ly) + build PathStart ly = ly + in build pt leaf +usesOf' _ i (Let a b) = usesOf' PathStart i a <> usesOf' PathStart (Succ i) b +usesOf' _ _ (Lit _) = mempty +usesOf' pt i (Cond p l r) = + usesOf' PathStart i p <> lymax (usesOf' pt i l) (usesOf' pt i r) +usesOf' _ _ (Const _) = mempty +usesOf' _ i (Pair a b) = usesOf' PathStart i a <> usesOf' PathStart i b +usesOf' pt i (Fst e) = usesOf' (PathFst pt) i e +usesOf' pt i (Snd e) = usesOf' (PathSnd pt) i e +-- TODO: Huge hack that allows arbitrary computational complexity increase. +-- The code current counts usages in the lambdas in Build and Ifold _once_, +-- whereas those usages are actually evaluated many times. The correct fix +-- would be to analyse whether the accessed indexes of the counted variable are +-- all disjoint, but that is hard. +usesOf' _ i (Build _ she fe) = usesOf' PathStart i she <> usesOf' PathStart i fe +-- usesOf' _ i (Build STZ she fe) = usesOf' PathStart i she <> usesOf' PathStart i fe +-- usesOf' _ i (Build _ she fe) = usesOf' PathStart i she <> mul2 (usesOf' PathStart i fe) +usesOf' _ i (Ifold _ fe e0 she) = + usesOf' PathStart i fe <> usesOf' PathStart i e0 <> usesOf' PathStart i she +-- usesOf' _ i (Ifold STZ fe e0 she) = +-- usesOf' PathStart i fe <> usesOf' PathStart i e0 <> usesOf' PathStart i she +-- usesOf' _ i (Ifold _ fe e0 she) = +-- mul2 (usesOf' PathStart i fe) <> usesOf' PathStart i e0 <> usesOf' PathStart i she +usesOf' _ i (Index a b) = usesOf' PathStart i a <> usesOf' PathStart i b +usesOf' _ i (Shape e) = usesOf' PathStart i e +usesOf' _ _ (Undef _) = mempty + +-- mul2 :: Semigroup m => Layout m -> Layout m +-- mul2 ly = ly <> ly diff --git a/Eval.hs b/Eval.hs new file mode 100644 index 0000000..19265bc --- /dev/null +++ b/Eval.hs @@ -0,0 +1,128 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeOperators #-} +module Eval ( + eval, +) where + +import Data.List (foldl') +import qualified Data.Vector as V + +import AST + + +data Val env where + Top :: Val '[] + Push :: Val env -> a -> Val (a ': env) + +prj :: Val env -> Idx env a -> a +prj (Push _ x) Zero = x +prj (Push env _) (Succ i) = prj env i + +eval :: Exp '[] a -> a +eval = eval' Top + +eval' :: forall env a. Val env -> Exp env a -> a +eval' env = \case + App f a -> rec f (rec a) + Lam _ e -> \x -> eval' (Push env x) e + Var _ i -> prj env i + Let a e -> eval' (Push env (rec a)) e + Lit l -> evalL l + Cond c a b -> if rec c then rec a else rec b + Const c -> evalC c + Pair a b -> (rec a, rec b) + Fst e -> fst (rec e) + Snd e -> snd (rec e) + Build sht she fe -> + let TFun _ ty = typeof fe + in build ty sht (rec she) (rec fe) + Ifold sht fe e0 she -> ifold sht (rec fe) (rec e0) (rec she) + Index a i -> index (rec a) (rec i) + Shape a -> shape (rec a) + Undef t -> error ("eval: Undef of type " ++ show t) + where rec :: Exp env t -> t + rec = eval' env + +evalL :: Literal a -> a +evalL (LInt n) = n +evalL (LBool b) = b +evalL (LDouble d) = d +evalL (LArray a) = a +evalL (LShape Z) = () +evalL (LShape sh) = unshape sh +evalL LNil = () +evalL (LPair a b) = (evalL a, evalL b) + +evalC :: Constant a -> a +evalC CAddI = uncurry (+) +evalC CSubI = uncurry (-) +evalC CMulI = uncurry (*) +evalC CDivI = uncurry div +evalC CAddF = uncurry (+) +evalC CSubF = uncurry (-) +evalC CMulF = uncurry (*) +evalC CDivF = uncurry (/) +evalC CLog = log +evalC CExp = exp +evalC CtoF = fromIntegral +evalC CRound = round +evalC CLtI = uncurry (<) +evalC CLeI = uncurry (<=) +evalC CLtF = uncurry (<) +evalC (CEq t) | Just Has <- typeHasEq t = uncurry (==) + | otherwise = error ("eval: Cannot Eq compare values of type " ++ show t) +evalC CAnd = uncurry (&&) +evalC COr = uncurry (||) +evalC CNot = not + +build :: Type a -> ShapeType sh -> sh -> (sh -> a) -> Array sh a +build ty sht sh f = + let sh' = toshape sht sh + in Array sh' ty (V.generate (shapesize sh') (\i -> f (fromlinear sh' i))) + +ifold :: ShapeType sh -> ((a, sh) -> a) -> a -> sh -> a +ifold sht f x0 sh = foldl' (curry f) x0 (enumshape (toshape sht sh)) + +index :: Array sh a -> sh -> a +index (Array sh _ v) idx = v V.! tolinear sh idx + +shape :: Array sh a -> sh +shape (Array sh _ _) = unshape sh + +enumshape :: Shape sh -> [sh] +enumshape sh = take (shapesize sh) (iterate (next sh) (zeroshape sh)) + where + next :: Shape sh -> sh -> sh + next Z () = () + next (sh' :. n) (idx, i) + | i < n = (idx, i + 1) + | otherwise = (next sh' idx, 0) + + zeroshape :: Shape sh -> sh + zeroshape Z = () + zeroshape (sh' :. _) = (zeroshape sh', 0) + +unshape :: Shape sh -> sh +unshape Z = () +unshape (sh :. n) = (unshape sh, n) + +toshape :: ShapeType sh -> sh -> Shape sh +toshape STZ () = Z +toshape (STC sht) (sh, n) = toshape sht sh :. n + +tolinear :: Shape sh -> sh -> Int +tolinear Z () = 0 +tolinear (sh :. n) (idx, i) = n * tolinear sh idx + i + +fromlinear :: Shape sh -> Int -> sh +fromlinear Z _ = () +fromlinear (sh :. n) i = + let (q, r) = i `divMod` n + in (fromlinear sh q, r) + +shapesize :: Shape sh -> Int +shapesize Z = 1 +shapesize (sh :. n) = n * shapesize sh diff --git a/Examples.hs b/Examples.hs index 9d9cda7..0719f1f 100644 --- a/Examples.hs +++ b/Examples.hs @@ -4,11 +4,11 @@ import AST import qualified Language as L -sumSq :: Exp env (Array (Int, ()) Double -> Double) +sumSq :: Exp env (Array L.DIM1 Double -> Double) sumSq = Lam (TArray (STC STZ) TDouble) (L.sum (App mapSq (Var (TArray (STC STZ) TDouble) Zero))) -mapSq :: Exp env (Array (Int, ()) Double -> Array (Int, ()) Double) +mapSq :: Exp env (Array L.DIM1 Double -> Array L.DIM1 Double) mapSq = Lam (TArray (STC STZ) TDouble) (L.map (Lam TDouble @@ -16,13 +16,16 @@ mapSq = (Pair (Var TDouble Zero) (Var TDouble Zero)))) (Var (TArray (STC STZ) TDouble) Zero)) -mapSqIota :: Exp env (Array (Int, ()) Double) +mapSqIota :: Exp env (Array L.DIM1 Double) mapSqIota = L.map (Lam TDouble (App (Const CMulF) (Pair (Var TDouble Zero) (Var TDouble Zero)))) (Build (STC STZ) - (Pair (Lit (LInt 5)) (Lit LNil)) - (Lam (TPair TInt TNil) + (Pair (Lit LNil) (Lit (LInt 5))) + (Lam L.infer (App (Const CtoF) - (Fst (Var (TPair TInt TNil) Zero))))) + (Snd (L.var Zero))))) + +transpose2 :: Exp env (Array L.DIM2 Double -> Array L.DIM2 Double) +transpose2 = Lam L.infer (App (L.transpose L.infer) (App (L.transpose L.infer) (L.var Zero))) diff --git a/Language.hs b/Language.hs index e16cf7c..8ab6199 100644 --- a/Language.hs +++ b/Language.hs @@ -1,4 +1,6 @@ +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE TypeOperators #-} {-| This module is intended to be imported qualified, perhaps as @L@. -} module Language where @@ -7,6 +9,28 @@ import AST import Sink +-- Convention: matrices are represented in row-major: (((), y), x) +type DIM0 = () +type DIM1 = (DIM0, Int) +type DIM2 = (DIM1, Int) +type DIM3 = (DIM2, Int) + +class InferType a where infer :: Type a +instance InferType Int where infer = TInt +instance InferType Bool where infer = TBool +instance InferType Double where infer = TDouble +instance (InferType a, InferShapeType sh) => InferType (Array sh a) where infer = TArray inferST infer +instance InferType () where infer = TNil +instance (InferType a, InferType b) => InferType (a, b) where infer = TPair infer infer +instance (InferType a, InferType b) => InferType (a -> b) where infer = TFun infer infer + +class InferShapeType sh where inferST :: ShapeType sh +instance InferShapeType () where inferST = STZ +instance InferShapeType sh => InferShapeType (sh, Int) where inferST = STC inferST + +var :: InferType a => Idx env a -> Exp env a +var = Var infer + map :: Exp env (a -> b) -> Exp env (Array sh a) -> Exp env (Array sh b) map f e = let ty@(TArray sht _) = typeof e @@ -18,17 +42,16 @@ map f e = (Index (Var ty (Succ Zero)) (Var sht' Zero))))) -sum :: Exp env (Array (Int, ()) Double) -> Exp env Double +sum :: Exp env (Array DIM1 Double) -> Exp env Double sum e = - let ty@(TArray sht _) = typeof e - in Let e - (Ifold sht - (Lam (TPair TDouble (TPair TInt TNil)) - (App (Const CAddF) (Pair - (Fst (Var (TPair TDouble (TPair TInt TNil)) Zero)) - (Index (Var ty (Succ Zero)) (Snd (Var (TPair TDouble (TPair TInt TNil)) Zero)))))) - (Lit (LDouble 0)) - (Shape (Var ty Zero))) + Let e + (Ifold inferST + (Lam (TPair TDouble (TPair TNil TInt)) + (App (Const CAddF) (Pair + (Fst (var Zero)) + (Index (var (Succ Zero)) (Snd (var Zero)))))) + (Lit (LDouble 0)) + (Shape (var Zero))) -- | The two input arrays are assumed to be the same size. zip :: Exp env (Array sh a) -> Exp env (Array sh b) -> Exp env (Array sh (a, b)) @@ -50,3 +73,38 @@ oneHot sht sh idx = (Cond (App (Const (CEq idxt)) (Pair (Var idxt Zero) (sinkExp1 idx))) (Lit (LDouble 1)) (Lit (LDouble 0)))) + +transpose :: Type a -> Exp env (Array DIM2 a -> Array DIM2 a) +transpose ty = + Lam (TArray inferST ty) + (Build inferST (Shape (Var (TArray inferST ty) Zero)) + (Lam infer (Index (Var (TArray inferST ty) (Succ Zero)) (Var infer Zero)))) + +eye :: Exp env (Int -> Array DIM2 Double) +eye = + Lam infer + (Build inferST (Pair (Pair (Lit LNil) (var Zero)) (var Zero)) + (Lam infer + (Cond (App (Const (CEq infer)) (Pair (Snd (var Zero)) (Snd (Fst (var Zero))))) + (Lit (LDouble 1)) + (Lit (LDouble 0))))) + +length :: Type a -> Exp env (Array DIM1 a -> Int) +length ty = Lam (TArray inferST ty) + (Snd (Shape (Var (TArray inferST ty) Zero))) + +vmmul :: Exp env (Array DIM1 Double -> Array DIM2 Double -> Array DIM1 Double) +vmmul = + Lam infer $ Lam infer $ + Build inferST + (Pair (Lit LNil) (Snd (Shape (var Zero)))) + (Lam infer $ + Ifold inferST + (Lam infer $ + App (Const CAddF) (Pair + (Fst (var Zero)) + (App (Const CMulF) (Pair + (Index (var (Succ (Succ (Succ Zero)))) (Snd (var Zero))) + (Index (var (Succ (Succ Zero))) (Pair (Pair (Lit LNil) (Snd (Snd (var Zero)))) (Snd (var (Succ Zero))))))))) + (Lit (LDouble 0)) + (Shape (var (Succ (Succ Zero))))) diff --git a/Pretty.hs b/Pretty.hs new file mode 100644 index 0000000..d63e3ce --- /dev/null +++ b/Pretty.hs @@ -0,0 +1,229 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeOperators #-} +module Pretty ( + prettyExp, + pprintExp, +) where + +import Data.Bifunctor +import Prettyprinter +import Prettyprinter.Render.String + +import AST + + +newtype IdGen a = IdGen { runIdGen :: Int -> (a, Int) } +instance Functor IdGen where + fmap f (IdGen g) = IdGen (first f . g) +instance Applicative IdGen where + pure x = IdGen (x,) + IdGen f <*> IdGen g = IdGen (\i -> let (f', j) = f i in first f' (g j)) +instance Monad IdGen where + IdGen f >>= g = IdGen (\i -> let (x, j) = f i in runIdGen (g x) j) + +evalIdGen :: Int -> IdGen a -> a +evalIdGen i = fst . ($ i) . runIdGen + +genId :: IdGen Int +genId = IdGen (\i -> (i, i + 1)) + +genName :: IdGen String +genName = ('x' :) . show <$> genId + +data PEnv env where + Top :: PEnv env + PCons :: String -> PEnv env -> PEnv (a ': env) + +prettyExp :: Exp env a -> String +prettyExp e = renderString (layoutSmart opts (evalIdGen 1 (pExp definfo 0 Top e))) + where opts = LayoutOptions (AvailablePerLine 120 0.7) + +pprintExp :: Exp env a -> IO () +pprintExp = putStrLn . prettyExp + +data Info = Info + { infoLamTypeSig :: Bool } + deriving (Show) + +definfo :: Info +definfo = Info True + +pExp :: forall env a x. Info -> Int -> PEnv env -> Exp env a -> IdGen (Doc x) +pExp thisinfo d env = \case + App (Const CAddF) (Pair a b) -> do + a' <- pExp definfo 7 env a + b' <- pExp definfo 7 env b + return (flatAlt (pParen (d > 10) $ pretty "AddF" <+> align (vsep [a', b'])) + (pParen (d > 6) $ hsep [a', pretty "+", b'])) + + App (Const CMulF) (Pair a b) -> do + a' <- pExp definfo 8 env a + b' <- pExp definfo 8 env b + return (flatAlt (pParen (d > 10) $ pretty "MulF" <+> align (vsep [a', b'])) + (pParen (d > 7) $ hsep [a', pretty "*", b'])) + + e@(App _ _) -> do + let collectAppsRev :: Exp env t -> IdGen (Doc x', [Doc x']) + collectAppsRev (App f a) = do + a' <- pExp definfo 11 env a + rest <- collectAppsRev f + return (fmap (a' :) rest) + collectAppsRev f = (,[]) <$> pExp definfo 11 env f + (func, rhss) <- collectAppsRev e + return (pParen (d > 10) $ func <+> align (sep (reverse rhss))) + + Lam t e -> do + name <- genName + let prefix | infoLamTypeSig thisinfo = + pretty ("\\(" ++ name ++ " :: " ++ showType 0 t ") ->") + | otherwise = + pretty ("\\" ++ name ++ " ->") + body <- pExp definfo 0 (PCons name env) e + return (pParen (d > 0) $ nest 2 (sep [prefix, body])) + + Var t i -> + case (env, i) of + (Top, _) -> return (pretty ("xUP_" ++ show (idxToInt i))) + (PCons name _, Zero) -> return (pretty name) + (PCons _ env', Succ i') -> pExp definfo d env' (Var t i') + + e@(Let _ _) -> do + let collectLets :: PEnv env' -> Exp env' t -> IdGen (Doc x', [Doc x']) + collectLets env' (Let rhs body) = do + name <- genName + rhs' <- (pretty (name ++ " = ") <>) . group <$> pExp definfo 0 env' rhs + rest <- collectLets (PCons name env') body + return (fmap (rhs' :) rest) + collectLets env' f = (,[]) <$> pExp definfo 0 env' f + (core, rhss) <- collectLets env e + return (pParen (d > 0) $ + align (vsep [pretty "let" <+> align (vsep rhss) + ,pretty "in" <+> group core])) + + Lit l -> return (pretty (showLit d l "")) + + Cond e1 e2 e3 -> do + e1' <- pExp definfo 11 env e1 + e2' <- pExp definfo 11 env e2 + e3' <- pExp definfo 11 env e3 + return (flatAlt (pParen (d > 10) $ pretty "cond" <+> align (vsep [e1', e2', e3'])) + (pParen (d > 0) $ hsep [e1', pretty "?", e2', pretty ":", e3'])) + + Const c -> return (pretty (showConst c)) + + Pair e1 e2 -> do + e1' <- pExp definfo 0 env e1 + e2' <- pExp definfo 0 env e2 + return (tupled [e1', e2']) + + Fst e -> do + e' <- pExp definfo 11 env e + return (pParen (d > 10) $ pretty "fst" <+> e') + + Snd e -> do + e' <- pExp definfo 11 env e + return (pParen (d > 10) $ pretty "snd" <+> e') + + Build sht e1 e2 -> do + e1' <- pExp definfo 11 env e1 + e2' <- pExp definfo{infoLamTypeSig=False} 11 env e2 + return (pParen (d > 10) $ + pretty "build" <+> align (sep + [pretty ("DIM" <> show (shtToInt sht)), e1', e2'])) + + Ifold sht e1 e2 e3 -> do + e1' <- pExp (definfo{infoLamTypeSig=False}) 11 env e1 + e2' <- pExp definfo 11 env e2 + e3' <- pExp definfo 11 env e3 + return (pParen (d > 10) $ + pretty "ifold" <+> align (sep + [pretty ("DIM" <> show (shtToInt sht)), e1', e2', e3'])) + + Index e1 e2 -> do + e1' <- pExp definfo 11 env e1 + e2' <- pExp definfo 11 env e2 + return (pParen (d > 10) $ + flatAlt (pretty "index" <+> align (sep [e1', e2'])) + (hsep [e1', pretty "!", e2'])) + + Shape e -> do + e' <- pExp definfo 11 env e + return (pParen (d > 10) $ pretty "shape" <+> e') + + Undef t -> return (pParen (d > 0) $ pretty ("UNDEF :: " ++ showType 0 t "")) + +pParen :: Bool -> Doc x -> Doc x +pParen True = parens +pParen False = id + +showLit :: Int -> Literal a -> ShowS +showLit _ (LInt i) = shows i +showLit _ (LBool b) = shows b +showLit _ (LDouble d) = shows d +showLit d (LArray (Array sh t v)) + | Just Has <- typeHasShow t + = showParen (d > 0) $ + shows v . showString " :: Array " . showShape 11 sh . showString " " . showType 11 t + | otherwise + = showParen (d > 0) $ + showString "[{noshow}] :: Array " . showShape 11 sh . showString " " . showType 11 t +showLit d (LShape sh) = showShape d sh +showLit _ LNil = showString "()" +showLit _ (LPair a b) = + showString "(" . showLit 0 a . showString ", " . showLit 0 b . showString ")" + +showConst :: Constant a -> String +showConst CAddI = "AddI" +showConst CSubI = "SubI" +showConst CMulI = "MulI" +showConst CDivI = "DivI" +showConst CAddF = "AddF" +showConst CSubF = "SubF" +showConst CMulF = "MulF" +showConst CDivF = "DivF" +showConst CLog = "Log" +showConst CExp = "Exp" +showConst CtoF = "ToF" +showConst CRound = "Round" +showConst CLtI = "LtI" +showConst CLeI = "LeI" +showConst CLtF = "LtF" +showConst (CEq _) = "Eq" +showConst CAnd = "And" +showConst COr = "Or" +showConst CNot = "Not" + +showShape :: Int -> Shape sh -> ShowS +showShape _ Z = showString "Z" +showShape d (sh :. n) = showParen (d > 10) $ + showShape 10 sh . showString ":" . shows n + +showType :: Int -> Type a -> ShowS +showType _ TInt = showString "Int" +showType _ TBool = showString "Bool" +showType _ TDouble = showString "Double" +showType _ (TArray sht t) = + let n = shtToInt sht + in showString (replicate n '[') . showType 0 t . showString (replicate n ']') +showType _ TNil = showString "()" +showType _ (TPair a b) = + showString "(" . showType 0 a . showString ", " . showType 0 b . showString ")" +showType d (TFun a b) = showParen (d > 10) $ + showType 11 a . showString " -> " . showType 10 b + +-- showTypeShort :: Int -> Type a -> ShowS +-- showTypeShort _ TInt = showString "i" +-- showTypeShort _ TBool = showString "b" +-- showTypeShort _ TDouble = showString "d" +-- showTypeShort _ (TArray sht t) = +-- let n = shtToInt sht +-- in showString (replicate n '[') . showTypeShort 0 t . showString (replicate n ']') +-- showTypeShort _ TNil = showString "." +-- showTypeShort _ (TPair a b) = +-- showString "(" . showTypeShort 11 a . showTypeShort 11 b . showString ")" +-- showTypeShort d (TFun a b) = showParen (d > 10) $ +-- showTypeShort 11 a . showString " -> " . showTypeShort 10 b diff --git a/Repl.hs b/Repl.hs index 85fe7be..9e7cef2 100644 --- a/Repl.hs +++ b/Repl.hs @@ -7,5 +7,7 @@ import AD import Examples import Gradient import qualified Language as L +import Language (InferType(..), InferShapeType(..), DIM0, DIM1, DIM2, DIM3, var) +import Pretty import Simplify import Sink diff --git a/Simplify.hs b/Simplify.hs index 9ceaef9..e95043d 100644 --- a/Simplify.hs +++ b/Simplify.hs @@ -2,12 +2,15 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} module Simplify ( simplify, simplifyFix, + simbeta, simpair, simindex, simifold1, + simfix, SimList(..), ) where import Data.Bifunctor @@ -16,26 +19,51 @@ import qualified Data.Kind as Kind import Data.List (find) import Data.Type.Equality +import Debug.Trace + import AST +import Count import Sink +data Bound a = Inclusive a | Exclusive a + deriving (Show) + +instance Functor Bound where + fmap f (Inclusive x) = Inclusive (f x) + fmap f (Exclusive x) = Exclusive (f x) + data family Info (env :: [Kind.Type]) a --- data instance Info env Int = InfoInt --- data instance Info env Bool = InfoBool --- data instance Info env Double = InfoDouble +-- | Lower bound (inclusive), upper bound (inclusive/exclusive) +data instance Info env Int = InfoInt (Maybe (Exp env Int)) (Maybe (Bound (Exp env Int))) data instance Info env (Array sh t) = InfoArray (Exp env sh) --- data instance Info env () = InfoNil -data instance Info env (a, b) = InfoPair (Info env a) (Info env b) --- data instance Info env (a -> b) = InfoFun +data instance Info env () = InfoNil +data instance Info env (a, b) = InfoPair (Maybe (Info env a)) (Maybe (Info env b)) data IEnv env where ITop :: IEnv env ICons :: Type a -> Maybe (Info (a ': env) a) -> IEnv env -> IEnv (a ': env) +showsInfo :: Int -> Type a -> Info env a -> ShowS +showsInfo d TInt (InfoInt a b) = showParen (d > 10) $ + showString "InfoInt " . showsPrec 11 a . showString " " . showsPrec 11 b +showsInfo d TArray{} (InfoArray a) = showParen (d > 10) $ + showString "InfoArray " . showsPrec 11 a +showsInfo _ TNil InfoNil = showString "InfoNil" +showsInfo d (TPair t1 t2) (InfoPair a b) = showParen (d > 10) $ + showString "InfoPair " . showsInfo' 11 t1 a . showString " " . showsInfo' 11 t2 b +showsInfo _ _ _ = error "showsInfo: No definition" + +showsInfo' :: Int -> Type a -> Maybe (Info env a) -> ShowS +showsInfo' _ _ Nothing = showString "Nothing" +showsInfo' d t (Just x) = showParen (d > 10) $ + showString "Just " . showsInfo 11 t x + sinkInfo1 :: Type a -> Info env a -> Info (t ': env) a +sinkInfo1 TInt (InfoInt a b) = InfoInt (sinkExp1 <$> a) (fmap sinkExp1 <$> b) sinkInfo1 TArray{} (InfoArray e) = InfoArray (sinkExp1 e) -sinkInfo1 (TPair t1 t2) (InfoPair a b) = InfoPair (sinkInfo1 t1 a) (sinkInfo1 t2 b) +sinkInfo1 TNil InfoNil = InfoNil +sinkInfo1 (TPair t1 t2) (InfoPair a b) = InfoPair (sinkInfo1 t1 <$> a) (sinkInfo1 t2 <$> b) sinkInfo1 _ _ = error "Unknown info in sinkInfo1" iprj :: IEnv env -> Idx env a -> Maybe (Type a, Info env a) @@ -64,6 +92,8 @@ simplify' env = \case let (arg', info) = simplify' env arg env' = ICons (typeof arg) (sinkInfo1 (typeof arg) <$> info) env in (simplifyLet arg' (fst (simplify' env' e)), Nothing) + Lit (LInt n) -> (Lit (LInt n), Just (InfoInt (Just (Lit (LInt n))) + (Just (Inclusive (Lit (LInt n)))))) Lit l -> (Lit l, Nothing) Cond a b c -> (Cond (fst (simplify' env a)) (fst (simplify' env b)) (fst (simplify' env c)), Nothing) @@ -71,27 +101,55 @@ simplify' env = \case Pair a b -> let (a', ia) = simplify' env a (b', ib) = simplify' env b - in (Pair a' b', InfoPair <$> ia <*> ib) - Fst e -> bimap simplifyFst (fmap (\(InfoPair i _) -> i)) (simplify' env e) - Snd e -> bimap simplifySnd (fmap (\(InfoPair _ i) -> i)) (simplify' env e) + in (simplifyPair a' b', Just (InfoPair ia ib)) + Fst e -> bimap simplifyFst (>>= (\(InfoPair i _) -> i)) (simplify' env e) + Snd e -> bimap simplifySnd (>>= (\(InfoPair _ i) -> i)) (simplify' env e) + Build sht a (Lam shty fe) -> + let a' = fst (simplify' env a) + env' = ICons shty (Just (shapeBoundInfo sht (sinkExp1 a'))) env + in (Build sht a' (Lam shty (fst (simplify' env' fe))), Just (InfoArray a')) Build sht a b -> let a' = fst (simplify' env a) in (Build sht a' (fst (simplify' env b)), Just (InfoArray a')) - Ifold sht a b c -> (Ifold sht (fst (simplify' env a)) (fst (simplify' env b)) (fst (simplify' env c)), Nothing) + Ifold sht a b c -> + (simplifyIfold env sht (fst (simplify' env a)) (fst (simplify' env b)) (fst (simplify' env c)), Nothing) Index a b -> (simplifyIndex (fst (simplify' env a)) (fst (simplify' env b)), Nothing) Shape e -> case simplify' env e of (_, Just (InfoArray she)) -> (she, Nothing) (e', _) -> (Shape e', Nothing) + Undef t -> (Undef t, Nothing) + +shapeBoundInfo :: ShapeType sh -> Exp env sh -> Info env sh +shapeBoundInfo STZ _ = InfoNil +shapeBoundInfo (STC sht) she = + InfoPair (Just (shapeBoundInfo sht (Fst she))) + (Just (InfoInt (Just (Lit (LInt 0))) (Just (Exclusive (Snd she))))) simplifyApp :: Exp env (a -> b) -> Exp env a -> Exp env b simplifyApp (Const CAddI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a + b)) +simplifyApp (Const CAddI) (Pair a (Lit (LInt 0))) = a +simplifyApp (Const CAddI) (Pair (Lit (LInt 0)) a) = a +-- simplifyApp (Const CAddI) (Pair a b) | Just Refl <- geq a b = +-- simplifyApp (Const CMulI) (Pair (Lit (LInt 2)) a) simplifyApp (Const CSubI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a - b)) simplifyApp (Const CMulI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a * b)) +simplifyApp (Const CMulI) (Pair a (Lit (LInt 1))) = a +simplifyApp (Const CMulI) (Pair (Lit (LInt 1)) a) = a +simplifyApp (Const CMulI) (Pair _ (Lit (LInt 0))) = Lit (LInt 0) +simplifyApp (Const CMulI) (Pair (Lit (LInt 0)) _) = Lit (LInt 0) simplifyApp (Const CDivI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a `div` b)) simplifyApp (Const CAddF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a + b)) +simplifyApp (Const CAddF) (Pair a (Lit (LDouble 0))) = a +simplifyApp (Const CAddF) (Pair (Lit (LDouble 0)) a) = a +-- simplifyApp (Const CAddF) (Pair a b) | Just Refl <- geq a b = +-- simplifyApp (Const CMulF) (Pair (Lit (LDouble 2)) a) simplifyApp (Const CSubF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a - b)) simplifyApp (Const CMulF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a * b)) +simplifyApp (Const CMulF) (Pair a (Lit (LDouble 1))) = a +simplifyApp (Const CMulF) (Pair (Lit (LDouble 1)) a) = a +simplifyApp (Const CMulF) (Pair _ (Lit (LDouble 0))) = Lit (LDouble 0) +simplifyApp (Const CMulF) (Pair (Lit (LDouble 0)) _) = Lit (LDouble 0) simplifyApp (Const CDivF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a / b)) simplifyApp (Const CLog) (Lit (LDouble a)) = Lit (LDouble (log a)) simplifyApp (Const CExp) (Lit (LDouble a)) = Lit (LDouble (exp a)) @@ -107,15 +165,17 @@ simplifyApp (Const COr) (Pair (Lit (LBool a)) (Lit (LBool b))) = Lit (LBool (a | simplifyApp (Const CNot) (Lit (LBool a)) = Lit (LBool (not a)) simplifyApp (Lam _ e) arg - | isDuplicable arg || countOcc Zero e <= 1 + | isDuplicable arg || usesOf Zero e <= 1 = simplify (subst arg e) simplifyApp (Lam _ e) arg = simplifyLet arg e +simplifyApp f (Cond c a b) = simplify $ Cond c (App f a) (App f b) + simplifyApp a b = App a b simplifyLet :: Exp env a -> Exp (a ': env) b -> Exp env b simplifyLet arg e - | isDuplicable arg || countOcc Zero e <= 1 + | isDuplicable arg || usesOf Zero e <= 1 = simplify (subst arg e) simplifyLet (Pair a b) e = simplifyLet a $ @@ -124,24 +184,89 @@ simplifyLet (Pair a b) e = (Var (typeof b) Zero) Succ i -> Var t (Succ (Succ i))) e -simplifyLet (Cond c a b) e - | isDuplicable a && isDuplicable b - = simplifyLet c $ - (subst' (\t -> \case Zero -> Cond (Var TBool Zero) (sinkExp1 a) (sinkExp1 b) - Succ i -> Var t (Succ i)) - e) -simplifyLet a b = Let (simplify a) (simplify b) +-- simplifyLet (Cond c a b) e +-- | isDuplicable a && isDuplicable b +-- = simplifyLet c $ +-- (subst' (\t -> \case Zero -> Cond (Var TBool Zero) (sinkExp1 a) (sinkExp1 b) +-- Succ i -> Var t (Succ i)) +-- e) +simplifyLet (Cond c a b) e = simplify $ Cond c (Let a e) (Let b e) +simplifyLet a b = Let a b + +simplifyPair :: Exp env a -> Exp env b -> Exp env (a, b) +simplifyPair (Cond c a b) d = simplify $ Cond c (Pair a d) (Pair b d) +simplifyPair d (Cond c a b) = simplify $ Cond c (Pair d a) (Pair d b) +simplifyPair a b = Pair a b simplifyFst :: Exp env (a, b) -> Exp env a simplifyFst (Pair e _) = e simplifyFst (Let a e) = simplifyLet a (simplifyFst e) +simplifyFst (Cond c a b) = simplify $ Cond c (Fst a) (Fst b) simplifyFst e = Fst e simplifySnd :: Exp env (a, b) -> Exp env b simplifySnd (Pair _ e) = e simplifySnd (Let a e) = simplifyLet a (simplifySnd e) +simplifySnd (Cond c a b) = simplify $ Cond c (Snd a) (Snd b) simplifySnd e = Snd e +simplifyIfold :: IEnv env -> ShapeType sh -> Exp env ((a, sh) -> a) -> Exp env a -> Exp env sh -> Exp env a +simplifyIfold env sht fe e0 she + | Just res <- splitIfold sht fe e0 she + = fst (simplify' env res) +-- Given the following: +-- ifold (\(a,i) -> if i == cmpref then val else a) _ she +-- and given that we can prove that 0 <= cmpref < she and that 'val' is free, +-- the whole fold can be replaced with 'val'. +simplifyIfold env sht (Lam argty (Cond (App (Const (CEq _)) (Pair (Snd (Var _ Zero)) cmpref)) val (Fst (Var _ Zero)))) e0 she + | let env' = ICons argty (Just (InfoPair Nothing (Just (shapeBoundInfo sht (sinkExp1 she))))) env + , trace ("si: trying") True + , proveShapeBound env' CLeI sht (zeroShapeExp sht) cmpref + , trace ("si: prf1 = True") True + , proveShapeBound env' CLtI sht cmpref (sinkExp1 she) + , trace ("si: prf2 = True") True + , trace ("si: cmpref = " ++ show val) True + , usesOf Zero cmpref == 0 + = simplifyLet (Pair e0 (subst (error "usesOf == 0 was wrong") cmpref)) val +simplifyIfold _ sht fe e0 she = Ifold sht fe e0 she + +zeroShapeExp :: ShapeType sh -> Exp env sh +zeroShapeExp STZ = Lit LNil +zeroShapeExp (STC sht) = Pair (zeroShapeExp sht) (Lit (LInt 0)) + +proveShapeBound :: IEnv env -> Constant ((Int, Int) -> Bool) -> ShapeType sh -> Exp env sh -> Exp env sh -> Bool +proveShapeBound _ _ STZ _ _ = True +proveShapeBound env cmpop (STC sht) e1 e2 = + let (_, info1) = simplify' env (Snd e1) + (_, info2) = simplify' env (Snd e2) + inclLo2 = case info2 of + Just (InfoInt (Just lo2) _) -> lo2 + _ -> Snd e2 -- this is also an inclusive lower bound, after all + restresult = proveShapeBound env cmpop sht (Fst e1) (Fst e2) + in restresult && case (cmpop, info1) of + (CLeI, Just (InfoInt _ (Just (Inclusive hi1)))) -> + proveLe hi1 inclLo2 + (CLtI, Just (InfoInt _ (Just (Exclusive hi1)))) -> + proveLe hi1 inclLo2 + _ -> trace ("proveShapeBound: " ++ show cmpop ++ " (" ++ show e1 ++ ") (" ++ show e2 ++ ")") $ + trace (" e1 = " ++ show e1) $ + trace (" e2 = " ++ show e2) $ + trace (" info1 = " ++ showsInfo' 0 TInt info1 "") $ + trace (" info2 = " ++ showsInfo' 0 TInt info2 "") $ + trace (" inclLo2 = " ++ show inclLo2) $ + False + +proveLe :: Exp env Int -> Exp env Int -> Bool +proveLe = \e1 e2 -> + let res = proveLe' e1 e2 + in trace ("proveLe: '" ++ show e1 ++ "' <= '" ++ show e2 ++ "' -> " ++ show res) + res + +proveLe' :: Exp env Int -> Exp env Int -> Bool +proveLe' e1 e2 | Just Refl <- geq e1 e2 = True +proveLe' (Lit (LInt a)) (Lit (LInt b)) | a <= b = True +proveLe' _ _ = False + simplifyIndex :: Exp env (Array sh a) -> Exp env sh -> Exp env a simplifyIndex (Build _ _ f) e = simplifyApp f e simplifyIndex a e = Index a e @@ -162,24 +287,6 @@ isDuplicable (Fst e) = isDuplicable e isDuplicable (Snd e) = isDuplicable e isDuplicable _ = False -countOcc :: Idx env t -> Exp env a -> Int -countOcc i (App a b) = countOcc i a + countOcc i b -countOcc i (Lam _ e) = countOcc (Succ i) e -countOcc i (Var _ j) - | Just Refl <- geq i j = 1 - | otherwise = 0 -countOcc i (Let a b) = countOcc i a + countOcc (Succ i) b -countOcc _ (Lit _) = 0 -countOcc i (Cond a b c) = countOcc i a + countOcc i b + countOcc i c -countOcc _ (Const _) = 0 -countOcc i (Pair a b) = countOcc i a + countOcc i b -countOcc i (Fst e) = countOcc i e -countOcc i (Snd e) = countOcc i e -countOcc i (Build _ a b) = countOcc i a + countOcc i b -countOcc i (Ifold _ a b c) = countOcc i a + countOcc i b + countOcc i c -countOcc i (Index a b) = countOcc i a + countOcc i b -countOcc i (Shape e) = countOcc i e - subst :: Exp env t -> Exp (t ': env) a -> Exp env a subst arg e = subst' (\t -> \case Zero -> arg ; Succ i -> Var t i) e @@ -201,3 +308,100 @@ subst' f (Build sht a b) = Build sht (subst' f a) (subst' f b) subst' f (Ifold sht a b c) = Ifold sht (subst' f a) (subst' f b) (subst' f c) subst' f (Index a b) = Index (subst' f a) (subst' f b) subst' f (Shape e) = Shape (subst' f e) +subst' _ (Undef t) = Undef t + +splitIfold :: ShapeType sh -> Exp env ((s, sh) -> s) -> Exp env s -> Exp env sh -> Maybe (Exp env s) +splitIfold sht (Lam (TPair (TPair t1 t2) tidx) (Pair e1 e2)) e0 she + | let uses1 = usesOf' PathStart Zero e1 + uses2 = usesOf' PathStart Zero e2 + , lycontract (lysnd (lyfst uses1)) == 0 + , lycontract (lyfst (lyfst uses2)) == 0 + -- Substitute the argument in e1 and t2 to refer to just the used + -- components of their argument. To do this we reconstruct the original, + -- partially unused argument by putting an 'Undef' in the unused spot. + , let e1' = subst' (\t -> \case Zero -> + Pair (Pair (Fst (Var (TPair t1 tidx) Zero)) + (Undef t2)) + (Snd (Var (TPair t1 tidx) Zero)) + Succ i -> Var t (Succ i)) + e1 + e2' = subst' (\t -> \case Zero -> + Pair (Pair (Undef t1) + (Fst (Var (TPair t2 tidx) Zero))) + (Snd (Var (TPair t2 tidx) Zero)) + Succ i -> Var t (Succ i)) + e2 + = Just $ + Let e0 $ Let (sinkExp1 she) $ + Pair (Ifold sht (sinkExp2 (Lam (TPair t1 tidx) e1')) + (Fst (Var (TPair t1 t2) (Succ Zero))) + (Var tidx Zero)) + (Ifold sht (sinkExp2 (Lam (TPair t2 tidx) e2')) + (Snd (Var (TPair t1 t2) (Succ Zero))) + (Var tidx Zero)) +splitIfold _ _ _ _ = Nothing + +simbeta :: Exp env a -> Exp env a +simbeta = \case + App (Lam _ e) a + | isDuplicable a || usesOf Zero e <= 1 + -> simbeta (subst a e) + | otherwise + -> Let (simbeta a) (simbeta e) + Let a e + | isDuplicable a || usesOf Zero e <= 1 + -> simbeta (subst a e) + e -> simrecurse simbeta e + +simpair :: Exp env a -> Exp env a +simpair = \case + Fst (Pair a _) -> simpair a + Snd (Pair _ b) -> simpair b + Fst (Let a b) -> Let (simpair a) (simpair (Fst b)) + Snd (Let a b) -> Let (simpair a) (simpair (Snd b)) + e -> simrecurse simpair e + +simindex :: Exp env a -> Exp env a +simindex = \case + Index (Build _ _ f) e -> + App (simindex f) (simindex e) + e -> simrecurse simindex e + +simifold1 :: Exp env a -> Exp env a +simifold1 = \case + Ifold sht fe e0 she + | Just res <- splitIfold sht (simifold1 fe) (simifold1 e0) (simifold1 she) + -> res + e -> simrecurse simifold1 e + +simrecurse :: (forall env' a'. Exp env' a' -> Exp env' a') -> Exp env a -> Exp env a +simrecurse f = \case + App a b -> App (f a) (f b) + Lam t e -> Lam t (f e) + Var t i -> Var t i + Let a e -> Let (f a) (f e) + Lit l -> Lit l + Cond a b c -> Cond (f a) (f b) (f c) + Const c -> Const c + Pair a b -> Pair (f a) (f b) + Fst e -> Fst (f e) + Snd e -> Snd (f e) + Build sht a b -> Build sht (f a) (f b) + Ifold sht a b c -> Ifold sht (f a) (f b) (f c) + Index a b -> Index (f a) (f b) + Shape e -> Shape (f e) + Undef t -> Undef t + +infixr :| +data SimList = (forall env' a'. Exp env' a' -> Exp env' a') :| SimList + | SimEnd + +simfix :: SimList -> Exp env a -> Exp env a +simfix list = \e -> let e' = looponce list e + in case geq e e' of + Just Refl -> e' + Nothing -> simfix list e' + where + looponce :: SimList -> Exp env a -> Exp env a + looponce SimEnd e = e + looponce (f :| l) e = looponce l (f e) diff --git a/Sink.hs b/Sink.hs index 1368cb6..c258dc5 100644 --- a/Sink.hs +++ b/Sink.hs @@ -39,6 +39,7 @@ sinkExp w = \case Ifold sht e1 e2 e3 -> Ifold sht (sinkExp w e1) (sinkExp w e2) (sinkExp w e3) Index e1 e2 -> Index (sinkExp w e1) (sinkExp w e2) Shape e -> Shape (sinkExp w e) + Undef t -> Undef t sinkExp1 :: Exp env a -> Exp (t ': env) a sinkExp1 = sinkExp (wSucc wId) diff --git a/ftilde.cabal b/ftilde.cabal index 1432547..04a5a2a 100644 --- a/ftilde.cabal +++ b/ftilde.cabal @@ -12,14 +12,18 @@ executable ftilde other-modules: AST AD + Count + Eval Examples Gradient Language + Pretty Repl Simplify Sink build-depends: base >= 4.13 && < 4.15, + prettyprinter >= 1.7.0 && < 1.8, vector, some hs-source-dirs: . -- cgit v1.2.3-70-g09d2