aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2021-06-27 18:34:35 +0200
committerTom Smeding <tom@tomsmeding.com>2021-06-27 18:34:35 +0200
commitd4abcc3b2dfefbbcb7cd4a182eec64f1da42d951 (patch)
tree1ab301617043ac6df228ef617afa22633a01a671
parent0fefe4822c65bde81ec4c0da1b5b32a9b411ca79 (diff)
-rw-r--r--AD.hs1
-rw-r--r--AST.hs58
-rw-r--r--Count.hs94
-rw-r--r--Eval.hs128
-rw-r--r--Examples.hs15
-rw-r--r--Language.hs78
-rw-r--r--Pretty.hs229
-rw-r--r--Repl.hs2
-rw-r--r--Simplify.hs280
-rw-r--r--Sink.hs1
-rw-r--r--ftilde.cabal4
11 files changed, 827 insertions, 63 deletions
diff --git a/AD.hs b/AD.hs
index 76fefe4..c9ac72e 100644
--- a/AD.hs
+++ b/AD.hs
@@ -99,6 +99,7 @@ ad' env = \case
| TArray sht _ <- typeof e
, Refl <- prfDualSht sht
-> Shape (ad' env e)
+ Undef t -> Undef (dual t)
convIdx :: DEnv env env' -> Idx env a -> Either (Idx env' a) (Idx env' (Dual a))
convIdx ETop i = Left i
diff --git a/AST.hs b/AST.hs
index 7e9c69c..3e1d2f6 100644
--- a/AST.hs
+++ b/AST.hs
@@ -1,6 +1,8 @@
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
module AST where
@@ -26,6 +28,7 @@ data Exp env a where
Ifold :: ShapeType sh -> Exp env ((s, sh) -> s) -> Exp env s -> Exp env sh -> Exp env s
Index :: Exp env (Array sh a) -> Exp env sh -> Exp env a
Shape :: Exp env (Array sh a) -> Exp env sh
+ Undef :: Type a -> Exp env a
data Constant a where
CAddI :: Constant ((Int, Int) -> Int)
@@ -42,6 +45,7 @@ data Constant a where
CRound :: Constant (Double -> Int)
CLtI :: Constant ((Int, Int) -> Bool)
+ CLeI :: Constant ((Int, Int) -> Bool)
CLtF :: Constant ((Double, Double) -> Bool)
CEq :: Type a -> Constant ((a, a) -> Bool)
CAnd :: Constant ((Bool, Bool) -> Bool)
@@ -72,11 +76,11 @@ data Literal a where
data Shape sh where
Z :: Shape ()
- (:.) :: Int -> Shape sh -> Shape (Int, sh)
+ (:.) :: Shape sh -> Int -> Shape (sh, Int)
data ShapeType sh where
STZ :: ShapeType ()
- STC :: ShapeType sh -> ShapeType (Int, sh)
+ STC :: ShapeType sh -> ShapeType (sh, Int)
data Array sh a where
Array :: Shape sh -> Type a -> Vector a -> Array sh a
@@ -86,15 +90,19 @@ deriving instance Show (Constant a)
deriving instance Show (Type a)
deriving instance Show (Idx env a)
deriving instance Show (Literal a)
-deriving instance Show (Shape a)
deriving instance Show (ShapeType a)
+instance Show (Shape sh) where
+ showsPrec _ Z = showString "Z"
+ showsPrec p (sh :. n) = showParen (p > 0) $
+ showsPrec 10 sh . showString " :. " . shows n
+
instance Show (Array sh a) where
showsPrec p (Array sh t v) =
showParen (p > 10) $
showString "Array "
- . showsPrec 11 sh
- . showsPrec 11 t
+ . showsPrec 11 sh . showString " "
+ . showsPrec 11 t . showString " "
. (case typeHasShow t of
Just Has -> showsPrec 11 v
Nothing -> showString ("[_ * " ++ show (V.length v) ++ "]"))
@@ -134,6 +142,8 @@ instance GEq (Exp env) where
geq Index{} _ = Nothing
geq (Shape a) (Shape a') | Just Refl <- geq a a' = Just Refl
geq Shape{} _ = Nothing
+ geq (Undef a) (Undef a') | Just Refl <- geq a a' = Just Refl
+ geq Undef{} _ = Nothing
instance GEq Constant where
geq CAddI CAddI = Just Refl ; geq CAddI _ = Nothing
@@ -149,6 +159,7 @@ instance GEq Constant where
geq CtoF CtoF = Just Refl ; geq CtoF _ = Nothing
geq CRound CRound = Just Refl ; geq CRound _ = Nothing
geq CLtI CLtI = Just Refl ; geq CLtI _ = Nothing
+ geq CLeI CLeI = Just Refl ; geq CLeI _ = Nothing
geq CLtF CLtF = Just Refl ; geq CLtF _ = Nothing
geq (CEq t) (CEq t') | Just Refl <- geq t t' = Just Refl ; geq CEq{} _ = Nothing
geq CAnd CAnd = Just Refl ; geq CAnd _ = Nothing
@@ -187,7 +198,7 @@ instance GEq Literal where
instance GEq Shape where
geq Z Z = Just Refl
- geq (n :. sh) (n' :. sh') | n == n', Just Refl <- geq sh sh' = Just Refl
+ geq (sh :. n) (sh' :. n') | n == n', Just Refl <- geq sh sh' = Just Refl
geq _ _ = Nothing
instance GEq ShapeType where
@@ -195,17 +206,36 @@ instance GEq ShapeType where
geq (STC sht) (STC sht') | Just Refl <- geq sht sht' = Just Refl
geq _ _ = Nothing
+-- Requires that the given term is neither 'Lam' nor 'Let'.
+recurseMon :: Monoid s => (forall t. Exp env t -> s) -> Exp env a -> s
+recurseMon f = \case
+ App a b -> f a <> f b
+ Lam _ _ -> error "recurseMon: Given Lam"
+ Var _ _ -> mempty
+ Let _ _ -> error "recurseMon: Given Let"
+ Lit _ -> mempty
+ Cond a b c -> f a <> f b <> f c
+ Const _ -> mempty
+ Pair a b -> f a <> f b
+ Fst a -> f a
+ Snd a -> f a
+ Build _ a b -> f a <> f b
+ Ifold _ a b c -> f a <> f b <> f c
+ Index a b -> f a <> f b
+ Shape a -> f a
+ Undef _ -> mempty
+
shapeType :: Shape sh -> ShapeType sh
shapeType Z = STZ
-shapeType (_ :. sh) = STC (shapeType sh)
+shapeType (sh :. _) = STC (shapeType sh)
shapeType' :: Shape sh -> Type sh
shapeType' Z = TNil
-shapeType' (_ :. sh) = TPair TInt (shapeType' sh)
+shapeType' (sh :. _) = TPair (shapeType' sh) TInt
shapeTypeType :: ShapeType sh -> Type sh
shapeTypeType STZ = TNil
-shapeTypeType (STC sht) = TPair TInt (shapeTypeType sht)
+shapeTypeType (STC sht) = TPair (shapeTypeType sht) TInt
literalType :: Literal a -> Type a
literalType LInt{} = TInt
@@ -230,6 +260,7 @@ constType CExp = TFun TDouble TDouble
constType CtoF = TFun TInt TDouble
constType CRound = TFun TDouble TInt
constType CLtI = TFun (TPair TInt TInt) TBool
+constType CLeI = TFun (TPair TInt TInt) TBool
constType CLtF = TFun (TPair TDouble TDouble) TBool
constType (CEq t) = TFun (TPair t t) TBool
constType CAnd = TFun (TPair TBool TBool) TBool
@@ -251,6 +282,15 @@ typeof (Build sht _ e) = let TFun _ t = typeof e in TArray sht t
typeof (Ifold _ _ e _) = typeof e
typeof (Index e _) = let TArray _ t = typeof e in t
typeof (Shape e) = let TArray sht _ = typeof e in shapeTypeType sht
+typeof (Undef t) = t
+
+idxToInt :: Idx env a -> Int
+idxToInt Zero = 0
+idxToInt (Succ i) = idxToInt i + 1
+
+shtToInt :: ShapeType sh -> Int
+shtToInt STZ = 0
+shtToInt (STC sht) = shtToInt sht + 1
data Has c a where
Has :: c a => Has c a
diff --git a/Count.hs b/Count.hs
new file mode 100644
index 0000000..d9fe661
--- /dev/null
+++ b/Count.hs
@@ -0,0 +1,94 @@
+{-# LANGUAGE GADTs #-}
+module Count where
+
+import Data.GADT.Compare
+import Data.Type.Equality
+
+import AST
+
+
+data Layout = LyLeaf Integer | LyPair Layout Layout
+ deriving (Show)
+
+instance Semigroup Layout where
+ LyLeaf a <> LyLeaf b = LyLeaf (a + b)
+ l@(LyLeaf _) <> LyPair l1 l2 = LyPair (l <> l1) (l <> l2)
+ LyPair l1 l2 <> l@(LyLeaf _) = LyPair (l1 <> l) (l2 <> l)
+ LyPair l1 l2 <> LyPair l3 l4 = LyPair (l1 <> l3) (l2 <> l4)
+
+instance Monoid Layout where
+ mempty = LyLeaf 0
+
+-- | Returns the maximum usage count of any component of the layout.
+lycontract :: Layout -> Integer
+lycontract (LyLeaf n) = n
+lycontract (LyPair ly1 ly2) = max (lycontract ly1) (lycontract ly2)
+
+-- | Given the usage count of a pair, return the usage count of its first component
+lyfst :: Layout -> Layout
+lyfst (LyLeaf n) = LyLeaf n
+lyfst (LyPair ly _) = ly
+
+-- | Given the usage count of a pair, return the usage count of its second component
+lysnd :: Layout -> Layout
+lysnd (LyLeaf n) = LyLeaf n
+lysnd (LyPair _ ly) = ly
+
+-- | Returns the usage count of an expression given its usage counts in two
+-- mutually exclusive program branches.
+lymax :: Layout -> Layout -> Layout
+lymax (LyLeaf n) (LyLeaf m) = LyLeaf (max n m)
+lymax (LyPair a1 a2) (LyPair b1 b2) = LyPair (lymax a1 b1) (lymax a2 b2)
+lymax (LyLeaf n) ly@LyPair{} = lymax (LyPair (LyLeaf n) (LyLeaf n)) ly
+lymax ly@LyPair{} (LyLeaf n) = lymax ly (LyPair (LyLeaf n) (LyLeaf n))
+
+-- | Count the uses of a variable in an expression
+usesOf :: Idx env t -> Exp env a -> Integer
+usesOf x e = lycontract (usesOf' PathStart x e)
+
+-- Path upwards in the tuple, starting at Start.
+data Path part t where
+ PathFst :: Path part t -> Path part (t, a)
+ PathSnd :: Path part t -> Path part (a, t)
+ PathStart :: Path part part
+
+-- | Count the uses of the components of a variable in an expression
+usesOf' :: Path large a -> Idx env t -> Exp env a -> Layout
+usesOf' _ i (App a b) = usesOf' PathStart i a <> usesOf' PathStart i b
+usesOf' _ i (Lam _ e) = usesOf' PathStart (Succ i) e
+usesOf' pt i (Var _ i') =
+ let leaf = case geq i i' of Just Refl -> LyLeaf 1
+ Nothing -> mempty
+ build :: Path a large -> Layout -> Layout
+ build (PathFst pt') ly = build pt' (LyPair ly mempty)
+ build (PathSnd pt') ly = build pt' (LyPair mempty ly)
+ build PathStart ly = ly
+ in build pt leaf
+usesOf' _ i (Let a b) = usesOf' PathStart i a <> usesOf' PathStart (Succ i) b
+usesOf' _ _ (Lit _) = mempty
+usesOf' pt i (Cond p l r) =
+ usesOf' PathStart i p <> lymax (usesOf' pt i l) (usesOf' pt i r)
+usesOf' _ _ (Const _) = mempty
+usesOf' _ i (Pair a b) = usesOf' PathStart i a <> usesOf' PathStart i b
+usesOf' pt i (Fst e) = usesOf' (PathFst pt) i e
+usesOf' pt i (Snd e) = usesOf' (PathSnd pt) i e
+-- TODO: Huge hack that allows arbitrary computational complexity increase.
+-- The code current counts usages in the lambdas in Build and Ifold _once_,
+-- whereas those usages are actually evaluated many times. The correct fix
+-- would be to analyse whether the accessed indexes of the counted variable are
+-- all disjoint, but that is hard.
+usesOf' _ i (Build _ she fe) = usesOf' PathStart i she <> usesOf' PathStart i fe
+-- usesOf' _ i (Build STZ she fe) = usesOf' PathStart i she <> usesOf' PathStart i fe
+-- usesOf' _ i (Build _ she fe) = usesOf' PathStart i she <> mul2 (usesOf' PathStart i fe)
+usesOf' _ i (Ifold _ fe e0 she) =
+ usesOf' PathStart i fe <> usesOf' PathStart i e0 <> usesOf' PathStart i she
+-- usesOf' _ i (Ifold STZ fe e0 she) =
+-- usesOf' PathStart i fe <> usesOf' PathStart i e0 <> usesOf' PathStart i she
+-- usesOf' _ i (Ifold _ fe e0 she) =
+-- mul2 (usesOf' PathStart i fe) <> usesOf' PathStart i e0 <> usesOf' PathStart i she
+usesOf' _ i (Index a b) = usesOf' PathStart i a <> usesOf' PathStart i b
+usesOf' _ i (Shape e) = usesOf' PathStart i e
+usesOf' _ _ (Undef _) = mempty
+
+-- mul2 :: Semigroup m => Layout m -> Layout m
+-- mul2 ly = ly <> ly
diff --git a/Eval.hs b/Eval.hs
new file mode 100644
index 0000000..19265bc
--- /dev/null
+++ b/Eval.hs
@@ -0,0 +1,128 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeOperators #-}
+module Eval (
+ eval,
+) where
+
+import Data.List (foldl')
+import qualified Data.Vector as V
+
+import AST
+
+
+data Val env where
+ Top :: Val '[]
+ Push :: Val env -> a -> Val (a ': env)
+
+prj :: Val env -> Idx env a -> a
+prj (Push _ x) Zero = x
+prj (Push env _) (Succ i) = prj env i
+
+eval :: Exp '[] a -> a
+eval = eval' Top
+
+eval' :: forall env a. Val env -> Exp env a -> a
+eval' env = \case
+ App f a -> rec f (rec a)
+ Lam _ e -> \x -> eval' (Push env x) e
+ Var _ i -> prj env i
+ Let a e -> eval' (Push env (rec a)) e
+ Lit l -> evalL l
+ Cond c a b -> if rec c then rec a else rec b
+ Const c -> evalC c
+ Pair a b -> (rec a, rec b)
+ Fst e -> fst (rec e)
+ Snd e -> snd (rec e)
+ Build sht she fe ->
+ let TFun _ ty = typeof fe
+ in build ty sht (rec she) (rec fe)
+ Ifold sht fe e0 she -> ifold sht (rec fe) (rec e0) (rec she)
+ Index a i -> index (rec a) (rec i)
+ Shape a -> shape (rec a)
+ Undef t -> error ("eval: Undef of type " ++ show t)
+ where rec :: Exp env t -> t
+ rec = eval' env
+
+evalL :: Literal a -> a
+evalL (LInt n) = n
+evalL (LBool b) = b
+evalL (LDouble d) = d
+evalL (LArray a) = a
+evalL (LShape Z) = ()
+evalL (LShape sh) = unshape sh
+evalL LNil = ()
+evalL (LPair a b) = (evalL a, evalL b)
+
+evalC :: Constant a -> a
+evalC CAddI = uncurry (+)
+evalC CSubI = uncurry (-)
+evalC CMulI = uncurry (*)
+evalC CDivI = uncurry div
+evalC CAddF = uncurry (+)
+evalC CSubF = uncurry (-)
+evalC CMulF = uncurry (*)
+evalC CDivF = uncurry (/)
+evalC CLog = log
+evalC CExp = exp
+evalC CtoF = fromIntegral
+evalC CRound = round
+evalC CLtI = uncurry (<)
+evalC CLeI = uncurry (<=)
+evalC CLtF = uncurry (<)
+evalC (CEq t) | Just Has <- typeHasEq t = uncurry (==)
+ | otherwise = error ("eval: Cannot Eq compare values of type " ++ show t)
+evalC CAnd = uncurry (&&)
+evalC COr = uncurry (||)
+evalC CNot = not
+
+build :: Type a -> ShapeType sh -> sh -> (sh -> a) -> Array sh a
+build ty sht sh f =
+ let sh' = toshape sht sh
+ in Array sh' ty (V.generate (shapesize sh') (\i -> f (fromlinear sh' i)))
+
+ifold :: ShapeType sh -> ((a, sh) -> a) -> a -> sh -> a
+ifold sht f x0 sh = foldl' (curry f) x0 (enumshape (toshape sht sh))
+
+index :: Array sh a -> sh -> a
+index (Array sh _ v) idx = v V.! tolinear sh idx
+
+shape :: Array sh a -> sh
+shape (Array sh _ _) = unshape sh
+
+enumshape :: Shape sh -> [sh]
+enumshape sh = take (shapesize sh) (iterate (next sh) (zeroshape sh))
+ where
+ next :: Shape sh -> sh -> sh
+ next Z () = ()
+ next (sh' :. n) (idx, i)
+ | i < n = (idx, i + 1)
+ | otherwise = (next sh' idx, 0)
+
+ zeroshape :: Shape sh -> sh
+ zeroshape Z = ()
+ zeroshape (sh' :. _) = (zeroshape sh', 0)
+
+unshape :: Shape sh -> sh
+unshape Z = ()
+unshape (sh :. n) = (unshape sh, n)
+
+toshape :: ShapeType sh -> sh -> Shape sh
+toshape STZ () = Z
+toshape (STC sht) (sh, n) = toshape sht sh :. n
+
+tolinear :: Shape sh -> sh -> Int
+tolinear Z () = 0
+tolinear (sh :. n) (idx, i) = n * tolinear sh idx + i
+
+fromlinear :: Shape sh -> Int -> sh
+fromlinear Z _ = ()
+fromlinear (sh :. n) i =
+ let (q, r) = i `divMod` n
+ in (fromlinear sh q, r)
+
+shapesize :: Shape sh -> Int
+shapesize Z = 1
+shapesize (sh :. n) = n * shapesize sh
diff --git a/Examples.hs b/Examples.hs
index 9d9cda7..0719f1f 100644
--- a/Examples.hs
+++ b/Examples.hs
@@ -4,11 +4,11 @@ import AST
import qualified Language as L
-sumSq :: Exp env (Array (Int, ()) Double -> Double)
+sumSq :: Exp env (Array L.DIM1 Double -> Double)
sumSq = Lam (TArray (STC STZ) TDouble)
(L.sum (App mapSq (Var (TArray (STC STZ) TDouble) Zero)))
-mapSq :: Exp env (Array (Int, ()) Double -> Array (Int, ()) Double)
+mapSq :: Exp env (Array L.DIM1 Double -> Array L.DIM1 Double)
mapSq =
Lam (TArray (STC STZ) TDouble)
(L.map (Lam TDouble
@@ -16,13 +16,16 @@ mapSq =
(Pair (Var TDouble Zero) (Var TDouble Zero))))
(Var (TArray (STC STZ) TDouble) Zero))
-mapSqIota :: Exp env (Array (Int, ()) Double)
+mapSqIota :: Exp env (Array L.DIM1 Double)
mapSqIota =
L.map (Lam TDouble
(App (Const CMulF)
(Pair (Var TDouble Zero) (Var TDouble Zero))))
(Build (STC STZ)
- (Pair (Lit (LInt 5)) (Lit LNil))
- (Lam (TPair TInt TNil)
+ (Pair (Lit LNil) (Lit (LInt 5)))
+ (Lam L.infer
(App (Const CtoF)
- (Fst (Var (TPair TInt TNil) Zero)))))
+ (Snd (L.var Zero)))))
+
+transpose2 :: Exp env (Array L.DIM2 Double -> Array L.DIM2 Double)
+transpose2 = Lam L.infer (App (L.transpose L.infer) (App (L.transpose L.infer) (L.var Zero)))
diff --git a/Language.hs b/Language.hs
index e16cf7c..8ab6199 100644
--- a/Language.hs
+++ b/Language.hs
@@ -1,4 +1,6 @@
+{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE TypeOperators #-}
{-| This module is intended to be imported qualified, perhaps as @L@. -}
module Language where
@@ -7,6 +9,28 @@ import AST
import Sink
+-- Convention: matrices are represented in row-major: (((), y), x)
+type DIM0 = ()
+type DIM1 = (DIM0, Int)
+type DIM2 = (DIM1, Int)
+type DIM3 = (DIM2, Int)
+
+class InferType a where infer :: Type a
+instance InferType Int where infer = TInt
+instance InferType Bool where infer = TBool
+instance InferType Double where infer = TDouble
+instance (InferType a, InferShapeType sh) => InferType (Array sh a) where infer = TArray inferST infer
+instance InferType () where infer = TNil
+instance (InferType a, InferType b) => InferType (a, b) where infer = TPair infer infer
+instance (InferType a, InferType b) => InferType (a -> b) where infer = TFun infer infer
+
+class InferShapeType sh where inferST :: ShapeType sh
+instance InferShapeType () where inferST = STZ
+instance InferShapeType sh => InferShapeType (sh, Int) where inferST = STC inferST
+
+var :: InferType a => Idx env a -> Exp env a
+var = Var infer
+
map :: Exp env (a -> b) -> Exp env (Array sh a) -> Exp env (Array sh b)
map f e =
let ty@(TArray sht _) = typeof e
@@ -18,17 +42,16 @@ map f e =
(Index (Var ty (Succ Zero))
(Var sht' Zero)))))
-sum :: Exp env (Array (Int, ()) Double) -> Exp env Double
+sum :: Exp env (Array DIM1 Double) -> Exp env Double
sum e =
- let ty@(TArray sht _) = typeof e
- in Let e
- (Ifold sht
- (Lam (TPair TDouble (TPair TInt TNil))
- (App (Const CAddF) (Pair
- (Fst (Var (TPair TDouble (TPair TInt TNil)) Zero))
- (Index (Var ty (Succ Zero)) (Snd (Var (TPair TDouble (TPair TInt TNil)) Zero))))))
- (Lit (LDouble 0))
- (Shape (Var ty Zero)))
+ Let e
+ (Ifold inferST
+ (Lam (TPair TDouble (TPair TNil TInt))
+ (App (Const CAddF) (Pair
+ (Fst (var Zero))
+ (Index (var (Succ Zero)) (Snd (var Zero))))))
+ (Lit (LDouble 0))
+ (Shape (var Zero)))
-- | The two input arrays are assumed to be the same size.
zip :: Exp env (Array sh a) -> Exp env (Array sh b) -> Exp env (Array sh (a, b))
@@ -50,3 +73,38 @@ oneHot sht sh idx =
(Cond (App (Const (CEq idxt)) (Pair (Var idxt Zero) (sinkExp1 idx)))
(Lit (LDouble 1))
(Lit (LDouble 0))))
+
+transpose :: Type a -> Exp env (Array DIM2 a -> Array DIM2 a)
+transpose ty =
+ Lam (TArray inferST ty)
+ (Build inferST (Shape (Var (TArray inferST ty) Zero))
+ (Lam infer (Index (Var (TArray inferST ty) (Succ Zero)) (Var infer Zero))))
+
+eye :: Exp env (Int -> Array DIM2 Double)
+eye =
+ Lam infer
+ (Build inferST (Pair (Pair (Lit LNil) (var Zero)) (var Zero))
+ (Lam infer
+ (Cond (App (Const (CEq infer)) (Pair (Snd (var Zero)) (Snd (Fst (var Zero)))))
+ (Lit (LDouble 1))
+ (Lit (LDouble 0)))))
+
+length :: Type a -> Exp env (Array DIM1 a -> Int)
+length ty = Lam (TArray inferST ty)
+ (Snd (Shape (Var (TArray inferST ty) Zero)))
+
+vmmul :: Exp env (Array DIM1 Double -> Array DIM2 Double -> Array DIM1 Double)
+vmmul =
+ Lam infer $ Lam infer $
+ Build inferST
+ (Pair (Lit LNil) (Snd (Shape (var Zero))))
+ (Lam infer $
+ Ifold inferST
+ (Lam infer $
+ App (Const CAddF) (Pair
+ (Fst (var Zero))
+ (App (Const CMulF) (Pair
+ (Index (var (Succ (Succ (Succ Zero)))) (Snd (var Zero)))
+ (Index (var (Succ (Succ Zero))) (Pair (Pair (Lit LNil) (Snd (Snd (var Zero)))) (Snd (var (Succ Zero)))))))))
+ (Lit (LDouble 0))
+ (Shape (var (Succ (Succ Zero)))))
diff --git a/Pretty.hs b/Pretty.hs
new file mode 100644
index 0000000..d63e3ce
--- /dev/null
+++ b/Pretty.hs
@@ -0,0 +1,229 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TypeOperators #-}
+module Pretty (
+ prettyExp,
+ pprintExp,
+) where
+
+import Data.Bifunctor
+import Prettyprinter
+import Prettyprinter.Render.String
+
+import AST
+
+
+newtype IdGen a = IdGen { runIdGen :: Int -> (a, Int) }
+instance Functor IdGen where
+ fmap f (IdGen g) = IdGen (first f . g)
+instance Applicative IdGen where
+ pure x = IdGen (x,)
+ IdGen f <*> IdGen g = IdGen (\i -> let (f', j) = f i in first f' (g j))
+instance Monad IdGen where
+ IdGen f >>= g = IdGen (\i -> let (x, j) = f i in runIdGen (g x) j)
+
+evalIdGen :: Int -> IdGen a -> a
+evalIdGen i = fst . ($ i) . runIdGen
+
+genId :: IdGen Int
+genId = IdGen (\i -> (i, i + 1))
+
+genName :: IdGen String
+genName = ('x' :) . show <$> genId
+
+data PEnv env where
+ Top :: PEnv env
+ PCons :: String -> PEnv env -> PEnv (a ': env)
+
+prettyExp :: Exp env a -> String
+prettyExp e = renderString (layoutSmart opts (evalIdGen 1 (pExp definfo 0 Top e)))
+ where opts = LayoutOptions (AvailablePerLine 120 0.7)
+
+pprintExp :: Exp env a -> IO ()
+pprintExp = putStrLn . prettyExp
+
+data Info = Info
+ { infoLamTypeSig :: Bool }
+ deriving (Show)
+
+definfo :: Info
+definfo = Info True
+
+pExp :: forall env a x. Info -> Int -> PEnv env -> Exp env a -> IdGen (Doc x)
+pExp thisinfo d env = \case
+ App (Const CAddF) (Pair a b) -> do
+ a' <- pExp definfo 7 env a
+ b' <- pExp definfo 7 env b
+ return (flatAlt (pParen (d > 10) $ pretty "AddF" <+> align (vsep [a', b']))
+ (pParen (d > 6) $ hsep [a', pretty "+", b']))
+
+ App (Const CMulF) (Pair a b) -> do
+ a' <- pExp definfo 8 env a
+ b' <- pExp definfo 8 env b
+ return (flatAlt (pParen (d > 10) $ pretty "MulF" <+> align (vsep [a', b']))
+ (pParen (d > 7) $ hsep [a', pretty "*", b']))
+
+ e@(App _ _) -> do
+ let collectAppsRev :: Exp env t -> IdGen (Doc x', [Doc x'])
+ collectAppsRev (App f a) = do
+ a' <- pExp definfo 11 env a
+ rest <- collectAppsRev f
+ return (fmap (a' :) rest)
+ collectAppsRev f = (,[]) <$> pExp definfo 11 env f
+ (func, rhss) <- collectAppsRev e
+ return (pParen (d > 10) $ func <+> align (sep (reverse rhss)))
+
+ Lam t e -> do
+ name <- genName
+ let prefix | infoLamTypeSig thisinfo =
+ pretty ("\\(" ++ name ++ " :: " ++ showType 0 t ") ->")
+ | otherwise =
+ pretty ("\\" ++ name ++ " ->")
+ body <- pExp definfo 0 (PCons name env) e
+ return (pParen (d > 0) $ nest 2 (sep [prefix, body]))
+
+ Var t i ->
+ case (env, i) of
+ (Top, _) -> return (pretty ("xUP_" ++ show (idxToInt i)))
+ (PCons name _, Zero) -> return (pretty name)
+ (PCons _ env', Succ i') -> pExp definfo d env' (Var t i')
+
+ e@(Let _ _) -> do
+ let collectLets :: PEnv env' -> Exp env' t -> IdGen (Doc x', [Doc x'])
+ collectLets env' (Let rhs body) = do
+ name <- genName
+ rhs' <- (pretty (name ++ " = ") <>) . group <$> pExp definfo 0 env' rhs
+ rest <- collectLets (PCons name env') body
+ return (fmap (rhs' :) rest)
+ collectLets env' f = (,[]) <$> pExp definfo 0 env' f
+ (core, rhss) <- collectLets env e
+ return (pParen (d > 0) $
+ align (vsep [pretty "let" <+> align (vsep rhss)
+ ,pretty "in" <+> group core]))
+
+ Lit l -> return (pretty (showLit d l ""))
+
+ Cond e1 e2 e3 -> do
+ e1' <- pExp definfo 11 env e1
+ e2' <- pExp definfo 11 env e2
+ e3' <- pExp definfo 11 env e3
+ return (flatAlt (pParen (d > 10) $ pretty "cond" <+> align (vsep [e1', e2', e3']))
+ (pParen (d > 0) $ hsep [e1', pretty "?", e2', pretty ":", e3']))
+
+ Const c -> return (pretty (showConst c))
+
+ Pair e1 e2 -> do
+ e1' <- pExp definfo 0 env e1
+ e2' <- pExp definfo 0 env e2
+ return (tupled [e1', e2'])
+
+ Fst e -> do
+ e' <- pExp definfo 11 env e
+ return (pParen (d > 10) $ pretty "fst" <+> e')
+
+ Snd e -> do
+ e' <- pExp definfo 11 env e
+ return (pParen (d > 10) $ pretty "snd" <+> e')
+
+ Build sht e1 e2 -> do
+ e1' <- pExp definfo 11 env e1
+ e2' <- pExp definfo{infoLamTypeSig=False} 11 env e2
+ return (pParen (d > 10) $
+ pretty "build" <+> align (sep
+ [pretty ("DIM" <> show (shtToInt sht)), e1', e2']))
+
+ Ifold sht e1 e2 e3 -> do
+ e1' <- pExp (definfo{infoLamTypeSig=False}) 11 env e1
+ e2' <- pExp definfo 11 env e2
+ e3' <- pExp definfo 11 env e3
+ return (pParen (d > 10) $
+ pretty "ifold" <+> align (sep
+ [pretty ("DIM" <> show (shtToInt sht)), e1', e2', e3']))
+
+ Index e1 e2 -> do
+ e1' <- pExp definfo 11 env e1
+ e2' <- pExp definfo 11 env e2
+ return (pParen (d > 10) $
+ flatAlt (pretty "index" <+> align (sep [e1', e2']))
+ (hsep [e1', pretty "!", e2']))
+
+ Shape e -> do
+ e' <- pExp definfo 11 env e
+ return (pParen (d > 10) $ pretty "shape" <+> e')
+
+ Undef t -> return (pParen (d > 0) $ pretty ("UNDEF :: " ++ showType 0 t ""))
+
+pParen :: Bool -> Doc x -> Doc x
+pParen True = parens
+pParen False = id
+
+showLit :: Int -> Literal a -> ShowS
+showLit _ (LInt i) = shows i
+showLit _ (LBool b) = shows b
+showLit _ (LDouble d) = shows d
+showLit d (LArray (Array sh t v))
+ | Just Has <- typeHasShow t
+ = showParen (d > 0) $
+ shows v . showString " :: Array " . showShape 11 sh . showString " " . showType 11 t
+ | otherwise
+ = showParen (d > 0) $
+ showString "[{noshow}] :: Array " . showShape 11 sh . showString " " . showType 11 t
+showLit d (LShape sh) = showShape d sh
+showLit _ LNil = showString "()"
+showLit _ (LPair a b) =
+ showString "(" . showLit 0 a . showString ", " . showLit 0 b . showString ")"
+
+showConst :: Constant a -> String
+showConst CAddI = "AddI"
+showConst CSubI = "SubI"
+showConst CMulI = "MulI"
+showConst CDivI = "DivI"
+showConst CAddF = "AddF"
+showConst CSubF = "SubF"
+showConst CMulF = "MulF"
+showConst CDivF = "DivF"
+showConst CLog = "Log"
+showConst CExp = "Exp"
+showConst CtoF = "ToF"
+showConst CRound = "Round"
+showConst CLtI = "LtI"
+showConst CLeI = "LeI"
+showConst CLtF = "LtF"
+showConst (CEq _) = "Eq"
+showConst CAnd = "And"
+showConst COr = "Or"
+showConst CNot = "Not"
+
+showShape :: Int -> Shape sh -> ShowS
+showShape _ Z = showString "Z"
+showShape d (sh :. n) = showParen (d > 10) $
+ showShape 10 sh . showString ":" . shows n
+
+showType :: Int -> Type a -> ShowS
+showType _ TInt = showString "Int"
+showType _ TBool = showString "Bool"
+showType _ TDouble = showString "Double"
+showType _ (TArray sht t) =
+ let n = shtToInt sht
+ in showString (replicate n '[') . showType 0 t . showString (replicate n ']')
+showType _ TNil = showString "()"
+showType _ (TPair a b) =
+ showString "(" . showType 0 a . showString ", " . showType 0 b . showString ")"
+showType d (TFun a b) = showParen (d > 10) $
+ showType 11 a . showString " -> " . showType 10 b
+
+-- showTypeShort :: Int -> Type a -> ShowS
+-- showTypeShort _ TInt = showString "i"
+-- showTypeShort _ TBool = showString "b"
+-- showTypeShort _ TDouble = showString "d"
+-- showTypeShort _ (TArray sht t) =
+-- let n = shtToInt sht
+-- in showString (replicate n '[') . showTypeShort 0 t . showString (replicate n ']')
+-- showTypeShort _ TNil = showString "."
+-- showTypeShort _ (TPair a b) =
+-- showString "(" . showTypeShort 11 a . showTypeShort 11 b . showString ")"
+-- showTypeShort d (TFun a b) = showParen (d > 10) $
+-- showTypeShort 11 a . showString " -> " . showTypeShort 10 b
diff --git a/Repl.hs b/Repl.hs
index 85fe7be..9e7cef2 100644
--- a/Repl.hs
+++ b/Repl.hs
@@ -7,5 +7,7 @@ import AD
import Examples
import Gradient
import qualified Language as L
+import Language (InferType(..), InferShapeType(..), DIM0, DIM1, DIM2, DIM3, var)
+import Pretty
import Simplify
import Sink
diff --git a/Simplify.hs b/Simplify.hs
index 9ceaef9..e95043d 100644
--- a/Simplify.hs
+++ b/Simplify.hs
@@ -2,12 +2,15 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Simplify (
simplify,
simplifyFix,
+ simbeta, simpair, simindex, simifold1,
+ simfix, SimList(..),
) where
import Data.Bifunctor
@@ -16,26 +19,51 @@ import qualified Data.Kind as Kind
import Data.List (find)
import Data.Type.Equality
+import Debug.Trace
+
import AST
+import Count
import Sink
+data Bound a = Inclusive a | Exclusive a
+ deriving (Show)
+
+instance Functor Bound where
+ fmap f (Inclusive x) = Inclusive (f x)
+ fmap f (Exclusive x) = Exclusive (f x)
+
data family Info (env :: [Kind.Type]) a
--- data instance Info env Int = InfoInt
--- data instance Info env Bool = InfoBool
--- data instance Info env Double = InfoDouble
+-- | Lower bound (inclusive), upper bound (inclusive/exclusive)
+data instance Info env Int = InfoInt (Maybe (Exp env Int)) (Maybe (Bound (Exp env Int)))
data instance Info env (Array sh t) = InfoArray (Exp env sh)
--- data instance Info env () = InfoNil
-data instance Info env (a, b) = InfoPair (Info env a) (Info env b)
--- data instance Info env (a -> b) = InfoFun
+data instance Info env () = InfoNil
+data instance Info env (a, b) = InfoPair (Maybe (Info env a)) (Maybe (Info env b))
data IEnv env where
ITop :: IEnv env
ICons :: Type a -> Maybe (Info (a ': env) a) -> IEnv env -> IEnv (a ': env)
+showsInfo :: Int -> Type a -> Info env a -> ShowS
+showsInfo d TInt (InfoInt a b) = showParen (d > 10) $
+ showString "InfoInt " . showsPrec 11 a . showString " " . showsPrec 11 b
+showsInfo d TArray{} (InfoArray a) = showParen (d > 10) $
+ showString "InfoArray " . showsPrec 11 a
+showsInfo _ TNil InfoNil = showString "InfoNil"
+showsInfo d (TPair t1 t2) (InfoPair a b) = showParen (d > 10) $
+ showString "InfoPair " . showsInfo' 11 t1 a . showString " " . showsInfo' 11 t2 b
+showsInfo _ _ _ = error "showsInfo: No definition"
+
+showsInfo' :: Int -> Type a -> Maybe (Info env a) -> ShowS
+showsInfo' _ _ Nothing = showString "Nothing"
+showsInfo' d t (Just x) = showParen (d > 10) $
+ showString "Just " . showsInfo 11 t x
+
sinkInfo1 :: Type a -> Info env a -> Info (t ': env) a
+sinkInfo1 TInt (InfoInt a b) = InfoInt (sinkExp1 <$> a) (fmap sinkExp1 <$> b)
sinkInfo1 TArray{} (InfoArray e) = InfoArray (sinkExp1 e)
-sinkInfo1 (TPair t1 t2) (InfoPair a b) = InfoPair (sinkInfo1 t1 a) (sinkInfo1 t2 b)
+sinkInfo1 TNil InfoNil = InfoNil
+sinkInfo1 (TPair t1 t2) (InfoPair a b) = InfoPair (sinkInfo1 t1 <$> a) (sinkInfo1 t2 <$> b)
sinkInfo1 _ _ = error "Unknown info in sinkInfo1"
iprj :: IEnv env -> Idx env a -> Maybe (Type a, Info env a)
@@ -64,6 +92,8 @@ simplify' env = \case
let (arg', info) = simplify' env arg
env' = ICons (typeof arg) (sinkInfo1 (typeof arg) <$> info) env
in (simplifyLet arg' (fst (simplify' env' e)), Nothing)
+ Lit (LInt n) -> (Lit (LInt n), Just (InfoInt (Just (Lit (LInt n)))
+ (Just (Inclusive (Lit (LInt n))))))
Lit l -> (Lit l, Nothing)
Cond a b c ->
(Cond (fst (simplify' env a)) (fst (simplify' env b)) (fst (simplify' env c)), Nothing)
@@ -71,27 +101,55 @@ simplify' env = \case
Pair a b ->
let (a', ia) = simplify' env a
(b', ib) = simplify' env b
- in (Pair a' b', InfoPair <$> ia <*> ib)
- Fst e -> bimap simplifyFst (fmap (\(InfoPair i _) -> i)) (simplify' env e)
- Snd e -> bimap simplifySnd (fmap (\(InfoPair _ i) -> i)) (simplify' env e)
+ in (simplifyPair a' b', Just (InfoPair ia ib))
+ Fst e -> bimap simplifyFst (>>= (\(InfoPair i _) -> i)) (simplify' env e)
+ Snd e -> bimap simplifySnd (>>= (\(InfoPair _ i) -> i)) (simplify' env e)
+ Build sht a (Lam shty fe) ->
+ let a' = fst (simplify' env a)
+ env' = ICons shty (Just (shapeBoundInfo sht (sinkExp1 a'))) env
+ in (Build sht a' (Lam shty (fst (simplify' env' fe))), Just (InfoArray a'))
Build sht a b ->
let a' = fst (simplify' env a)
in (Build sht a' (fst (simplify' env b)), Just (InfoArray a'))
- Ifold sht a b c -> (Ifold sht (fst (simplify' env a)) (fst (simplify' env b)) (fst (simplify' env c)), Nothing)
+ Ifold sht a b c ->
+ (simplifyIfold env sht (fst (simplify' env a)) (fst (simplify' env b)) (fst (simplify' env c)), Nothing)
Index a b -> (simplifyIndex (fst (simplify' env a)) (fst (simplify' env b)), Nothing)
Shape e ->
case simplify' env e of
(_, Just (InfoArray she)) -> (she, Nothing)
(e', _) -> (Shape e', Nothing)
+ Undef t -> (Undef t, Nothing)
+
+shapeBoundInfo :: ShapeType sh -> Exp env sh -> Info env sh
+shapeBoundInfo STZ _ = InfoNil
+shapeBoundInfo (STC sht) she =
+ InfoPair (Just (shapeBoundInfo sht (Fst she)))
+ (Just (InfoInt (Just (Lit (LInt 0))) (Just (Exclusive (Snd she)))))
simplifyApp :: Exp env (a -> b) -> Exp env a -> Exp env b
simplifyApp (Const CAddI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a + b))
+simplifyApp (Const CAddI) (Pair a (Lit (LInt 0))) = a
+simplifyApp (Const CAddI) (Pair (Lit (LInt 0)) a) = a
+-- simplifyApp (Const CAddI) (Pair a b) | Just Refl <- geq a b =
+-- simplifyApp (Const CMulI) (Pair (Lit (LInt 2)) a)
simplifyApp (Const CSubI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a - b))
simplifyApp (Const CMulI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a * b))
+simplifyApp (Const CMulI) (Pair a (Lit (LInt 1))) = a
+simplifyApp (Const CMulI) (Pair (Lit (LInt 1)) a) = a
+simplifyApp (Const CMulI) (Pair _ (Lit (LInt 0))) = Lit (LInt 0)
+simplifyApp (Const CMulI) (Pair (Lit (LInt 0)) _) = Lit (LInt 0)
simplifyApp (Const CDivI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a `div` b))
simplifyApp (Const CAddF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a + b))
+simplifyApp (Const CAddF) (Pair a (Lit (LDouble 0))) = a
+simplifyApp (Const CAddF) (Pair (Lit (LDouble 0)) a) = a
+-- simplifyApp (Const CAddF) (Pair a b) | Just Refl <- geq a b =
+-- simplifyApp (Const CMulF) (Pair (Lit (LDouble 2)) a)
simplifyApp (Const CSubF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a - b))
simplifyApp (Const CMulF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a * b))
+simplifyApp (Const CMulF) (Pair a (Lit (LDouble 1))) = a
+simplifyApp (Const CMulF) (Pair (Lit (LDouble 1)) a) = a
+simplifyApp (Const CMulF) (Pair _ (Lit (LDouble 0))) = Lit (LDouble 0)
+simplifyApp (Const CMulF) (Pair (Lit (LDouble 0)) _) = Lit (LDouble 0)
simplifyApp (Const CDivF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a / b))
simplifyApp (Const CLog) (Lit (LDouble a)) = Lit (LDouble (log a))
simplifyApp (Const CExp) (Lit (LDouble a)) = Lit (LDouble (exp a))
@@ -107,15 +165,17 @@ simplifyApp (Const COr) (Pair (Lit (LBool a)) (Lit (LBool b))) = Lit (LBool (a |
simplifyApp (Const CNot) (Lit (LBool a)) = Lit (LBool (not a))
simplifyApp (Lam _ e) arg
- | isDuplicable arg || countOcc Zero e <= 1
+ | isDuplicable arg || usesOf Zero e <= 1
= simplify (subst arg e)
simplifyApp (Lam _ e) arg = simplifyLet arg e
+simplifyApp f (Cond c a b) = simplify $ Cond c (App f a) (App f b)
+
simplifyApp a b = App a b
simplifyLet :: Exp env a -> Exp (a ': env) b -> Exp env b
simplifyLet arg e
- | isDuplicable arg || countOcc Zero e <= 1
+ | isDuplicable arg || usesOf Zero e <= 1
= simplify (subst arg e)
simplifyLet (Pair a b) e =
simplifyLet a $
@@ -124,24 +184,89 @@ simplifyLet (Pair a b) e =
(Var (typeof b) Zero)
Succ i -> Var t (Succ (Succ i)))
e
-simplifyLet (Cond c a b) e
- | isDuplicable a && isDuplicable b
- = simplifyLet c $
- (subst' (\t -> \case Zero -> Cond (Var TBool Zero) (sinkExp1 a) (sinkExp1 b)
- Succ i -> Var t (Succ i))
- e)
-simplifyLet a b = Let (simplify a) (simplify b)
+-- simplifyLet (Cond c a b) e
+-- | isDuplicable a && isDuplicable b
+-- = simplifyLet c $
+-- (subst' (\t -> \case Zero -> Cond (Var TBool Zero) (sinkExp1 a) (sinkExp1 b)
+-- Succ i -> Var t (Succ i))
+-- e)
+simplifyLet (Cond c a b) e = simplify $ Cond c (Let a e) (Let b e)
+simplifyLet a b = Let a b
+
+simplifyPair :: Exp env a -> Exp env b -> Exp env (a, b)
+simplifyPair (Cond c a b) d = simplify $ Cond c (Pair a d) (Pair b d)
+simplifyPair d (Cond c a b) = simplify $ Cond c (Pair d a) (Pair d b)
+simplifyPair a b = Pair a b
simplifyFst :: Exp env (a, b) -> Exp env a
simplifyFst (Pair e _) = e
simplifyFst (Let a e) = simplifyLet a (simplifyFst e)
+simplifyFst (Cond c a b) = simplify $ Cond c (Fst a) (Fst b)
simplifyFst e = Fst e
simplifySnd :: Exp env (a, b) -> Exp env b
simplifySnd (Pair _ e) = e
simplifySnd (Let a e) = simplifyLet a (simplifySnd e)
+simplifySnd (Cond c a b) = simplify $ Cond c (Snd a) (Snd b)
simplifySnd e = Snd e
+simplifyIfold :: IEnv env -> ShapeType sh -> Exp env ((a, sh) -> a) -> Exp env a -> Exp env sh -> Exp env a
+simplifyIfold env sht fe e0 she
+ | Just res <- splitIfold sht fe e0 she
+ = fst (simplify' env res)
+-- Given the following:
+-- ifold (\(a,i) -> if i == cmpref then val else a) _ she
+-- and given that we can prove that 0 <= cmpref < she and that 'val' is free,
+-- the whole fold can be replaced with 'val'.
+simplifyIfold env sht (Lam argty (Cond (App (Const (CEq _)) (Pair (Snd (Var _ Zero)) cmpref)) val (Fst (Var _ Zero)))) e0 she
+ | let env' = ICons argty (Just (InfoPair Nothing (Just (shapeBoundInfo sht (sinkExp1 she))))) env
+ , trace ("si: trying") True
+ , proveShapeBound env' CLeI sht (zeroShapeExp sht) cmpref
+ , trace ("si: prf1 = True") True
+ , proveShapeBound env' CLtI sht cmpref (sinkExp1 she)
+ , trace ("si: prf2 = True") True
+ , trace ("si: cmpref = " ++ show val) True
+ , usesOf Zero cmpref == 0
+ = simplifyLet (Pair e0 (subst (error "usesOf == 0 was wrong") cmpref)) val
+simplifyIfold _ sht fe e0 she = Ifold sht fe e0 she
+
+zeroShapeExp :: ShapeType sh -> Exp env sh
+zeroShapeExp STZ = Lit LNil
+zeroShapeExp (STC sht) = Pair (zeroShapeExp sht) (Lit (LInt 0))
+
+proveShapeBound :: IEnv env -> Constant ((Int, Int) -> Bool) -> ShapeType sh -> Exp env sh -> Exp env sh -> Bool
+proveShapeBound _ _ STZ _ _ = True
+proveShapeBound env cmpop (STC sht) e1 e2 =
+ let (_, info1) = simplify' env (Snd e1)
+ (_, info2) = simplify' env (Snd e2)
+ inclLo2 = case info2 of
+ Just (InfoInt (Just lo2) _) -> lo2
+ _ -> Snd e2 -- this is also an inclusive lower bound, after all
+ restresult = proveShapeBound env cmpop sht (Fst e1) (Fst e2)
+ in restresult && case (cmpop, info1) of
+ (CLeI, Just (InfoInt _ (Just (Inclusive hi1)))) ->
+ proveLe hi1 inclLo2
+ (CLtI, Just (InfoInt _ (Just (Exclusive hi1)))) ->
+ proveLe hi1 inclLo2
+ _ -> trace ("proveShapeBound: " ++ show cmpop ++ " (" ++ show e1 ++ ") (" ++ show e2 ++ ")") $
+ trace (" e1 = " ++ show e1) $
+ trace (" e2 = " ++ show e2) $
+ trace (" info1 = " ++ showsInfo' 0 TInt info1 "") $
+ trace (" info2 = " ++ showsInfo' 0 TInt info2 "") $
+ trace (" inclLo2 = " ++ show inclLo2) $
+ False
+
+proveLe :: Exp env Int -> Exp env Int -> Bool
+proveLe = \e1 e2 ->
+ let res = proveLe' e1 e2
+ in trace ("proveLe: '" ++ show e1 ++ "' <= '" ++ show e2 ++ "' -> " ++ show res)
+ res
+
+proveLe' :: Exp env Int -> Exp env Int -> Bool
+proveLe' e1 e2 | Just Refl <- geq e1 e2 = True
+proveLe' (Lit (LInt a)) (Lit (LInt b)) | a <= b = True
+proveLe' _ _ = False
+
simplifyIndex :: Exp env (Array sh a) -> Exp env sh -> Exp env a
simplifyIndex (Build _ _ f) e = simplifyApp f e
simplifyIndex a e = Index a e
@@ -162,24 +287,6 @@ isDuplicable (Fst e) = isDuplicable e
isDuplicable (Snd e) = isDuplicable e
isDuplicable _ = False
-countOcc :: Idx env t -> Exp env a -> Int
-countOcc i (App a b) = countOcc i a + countOcc i b
-countOcc i (Lam _ e) = countOcc (Succ i) e
-countOcc i (Var _ j)
- | Just Refl <- geq i j = 1
- | otherwise = 0
-countOcc i (Let a b) = countOcc i a + countOcc (Succ i) b
-countOcc _ (Lit _) = 0
-countOcc i (Cond a b c) = countOcc i a + countOcc i b + countOcc i c
-countOcc _ (Const _) = 0
-countOcc i (Pair a b) = countOcc i a + countOcc i b
-countOcc i (Fst e) = countOcc i e
-countOcc i (Snd e) = countOcc i e
-countOcc i (Build _ a b) = countOcc i a + countOcc i b
-countOcc i (Ifold _ a b c) = countOcc i a + countOcc i b + countOcc i c
-countOcc i (Index a b) = countOcc i a + countOcc i b
-countOcc i (Shape e) = countOcc i e
-
subst :: Exp env t -> Exp (t ': env) a -> Exp env a
subst arg e = subst' (\t -> \case Zero -> arg ; Succ i -> Var t i) e
@@ -201,3 +308,100 @@ subst' f (Build sht a b) = Build sht (subst' f a) (subst' f b)
subst' f (Ifold sht a b c) = Ifold sht (subst' f a) (subst' f b) (subst' f c)
subst' f (Index a b) = Index (subst' f a) (subst' f b)
subst' f (Shape e) = Shape (subst' f e)
+subst' _ (Undef t) = Undef t
+
+splitIfold :: ShapeType sh -> Exp env ((s, sh) -> s) -> Exp env s -> Exp env sh -> Maybe (Exp env s)
+splitIfold sht (Lam (TPair (TPair t1 t2) tidx) (Pair e1 e2)) e0 she
+ | let uses1 = usesOf' PathStart Zero e1
+ uses2 = usesOf' PathStart Zero e2
+ , lycontract (lysnd (lyfst uses1)) == 0
+ , lycontract (lyfst (lyfst uses2)) == 0
+ -- Substitute the argument in e1 and t2 to refer to just the used
+ -- components of their argument. To do this we reconstruct the original,
+ -- partially unused argument by putting an 'Undef' in the unused spot.
+ , let e1' = subst' (\t -> \case Zero ->
+ Pair (Pair (Fst (Var (TPair t1 tidx) Zero))
+ (Undef t2))
+ (Snd (Var (TPair t1 tidx) Zero))
+ Succ i -> Var t (Succ i))
+ e1
+ e2' = subst' (\t -> \case Zero ->
+ Pair (Pair (Undef t1)
+ (Fst (Var (TPair t2 tidx) Zero)))
+ (Snd (Var (TPair t2 tidx) Zero))
+ Succ i -> Var t (Succ i))
+ e2
+ = Just $
+ Let e0 $ Let (sinkExp1 she) $
+ Pair (Ifold sht (sinkExp2 (Lam (TPair t1 tidx) e1'))
+ (Fst (Var (TPair t1 t2) (Succ Zero)))
+ (Var tidx Zero))
+ (Ifold sht (sinkExp2 (Lam (TPair t2 tidx) e2'))
+ (Snd (Var (TPair t1 t2) (Succ Zero)))
+ (Var tidx Zero))
+splitIfold _ _ _ _ = Nothing
+
+simbeta :: Exp env a -> Exp env a
+simbeta = \case
+ App (Lam _ e) a
+ | isDuplicable a || usesOf Zero e <= 1
+ -> simbeta (subst a e)
+ | otherwise
+ -> Let (simbeta a) (simbeta e)
+ Let a e
+ | isDuplicable a || usesOf Zero e <= 1
+ -> simbeta (subst a e)
+ e -> simrecurse simbeta e
+
+simpair :: Exp env a -> Exp env a
+simpair = \case
+ Fst (Pair a _) -> simpair a
+ Snd (Pair _ b) -> simpair b
+ Fst (Let a b) -> Let (simpair a) (simpair (Fst b))
+ Snd (Let a b) -> Let (simpair a) (simpair (Snd b))
+ e -> simrecurse simpair e
+
+simindex :: Exp env a -> Exp env a
+simindex = \case
+ Index (Build _ _ f) e ->
+ App (simindex f) (simindex e)
+ e -> simrecurse simindex e
+
+simifold1 :: Exp env a -> Exp env a
+simifold1 = \case
+ Ifold sht fe e0 she
+ | Just res <- splitIfold sht (simifold1 fe) (simifold1 e0) (simifold1 she)
+ -> res
+ e -> simrecurse simifold1 e
+
+simrecurse :: (forall env' a'. Exp env' a' -> Exp env' a') -> Exp env a -> Exp env a
+simrecurse f = \case
+ App a b -> App (f a) (f b)
+ Lam t e -> Lam t (f e)
+ Var t i -> Var t i
+ Let a e -> Let (f a) (f e)
+ Lit l -> Lit l
+ Cond a b c -> Cond (f a) (f b) (f c)
+ Const c -> Const c
+ Pair a b -> Pair (f a) (f b)
+ Fst e -> Fst (f e)
+ Snd e -> Snd (f e)
+ Build sht a b -> Build sht (f a) (f b)
+ Ifold sht a b c -> Ifold sht (f a) (f b) (f c)
+ Index a b -> Index (f a) (f b)
+ Shape e -> Shape (f e)
+ Undef t -> Undef t
+
+infixr :|
+data SimList = (forall env' a'. Exp env' a' -> Exp env' a') :| SimList
+ | SimEnd
+
+simfix :: SimList -> Exp env a -> Exp env a
+simfix list = \e -> let e' = looponce list e
+ in case geq e e' of
+ Just Refl -> e'
+ Nothing -> simfix list e'
+ where
+ looponce :: SimList -> Exp env a -> Exp env a
+ looponce SimEnd e = e
+ looponce (f :| l) e = looponce l (f e)
diff --git a/Sink.hs b/Sink.hs
index 1368cb6..c258dc5 100644
--- a/Sink.hs
+++ b/Sink.hs
@@ -39,6 +39,7 @@ sinkExp w = \case
Ifold sht e1 e2 e3 -> Ifold sht (sinkExp w e1) (sinkExp w e2) (sinkExp w e3)
Index e1 e2 -> Index (sinkExp w e1) (sinkExp w e2)
Shape e -> Shape (sinkExp w e)
+ Undef t -> Undef t
sinkExp1 :: Exp env a -> Exp (t ': env) a
sinkExp1 = sinkExp (wSucc wId)
diff --git a/ftilde.cabal b/ftilde.cabal
index 1432547..04a5a2a 100644
--- a/ftilde.cabal
+++ b/ftilde.cabal
@@ -12,14 +12,18 @@ executable ftilde
other-modules:
AST
AD
+ Count
+ Eval
Examples
Gradient
Language
+ Pretty
Repl
Simplify
Sink
build-depends:
base >= 4.13 && < 4.15,
+ prettyprinter >= 1.7.0 && < 1.8,
vector,
some
hs-source-dirs: .