aboutsummaryrefslogtreecommitdiff
path: root/src/AST
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST')
-rw-r--r--src/AST/Accum.hs137
-rw-r--r--src/AST/Bindings.hs84
-rw-r--r--src/AST/Count.hs930
-rw-r--r--src/AST/Env.hs95
-rw-r--r--src/AST/Pretty.hs525
-rw-r--r--src/AST/Sparse.hs287
-rw-r--r--src/AST/Sparse/Types.hs107
-rw-r--r--src/AST/SplitLets.hs191
-rw-r--r--src/AST/Types.hs215
-rw-r--r--src/AST/UnMonoid.hs255
-rw-r--r--src/AST/Weaken.hs138
-rw-r--r--src/AST/Weaken/Auto.hs192
12 files changed, 0 insertions, 3156 deletions
diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs
deleted file mode 100644
index 988a450..0000000
--- a/src/AST/Accum.hs
+++ /dev/null
@@ -1,137 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE TypeData #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE UndecidableInstances #-}
-module AST.Accum where
-
-import AST.Types
-import Data
-
-
-data AcPrj
- = APHere
- | APFst AcPrj
- | APSnd AcPrj
- | APLeft AcPrj
- | APRight AcPrj
- | APJust AcPrj
- | APArrIdx AcPrj
- | APArrSlice Nat
-
--- | @b@ is a small part of @a@, indicated by the projection @p@.
-data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where
- SAPHere :: SAcPrj APHere a a
- SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair a t) b
- SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair t a) b
- SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TLEither a t) b
- SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TLEither t a) b
- SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe a) b
- SAPArrIdx :: SAcPrj p a b -> SAcPrj (APArrIdx p) (TArr n a) b
- -- TODO:
- -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t)
-deriving instance Show (SAcPrj p a b)
-
-type data AIDense = AID | AIS
-
-data SAIDense d where
- SAID :: SAIDense AID
- SAIS :: SAIDense AIS
-deriving instance Show (SAIDense d)
-
-type family AcIdx d p t where
- AcIdx d APHere t = TNil
- AcIdx AID (APFst p) (TPair a b) = AcIdx AID p a
- AcIdx AID (APSnd p) (TPair a b) = AcIdx AID p b
- AcIdx AIS (APFst p) (TPair a b) = TPair (AcIdx AIS p a) (ZeroInfo b)
- AcIdx AIS (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AIS p b)
- AcIdx d (APLeft p) (TLEither a b) = AcIdx d p a
- AcIdx d (APRight p) (TLEither a b) = AcIdx d p b
- AcIdx d (APJust p) (TMaybe a) = AcIdx d p a
- AcIdx AID (APArrIdx p) (TArr n a) =
- -- (index, recursive info)
- TPair (Tup (Replicate n TIx)) (AcIdx AID p a)
- AcIdx AIS (APArrIdx p) (TArr n a) =
- -- ((index, shape info), recursive info)
- TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a)))
- (AcIdx AIS p a)
- -- AcIdx AID (APArrSlice m) (TArr n a) =
- -- -- index
- -- Tup (Replicate m TIx)
- -- AcIdx AIS (APArrSlice m) (TArr n a) =
- -- -- (index, array shape)
- -- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx))
-
-type AcIdxD p t = AcIdx AID p t
-type AcIdxS p t = AcIdx AIS p t
-
-acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b
-acPrjTy SAPHere t = t
-acPrjTy (SAPFst prj) (SMTPair t _) = acPrjTy prj t
-acPrjTy (SAPSnd prj) (SMTPair _ t) = acPrjTy prj t
-acPrjTy (SAPLeft prj) (SMTLEither t _) = acPrjTy prj t
-acPrjTy (SAPRight prj) (SMTLEither _ t) = acPrjTy prj t
-acPrjTy (SAPJust prj) (SMTMaybe t) = acPrjTy prj t
-acPrjTy (SAPArrIdx prj) (SMTArr _ t) = acPrjTy prj t
-
-type family ZeroInfo t where
- ZeroInfo TNil = TNil
- ZeroInfo (TPair a b) = TPair (ZeroInfo a) (ZeroInfo b)
- ZeroInfo (TLEither a b) = TNil
- ZeroInfo (TMaybe a) = TNil
- ZeroInfo (TArr n t) = TArr n (ZeroInfo t)
- ZeroInfo (TScal t) = TNil
-
-tZeroInfo :: SMTy t -> STy (ZeroInfo t)
-tZeroInfo SMTNil = STNil
-tZeroInfo (SMTPair a b) = STPair (tZeroInfo a) (tZeroInfo b)
-tZeroInfo (SMTLEither _ _) = STNil
-tZeroInfo (SMTMaybe _) = STNil
-tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t)
-tZeroInfo (SMTScal _) = STNil
-
--- | Info needed to create a zero-valued deep accumulator for a monoid type.
--- Should be constructable from a D1.
-type family DeepZeroInfo t where
- DeepZeroInfo TNil = TNil
- DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b)
- DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b)
- DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a)
- DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a)
- DeepZeroInfo (TScal t) = TNil
-
-tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t)
-tDeepZeroInfo SMTNil = STNil
-tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b)
-tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b)
-tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a)
-tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t)
-tDeepZeroInfo (SMTScal _) = STNil
-
--- -- | Additional info needed for accumulation. This is empty unless there is
--- -- sparsity in the monoid.
--- type family AccumInfo t where
--- AccumInfo TNil = TNil
--- AccumInfo (TPair a b) = TPair (AccumInfo a) (AccumInfo b)
--- AccumInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b)
--- AccumInfo (TMaybe a) = TMaybe (AccumInfo a)
--- AccumInfo (TArr n t) = TArr n (AccumInfo t)
--- AccumInfo (TScal t) = TNil
-
--- type family PrimalInfo t where
--- PrimalInfo TNil = TNil
--- PrimalInfo (TPair a b) = TPair (PrimalInfo a) (PrimalInfo b)
--- PrimalInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b)
--- PrimalInfo (TMaybe a) = TMaybe (PrimalInfo a)
--- PrimalInfo (TArr n t) = TArr n (PrimalInfo t)
--- PrimalInfo (TScal t) = TNil
-
--- tPrimalInfo :: SMTy t -> STy (PrimalInfo t)
--- tPrimalInfo SMTNil = STNil
--- tPrimalInfo (SMTPair a b) = STPair (tPrimalInfo a) (tPrimalInfo b)
--- tPrimalInfo (SMTLEither a b) = STLEither (tPrimalInfo a) (tPrimalInfo b)
--- tPrimalInfo (SMTMaybe a) = STMaybe (tPrimalInfo a)
--- tPrimalInfo (SMTArr n t) = STArr n (tPrimalInfo t)
--- tPrimalInfo (SMTScal _) = STNil
diff --git a/src/AST/Bindings.hs b/src/AST/Bindings.hs
deleted file mode 100644
index 463586a..0000000
--- a/src/AST/Bindings.hs
+++ /dev/null
@@ -1,84 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE QuantifiedConstraints #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeOperators #-}
-
--- I want to bring various type variables in scope using type annotations in
--- patterns, but I don't want to have to mention all the other type parameters
--- of the types in question as well then. Partial type signatures (with '_') are
--- useful here.
-{-# LANGUAGE PartialTypeSignatures #-}
-{-# OPTIONS -Wno-partial-type-signatures #-}
-module AST.Bindings where
-
-import AST
-import AST.Env
-import Data
-import Lemmas
-
-
--- binding lists: a let stack without a body. The stack lives in 'env' and defines 'binds'.
-data Bindings f env binds where
- BTop :: Bindings f env '[]
- BPush :: Bindings f env binds -> (STy t, f (Append binds env) t) -> Bindings f env (t : binds)
-deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env')
-infixl `BPush`
-
-bpush :: Bindings (Expr x) env binds -> Expr x (Append binds env) t -> Bindings (Expr x) env (t : binds)
-bpush b e = b `BPush` (typeOf e, e)
-infixl `bpush`
-
-mapBindings :: (forall env' t'. f env' t' -> g env' t')
- -> Bindings f env binds -> Bindings g env binds
-mapBindings _ BTop = BTop
-mapBindings f (BPush b (t, e)) = BPush (mapBindings f b) (t, f e)
-
-weakenBindings :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t)
- -> env1 :> env2
- -> Bindings f env1 binds
- -> (Bindings f env2 binds, Append binds env1 :> Append binds env2)
-weakenBindings _ w BTop = (BTop, w)
-weakenBindings wf w (BPush b (t, x)) =
- let (b', w') = weakenBindings wf w b
- in (BPush b' (t, wf w' x), WCopy w')
-
-weakenBindingsE :: env1 :> env2
- -> Bindings (Expr x) env1 binds
- -> (Bindings (Expr x) env2 binds, Append binds env1 :> Append binds env2)
-weakenBindingsE = weakenBindings weakenExpr
-
-weakenOver :: SList STy ts -> env :> env' -> Append ts env :> Append ts env'
-weakenOver SNil w = w
-weakenOver (SCons _ ts) w = WCopy (weakenOver ts w)
-
-sinkWithBindings :: forall env' env binds f. Bindings f env binds -> env' :> Append binds env'
-sinkWithBindings BTop = WId
-sinkWithBindings (BPush b _) = WSink .> sinkWithBindings b
-
-bconcat :: forall f env binds1 binds2. Bindings f env binds1 -> Bindings f (Append binds1 env) binds2 -> Bindings f env (Append binds2 binds1)
-bconcat b1 BTop = b1
-bconcat b1 (BPush (b2 :: Bindings _ (Append binds1 env) binds2C) (t, x))
- | Refl <- lemAppendAssoc @binds2C @binds1 @env
- = BPush (bconcat b1 b2) (t, x)
-
-bindingsBinds :: Bindings f env binds -> SList STy binds
-bindingsBinds BTop = SNil
-bindingsBinds (BPush binds (t, _)) = SCons t (bindingsBinds binds)
-
-letBinds :: Bindings Ex env binds -> Ex (Append binds env) t -> Ex env t
-letBinds BTop = id
-letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs
-
-collectBindings :: SList STy env -> Subenv env env' -> Bindings Ex env env'
-collectBindings = \env -> fst . go env WId
- where
- go :: SList STy env -> env :> env0 -> Subenv env env' -> (Bindings Ex env0 env', env0 :> Append env' env0)
- go _ _ SETop = (BTop, WId)
- go (ty `SCons` env) w (SEYesR sub) =
- let (bs, w') = go env (WPop w) sub
- in (BPush bs (ty, EVar ext ty (w' .> w @> IZ)), WSink .> w')
- go (_ `SCons` env) w (SENo sub) = go env (WPop w) sub
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
deleted file mode 100644
index ac8634e..0000000
--- a/src/AST/Count.hs
+++ /dev/null
@@ -1,930 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveGeneric #-}
-{-# LANGUAGE DerivingStrategies #-}
-{-# LANGUAGE DerivingVia #-}
-{-# LANGUAGE EmptyCase #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE QuantifiedConstraints #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE ViewPatterns #-}
-{-# LANGUAGE PatternSynonyms #-}
-module AST.Count where
-
-import Data.Functor.Product
-import Data.Some
-import Data.Type.Equality
-import GHC.Generics (Generic, Generically(..))
-
-import Array
-import AST
-import AST.Env
-import Data
-
-
--- | The monoid operation combines assuming that /both/ branches are taken.
-class Monoid a => Occurrence a where
- -- | One of the two branches is taken
- (<||>) :: a -> a -> a
- -- | This code is executed many times
- scaleMany :: a -> a
-
-
-data Count = Zero | One | Many
- deriving (Show, Eq, Ord)
-
-instance Semigroup Count where
- Zero <> n = n
- n <> Zero = n
- _ <> _ = Many
-instance Monoid Count where
- mempty = Zero
-instance Occurrence Count where
- (<||>) = max
- scaleMany Zero = Zero
- scaleMany _ = Many
-
-data Occ = Occ { _occLexical :: Count
- , _occRuntime :: Count }
- deriving (Eq, Generic)
- deriving (Semigroup, Monoid) via Generically Occ
-
-instance Show Occ where
- showsPrec d (Occ l r) = showParen (d > 10) $
- showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r
-
-instance Occurrence Occ where
- Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (r1 <||> r2)
- scaleMany (Occ l c) = Occ l (scaleMany c)
-
-
-data Substruc t t' where
- -- If you add constructors here, do not forget to update the COMPLETE pragmas of any pattern synonyms below
- SsFull :: Substruc t t
- SsNone :: Substruc t TNil
- SsPair :: Substruc a a' -> Substruc b b' -> Substruc (TPair a b) (TPair a' b')
- SsEither :: Substruc a a' -> Substruc b b' -> Substruc (TEither a b) (TEither a' b')
- SsLEither :: Substruc a a' -> Substruc b b' -> Substruc (TLEither a b) (TLEither a' b')
- SsMaybe :: Substruc a a' -> Substruc (TMaybe a) (TMaybe a')
- SsArr :: Substruc a a' -> Substruc (TArr n a) (TArr n a') -- ^ union of usages of all array elements
- SsAccum :: Substruc a a' -> Substruc (TAccum a) (TAccum a')
-
-pattern SsPair' :: forall a b t'. forall a' b'. t' ~ TPair a' b' => Substruc a a' -> Substruc b b' -> Substruc (TPair a b) t'
-pattern SsPair' s1 s2 <- ((\case { SsFull -> SsPair SsFull SsFull ; s -> s }) -> SsPair s1 s2)
- where SsPair' = SsPair
-{-# COMPLETE SsNone, SsPair', SsEither, SsLEither, SsMaybe, SsArr, SsAccum #-}
-
-pattern SsArr' :: forall n a t'. forall a'. t' ~ TArr n a' => Substruc a a' -> Substruc (TArr n a) t'
-pattern SsArr' s <- ((\case { SsFull -> SsArr SsFull ; s -> s }) -> SsArr s)
- where SsArr' = SsArr
-{-# COMPLETE SsNone, SsPair, SsEither, SsLEither, SsMaybe, SsArr', SsAccum #-}
-
-instance Semigroup (Some (Substruc t)) where
- Some SsFull <> _ = Some SsFull
- _ <> Some SsFull = Some SsFull
- Some SsNone <> s = s
- s <> Some SsNone = s
- Some (SsPair a b) <> Some (SsPair a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsPair a2 b2)
- Some (SsEither a b) <> Some (SsEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsEither a2 b2)
- Some (SsLEither a b) <> Some (SsLEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsLEither a2 b2)
- Some (SsMaybe a) <> Some (SsMaybe a') = withSome (Some a <> Some a') $ \a2 -> Some (SsMaybe a2)
- Some (SsArr a) <> Some (SsArr a') = withSome (Some a <> Some a') $ \a2 -> Some (SsArr a2)
- Some (SsAccum a) <> Some (SsAccum a') = withSome (Some a <> Some a') $ \a2 -> Some (SsAccum a2)
-instance Monoid (Some (Substruc t)) where
- mempty = Some SsNone
-
-instance TestEquality (Substruc t) where
- testEquality SsFull s = isFull s
- testEquality s SsFull = sym <$> isFull s
- testEquality SsNone SsNone = Just Refl
- testEquality SsNone _ = Nothing
- testEquality _ SsNone = Nothing
- testEquality (SsPair a b) (SsPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing
- testEquality (SsEither a b) (SsEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing
- testEquality (SsLEither a b) (SsLEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing
- testEquality (SsMaybe s) (SsMaybe s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing
- testEquality (SsArr s) (SsArr s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing
- testEquality (SsAccum s) (SsAccum s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing
-
-isFull :: Substruc t t' -> Maybe (t :~: t')
-isFull SsFull = Just Refl
-isFull SsNone = Nothing -- TODO: nil?
-isFull (SsPair a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing
-isFull (SsEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing
-isFull (SsLEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing
-isFull (SsMaybe s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing
-isFull (SsArr s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing
-isFull (SsAccum s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing
-
-applySubstruc :: Substruc t t' -> STy t -> STy t'
-applySubstruc SsFull t = t
-applySubstruc SsNone _ = STNil
-applySubstruc (SsPair s1 s2) (STPair a b) = STPair (applySubstruc s1 a) (applySubstruc s2 b)
-applySubstruc (SsEither s1 s2) (STEither a b) = STEither (applySubstruc s1 a) (applySubstruc s2 b)
-applySubstruc (SsLEither s1 s2) (STLEither a b) = STLEither (applySubstruc s1 a) (applySubstruc s2 b)
-applySubstruc (SsMaybe s) (STMaybe t) = STMaybe (applySubstruc s t)
-applySubstruc (SsArr s) (STArr n t) = STArr n (applySubstruc s t)
-applySubstruc (SsAccum s) (STAccum t) = STAccum (applySubstrucM s t)
-
-applySubstrucM :: Substruc t t' -> SMTy t -> SMTy t'
-applySubstrucM SsFull t = t
-applySubstrucM SsNone _ = SMTNil
-applySubstrucM (SsPair s1 s2) (SMTPair a b) = SMTPair (applySubstrucM s1 a) (applySubstrucM s2 b)
-applySubstrucM (SsLEither s1 s2) (SMTLEither a b) = SMTLEither (applySubstrucM s1 a) (applySubstrucM s2 b)
-applySubstrucM (SsMaybe s) (SMTMaybe t) = SMTMaybe (applySubstrucM s t)
-applySubstrucM (SsArr s) (SMTArr n t) = SMTArr n (applySubstrucM s t)
-applySubstrucM _ t = case t of {}
-
-data ExMap a b = ExMap (forall env. Ex env a -> Ex env b)
- | a ~ b => ExMapId
-
-fromExMap :: ExMap a b -> Ex env a -> Ex env b
-fromExMap (ExMap f) = f
-fromExMap ExMapId = id
-
-simplifySubstruc :: STy t -> Substruc t t' -> Substruc t t'
-simplifySubstruc STNil SsNone = SsFull
-
-simplifySubstruc _ SsFull = SsFull
-simplifySubstruc _ SsNone = SsNone
-simplifySubstruc (STPair t1 t2) (SsPair s1 s2) = SsPair (simplifySubstruc t1 s1) (simplifySubstruc t2 s2)
-simplifySubstruc (STEither t1 t2) (SsEither s1 s2) = SsEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2)
-simplifySubstruc (STLEither t1 t2) (SsLEither s1 s2) = SsLEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2)
-simplifySubstruc (STMaybe t) (SsMaybe s) = SsMaybe (simplifySubstruc t s)
-simplifySubstruc (STArr _ t) (SsArr s) = SsArr (simplifySubstruc t s)
-simplifySubstruc (STAccum t) (SsAccum s) = SsAccum (simplifySubstruc (fromSMTy t) s)
-
--- simplifySubstruc' :: Substruc t t'
--- -> (forall t'2. Substruc t t'2 -> ExMap t'2 t' -> r) -> r
--- simplifySubstruc' SsFull k = k SsFull ExMapId
--- simplifySubstruc' SsNone k = k SsNone ExMapId
--- simplifySubstruc' (SsPair s1 s2) k =
--- simplifySubstruc' s1 $ \s1' f1 ->
--- simplifySubstruc' s2 $ \s2' f2 ->
--- case (s1', s2') of
--- (SsFull, SsFull) ->
--- k SsFull (case (f1, f2) of
--- (ExMapId, ExMapId) -> ExMapId
--- _ -> ExMap (\e -> eunPair e $ \_ e1 e2 ->
--- EPair ext (fromExMap f1 e1) (fromExMap f2 e2)))
--- (SsNone, SsNone) -> k SsNone (ExMap (\_ -> EPair ext (fromExMap f1 (ENil ext)) (fromExMap f2 (ENil ext))))
--- _ -> k (SsPair s1' s2') (ExMap (\e -> elet e $ EPair ext (fromExMap f1 (EFst ext (evar IZ))) (fromExMap f2 (ESnd ext (evar IZ)))))
--- simplifySubstruc' _ _ = _
-
--- ssUnpair :: Substruc (TPair a b) -> (Substruc a, Substruc b)
--- ssUnpair SsFull = (SsFull, SsFull)
--- ssUnpair SsNone = (SsNone, SsNone)
--- ssUnpair (SsPair a b) = (a, b)
-
--- ssUnleft :: Substruc (TEither a b) -> Substruc a
--- ssUnleft SsFull = SsFull
--- ssUnleft SsNone = SsNone
--- ssUnleft (SsEither a _) = a
-
--- ssUnright :: Substruc (TEither a b) -> Substruc b
--- ssUnright SsFull = SsFull
--- ssUnright SsNone = SsNone
--- ssUnright (SsEither _ b) = b
-
--- ssUnlleft :: Substruc (TLEither a b) -> Substruc a
--- ssUnlleft SsFull = SsFull
--- ssUnlleft SsNone = SsNone
--- ssUnlleft (SsLEither a _) = a
-
--- ssUnlright :: Substruc (TLEither a b) -> Substruc b
--- ssUnlright SsFull = SsFull
--- ssUnlright SsNone = SsNone
--- ssUnlright (SsLEither _ b) = b
-
--- ssUnjust :: Substruc (TMaybe a) -> Substruc a
--- ssUnjust SsFull = SsFull
--- ssUnjust SsNone = SsNone
--- ssUnjust (SsMaybe a) = a
-
--- ssUnarr :: Substruc (TArr n a) -> Substruc a
--- ssUnarr SsFull = SsFull
--- ssUnarr SsNone = SsNone
--- ssUnarr (SsArr a) = a
-
--- ssUnaccum :: Substruc (TAccum a) -> Substruc a
--- ssUnaccum SsFull = SsFull
--- ssUnaccum SsNone = SsNone
--- ssUnaccum (SsAccum a) = a
-
-
-type family MapEmpty env where
- MapEmpty '[] = '[]
- MapEmpty (t : env) = TNil : MapEmpty env
-
-data OccEnv a env env' where
- OccEnd :: OccEnv a env (MapEmpty env) -- not necessarily top!
- OccPush :: OccEnv a env env' -> a -> Substruc t t' -> OccEnv a (t : env) (t' : env')
-
-instance Semigroup a => Semigroup (Some (OccEnv a env)) where
- Some OccEnd <> e = e
- e <> Some OccEnd = e
- Some (OccPush e o s) <> Some (OccPush e' o' s') = withSome (Some e <> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <> o') s2)
-
-instance Semigroup a => Monoid (Some (OccEnv a env)) where
- mempty = Some OccEnd
-
-instance Occurrence a => Occurrence (Some (OccEnv a env)) where
- Some OccEnd <||> e = e
- e <||> Some OccEnd = e
- Some (OccPush e o s) <||> Some (OccPush e' o' s') = withSome (Some e <||> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <||> o') s2)
-
- scaleMany (Some OccEnd) = Some OccEnd
- scaleMany (Some (OccPush e o s)) = withSome (scaleMany (Some e)) $ \e2 -> Some (OccPush e2 (scaleMany o) s)
-
-onehotOccEnv :: Monoid a => Idx env t -> a -> Substruc t t' -> Some (OccEnv a env)
-onehotOccEnv IZ v s = Some (OccPush OccEnd v s)
-onehotOccEnv (IS i) v s
- | Some env' <- onehotOccEnv i v s
- = Some (OccPush env' mempty SsNone)
-
-occEnvPop :: OccEnv a (t : env) (t' : env') -> (OccEnv a env env', Substruc t t')
-occEnvPop (OccPush e _ s) = (e, s)
-occEnvPop OccEnd = (OccEnd, SsNone)
-
-occEnvPop' :: OccEnv a (t : env) env' -> (forall t' env''. env' ~ t' : env'' => OccEnv a env env'' -> Substruc t t' -> r) -> r
-occEnvPop' (OccPush e _ s) k = k e s
-occEnvPop' OccEnd k = k OccEnd SsNone
-
-occEnvPopSome :: Some (OccEnv a (t : env)) -> Some (OccEnv a env)
-occEnvPopSome (Some (OccPush e _ _)) = Some e
-occEnvPopSome (Some OccEnd) = Some OccEnd
-
-occEnvPrj :: Monoid a => OccEnv a env env' -> Idx env t -> (a, Some (Substruc t))
-occEnvPrj OccEnd _ = mempty
-occEnvPrj (OccPush _ o s) IZ = (o, Some s)
-occEnvPrj (OccPush e _ _) (IS i) = occEnvPrj e i
-
-occEnvPrjS :: OccEnv a env env' -> Idx env t -> Some (Product (Substruc t) (Idx env'))
-occEnvPrjS OccEnd IZ = Some (Pair SsNone IZ)
-occEnvPrjS OccEnd (IS i) | Some (Pair s i') <- occEnvPrjS OccEnd i = Some (Pair s (IS i'))
-occEnvPrjS (OccPush _ _ s) IZ = Some (Pair s IZ)
-occEnvPrjS (OccPush e _ _) (IS i)
- | Some (Pair s' i') <- occEnvPrjS e i
- = Some (Pair s' (IS i'))
-
-projectSmallerSubstruc :: Substruc t t'big -> Substruc t t'small -> Ex env t'big -> Ex env t'small
-projectSmallerSubstruc topsbig topssmall ex = case (topsbig, topssmall) of
- _ | Just Refl <- testEquality topsbig topssmall -> ex
-
- (SsFull, SsFull) -> ex
- (SsNone, SsNone) -> ex
- (SsNone, _) -> error "projectSmallerSubstruc: smaller substructure not smaller"
- (_, SsNone) ->
- case typeOf ex of
- STNil -> ex
- _ -> use ex $ ENil ext
-
- (SsPair s1 s2, SsPair s1' s2') ->
- eunPair ex $ \_ e1 e2 ->
- EPair ext (projectSmallerSubstruc s1 s1' e1) (projectSmallerSubstruc s2 s2' e2)
- (s@SsPair{}, SsFull) -> projectSmallerSubstruc s (SsPair SsFull SsFull) ex
- (SsFull, s@SsPair{}) -> projectSmallerSubstruc (SsPair SsFull SsFull) s ex
-
- (SsEither s1 s2, SsEither s1' s2')
- | STEither t1 t2 <- typeOf ex ->
- let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ)
- e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ)
- in ecase ex
- (EInl ext (typeOf e2) e1)
- (EInr ext (typeOf e1) e2)
- (s@SsEither{}, SsFull) -> projectSmallerSubstruc s (SsEither SsFull SsFull) ex
- (SsFull, s@SsEither{}) -> projectSmallerSubstruc (SsEither SsFull SsFull) s ex
-
- (SsLEither s1 s2, SsLEither s1' s2')
- | STLEither t1 t2 <- typeOf ex ->
- let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ)
- e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ)
- in elcase ex
- (ELNil ext (typeOf e1) (typeOf e2))
- (ELInl ext (typeOf e2) e1)
- (ELInr ext (typeOf e1) e2)
- (s@SsLEither{}, SsFull) -> projectSmallerSubstruc s (SsLEither SsFull SsFull) ex
- (SsFull, s@SsLEither{}) -> projectSmallerSubstruc (SsLEither SsFull SsFull) s ex
-
- (SsMaybe s1, SsMaybe s1')
- | STMaybe t1 <- typeOf ex ->
- let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ)
- in emaybe ex
- (ENothing ext (typeOf e1))
- (EJust ext e1)
- (s@SsMaybe{}, SsFull) -> projectSmallerSubstruc s (SsMaybe SsFull) ex
- (SsFull, s@SsMaybe{}) -> projectSmallerSubstruc (SsMaybe SsFull) s ex
-
- (SsArr s1, SsArr s2) -> emap (projectSmallerSubstruc s1 s2 (evar IZ)) ex
- (s@SsArr{}, SsFull) -> projectSmallerSubstruc s (SsArr SsFull) ex
- (SsFull, s@SsArr{}) -> projectSmallerSubstruc (SsArr SsFull) s ex
-
- (SsAccum _, SsAccum _) -> error "TODO smaller ssaccum"
- (s@SsAccum{}, SsFull) -> projectSmallerSubstruc s (SsAccum SsFull) ex
- (SsFull, s@SsAccum{}) -> projectSmallerSubstruc (SsAccum SsFull) s ex
-
-
--- | A boolean for each entry in the environment, with the ability to uniformly
--- mask the top part above a certain index.
-data EnvMask env where
- EMRest :: Bool -> EnvMask env
- EMPush :: EnvMask env -> Bool -> EnvMask (t : env)
-
-envMaskPrj :: EnvMask env -> Idx env t -> Bool
-envMaskPrj (EMRest b) _ = b
-envMaskPrj (_ `EMPush` b) IZ = b
-envMaskPrj (env `EMPush` _) (IS i) = envMaskPrj env i
-
-occCount :: Idx env a -> Expr x env t -> Occ
-occCount idx ex
- | Some env <- occCountAll ex
- = fst (occEnvPrj env idx)
-
-occCountAll :: Expr x env t -> Some (OccEnv Occ env)
-occCountAll ex = occCountX SsFull ex $ \env _ -> Some env
-
-pruneExpr :: SList f env -> Expr x env t -> Ex env t
-pruneExpr env ex = occCountX SsFull ex $ \_ mkex -> mkex (fullOccEnv env)
- where
- fullOccEnv :: SList f env -> OccEnv () env env
- fullOccEnv SNil = OccEnd
- fullOccEnv (_ `SCons` e) = OccPush (fullOccEnv e) () SsFull
-
--- In one traversal, count occurrences of variables and determine what parts of
--- expressions are actually used. These two results are computed independently:
--- even if (almost) nothing of a particular term is actually used, variable
--- references in that term still count as usual.
---
--- In @occCountX s t k@:
--- * s: how much of the result of this term is required
--- * t: the term to analyse
--- * k: is passed the actual environment usage of this expression, including
--- occurrence counts. The callback reconstructs a new expression in an
--- updated "response" environment. The response must be at least as large as
--- the computed usages.
-occCountX :: forall env t t' x r. Substruc t t' -> Expr x env t
- -> (forall env'. OccEnv Occ env env'
- -- response OccEnv must be at least as large as the OccEnv returned above
- -> (forall env''. OccEnv () env env'' -> Ex env'' t')
- -> r)
- -> r
-occCountX initialS topexpr k = case topexpr of
- EVar _ t i ->
- withSome (onehotOccEnv i (Occ One One) s) $ \env ->
- k env $ \env' ->
- withSome (occEnvPrjS env' i) $ \(Pair s' i') ->
- projectSmallerSubstruc s' s (EVar ext (applySubstruc s' t) i')
- ELet _ rhs body ->
- occCountX s body $ \envB mkbody ->
- occEnvPop' envB $ \envB' s1 ->
- occCountX s1 rhs $ \envR mkrhs ->
- withSome (Some envB' <> Some envR) $ \env ->
- k env $ \env' ->
- ELet ext (mkrhs env') (mkbody (OccPush env' () s1))
- EPair _ a b ->
- case s of
- SsNone ->
- occCountX SsNone a $ \env1 mka ->
- occCountX SsNone b $ \env2 mkb ->
- withSome (Some env1 <> Some env2) $ \env ->
- k env $ \env' ->
- use (mka env') $ use (mkb env') $ ENil ext
- SsPair' s1 s2 ->
- occCountX s1 a $ \env1 mka ->
- occCountX s2 b $ \env2 mkb ->
- withSome (Some env1 <> Some env2) $ \env ->
- k env $ \env' ->
- EPair ext (mka env') (mkb env')
- EFst _ e ->
- occCountX (SsPair s SsNone) e $ \env1 mke ->
- k env1 $ \env' ->
- EFst ext (mke env')
- ESnd _ e ->
- occCountX (SsPair SsNone s) e $ \env1 mke ->
- k env1 $ \env' ->
- ESnd ext (mke env')
- ENil _ ->
- case s of
- SsFull -> k OccEnd (\_ -> ENil ext)
- SsNone -> k OccEnd (\_ -> ENil ext)
- EInl _ t e ->
- case s of
- SsNone ->
- occCountX SsNone e $ \env1 mke ->
- k env1 $ \env' ->
- use (mke env') $ ENil ext
- SsEither s1 s2 ->
- occCountX s1 e $ \env1 mke ->
- k env1 $ \env' ->
- EInl ext (applySubstruc s2 t) (mke env')
- SsFull -> occCountX (SsEither SsFull SsFull) topexpr k
- EInr _ t e ->
- case s of
- SsNone ->
- occCountX SsNone e $ \env1 mke ->
- k env1 $ \env' ->
- use (mke env') $ ENil ext
- SsEither s1 s2 ->
- occCountX s2 e $ \env1 mke ->
- k env1 $ \env' ->
- EInr ext (applySubstruc s1 t) (mke env')
- SsFull -> occCountX (SsEither SsFull SsFull) topexpr k
- ECase _ e a b ->
- occCountX s a $ \env1' mka ->
- occCountX s b $ \env2' mkb ->
- occEnvPop' env1' $ \env1 s1 ->
- occEnvPop' env2' $ \env2 s2 ->
- occCountX (SsEither s1 s2) e $ \env0 mke ->
- withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env ->
- k env $ \env' ->
- ECase ext (mke env') (mka (OccPush env' () s1)) (mkb (OccPush env' () s2))
- ENothing _ t ->
- case s of
- SsNone -> k OccEnd (\_ -> ENil ext)
- SsMaybe s' -> k OccEnd (\_ -> ENothing ext (applySubstruc s' t))
- SsFull -> occCountX (SsMaybe SsFull) topexpr k
- EJust _ e ->
- case s of
- SsNone ->
- occCountX SsNone e $ \env1 mke ->
- k env1 $ \env' ->
- use (mke env') $ ENil ext
- SsMaybe s' ->
- occCountX s' e $ \env1 mke ->
- k env1 $ \env' ->
- EJust ext (mke env')
- SsFull -> occCountX (SsMaybe SsFull) topexpr k
- EMaybe _ a b e ->
- occCountX s a $ \env1 mka ->
- occCountX s b $ \env2' mkb ->
- occEnvPop' env2' $ \env2 s2 ->
- occCountX (SsMaybe s2) e $ \env0 mke ->
- withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env ->
- k env $ \env' ->
- EMaybe ext (mka env') (mkb (OccPush env' () s2)) (mke env')
- ELNil _ t1 t2 ->
- case s of
- SsNone -> k OccEnd (\_ -> ENil ext)
- SsLEither s1 s2 -> k OccEnd (\_ -> ELNil ext (applySubstruc s1 t1) (applySubstruc s2 t2))
- SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k
- ELInl _ t e ->
- case s of
- SsNone ->
- occCountX SsNone e $ \env1 mke ->
- k env1 $ \env' ->
- use (mke env') $ ENil ext
- SsLEither s1 s2 ->
- occCountX s1 e $ \env1 mke ->
- k env1 $ \env' ->
- ELInl ext (applySubstruc s2 t) (mke env')
- SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k
- ELInr _ t e ->
- case s of
- SsNone ->
- occCountX SsNone e $ \env1 mke ->
- k env1 $ \env' ->
- use (mke env') $ ENil ext
- SsLEither s1 s2 ->
- occCountX s2 e $ \env1 mke ->
- k env1 $ \env' ->
- ELInr ext (applySubstruc s1 t) (mke env')
- SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k
- ELCase _ e a b c ->
- occCountX s a $ \env1 mka ->
- occCountX s b $ \env2' mkb ->
- occCountX s c $ \env3' mkc ->
- occEnvPop' env2' $ \env2 s1 ->
- occEnvPop' env3' $ \env3 s2 ->
- occCountX (SsLEither s1 s2) e $ \env0 mke ->
- withSome (Some env0 <> (Some env1 <||> Some env2 <||> Some env3)) $ \env ->
- k env $ \env' ->
- ELCase ext (mke env') (mka env') (mkb (OccPush env' () s1)) (mkc (OccPush env' () s2))
-
- EConstArr _ n t x ->
- case s of
- SsNone -> k OccEnd (\_ -> ENil ext)
- SsArr' SsNone -> k OccEnd (\_ -> EBuild ext n (eshapeConst (arrayShape x)) (ENil ext))
- SsArr' SsFull -> k OccEnd (\_ -> EConstArr ext n t x)
-
- EBuild _ n a b ->
- case s of
- SsNone ->
- occCountX SsFull a $ \env1 mka ->
- occCountX SsNone b $ \env2'' mkb ->
- occEnvPop' env2'' $ \env2' s2 ->
- withSome (Some env1 <> scaleMany (Some env2')) $ \env ->
- k env $ \env' ->
- use (EBuild ext n (mka env') $
- use (elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $
- weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))) $
- ENil ext) $
- ENil ext
- SsArr' s' ->
- occCountX SsFull a $ \env1 mka ->
- occCountX s' b $ \env2'' mkb ->
- occEnvPop' env2'' $ \env2' s2 ->
- withSome (Some env1 <> scaleMany (Some env2')) $ \env ->
- k env $ \env' ->
- EBuild ext n (mka env') $
- elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $
- weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))
-
- EMap _ a b ->
- case s of
- SsNone ->
- occCountX SsNone a $ \env1'' mka ->
- occEnvPop' env1'' $ \env1' s1 ->
- occCountX (SsArr s1) b $ \env2 mkb ->
- withSome (scaleMany (Some env1') <> Some env2) $ \env ->
- k env $ \env' ->
- use (EMap ext (mka (OccPush env' () s1)) (mkb env')) $
- ENil ext
- SsArr' s' ->
- occCountX s' a $ \env1'' mka ->
- occEnvPop' env1'' $ \env1' s1 ->
- occCountX (SsArr s1) b $ \env2 mkb ->
- withSome (scaleMany (Some env1') <> Some env2) $ \env ->
- k env $ \env' ->
- EMap ext (mka (OccPush env' () s1)) (mkb env')
-
- EFold1Inner _ commut a b c ->
- occCountX SsFull a $ \env1'' mka ->
- occEnvPop' env1'' $ \env1' s1' ->
- let s1 = case s1' of
- SsNone -> Some SsNone
- SsPair' s1'a s1'b -> Some s1'a <> Some s1'b
- s0 = case s of
- SsNone -> Some SsNone
- SsArr' s' -> Some s' in
- withSome (s1 <> s0) $ \sElt ->
- occCountX sElt b $ \env2 mkb ->
- occCountX (SsArr sElt) c $ \env3 mkc ->
- withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env ->
- k env $ \env' ->
- projectSmallerSubstruc (SsArr sElt) s $
- EFold1Inner ext commut
- (projectSmallerSubstruc SsFull sElt $
- mka (OccPush env' () (SsPair sElt sElt)))
- (mkb env') (mkc env')
-
- ESum1Inner _ e -> handleReduction (ESum1Inner ext) e
-
- EUnit _ e ->
- case s of
- SsNone ->
- occCountX SsNone e $ \env mke ->
- k env $ \env' ->
- use (mke env') $ ENil ext
- SsArr' s' ->
- occCountX s' e $ \env mke ->
- k env $ \env' ->
- EUnit ext (mke env')
-
- EReplicate1Inner _ a b ->
- case s of
- SsNone ->
- occCountX SsNone a $ \env1 mka ->
- occCountX SsNone b $ \env2 mkb ->
- withSome (Some env1 <> Some env2) $ \env ->
- k env $ \env' ->
- use (mka env') $ use (mkb env') $ ENil ext
- SsArr' s' ->
- occCountX SsFull a $ \env1 mka ->
- occCountX (SsArr s') b $ \env2 mkb ->
- withSome (Some env1 <> Some env2) $ \env ->
- k env $ \env' ->
- EReplicate1Inner ext (mka env') (mkb env')
-
- EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e
- EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e
-
- EReshape _ n esh e ->
- case s of
- SsNone ->
- occCountX SsNone esh $ \env1 mkesh ->
- occCountX SsNone e $ \env2 mke ->
- withSome (Some env1 <> Some env2) $ \env ->
- k env $ \env' ->
- use (mkesh env') $ use (mke env') $ ENil ext
- SsArr' s' ->
- occCountX SsFull esh $ \env1 mkesh ->
- occCountX (SsArr s') e $ \env2 mke ->
- withSome (Some env1 <> Some env2) $ \env ->
- k env $ \env' ->
- EReshape ext n (mkesh env') (mke env')
-
- EZip _ a b ->
- case s of
- SsNone ->
- occCountX SsNone a $ \env1 mka ->
- occCountX SsNone b $ \env2 mkb ->
- withSome (Some env1 <> Some env2) $ \env ->
- k env $ \env' ->
- use (mka env') $ use (mkb env') $ ENil ext
- SsArr' SsNone ->
- occCountX (SsArr SsNone) a $ \env1 mka ->
- occCountX SsNone b $ \env2 mkb ->
- withSome (Some env1 <> Some env2) $ \env ->
- k env $ \env' ->
- use (mkb env') $ mka env'
- SsArr' (SsPair' SsNone s2) ->
- occCountX SsNone a $ \env1 mka ->
- occCountX (SsArr s2) b $ \env2 mkb ->
- withSome (Some env1 <> Some env2) $ \env ->
- k env $ \env' ->
- use (mka env') $
- emap (EPair ext (ENil ext) (evar IZ)) (mkb env')
- SsArr' (SsPair' s1 SsNone) ->
- occCountX (SsArr s1) a $ \env1 mka ->
- occCountX SsNone b $ \env2 mkb ->
- withSome (Some env1 <> Some env2) $ \env ->
- k env $ \env' ->
- use (mkb env') $
- emap (EPair ext (evar IZ) (ENil ext)) (mka env')
- SsArr' (SsPair' s1 s2) ->
- occCountX (SsArr s1) a $ \env1 mka ->
- occCountX (SsArr s2) b $ \env2 mkb ->
- withSome (Some env1 <> Some env2) $ \env ->
- k env $ \env' ->
- EZip ext (mka env') (mkb env')
-
- EFold1InnerD1 _ cm e1 e2 e3 ->
- case s of
- -- If nothing is necessary, we can execute a fold and then proceed to ignore it
- SsNone ->
- let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1))
- (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3)
- in occCountX SsNone foldex $ \env1 mkfoldex -> k env1 mkfoldex
- -- If we don't need the stores, still a fold suffices
- SsPair' sP SsNone ->
- let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1))
- (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3)
- in occCountX sP foldex $ \env1 mkfoldex -> k env1 $ \env' -> EPair ext (mkfoldex env') (ENil ext)
- -- If for whatever reason the additional stores themselves are
- -- unnecessary but the shape of the array is, then oblige
- SsPair' sP (SsArr' SsNone) ->
- let STArr sn _ = typeOf e3
- foldex =
- elet (mapExt (\_ -> ext) e3) $
- EPair ext
- (EShape ext (evar IZ))
- (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy WSink) e1)))
- (mapExt (\_ -> ext) (weakenExpr WSink e2))
- (evar IZ))
- in occCountX (SsPair SsFull sP) foldex $ \env1 mkfoldex ->
- k env1 $ \env' ->
- eunPair (mkfoldex env') $ \_ eshape earr ->
- EPair ext earr (EBuild ext sn eshape (ENil ext))
- -- If at least some of the additional stores are required, we need to keep this a mapAccum
- SsPair' _ (SsArr' sB) ->
- -- TODO: propagate usage of primals
- occCountX (SsPair SsFull sB) e1 $ \env1_1' mka ->
- occEnvPop' env1_1' $ \env1' _ ->
- occCountX SsFull e2 $ \env2 mkb ->
- occCountX SsFull e3 $ \env3 mkc ->
- withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env ->
- k env $ \env' ->
- projectSmallerSubstruc (SsPair SsFull (SsArr sB)) s $
- EFold1InnerD1 ext cm (mka (OccPush env' () SsFull))
- (mkb env') (mkc env')
-
- EFold1InnerD2 _ cm ef ebog ed ->
- -- TODO: propagate usage of duals
- occCountX SsFull ef $ \env1_2' mkef ->
- occEnvPop' env1_2' $ \env1_1' _ ->
- occEnvPop' env1_1' $ \env1' sB ->
- occCountX (SsArr sB) ebog $ \env2 mkebog ->
- occCountX SsFull ed $ \env3 mked ->
- withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env ->
- k env $ \env' ->
- projectSmallerSubstruc SsFull s $
- EFold1InnerD2 ext cm
- (mkef (OccPush (OccPush env' () sB) () SsFull))
- (mkebog env') (mked env')
-
- EConst _ t x ->
- k OccEnd $ \_ ->
- case s of
- SsNone -> ENil ext
- SsFull -> EConst ext t x
-
- EIdx0 _ e ->
- occCountX (SsArr s) e $ \env1 mke ->
- k env1 $ \env' ->
- EIdx0 ext (mke env')
-
- EIdx1 _ a b ->
- case s of
- SsNone ->
- occCountX SsNone a $ \env1 mka ->
- occCountX SsNone b $ \env2 mkb ->
- withSome (Some env1 <> Some env2) $ \env ->
- k env $ \env' ->
- use (mka env') $ use (mkb env') $ ENil ext
- SsArr' s' ->
- occCountX (SsArr s') a $ \env1 mka ->
- occCountX SsFull b $ \env2 mkb ->
- withSome (Some env1 <> Some env2) $ \env ->
- k env $ \env' ->
- EIdx1 ext (mka env') (mkb env')
-
- EIdx _ a b ->
- case s of
- SsNone ->
- occCountX SsNone a $ \env1 mka ->
- occCountX SsNone b $ \env2 mkb ->
- withSome (Some env1 <> Some env2) $ \env ->
- k env $ \env' ->
- use (mka env') $ use (mkb env') $ ENil ext
- _ ->
- occCountX (SsArr s) a $ \env1 mka ->
- occCountX SsFull b $ \env2 mkb ->
- withSome (Some env1 <> Some env2) $ \env ->
- k env $ \env' ->
- EIdx ext (mka env') (mkb env')
-
- EShape _ e ->
- case s of
- SsNone ->
- occCountX SsNone e $ \env1 mke ->
- k env1 $ \env' ->
- use (mke env') $ ENil ext
- _ ->
- occCountX (SsArr SsNone) e $ \env1 mke ->
- k env1 $ \env' ->
- projectSmallerSubstruc SsFull s $ EShape ext (mke env')
-
- EOp _ op e ->
- case s of
- SsNone ->
- occCountX SsNone e $ \env1 mke ->
- k env1 $ \env' ->
- use (mke env') $ ENil ext
- _ ->
- occCountX SsFull e $ \env1 mke ->
- k env1 $ \env' ->
- projectSmallerSubstruc SsFull s $ EOp ext op (mke env')
-
- ECustom _ t1 t2 t3 e1 e2 e3 a b
- | typeHasAccums t1 || typeHasAccums t2 || typeHasAccums t3 ->
- error "Accumulators not allowed in input/output/tape of an ECustom"
- | otherwise ->
- case s of
- SsNone ->
- -- Allowed to ignore e1/e2/e3 here because no accumulators are
- -- communicated, and hence no relevant effects exist
- occCountX SsNone a $ \env1 mka ->
- occCountX SsNone b $ \env2 mkb ->
- withSome (Some env1 <> Some env2) $ \env ->
- k env $ \env' ->
- use (mka env') $ use (mkb env') $ ENil ext
- s' -> -- Let's be pessimistic for safety
- occCountX SsFull a $ \env1 mka ->
- occCountX SsFull b $ \env2 mkb ->
- withSome (Some env1 <> Some env2) $ \env ->
- k env $ \env' ->
- projectSmallerSubstruc SsFull s' $
- ECustom ext t1 t2 t3 (mapExt (const ext) e1) (mapExt (const ext) e2) (mapExt (const ext) e3) (mka env') (mkb env')
-
- ERecompute _ e ->
- occCountX s e $ \env1 mke ->
- k env1 $ \env' ->
- ERecompute ext (mke env')
-
- EWith _ t a b ->
- case s of
- SsNone -> -- TODO: simplifier should remove accumulations to an unused with, and then remove the with
- occCountX SsNone b $ \env2' mkb ->
- occEnvPop' env2' $ \env2 s1 ->
- withSome (case s1 of
- SsFull -> Some SsFull
- SsAccum s' -> Some s'
- SsNone -> Some SsNone) $ \s1' ->
- occCountX s1' a $ \env1 mka ->
- withSome (Some env1 <> Some env2) $ \env ->
- k env $ \env' ->
- use (EWith ext (applySubstrucM s1' t) (mka env') (mkb (OccPush env' () (SsAccum s1')))) $
- ENil ext
- SsPair sB sA ->
- occCountX sB b $ \env2' mkb ->
- occEnvPop' env2' $ \env2 s1 ->
- let s1' = case s1 of
- SsFull -> Some SsFull
- SsAccum s' -> Some s'
- SsNone -> Some SsNone in
- withSome (Some sA <> s1') $ \sA' ->
- occCountX sA' a $ \env1 mka ->
- withSome (Some env1 <> Some env2) $ \env ->
- k env $ \env' ->
- projectSmallerSubstruc (SsPair sB sA') (SsPair sB sA) $
- EWith ext (applySubstrucM sA' t) (mka env') (mkb (OccPush env' () (SsAccum sA')))
- SsFull -> occCountX (SsPair SsFull SsFull) topexpr k
-
- EAccum _ t p a sp b e ->
- -- TODO: do better!
- occCountX SsFull a $ \env1 mka ->
- occCountX SsFull b $ \env2 mkb ->
- occCountX SsFull e $ \env3 mke ->
- withSome (Some env1 <> Some env2) $ \env12 ->
- withSome (Some env12 <> Some env3) $ \env ->
- k env $ \env' ->
- case s of {SsFull -> id; SsNone -> id} $
- EAccum ext t p (mka env') sp (mkb env') (mke env')
-
- EZero _ t e ->
- occCountX (subZeroInfo s) e $ \env1 mke ->
- k env1 $ \env' ->
- EZero ext (applySubstrucM s t) (mke env')
- where
- subZeroInfo :: Substruc t1 t2 -> Substruc (ZeroInfo t1) (ZeroInfo t2)
- subZeroInfo SsFull = SsFull
- subZeroInfo SsNone = SsNone
- subZeroInfo (SsPair s1 s2) = SsPair (subZeroInfo s1) (subZeroInfo s2)
- subZeroInfo SsEither{} = error "Either is not a monoid"
- subZeroInfo SsLEither{} = SsNone
- subZeroInfo SsMaybe{} = SsNone
- subZeroInfo (SsArr s') = SsArr (subZeroInfo s')
- subZeroInfo SsAccum{} = error "Accum is not a monoid"
-
- EDeepZero _ t e ->
- occCountX (subDeepZeroInfo s) e $ \env1 mke ->
- k env1 $ \env' ->
- EDeepZero ext (applySubstrucM s t) (mke env')
- where
- subDeepZeroInfo :: Substruc t1 t2 -> Substruc (DeepZeroInfo t1) (DeepZeroInfo t2)
- subDeepZeroInfo SsFull = SsFull
- subDeepZeroInfo SsNone = SsNone
- subDeepZeroInfo (SsPair s1 s2) = SsPair (subDeepZeroInfo s1) (subDeepZeroInfo s2)
- subDeepZeroInfo SsEither{} = error "Either is not a monoid"
- subDeepZeroInfo (SsLEither s1 s2) = SsLEither (subDeepZeroInfo s1) (subDeepZeroInfo s2)
- subDeepZeroInfo (SsMaybe s') = SsMaybe (subDeepZeroInfo s')
- subDeepZeroInfo (SsArr s') = SsArr (subDeepZeroInfo s')
- subDeepZeroInfo SsAccum{} = error "Accum is not a monoid"
-
- EPlus _ t a b ->
- occCountX s a $ \env1 mka ->
- occCountX s b $ \env2 mkb ->
- withSome (Some env1 <> Some env2) $ \env ->
- k env $ \env' ->
- EPlus ext (applySubstrucM s t) (mka env') (mkb env')
-
- EOneHot _ t p a b ->
- occCountX SsFull a $ \env1 mka ->
- occCountX SsFull b $ \env2 mkb -> -- TODO: do better
- withSome (Some env1 <> Some env2) $ \env ->
- k env $ \env' ->
- projectSmallerSubstruc SsFull s $ EOneHot ext t p (mka env') (mkb env')
-
- EError _ t msg ->
- k OccEnd $ \_ -> EError ext (applySubstruc s t) msg
- where
- s = simplifySubstruc (typeOf topexpr) initialS
-
- handleReduction :: t ~ TArr n (TScal t2)
- => (forall env2. Ex env2 (TArr (S n) (TScal t2)) -> Ex env2 (TArr n (TScal t2)))
- -> Expr x env (TArr (S n) (TScal t2))
- -> r
- handleReduction reduce e
- | STArr (SS n) _ <- typeOf e =
- case s of
- SsNone ->
- occCountX SsNone e $ \env mke ->
- k env $ \env' ->
- use (mke env') $ ENil ext
- SsArr' SsNone ->
- occCountX (SsArr SsNone) e $ \env mke ->
- k env $ \env' ->
- elet (mke env') $
- EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext)
- SsArr' SsFull ->
- occCountX (SsArr SsFull) e $ \env mke ->
- k env $ \env' ->
- reduce (mke env')
-
-
-deleteUnused :: SList f env -> Some (OccEnv Occ env) -> (forall env'. Subenv env env' -> r) -> r
-deleteUnused SNil (Some OccEnd) k = k SETop
-deleteUnused (_ `SCons` env) (Some OccEnd) k =
- deleteUnused env (Some OccEnd) $ \sub -> k (SENo sub)
-deleteUnused (_ `SCons` env) (Some (OccPush occenv (Occ _ count) _)) k =
- deleteUnused env (Some occenv) $ \sub ->
- case count of Zero -> k (SENo sub)
- _ -> k (SEYesR sub)
-
-unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t
-unsafeWeakenWithSubenv = \sub ->
- subst (\x t i -> case sinkViaSubenv i sub of
- Just i' -> EVar x t i'
- Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away")
- where
- sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t)
- sinkViaSubenv IZ (SEYesR _) = Just IZ
- sinkViaSubenv IZ (SENo _) = Nothing
- sinkViaSubenv (IS i) (SEYesR sub) = IS <$> sinkViaSubenv i sub
- sinkViaSubenv (IS i) (SENo sub) = sinkViaSubenv i sub
diff --git a/src/AST/Env.hs b/src/AST/Env.hs
deleted file mode 100644
index 85faba3..0000000
--- a/src/AST/Env.hs
+++ /dev/null
@@ -1,95 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE EmptyCase #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE PatternSynonyms #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE QuantifiedConstraints #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE TypeOperators #-}
-module AST.Env where
-
-import Data.Type.Equality
-
-import AST.Sparse
-import AST.Weaken
-import CHAD.Types
-import Data
-
-
--- | @env'@ is a subset of @env@: each element of @env@ is either included in
--- @env'@ ('SEYes') or not included in @env'@ ('SENo').
-data Subenv' s env env' where
- SETop :: Subenv' s '[] '[]
- SEYes :: forall t t' env env' s. s t t' -> Subenv' s env env' -> Subenv' s (t : env) (t' : env')
- SENo :: forall t env env' s. Subenv' s env env' -> Subenv' s (t : env) env'
-deriving instance (forall t t'. Show (s t t')) => Show (Subenv' s env env')
-
-type Subenv = Subenv' (:~:)
-type SubenvS = Subenv' Sparse
-
-pattern SEYesR :: forall tenv tenv'. ()
- => forall t env env'. (tenv ~ t : env, tenv' ~ t : env')
- => Subenv env env' -> Subenv tenv tenv'
-pattern SEYesR s = SEYes Refl s
-
-{-# COMPLETE SETop, SEYesR, SENo #-}
-
-subList :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env' -> SList f env'
-subList SNil SETop = SNil
-subList (SCons x xs) (SEYes s sub) = SCons (subtApply s x) (subList xs sub)
-subList (SCons _ xs) (SENo sub) = subList xs sub
-
-subenvAll :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env
-subenvAll SNil = SETop
-subenvAll (SCons t env) = SEYes (subtFull t) (subenvAll env)
-
-subenvNone :: SList f env -> Subenv' s env '[]
-subenvNone SNil = SETop
-subenvNone (SCons _ env) = SENo (subenvNone env)
-
-subenvOnehot :: SList f env -> Idx env t -> s t t' -> Subenv' s env '[t']
-subenvOnehot (SCons _ env) IZ sp = SEYes sp (subenvNone env)
-subenvOnehot (SCons _ env) (IS i) sp = SENo (subenvOnehot env i sp)
-subenvOnehot SNil i _ = case i of {}
-
-subenvCompose :: IsSubType s => Subenv' s env1 env2 -> Subenv' s env2 env3 -> Subenv' s env1 env3
-subenvCompose SETop SETop = SETop
-subenvCompose (SEYes s1 sub1) (SEYes s2 sub2) = SEYes (subtTrans s1 s2) (subenvCompose sub1 sub2)
-subenvCompose (SEYes _ sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2)
-subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2)
-
-subenvConcat :: Subenv' s env1 env1' -> Subenv' s env2 env2' -> Subenv' s (Append env2 env1) (Append env2' env1')
-subenvConcat sub1 SETop = sub1
-subenvConcat sub1 (SEYes s sub2) = SEYes s (subenvConcat sub1 sub2)
-subenvConcat sub1 (SENo sub2) = SENo (subenvConcat sub1 sub2)
-
--- subenvSplit :: SList f env1a -> Subenv' s (Append env1a env1b) env2
--- -> (forall env2a env2b. Subenv' s env1a env2a -> Subenv' s env1b env2b -> r) -> r
--- subenvSplit SNil sub k = k SETop sub
--- subenvSplit (SCons _ list) (SENo sub) k =
--- subenvSplit list sub $ \sub1 sub2 ->
--- k (SENo sub1) sub2
--- subenvSplit (SCons _ list) (SEYes s sub) k =
--- subenvSplit list sub $ \sub1 sub2 ->
--- k (SEYes s sub1) sub2
-
-sinkWithSubenv :: Subenv' s env env' -> env0 :> Append env' env0
-sinkWithSubenv SETop = WId
-sinkWithSubenv (SEYes _ sub) = WSink .> sinkWithSubenv sub
-sinkWithSubenv (SENo sub) = sinkWithSubenv sub
-
-wUndoSubenv :: Subenv' (:~:) env env' -> env' :> env
-wUndoSubenv SETop = WId
-wUndoSubenv (SEYes Refl sub) = WCopy (wUndoSubenv sub)
-wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub
-
-subenvMap :: (forall a a'. f a -> s a a' -> s' a a') -> SList f env -> Subenv' s env env' -> Subenv' s' env env'
-subenvMap _ SNil SETop = SETop
-subenvMap f (t `SCons` l) (SEYes s sub) = SEYes (f t s) (subenvMap f l sub)
-subenvMap f (_ `SCons` l) (SENo sub) = SENo (subenvMap f l sub)
-
-subenvD2E :: Subenv env env' -> Subenv (D2E env) (D2E env')
-subenvD2E SETop = SETop
-subenvD2E (SEYesR sub) = SEYesR (subenvD2E sub)
-subenvD2E (SENo sub) = SENo (subenvD2E sub)
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
deleted file mode 100644
index bbcfd9e..0000000
--- a/src/AST/Pretty.hs
+++ /dev/null
@@ -1,525 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveFunctor #-}
-{-# LANGUAGE EmptyCase #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE TupleSections #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeOperators #-}
-module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where
-
-import Control.Monad (ap)
-import Data.List (intersperse, intercalate)
-import Data.Functor.Const
-import qualified Data.Functor.Product as Product
-import Data.String (fromString)
-import Prettyprinter
-import Prettyprinter.Render.String
-
-import qualified Data.Text.Lazy as TL
-import qualified Prettyprinter.Render.Terminal as PT
-import System.Console.ANSI (hSupportsANSI)
-import System.IO (stdout)
-import System.IO.Unsafe (unsafePerformIO)
-
-import AST
-import AST.Count
-import AST.Sparse.Types
-import CHAD.Types
-import Data
-
-
-class PrettyX x where
- prettyX :: x t -> String
-
- prettyXsuffix :: x t -> String
- prettyXsuffix x = "<" ++ prettyX x ++ ">"
-
-instance PrettyX (Const ()) where
- prettyX _ = ""
- prettyXsuffix _ = ""
-
-
-type SVal = SList (Const String)
-
-newtype M a = M { runM :: Int -> (a, Int) }
- deriving (Functor)
-instance Applicative M where { pure x = M (\i -> (x, i)) ; (<*>) = ap }
-instance Monad M where { M f >>= g = M (\i -> let (x, j) = f i in runM (g x) j) }
-
-genId :: M Int
-genId = M (\i -> (i, i + 1))
-
-nameBaseForType :: STy t -> String
-nameBaseForType STNil = "nil"
-nameBaseForType (STPair{}) = "p"
-nameBaseForType (STEither{}) = "e"
-nameBaseForType (STMaybe{}) = "m"
-nameBaseForType (STScal STI32) = "n"
-nameBaseForType (STScal STI64) = "n"
-nameBaseForType (STArr{}) = "a"
-nameBaseForType (STAccum{}) = "ac"
-nameBaseForType _ = "x"
-
-genName' :: String -> M String
-genName' prefix = (prefix ++) . show <$> genId
-
-genNameIfUsedIn' :: String -> STy a -> Idx env a -> Expr x env t -> M String
-genNameIfUsedIn' prefix ty idx ex
- | occCount idx ex == mempty = case ty of STNil -> return "()"
- _ -> return "_"
- | otherwise = genName' prefix
-
--- TODO: let this return a type-tagged thing so that name environments are more typed than Const
-genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String
-genNameIfUsedIn = \t -> genNameIfUsedIn' (nameBaseForType t) t
-
-pprintExpr :: (KnownEnv env, PrettyX x) => Expr x env t -> IO ()
-pprintExpr = putStrLn . ppExpr knownEnv
-
-ppExpr :: PrettyX x => SList STy env -> Expr x env t -> String
-ppExpr senv e = render $ fst . flip runM 1 $ do
- val <- mkVal senv
- e' <- ppExpr' 0 val e
- let lam = "λ" ++ intercalate " " (reverse (unSList (\(Product.Pair (Const name) ty) -> "(" ++ name ++ " : " ++ ppSTy 0 ty ++ ")") (slistZip val senv))) ++ "."
- return $ group $ flatAlt
- (hang 2 $
- ppString lam
- <> hardline <> e')
- (ppString lam <+> e')
- where
- mkVal :: SList f env -> M (SVal env)
- mkVal SNil = return SNil
- mkVal (SCons _ v) = do
- val <- mkVal v
- name <- genName' "arg"
- return (Const name `SCons` val)
-
-ppExpr' :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc
-ppExpr' d val expr = case expr of
- EVar _ _ i -> return $ ppString (getConst (slistIdx val i)) <> ppX expr
-
- e@ELet{} -> ppExprLet d val e
-
- EPair _ a b -> do
- a' <- ppExpr' 0 val a
- b' <- ppExpr' 0 val b
- return $ group $ flatAlt (align $ ppString "(" <> a' <> hardline <> ppString "," <> b' <> ppString ")" <> ppX expr)
- (ppString "(" <> a' <> ppString "," <+> b' <> ppString ")" <> ppX expr)
-
- EFst _ e -> do
- e' <- ppExpr' 11 val e
- return $ ppParen (d > 10) $ ppString "fst" <> ppX expr <+> e'
-
- ESnd _ e -> do
- e' <- ppExpr' 11 val e
- return $ ppParen (d > 10) $ ppString "snd" <> ppX expr <+> e'
-
- ENil _ -> return $ ppString "()"
-
- EInl _ _ e -> do
- e' <- ppExpr' 11 val e
- return $ ppParen (d > 10) $ ppString "Inl" <> ppX expr <+> e'
-
- EInr _ _ e -> do
- e' <- ppExpr' 11 val e
- return $ ppParen (d > 10) $ ppString "Inr" <> ppX expr <+> e'
-
- ECase _ e a b -> do
- e' <- ppExpr' 0 val e
- let STEither t1 t2 = typeOf e
- name1 <- genNameIfUsedIn t1 IZ a
- a' <- ppExpr' 0 (Const name1 `SCons` val) a
- name2 <- genNameIfUsedIn t2 IZ b
- b' <- ppExpr' 0 (Const name2 `SCons` val) b
- return $ ppParen (d > 0) $
- hang 2 $
- annotate AKey (ppString "case") <> ppX expr <+> e' <+> annotate AKey (ppString "of")
- <> hardline <> ppString "Inl" <+> ppString name1 <+> ppString "->" <+> a'
- <> hardline <> ppString "Inr" <+> ppString name2 <+> ppString "->" <+> b'
-
- ENothing _ _ -> return $ ppString "Nothing"
-
- EJust _ e -> do
- e' <- ppExpr' 11 val e
- return $ ppParen (d > 10) $ ppString "Just" <> ppX expr <+> e'
-
- EMaybe _ a b e -> do
- let STMaybe t = typeOf e
- e' <- ppExpr' 0 val e
- a' <- ppExpr' 0 val a
- name <- genNameIfUsedIn t IZ b
- b' <- ppExpr' 0 (Const name `SCons` val) b
- return $ ppParen (d > 0) $
- align $
- group (flatAlt
- (annotate AKey (ppString "case") <> ppX expr <+> e'
- <> hardline <> annotate AKey (ppString "of"))
- (annotate AKey (ppString "case") <> ppX expr <+> e' <+> annotate AKey (ppString "of")))
- <> hardline
- <> indent 2
- (ppString "Nothing" <+> ppString "->" <+> a'
- <> hardline <> ppString "Just" <+> ppString name <+> ppString "->" <+> b')
-
- ELNil _ _ _ -> return (ppString "LNil")
-
- ELInl _ _ e -> do
- e' <- ppExpr' 11 val e
- return $ ppParen (d > 10) $ ppString "LInl" <> ppX expr <+> e'
-
- ELInr _ _ e -> do
- e' <- ppExpr' 11 val e
- return $ ppParen (d > 10) $ ppString "LInr" <> ppX expr <+> e'
-
- ELCase _ e a b c -> do
- e' <- ppExpr' 0 val e
- let STLEither t1 t2 = typeOf e
- a' <- ppExpr' 11 val a
- name1 <- genNameIfUsedIn t1 IZ b
- b' <- ppExpr' 0 (Const name1 `SCons` val) b
- name2 <- genNameIfUsedIn t2 IZ c
- c' <- ppExpr' 0 (Const name2 `SCons` val) c
- return $ ppParen (d > 0) $
- hang 2 $
- annotate AKey (ppString "lcase") <> ppX expr <+> e' <+> annotate AKey (ppString "of")
- <> hardline <> ppString "LNil" <+> ppString "->" <+> a'
- <> hardline <> ppString "LInl" <+> ppString name1 <+> ppString "->" <+> b'
- <> hardline <> ppString "LInr" <+> ppString name2 <+> ppString "->" <+> c'
-
- EConstArr _ _ ty v
- | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr
-
- EBuild _ n a b -> do
- a' <- ppExpr' 11 val a
- name <- genNameIfUsedIn' "i" (tTup (sreplicate n tIx)) IZ b
- e' <- ppExpr' 0 (Const name `SCons` val) b
- let primName = ppString ("build" ++ intSubscript (fromSNat n))
- return $ ppParen (d > 0) $
- group $ flatAlt
- (hang 2 $
- annotate AHighlight primName <> ppX expr <+> a'
- <+> ppString "$" <+> ppString "\\" <> ppString name <+> ppString "->"
- <> hardline <> e')
- (ppApp (annotate AHighlight primName <> ppX expr) [a', ppLam [ppString name] e'])
-
- EMap _ a b -> do
- let STArr _ t1 = typeOf b
- name <- genNameIfUsedIn t1 IZ a
- a' <- ppExpr' 0 (Const name `SCons` val) a
- b' <- ppExpr' 11 val b
- return $ ppParen (d > 0) $
- ppApp (annotate AHighlight (ppString "map") <> ppX expr) [ppLam [ppString name] a', b']
-
- EFold1Inner _ cm a b c -> do
- name <- genNameIfUsedIn (STPair (typeOf a) (typeOf a)) IZ a
- a' <- ppExpr' 0 (Const name `SCons` val) a
- b' <- ppExpr' 11 val b
- c' <- ppExpr' 11 val c
- let opname = "fold1i" ++ ppCommut cm
- return $ ppParen (d > 10) $
- ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c']
-
- ESum1Inner _ e -> do
- e' <- ppExpr' 11 val e
- return $ ppParen (d > 10) $ ppString "sum1i" <> ppX expr <+> e'
-
- EUnit _ e -> do
- e' <- ppExpr' 11 val e
- return $ ppParen (d > 10) $ ppString "unit" <> ppX expr <+> e'
-
- EReplicate1Inner _ a b -> do
- a' <- ppExpr' 11 val a
- b' <- ppExpr' 11 val b
- return $ ppParen (d > 10) $ ppApp (ppString "replicate1i" <> ppX expr) [a', b']
-
- EMaximum1Inner _ e -> do
- e' <- ppExpr' 11 val e
- return $ ppParen (d > 10) $ ppString "maximum1i" <> ppX expr <+> e'
-
- EMinimum1Inner _ e -> do
- e' <- ppExpr' 11 val e
- return $ ppParen (d > 10) $ ppString "minimum1i" <> ppX expr <+> e'
-
- EReshape _ n esh e -> do
- esh' <- ppExpr' 11 val esh
- e' <- ppExpr' 11 val e
- return $ ppParen (d > 10) $ ppApp (ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr) [esh', e']
-
- EZip _ e1 e2 -> do
- e1' <- ppExpr' 11 val e1
- e2' <- ppExpr' 11 val e2
- return $ ppParen (d > 10) $ ppApp (ppString "zip" <> ppX expr) [e1', e2']
-
- EFold1InnerD1 _ cm a b c -> do
- name <- genNameIfUsedIn (STPair (typeOf b) (typeOf b)) IZ a
- a' <- ppExpr' 0 (Const name `SCons` val) a
- b' <- ppExpr' 11 val b
- c' <- ppExpr' 11 val c
- let opname = "fold1iD1" ++ ppCommut cm
- return $ ppParen (d > 10) $
- ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c']
-
- EFold1InnerD2 _ cm ef ebog ed -> do
- let STArr _ tB = typeOf ebog
- STArr _ t2 = typeOf ed
- namef1 <- genNameIfUsedIn tB (IS IZ) ef
- namef2 <- genNameIfUsedIn t2 IZ ef
- ef' <- ppExpr' 0 (Const namef2 `SCons` Const namef1 `SCons` val) ef
- ebog' <- ppExpr' 11 val ebog
- ed' <- ppExpr' 11 val ed
- let opname = "fold1iD2" ++ ppCommut cm
- return $ ppParen (d > 10) $
- ppApp (annotate AHighlight (ppString opname) <> ppX expr)
- [ppLam [ppString namef1, ppString namef2] ef', ebog', ed']
-
- EConst _ ty v
- | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr
-
- EIdx0 _ e -> do
- e' <- ppExpr' 11 val e
- return $ ppParen (d > 10) $ ppString "idx0" <> ppX expr <+> e'
-
- EIdx1 _ a b -> do
- a' <- ppExpr' 9 val a
- b' <- ppExpr' 9 val b
- return $ ppParen (d > 8) $ a' <+> ppString ".!" <> ppX expr <+> b'
-
- EIdx _ a b -> do
- a' <- ppExpr' 9 val a
- b' <- ppExpr' 10 val b
- return $ ppParen (d > 8) $
- a' <+> ppString "!" <> ppX expr <+> b'
-
- EShape _ e -> do
- e' <- ppExpr' 11 val e
- return $ ppParen (d > 10) $ ppString "shape" <> ppX expr <+> e'
-
- EOp _ op (EPair _ a b)
- | (Infix, ops) <- operator op -> do
- a' <- ppExpr' 9 val a
- b' <- ppExpr' 9 val b
- return $ ppParen (d > 8) $ a' <+> ppString ops <> ppX expr <+> b'
-
- EOp _ op e -> do
- e' <- ppExpr' 11 val e
- let ops = case operator op of
- (Infix, s) -> "(" ++ s ++ ")"
- (Prefix, s) -> s
- return $ ppParen (d > 10) $ ppString ops <> ppX expr <+> e'
-
- ECustom _ t1 t2 t3 a b c e1 e2 -> do
- en1 <- genNameIfUsedIn t1 (IS IZ) a
- en2 <- genNameIfUsedIn t2 IZ a
- pn1 <- genNameIfUsedIn (d1 t1) (IS IZ) b
- pn2 <- genNameIfUsedIn (d1 t2) IZ b
- dn1 <- genNameIfUsedIn' "tape" t3 (IS IZ) c
- dn2 <- genNameIfUsedIn' "d" (d2 (typeOf a)) IZ c
- a' <- ppExpr' 11 (Const pn2 `SCons` Const pn1 `SCons` SNil) a
- b' <- ppExpr' 11 (Const pn2 `SCons` Const pn1 `SCons` SNil) b
- c' <- ppExpr' 11 (Const dn2 `SCons` Const dn1 `SCons` SNil) c
- e1' <- ppExpr' 11 val e1
- e2' <- ppExpr' 11 val e2
- return $ ppParen (d > 10) $
- ppApp (ppString "custom" <> ppX expr)
- [ppLam [ppString en1, ppString en2] a'
- ,ppLam [ppString pn1, ppString pn2] b'
- ,ppLam [ppString dn1, ppString dn2] c'
- ,e1'
- ,e2']
-
- ERecompute _ e -> do
- e' <- ppExpr' 11 val e
- return $ ppParen (d > 10) $ ppApp (ppString "recompute" <> ppX expr) [e']
-
- EWith _ t e1 e2 -> do
- e1' <- ppExpr' 11 val e1
- name <- genNameIfUsedIn' "ac" (STAccum t) IZ e2
- e2' <- ppExpr' 0 (Const name `SCons` val) e2
- return $ ppParen (d > 0) $
- group $ flatAlt
- (hang 2 $
- annotate AWith (ppString "with") <> ppX expr <+> e1'
- <+> ppString "$" <+> ppString "\\" <> ppString name <+> ppString "->"
- <> hardline <> e2')
- (ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2'])
-
- EAccum _ t prj e1 sp e2 e3 -> do
- e1' <- ppExpr' 11 val e1
- e2' <- ppExpr' 11 val e2
- e3' <- ppExpr' 11 val e3
- return $ ppParen (d > 10) $
- ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (applySparse sp (acPrjTy prj t)))
- [ppString (ppAcPrj t prj), ppString (ppSparse (acPrjTy prj t) sp), e1', e2', e3']
-
- EZero _ t e1 -> do
- e1' <- ppExpr' 11 val e1
- return $ ppParen (d > 0) $
- annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1'
-
- EDeepZero _ t e1 -> do
- e1' <- ppExpr' 11 val e1
- return $ ppParen (d > 0) $
- annotate AMonoid (ppString "deepzero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1'
-
- EPlus _ t a b -> do
- a' <- ppExpr' 11 val a
- b' <- ppExpr' 11 val b
- return $ ppParen (d > 10) $
- ppApp (annotate AMonoid (ppString "plus") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t) [a', b']
-
- EOneHot _ t prj a b -> do
- a' <- ppExpr' 11 val a
- b' <- ppExpr' 11 val b
- return $ ppParen (d > 10) $
- ppApp (annotate AMonoid (ppString "onehot") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (acPrjTy prj t)) [ppString (ppAcPrj t prj), a', b']
-
- EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s)
-
-ppExprLet :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc
-ppExprLet d val etop = do
- let collect :: PrettyX x => SVal env -> Expr x env t -> M ([(String, Occ, ADoc)], ADoc)
- collect val' (ELet _ rhs body) = do
- let occ = occCount IZ body
- name <- genNameIfUsedIn (typeOf rhs) IZ body
- rhs' <- ppExpr' 0 val' rhs
- (binds, core) <- collect (Const name `SCons` val') body
- return ((name, occ, rhs') : binds, core)
- collect val' e = ([],) <$> ppExpr' 0 val' e
-
- (binds, core) <- collect val etop
-
- return $ ppParen (d > 0) $
- align $
- annotate AKey (ppString "let")
- <+> align (mconcat $ intersperse hardline $
- map (\(name, _occ, rhs) ->
- ppString (name ++ {- " (" ++ show _occ ++ ")" ++ -} " = ") <> rhs)
- binds)
- <> hardline <> annotate AKey (ppString "in") <+> core
-
-ppApp :: ADoc -> [ADoc] -> ADoc
-ppApp fun args = group $ fun <+> align (sep args)
-
-ppLam :: [ADoc] -> ADoc -> ADoc
-ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"])
- <> softline <> body <> ppString ")")
-
-ppAcPrj :: SMTy a -> SAcPrj p a b -> String
-ppAcPrj _ SAPHere = "."
-ppAcPrj (SMTPair t _) (SAPFst prj) = "(" ++ ppAcPrj t prj ++ ",)"
-ppAcPrj (SMTPair _ t) (SAPSnd prj) = "(," ++ ppAcPrj t prj ++ ")"
-ppAcPrj (SMTLEither t _) (SAPLeft prj) = "(" ++ ppAcPrj t prj ++ "|)"
-ppAcPrj (SMTLEither _ t) (SAPRight prj) = "(|" ++ ppAcPrj t prj ++ ")"
-ppAcPrj (SMTMaybe t) (SAPJust prj) = "J" ++ ppAcPrj t prj
-ppAcPrj (SMTArr n t) (SAPArrIdx prj) = "[" ++ ppAcPrj t prj ++ "]" ++ intSubscript (fromSNat n)
-
-ppSparse :: SMTy a -> Sparse a b -> String
-ppSparse t sp | Just Refl <- isDense t sp = "D"
-ppSparse _ SpAbsent = "A"
-ppSparse t (SpSparse s) = "S" ++ ppSparse t s
-ppSparse (SMTPair t1 t2) (SpPair s1 s2) = "(" ++ ppSparse t1 s1 ++ "," ++ ppSparse t2 s2 ++ ")"
-ppSparse (SMTLEither t1 t2) (SpLEither s1 s2) = "(" ++ ppSparse t1 s1 ++ "|" ++ ppSparse t2 s2 ++ ")"
-ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s
-ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s
-ppSparse (SMTScal _) SpScal = "."
-
-ppCommut :: Commutative -> String
-ppCommut Commut = "(C)"
-ppCommut Noncommut = ""
-
-ppX :: PrettyX x => Expr x env t -> ADoc
-ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr)
-
-data Fixity = Prefix | Infix
- deriving (Show)
-
-operator :: SOp a t -> (Fixity, String)
-operator OAdd{} = (Infix, "+")
-operator OMul{} = (Infix, "*")
-operator ONeg{} = (Prefix, "negate")
-operator OLt{} = (Infix, "<")
-operator OLe{} = (Infix, "<=")
-operator OEq{} = (Infix, "==")
-operator ONot = (Prefix, "not")
-operator OAnd = (Infix, "&&")
-operator OOr = (Infix, "||")
-operator OIf = (Prefix, "ifB")
-operator ORound64 = (Prefix, "round")
-operator OToFl64 = (Prefix, "toFl64")
-operator ORecip{} = (Prefix, "recip")
-operator OExp{} = (Prefix, "exp")
-operator OLog{} = (Prefix, "log")
-operator OIDiv{} = (Infix, "`div`")
-operator OMod{} = (Infix, "`mod`")
-
-ppSTy :: Int -> STy t -> String
-ppSTy d ty = render $ ppSTy' d ty
-
-ppSTy' :: Int -> STy t -> Doc q
-ppSTy' _ STNil = ppString "1"
-ppSTy' d (STPair a b) = ppParen (d > 7) $ ppSTy' 8 a <> ppString " * " <> ppSTy' 8 b
-ppSTy' d (STEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " + " <> ppSTy' 7 b
-ppSTy' d (STLEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " ⊕ " <> ppSTy' 7 b
-ppSTy' d (STMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSTy' 11 t
-ppSTy' d (STArr n t) = ppParen (d > 10) $
- ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSTy' 11 t
-ppSTy' _ (STScal sty) = ppString $ case sty of
- STI32 -> "i32"
- STI64 -> "i64"
- STF32 -> "f32"
- STF64 -> "f64"
- STBool -> "bool"
-ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSMTy' 11 t
-
-ppSMTy :: Int -> SMTy t -> String
-ppSMTy d ty = render $ ppSMTy' d ty
-
-ppSMTy' :: Int -> SMTy t -> Doc q
-ppSMTy' _ SMTNil = ppString "1"
-ppSMTy' d (SMTPair a b) = ppParen (d > 7) $ ppSMTy' 8 a <> ppString " * " <> ppSMTy' 8 b
-ppSMTy' d (SMTLEither a b) = ppParen (d > 6) $ ppSMTy' 7 a <> ppString " ⊕ " <> ppSMTy' 7 b
-ppSMTy' d (SMTMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSMTy' 11 t
-ppSMTy' d (SMTArr n t) = ppParen (d > 10) $
- ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSMTy' 11 t
-ppSMTy' _ (SMTScal sty) = ppString $ case sty of
- STI32 -> "i32"
- STI64 -> "i64"
- STF32 -> "f32"
- STF64 -> "f64"
-
-ppString :: String -> Doc x
-ppString = fromString
-
-ppParen :: Bool -> Doc x -> Doc x
-ppParen True = parens
-ppParen False = id
-
-intSubscript :: Int -> String
-intSubscript = \case 0 -> "₀"
- n | n < 0 -> '₋' : go (-n) ""
- | otherwise -> go n ""
- where go 0 suff = suff
- go n suff = let (q, r) = n `quotRem` 10
- in go q ("₀₁₂₃₄₅₆₇₈₉" !! r : suff)
-
-data Annot = AKey | AWith | AHighlight | AMonoid | AExt
- deriving (Show)
-
-annotToANSI :: Annot -> PT.AnsiStyle
-annotToANSI AKey = PT.bold
-annotToANSI AWith = PT.color PT.Red <> PT.underlined
-annotToANSI AHighlight = PT.color PT.Blue
-annotToANSI AMonoid = PT.color PT.Green
-annotToANSI AExt = PT.colorDull PT.White
-
-type ADoc = Doc Annot
-
-render :: Doc Annot -> String
-render =
- (if stdoutTTY then TL.unpack . PT.renderLazy . reAnnotateS annotToANSI
- else renderString)
- . layoutPretty LayoutOptions { layoutPageWidth = AvailablePerLine 120 1.0 }
- where
- stdoutTTY = unsafePerformIO $ hSupportsANSI stdout
diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs
deleted file mode 100644
index 2a29799..0000000
--- a/src/AST/Sparse.hs
+++ /dev/null
@@ -1,287 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE ImpredicativeTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE RankNTypes #-}
-
-{-# OPTIONS_GHC -fmax-pmcheck-models=80 #-}
-module AST.Sparse (module AST.Sparse, module AST.Sparse.Types) where
-
-import Data.Type.Equality
-
-import AST
-import AST.Sparse.Types
-import Data (SBool(..))
-
-
-sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t'
-sparsePlus _ SpAbsent e1 e2 = use e1 $ use e2 $ ENil ext
-sparsePlus t sp e1 e2 | Just Refl <- isDense t sp = EPlus ext t e1 e2
-sparsePlus t (SpSparse sp) e1 e2 = sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 -- heh
-sparsePlus (SMTPair t1 t2) (SpPair sp1 sp2) e1 e2 =
- eunPair e1 $ \w1 e1a e1b ->
- eunPair (weakenExpr w1 e2) $ \w2 e2a e2b ->
- EPair ext (sparsePlus t1 sp1 (weakenExpr w2 e1a) e2a)
- (sparsePlus t2 sp2 (weakenExpr w2 e1b) e2b)
-sparsePlus (SMTLEither t1 t2) (SpLEither sp1 sp2) e1 e2 =
- elet e2 $
- elcase (weakenExpr WSink e1)
- (evar IZ)
- (elcase (evar (IS IZ))
- (ELInl ext (applySparse sp2 (fromSMTy t2)) (evar IZ))
- (ELInl ext (applySparse sp2 (fromSMTy t2)) (sparsePlus t1 sp1 (evar (IS IZ)) (evar IZ)))
- (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus ll+lr"))
- (elcase (evar (IS IZ))
- (ELInr ext (applySparse sp1 (fromSMTy t1)) (evar IZ))
- (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus lr+ll")
- (ELInr ext (applySparse sp1 (fromSMTy t1)) (sparsePlus t2 sp2 (evar (IS IZ)) (evar IZ))))
-sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 =
- elet e2 $
- emaybe (weakenExpr WSink e1)
- (evar IZ)
- (emaybe (evar (IS IZ))
- (EJust ext (evar IZ))
- (EJust ext (sparsePlus t sp (evar (IS IZ)) (evar IZ))))
-sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS IZ)) (evar IZ)) e1 e2
-sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2
-
-
-cheapZero :: SMTy t -> Maybe (forall env. Ex env t)
-cheapZero SMTNil = Just (ENil ext)
-cheapZero (SMTPair t1 t2)
- | Just e1 <- cheapZero t1
- , Just e2 <- cheapZero t2
- = Just (EPair ext e1 e2)
- | otherwise
- = Nothing
-cheapZero (SMTLEither t1 t2) = Just (ELNil ext (fromSMTy t1) (fromSMTy t2))
-cheapZero (SMTMaybe t) = Just (ENothing ext (fromSMTy t))
-cheapZero SMTArr{} = Nothing
-cheapZero (SMTScal t) = case t of
- STI32 -> Just (EConst ext t 0)
- STI64 -> Just (EConst ext t 0)
- STF32 -> Just (EConst ext t 0.0)
- STF64 -> Just (EConst ext t 0.0)
-
-
-data Injection sp a b where
- -- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that
- -- 'sparsePlusS' can provide injections even if the caller doesn't require
- -- them. This simplifies the sparsePlusS code.
- Inj :: (forall e. Ex e a -> Ex e b) -> Injection sp a b
- Noinj :: Injection False a b
-
-withInj :: Injection sp a b -> ((forall e. Ex e a -> Ex e b) -> (forall e'. Ex e' a' -> Ex e' b')) -> Injection sp a' b'
-withInj (Inj f) k = Inj (k f)
-withInj Noinj _ = Noinj
-
-withInj2 :: Injection sp a1 b1 -> Injection sp a2 b2
- -> ((forall e. Ex e a1 -> Ex e b1)
- -> (forall e. Ex e a2 -> Ex e b2)
- -> (forall e'. Ex e' a' -> Ex e' b'))
- -> Injection sp a' b'
-withInj2 (Inj f) (Inj g) k = Inj (k f g)
-withInj2 Noinj _ _ = Noinj
-withInj2 _ Noinj _ = Noinj
-
--- | This function produces quadratically-sized code in the presence of nested
--- dynamic sparsity. TODO can this be improved?
-sparsePlusS
- :: SBool inj1 -> SBool inj2
- -> SMTy t -> Sparse t t1 -> Sparse t t2
- -> (forall t3. Sparse t t3
- -> Injection inj1 t1 t3 -- only available if first injection is requested (second argument may be absent)
- -> Injection inj2 t2 t3 -- only available if second injection is requested (first argument may be absent)
- -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3)
- -> r)
- -> r
--- nil override (but don't destroy effects!)
-sparsePlusS _ _ SMTNil _ _ k =
- k SpAbsent (Inj $ \a -> use a $ ENil ext) (Inj $ \b -> use b $ ENil ext) (\a b -> use a $ use b $ ENil ext)
-
--- simplifications
-sparsePlusS req1 req2 t (SpSparse SpAbsent) sp2 k =
- sparsePlusS req1 req2 t SpAbsent sp2 $ \sp3 minj1 minj2 plus ->
- k sp3 (withInj minj1 $ \inj1 -> \a -> use a $ inj1 (ENil ext)) minj2 (\a b -> use a $ plus (ENil ext) b)
-sparsePlusS req1 req2 t sp1 (SpSparse SpAbsent) k =
- sparsePlusS req1 req2 t sp1 SpAbsent $ \sp3 minj1 minj2 plus ->
- k sp3 minj1 (withInj minj2 $ \inj2 -> \b -> use b $ inj2 (ENil ext)) (\a b -> use b $ plus a (ENil ext))
-
-sparsePlusS req1 req2 t (SpSparse (SpSparse sp1)) sp2 k =
- let ta = applySparse sp1 (fromSMTy t) in
- sparsePlusS req1 req2 t (SpSparse sp1) sp2 $ \sp3 minj1 minj2 plus ->
- k sp3
- (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)))
- minj2
- (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b)
-sparsePlusS req1 req2 t sp1 (SpSparse (SpSparse sp2)) k =
- let tb = applySparse sp2 (fromSMTy t) in
- sparsePlusS req1 req2 t sp1 (SpSparse sp2) $ \sp3 minj1 minj2 plus ->
- k sp3
- minj1
- (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ)))
- (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ)))
-
-sparsePlusS req1 req2 t (SpSparse (SpLEither sp1a sp1b)) sp2 k =
- let STLEither ta tb = applySparse (SpLEither sp1a sp1b) (fromSMTy t) in
- sparsePlusS req1 req2 t (SpLEither sp1a sp1b) sp2 $ \sp3 minj1 minj2 plus ->
- k sp3
- (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)))
- minj2
- (\a b -> plus (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)) b)
-sparsePlusS req1 req2 t sp1 (SpSparse (SpLEither sp2a sp2b)) k =
- let STLEither ta tb = applySparse (SpLEither sp2a sp2b) (fromSMTy t) in
- sparsePlusS req1 req2 t sp1 (SpLEither sp2a sp2b) $ \sp3 minj1 minj2 plus ->
- k sp3
- minj1
- (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)))
- (\a b -> plus a (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)))
-
-sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k =
- let STMaybe ta = applySparse (SpMaybe sp1) (fromSMTy t) in
- sparsePlusS req1 req2 t (SpMaybe sp1) sp2 $ \sp3 minj1 minj2 plus ->
- k sp3
- (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (evar IZ)))
- minj2
- (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b)
-sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k =
- let STMaybe tb = applySparse (SpMaybe sp2) (fromSMTy t) in
- sparsePlusS req1 req2 t sp1 (SpMaybe sp2) $ \sp3 minj1 minj2 plus ->
- k sp3
- minj1
- (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (evar IZ)))
- (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ)))
-sparsePlusS req1 req2 t (SpMaybe (SpSparse sp1)) sp2 k = sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k
-sparsePlusS req1 req2 t sp1 (SpMaybe (SpSparse sp2)) k = sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k
-
--- TODO: sparse of Just is just Maybe
-
--- dense plus
-sparsePlusS _ _ t sp1 sp2 k
- | Just Refl <- isDense t sp1
- , Just Refl <- isDense t sp2
- = k (spDense t) (Inj id) (Inj id) (\a b -> EPlus ext t a b)
-
--- handle absents
-sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\a b -> use a $ b)
-sparsePlusS ST _ t SpAbsent sp2 k
- | Just zero2 <- cheapZero (applySparse sp2 t) =
- k sp2 (Inj $ \a -> use a $ zero2) (Inj id) (\a b -> use a $ b)
- | otherwise =
- k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b)
-
-sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a b -> use b $ a)
-sparsePlusS _ ST t sp1 SpAbsent k
- | Just zero1 <- cheapZero (applySparse sp1 t) =
- k sp1 (Inj id) (Inj $ \b -> use b $ zero1) (\a b -> use b $ a)
- | otherwise =
- k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a)
-
--- double sparse yields sparse
-sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k =
- sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
- k (SpSparse sp3)
- (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
- (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ))))
- (\a b ->
- elet b $
- emaybe (weakenExpr WSink a)
- (emaybe (evar IZ)
- (ENothing ext (applySparse sp3 (fromSMTy t)))
- (EJust ext (inj2 (evar IZ))))
- (emaybe (evar (IS IZ))
- (EJust ext (inj1 (evar IZ)))
- (EJust ext (plus (evar (IS IZ)) (evar IZ)))))
-
--- single sparse can yield non-sparse if the other argument is always present
-sparsePlusS SF _ t (SpSparse sp1) sp2 k =
- sparsePlusS SF ST t sp1 sp2 $ \sp3 _ (Inj inj2) plus ->
- k sp3 Noinj (Inj inj2)
- (\a b ->
- elet b $
- emaybe (weakenExpr WSink a)
- (inj2 (evar IZ))
- (plus (evar IZ) (evar (IS IZ))))
-sparsePlusS ST _ t (SpSparse sp1) sp2 k =
- sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
- k (SpSparse sp3)
- (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
- (Inj $ \b -> EJust ext (inj2 b))
- (\a b ->
- elet b $
- emaybe (weakenExpr WSink a)
- (EJust ext (inj2 (evar IZ)))
- (EJust ext (plus (evar IZ) (evar (IS IZ)))))
-sparsePlusS req1 req2 t sp1 (SpSparse sp2) k =
- sparsePlusS req2 req1 t (SpSparse sp2) sp1 $ \sp3 inj1 inj2 plus ->
- k sp3 inj2 inj1 (flip plus)
-
--- products
-sparsePlusS req1 req2 (SMTPair ta tb) (SpPair sp1a sp1b) (SpPair sp2a sp2b) k =
- sparsePlusS req1 req2 ta sp1a sp2a $ \sp3a minj13a minj23a plusa ->
- sparsePlusS req1 req2 tb sp1b sp2b $ \sp3b minj13b minj23b plusb ->
- k (SpPair sp3a sp3b)
- (withInj2 minj13a minj13b $ \inj13a inj13b ->
- \x1 -> eunPair x1 $ \_ x1a x1b -> EPair ext (inj13a x1a) (inj13b x1b))
- (withInj2 minj23a minj23b $ \inj23a inj23b ->
- \x2 -> eunPair x2 $ \_ x2a x2b -> EPair ext (inj23a x2a) (inj23b x2b))
- (\x1 x2 ->
- eunPair x1 $ \w1 x1a x1b ->
- eunPair (weakenExpr w1 x2) $ \w2 x2a x2b ->
- EPair ext (plusa (weakenExpr w2 x1a) x2a) (plusb (weakenExpr w2 x1b) x2b))
-
--- coproducts
-sparsePlusS _ _ (SMTLEither ta tb) (SpLEither sp1a sp1b) (SpLEither sp2a sp2b) k =
- sparsePlusS ST ST ta sp1a sp2a $ \(sp3a :: Sparse _t3 t3a) (Inj inj13a) (Inj inj23a) plusa ->
- sparsePlusS ST ST tb sp1b sp2b $ \(sp3b :: Sparse _t3' t3b) (Inj inj13b) (Inj inj23b) plusb ->
- let nil :: Ex e (TLEither t3a t3b) ; nil = ELNil ext (applySparse sp3a (fromSMTy ta)) (applySparse sp3b (fromSMTy tb))
- inl :: Ex e t3a -> Ex e (TLEither t3a t3b) ; inl = ELInl ext (applySparse sp3b (fromSMTy tb))
- inr :: Ex e t3b -> Ex e (TLEither t3a t3b) ; inr = ELInr ext (applySparse sp3a (fromSMTy ta))
- in
- k (SpLEither sp3a sp3b)
- (Inj $ \x1 -> elcase x1 nil (inl (inj13a (evar IZ))) (inr (inj13b (evar IZ))))
- (Inj $ \x2 -> elcase x2 nil (inl (inj23a (evar IZ))) (inr (inj23b (evar IZ))))
- (\x1 x2 ->
- elet x2 $
- elcase (weakenExpr WSink x1)
- (elcase (evar IZ)
- nil
- (inl (inj23a (evar IZ)))
- (inr (inj23b (evar IZ))))
- (elcase (evar (IS IZ))
- (inl (inj13a (evar IZ)))
- (inl (plusa (evar (IS IZ)) (evar IZ)))
- (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS ll+lr"))
- (elcase (evar (IS IZ))
- (inr (inj13b (evar IZ)))
- (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS lr+ll")
- (inr (plusb (evar (IS IZ)) (evar IZ)))))
-
--- maybe
-sparsePlusS _ _ (SMTMaybe t) (SpMaybe sp1) (SpMaybe sp2) k =
- sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
- k (SpMaybe sp3)
- (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
- (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ))))
- (\a b ->
- elet b $
- emaybe (weakenExpr WSink a)
- (emaybe (evar IZ)
- (ENothing ext (applySparse sp3 (fromSMTy t)))
- (EJust ext (inj2 (evar IZ))))
- (emaybe (evar (IS IZ))
- (EJust ext (inj1 (evar IZ)))
- (EJust ext (plus (evar (IS IZ)) (evar IZ)))))
-
--- dense array cotangents simply recurse
-sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k =
- sparsePlusS req1 req2 t sp1 sp2 $ \sp3 minj1 minj2 plus ->
- k (SpArr sp3)
- (withInj minj1 $ \inj1 -> emap (inj1 (EVar ext (applySparse sp1 (fromSMTy t)) IZ)))
- (withInj minj2 $ \inj2 -> emap (inj2 (EVar ext (applySparse sp2 (fromSMTy t)) IZ)))
- (ezipWith (plus (EVar ext (applySparse sp1 (fromSMTy t)) (IS IZ))
- (EVar ext (applySparse sp2 (fromSMTy t)) IZ)))
-
--- scalars
-sparsePlusS _ _ (SMTScal t) SpScal SpScal k = k SpScal (Inj id) (Inj id) (EPlus ext (SMTScal t))
diff --git a/src/AST/Sparse/Types.hs b/src/AST/Sparse/Types.hs
deleted file mode 100644
index 10cac4e..0000000
--- a/src/AST/Sparse/Types.hs
+++ /dev/null
@@ -1,107 +0,0 @@
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-module AST.Sparse.Types where
-
-import AST.Types
-
-import Data.Kind (Type, Constraint)
-import Data.Type.Equality
-
-
-data Sparse t t' where
- SpSparse :: Sparse t t' -> Sparse t (TMaybe t')
- SpAbsent :: Sparse t TNil
-
- SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b')
- SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b')
- SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t')
- SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t')
- SpScal :: Sparse (TScal t) (TScal t)
-deriving instance Show (Sparse t t')
-
-class ApplySparse f where
- applySparse :: Sparse t t' -> f t -> f t'
-
-instance ApplySparse STy where
- applySparse (SpSparse s) t = STMaybe (applySparse s t)
- applySparse SpAbsent _ = STNil
- applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2)
- applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2)
- applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t)
- applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t)
- applySparse SpScal t = t
-
-instance ApplySparse SMTy where
- applySparse (SpSparse s) t = SMTMaybe (applySparse s t)
- applySparse SpAbsent _ = SMTNil
- applySparse (SpPair s1 s2) (SMTPair t1 t2) = SMTPair (applySparse s1 t1) (applySparse s2 t2)
- applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2)
- applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t)
- applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t)
- applySparse SpScal t = t
-
-
-class IsSubType s where
- type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint
- subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t'
- subtTrans :: s a b -> s b c -> s a c
- subtFull :: IsSubTypeSubject s f => f t -> s t t
-
-instance IsSubType (:~:) where
- type IsSubTypeSubject (:~:) f = ()
- subtApply = gcastWith
- subtTrans = trans
- subtFull _ = Refl
-
-instance IsSubType Sparse where
- type IsSubTypeSubject Sparse f = f ~ SMTy
- subtApply = applySparse
-
- subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2)
- subtTrans _ SpAbsent = SpAbsent
- subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b)
- subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b)
- subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2)
- subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2)
- subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2)
- subtTrans SpScal SpScal = SpScal
-
- subtFull = spDense
-
-spDense :: SMTy t -> Sparse t t
-spDense SMTNil = SpAbsent
-spDense (SMTPair t1 t2) = SpPair (spDense t1) (spDense t2)
-spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2)
-spDense (SMTMaybe t) = SpMaybe (spDense t)
-spDense (SMTArr _ t) = SpArr (spDense t)
-spDense (SMTScal _) = SpScal
-
-isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t')
-isDense SMTNil SpAbsent = Just Refl
-isDense _ SpSparse{} = Nothing
-isDense _ SpAbsent = Nothing
-isDense (SMTPair t1 t2) (SpPair s1 s2)
- | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl
- | otherwise = Nothing
-isDense (SMTLEither t1 t2) (SpLEither s1 s2)
- | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl
- | otherwise = Nothing
-isDense (SMTMaybe t) (SpMaybe s)
- | Just Refl <- isDense t s = Just Refl
- | otherwise = Nothing
-isDense (SMTArr _ t) (SpArr s)
- | Just Refl <- isDense t s = Just Refl
- | otherwise = Nothing
-isDense (SMTScal _) SpScal = Just Refl
-
-isAbsent :: Sparse t t' -> Bool
-isAbsent (SpSparse s) = isAbsent s
-isAbsent SpAbsent = True
-isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2
-isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2
-isAbsent (SpMaybe s) = isAbsent s
-isAbsent (SpArr s) = isAbsent s
-isAbsent SpScal = False
diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs
deleted file mode 100644
index 267dd87..0000000
--- a/src/AST/SplitLets.hs
+++ /dev/null
@@ -1,191 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
-module AST.SplitLets (splitLets) where
-
-import Data.Type.Equality
-
-import AST
-import AST.Bindings
-import Lemmas
-
-
-splitLets :: Ex env t -> Ex env t
-splitLets = splitLets' (\t i w -> EVar ext t (w @> i))
-
-splitLets' :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> Ex env t -> Ex env' t
-splitLets' = \sub -> \case
- EVar _ t i -> sub t i WId
- ELet _ rhs body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body)
- ECase x e a b ->
- let STEither t1 t2 = typeOf e
- in ECase x (splitLets' sub e) (split1 sub t1 a) (split1 sub t2 b)
- EMaybe x a b e ->
- let STMaybe t1 = typeOf e
- in EMaybe x (splitLets' sub a) (split1 sub t1 b) (splitLets' sub e)
- ELCase x e a b c ->
- let STLEither t1 t2 = typeOf e
- in ELCase x (splitLets' sub e) (splitLets' sub a) (split1 sub t1 b) (split1 sub t2 c)
- EFold1Inner x cm a b c ->
- let STArr _ t1 = typeOf c
- in EFold1Inner x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c)
- EFold1InnerD1 x cm a b c ->
- let STArr _ t1 = typeOf c
- in EFold1InnerD1 x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c)
- EFold1InnerD2 x cm a b c ->
- let STArr _ tB = typeOf b
- STArr _ t2 = typeOf c
- in EFold1InnerD2 x cm (split2 sub tB t2 a) (splitLets' sub b) (splitLets' sub c)
-
- EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b)
- EFst x e -> EFst x (splitLets' sub e)
- ESnd x e -> ESnd x (splitLets' sub e)
- ENil x -> ENil x
- EInl x t e -> EInl x t (splitLets' sub e)
- EInr x t e -> EInr x t (splitLets' sub e)
- ENothing x t -> ENothing x t
- EJust x e -> EJust x (splitLets' sub e)
- ELNil x t1 t2 -> ELNil x t1 t2
- ELInl x t e -> ELInl x t (splitLets' sub e)
- ELInr x t e -> ELInr x t (splitLets' sub e)
- EConstArr x n t a -> EConstArr x n t a
- EBuild x n a b -> EBuild x n (splitLets' sub a) (splitLets' (sinkF sub) b)
- EMap x a b -> EMap x (splitLets' (sinkF sub) a) (splitLets' sub b)
- ESum1Inner x e -> ESum1Inner x (splitLets' sub e)
- EUnit x e -> EUnit x (splitLets' sub e)
- EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b)
- EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e)
- EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e)
- EReshape x n a b -> EReshape x n (splitLets' sub a) (splitLets' sub b)
- EZip x a b -> EZip x (splitLets' sub a) (splitLets' sub b)
- EConst x t v -> EConst x t v
- EIdx0 x e -> EIdx0 x (splitLets' sub e)
- EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b)
- EIdx x e es -> EIdx x (splitLets' sub e) (splitLets' sub es)
- EShape x e -> EShape x (splitLets' sub e)
- EOp x op e -> EOp x op (splitLets' sub e)
- ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2)
- ERecompute x e -> ERecompute x (splitLets' sub e)
- EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2)
- EAccum x t p e1 sp e2 e3 -> EAccum x t p (splitLets' sub e1) sp (splitLets' sub e2) (splitLets' sub e3)
- EZero x t ezi -> EZero x t (splitLets' sub ezi)
- EDeepZero x t ezi -> EDeepZero x t (splitLets' sub ezi)
- EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b)
- EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b)
- EError x t s -> EError x t s
- where
- sinkF :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
- -> STy t -> Idx (b : env) t -> (b : env') :> env3 -> Ex env3 t
- sinkF _ t IZ w = EVar ext t (w @> IZ)
- sinkF f t (IS i) w = f t i (w .> WSink)
-
- split1 :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
- -> STy bind -> Ex (bind : env) t -> Ex (bind : env') t
- split1 sub (tbind :: STy bind) body =
- let (ptrs, bs) = split tbind
- in letBinds bs $
- splitLets' (\cases _ IZ w -> subPointers ptrs w
- t (IS i) w -> sub t i (WPop @bind (wPops (bindingsBinds bs) w)))
- body
-
- split2 :: forall bind1 bind2 env' env t.
- (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
- -> STy bind1 -> STy bind2 -> Ex (bind2 : bind1 : env) t -> Ex (bind2 : bind1 : env') t
- split2 sub tbind1 tbind2 body =
- let (ptrs1', bs1') = split @env' tbind1
- bs1 = fst (weakenBindingsE WSink bs1')
- (ptrs2, bs2) = split @(bind1 : env') tbind2
- in letBinds bs1 $
- letBinds (fst (weakenBindingsE (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $
- splitLets' (\cases _ IZ w -> subPointers ptrs2 (w .> wCopies (bindingsBinds bs2) (wSinks @(bind2 : bind1 : env') (bindingsBinds bs1)))
- _ (IS IZ) w -> subPointers ptrs1' (w .> wSinks (bindingsBinds bs2) .> wCopies (bindingsBinds bs1) (WSink @bind2 @(bind1 : env')))
- t (IS (IS i)) w -> sub t i (WPop @bind1 (WPop @bind2 (wPops (bindingsBinds bs1) (wPops (bindingsBinds bs2) w)))))
- body
-
- -- TODO: abstract this to splitN lol wtf
- _split4 :: forall bind1 bind2 bind3 bind4 env' env t.
- (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
- -> STy bind1 -> STy bind2 -> STy bind3 -> STy bind4 -> Ex (bind4 : bind3 : bind2 : bind1 : env) t -> Ex (bind4 : bind3 : bind2 : bind1 : env') t
- _split4 sub tbind1 tbind2 tbind3 tbind4 body =
- let (ptrs1, bs1') = split @env' tbind1
- (ptrs2, bs2') = split @(bind1 : env') tbind2
- (ptrs3, bs3') = split @(bind2 : bind1 : env') tbind3
- (ptrs4, bs4) = split @(bind3 : bind2 : bind1 : env') tbind4
- bs1 = fst (weakenBindingsE (WSink .> WSink .> WSink) bs1')
- bs2 = fst (weakenBindingsE (WSink .> WSink) bs2')
- bs3 = fst (weakenBindingsE WSink bs3')
- b1 = bindingsBinds bs1
- b2 = bindingsBinds bs2
- b3 = bindingsBinds bs3
- b4 = bindingsBinds bs4
- in letBinds bs1 $
- letBinds (fst (weakenBindingsE ( sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs2)) $
- letBinds (fst (weakenBindingsE ( sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs3)) $
- letBinds (fst (weakenBindingsE (sinkWithBindings bs3 .> sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs4)) $
- splitLets' (\cases _ IZ w -> subPointers ptrs4 (w .> wCopies b4 (wSinks b3 .> wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1))
- _ (IS IZ) w -> subPointers ptrs3 (w .> wSinks b4 .> wCopies b3 (wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink))
- _ (IS (IS IZ)) w -> subPointers ptrs2 (w .> wSinks b4 .> wSinks b3 .> wCopies b2 (wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink .> WSink))
- _ (IS (IS (IS IZ))) w -> subPointers ptrs1 (w .> wSinks b4 .> wSinks b3 .> wSinks b2 .> wCopies b1 (WSink @bind4 .> WSink @bind3 .> WSink @bind2 @(bind1 : env')))
- t (IS (IS (IS (IS i)))) w -> sub t i (WPop @bind1 (WPop @bind2 (WPop @bind3 (WPop @bind4 (wPops b1 (wPops b2 (wPops b3 (wPops b4 w)))))))))
- body
-
-type family Split t where
- Split (TPair a b) = SplitRec (TPair a b)
- Split _ = '[]
-
-type family SplitRec t where
- SplitRec TNil = '[]
- SplitRec (TPair a b) = Append (SplitRec b) (SplitRec a)
- SplitRec t = '[t]
-
-data Pointers env t where
- Point :: STy t -> Idx env t -> Pointers env t
- PNil :: Pointers env TNil
- PPair :: Pointers env a -> Pointers env b -> Pointers env (TPair a b)
- PWeak :: env' :> env -> Pointers env' t -> Pointers env t
-
-subPointers :: Pointers env t -> env :> env' -> Ex env' t
-subPointers (Point t i) w = EVar ext t (w @> i)
-subPointers PNil _ = ENil ext
-subPointers (PPair a b) w = EPair ext (subPointers a w) (subPointers b w)
-subPointers (PWeak w' p) w = subPointers p (w .> w')
-
-split :: forall env t. STy t
- -> (Pointers (Append (Split t) (t : env)) t, Bindings Ex (t : env) (Split t))
-split typ = case typ of
- STPair{} -> splitRec (EVar ext typ IZ) typ
- STNil -> other
- STEither{} -> other
- STLEither{} -> other
- STMaybe{} -> other
- STArr{} -> other
- STScal{} -> other
- STAccum{} -> other
- where
- other :: (Pointers (t : env) t, Bindings Ex (t : env) '[])
- other = (Point typ IZ, BTop)
-
-splitRec :: forall env t. Ex env t -> STy t
- -> (Pointers (Append (SplitRec t) env) t, Bindings Ex env (SplitRec t))
-splitRec rhs typ = case typ of
- STNil -> (PNil, BTop)
- STPair (a :: STy a) (b :: STy b)
- | Refl <- lemAppendAssoc @(SplitRec b) @(SplitRec a) @env ->
- let (p1, bs1) = splitRec (EFst ext rhs) a
- (p2, bs2) = splitRec (ESnd ext (sinkWithBindings bs1 `weakenExpr` rhs)) b
- in (PPair (PWeak (sinkWithBindings bs2) p1) p2, bconcat bs1 bs2)
- STEither{} -> other
- STLEither{} -> other
- STMaybe{} -> other
- STArr{} -> other
- STScal{} -> other
- STAccum{} -> other
- where
- other :: (Pointers (t : env) t, Bindings Ex env '[t])
- other = (Point typ IZ, BPush BTop (typ, rhs))
diff --git a/src/AST/Types.hs b/src/AST/Types.hs
deleted file mode 100644
index 4ddcb50..0000000
--- a/src/AST/Types.hs
+++ /dev/null
@@ -1,215 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE TypeData #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-module AST.Types where
-
-import Data.Int (Int32, Int64)
-import Data.GADT.Compare
-import Data.GADT.Show
-import Data.Kind (Type)
-import Data.Type.Equality
-
-import Data
-
-
-type data Ty
- = TNil
- | TPair Ty Ty
- | TEither Ty Ty
- | TLEither Ty Ty
- | TMaybe Ty
- | TArr Nat Ty -- ^ rank, element type
- | TScal ScalTy
- | TAccum Ty -- ^ contained type must be a monoid type
-
-type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
-
-type STy :: Ty -> Type
-data STy t where
- STNil :: STy TNil
- STPair :: STy a -> STy b -> STy (TPair a b)
- STEither :: STy a -> STy b -> STy (TEither a b)
- STLEither :: STy a -> STy b -> STy (TLEither a b)
- STMaybe :: STy a -> STy (TMaybe a)
- STArr :: SNat n -> STy t -> STy (TArr n t)
- STScal :: SScalTy t -> STy (TScal t)
- STAccum :: SMTy t -> STy (TAccum t)
-deriving instance Show (STy t)
-
-instance GCompare STy where
- gcompare = \cases
- STNil STNil -> GEQ
- STNil _ -> GLT ; _ STNil -> GGT
- (STPair a b) (STPair a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
- STPair{} _ -> GLT ; _ STPair{} -> GGT
- (STEither a b) (STEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
- STEither{} _ -> GLT ; _ STEither{} -> GGT
- (STLEither a b) (STLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
- STLEither{} _ -> GLT ; _ STLEither{} -> GGT
- (STMaybe a) (STMaybe a') -> gorderingLift1 (gcompare a a')
- STMaybe{} _ -> GLT ; _ STMaybe{} -> GGT
- (STArr n t) (STArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t')
- STArr{} _ -> GLT ; _ STArr{} -> GGT
- (STScal t) (STScal t') -> gorderingLift1 (gcompare t t')
- STScal{} _ -> GLT ; _ STScal{} -> GGT
- (STAccum t) (STAccum t') -> gorderingLift1 (gcompare t t')
- -- STAccum{} _ -> GLT ; _ STAccum{} -> GGT
-
-instance TestEquality STy where testEquality = geq
-instance GEq STy where geq = defaultGeq
-instance GShow STy where gshowsPrec = defaultGshowsPrec
-
--- | Monoid types
-type SMTy :: Ty -> Type
-data SMTy t where
- SMTNil :: SMTy TNil
- SMTPair :: SMTy a -> SMTy b -> SMTy (TPair a b)
- SMTLEither :: SMTy a -> SMTy b -> SMTy (TLEither a b)
- SMTMaybe :: SMTy a -> SMTy (TMaybe a)
- SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t)
- SMTScal :: ScalIsNumeric t ~ True => SScalTy t -> SMTy (TScal t)
-deriving instance Show (SMTy t)
-
-instance GCompare SMTy where
- gcompare = \cases
- SMTNil SMTNil -> GEQ
- SMTNil _ -> GLT ; _ SMTNil -> GGT
- (SMTPair a b) (SMTPair a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
- SMTPair{} _ -> GLT ; _ SMTPair{} -> GGT
- (SMTLEither a b) (SMTLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
- SMTLEither{} _ -> GLT ; _ SMTLEither{} -> GGT
- (SMTMaybe a) (SMTMaybe a') -> gorderingLift1 (gcompare a a')
- SMTMaybe{} _ -> GLT ; _ SMTMaybe{} -> GGT
- (SMTArr n t) (SMTArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t')
- SMTArr{} _ -> GLT ; _ SMTArr{} -> GGT
- (SMTScal t) (SMTScal t') -> gorderingLift1 (gcompare t t')
- -- SMTScal{} _ -> GLT ; _ SMTScal{} -> GGT
-
-instance TestEquality SMTy where testEquality = geq
-instance GEq SMTy where geq = defaultGeq
-instance GShow SMTy where gshowsPrec = defaultGshowsPrec
-
-fromSMTy :: SMTy t -> STy t
-fromSMTy = \case
- SMTNil -> STNil
- SMTPair t1 t2 -> STPair (fromSMTy t1) (fromSMTy t2)
- SMTLEither t1 t2 -> STLEither (fromSMTy t1) (fromSMTy t2)
- SMTMaybe t -> STMaybe (fromSMTy t)
- SMTArr n t -> STArr n (fromSMTy t)
- SMTScal sty -> STScal sty
-
-data SScalTy t where
- STI32 :: SScalTy TI32
- STI64 :: SScalTy TI64
- STF32 :: SScalTy TF32
- STF64 :: SScalTy TF64
- STBool :: SScalTy TBool
-deriving instance Show (SScalTy t)
-
-instance GCompare SScalTy where
- gcompare = \cases
- STI32 STI32 -> GEQ
- STI32 _ -> GLT ; _ STI32 -> GGT
- STI64 STI64 -> GEQ
- STI64 _ -> GLT ; _ STI64 -> GGT
- STF32 STF32 -> GEQ
- STF32 _ -> GLT ; _ STF32 -> GGT
- STF64 STF64 -> GEQ
- STF64 _ -> GLT ; _ STF64 -> GGT
- STBool STBool -> GEQ
- -- STBool _ -> GLT ; _ STBool -> GGT
-
-instance TestEquality SScalTy where testEquality = geq
-instance GEq SScalTy where geq = defaultGeq
-instance GShow SScalTy where gshowsPrec = defaultGshowsPrec
-
-scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t))
-scalRepIsShow STI32 = Dict
-scalRepIsShow STI64 = Dict
-scalRepIsShow STF32 = Dict
-scalRepIsShow STF64 = Dict
-scalRepIsShow STBool = Dict
-
-type TIx = TScal TI64
-
-tIx :: STy TIx
-tIx = STScal STI64
-
-type family ScalRep t where
- ScalRep TI32 = Int32
- ScalRep TI64 = Int64
- ScalRep TF32 = Float
- ScalRep TF64 = Double
- ScalRep TBool = Bool
-
-type family ScalIsNumeric t where
- ScalIsNumeric TI32 = True
- ScalIsNumeric TI64 = True
- ScalIsNumeric TF32 = True
- ScalIsNumeric TF64 = True
- ScalIsNumeric TBool = False
-
-type family ScalIsFloating t where
- ScalIsFloating TI32 = False
- ScalIsFloating TI64 = False
- ScalIsFloating TF32 = True
- ScalIsFloating TF64 = True
- ScalIsFloating TBool = False
-
-type family ScalIsIntegral t where
- ScalIsIntegral TI32 = True
- ScalIsIntegral TI64 = True
- ScalIsIntegral TF32 = False
- ScalIsIntegral TF64 = False
- ScalIsIntegral TBool = False
-
--- | Returns true for arrays /and/ accumulators.
-typeHasArrays :: STy t' -> Bool
-typeHasArrays STNil = False
-typeHasArrays (STPair a b) = typeHasArrays a || typeHasArrays b
-typeHasArrays (STEither a b) = typeHasArrays a || typeHasArrays b
-typeHasArrays (STLEither a b) = typeHasArrays a || typeHasArrays b
-typeHasArrays (STMaybe t) = typeHasArrays t
-typeHasArrays STArr{} = True
-typeHasArrays STScal{} = False
-typeHasArrays STAccum{} = True
-
-typeHasAccums :: STy t' -> Bool
-typeHasAccums STNil = False
-typeHasAccums (STPair a b) = typeHasAccums a || typeHasAccums b
-typeHasAccums (STEither a b) = typeHasAccums a || typeHasAccums b
-typeHasAccums (STLEither a b) = typeHasAccums a || typeHasAccums b
-typeHasAccums (STMaybe t) = typeHasAccums t
-typeHasAccums STArr{} = False
-typeHasAccums STScal{} = False
-typeHasAccums STAccum{} = True
-
-type family Tup env where
- Tup '[] = TNil
- Tup (t : ts) = TPair (Tup ts) t
-
-mkTup :: f TNil -> (forall a b. f a -> f b -> f (TPair a b))
- -> SList f list -> f (Tup list)
-mkTup nil _ SNil = nil
-mkTup nil pair (e `SCons` es) = pair (mkTup nil pair es) e
-
-tTup :: SList STy env -> STy (Tup env)
-tTup = mkTup STNil STPair
-
-unTup :: (forall a b. c (TPair a b) -> (c a, c b))
- -> SList f list -> c (Tup list) -> SList c list
-unTup _ SNil _ = SNil
-unTup unpack (_ `SCons` list) tup =
- let (xs, x) = unpack tup
- in x `SCons` unTup unpack list xs
-
-type family InvTup core env where
- InvTup core '[] = core
- InvTup core (t : ts) = InvTup (TPair core t) ts
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
deleted file mode 100644
index 1712ba5..0000000
--- a/src/AST/UnMonoid.hs
+++ /dev/null
@@ -1,255 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE TypeOperators #-}
-module AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where
-
-import AST
-import AST.Sparse.Types
-import Data
-
-
--- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by
--- expanding them into their concrete implementations. Also ensure that
--- 'EAccum' has a dense sparsity.
-unMonoid :: Ex env t -> Ex env t
-unMonoid = \case
- EZero _ t e -> zero t e
- EDeepZero _ t e -> deepZero t e
- EPlus _ t a b -> plus t (unMonoid a) (unMonoid b)
- EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b)
-
- EVar _ t i -> EVar ext t i
- ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body)
- EPair _ a b -> EPair ext (unMonoid a) (unMonoid b)
- EFst _ e -> EFst ext (unMonoid e)
- ESnd _ e -> ESnd ext (unMonoid e)
- ENil _ -> ENil ext
- EInl _ t e -> EInl ext t (unMonoid e)
- EInr _ t e -> EInr ext t (unMonoid e)
- ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b)
- ENothing _ t -> ENothing ext t
- EJust _ e -> EJust ext (unMonoid e)
- EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e)
- ELNil _ t1 t2 -> ELNil ext t1 t2
- ELInl _ t e -> ELInl ext t (unMonoid e)
- ELInr _ t e -> ELInr ext t (unMonoid e)
- ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c)
- EConstArr _ n t x -> EConstArr ext n t x
- EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b)
- EMap _ a b -> EMap ext (unMonoid a) (unMonoid b)
- EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c)
- ESum1Inner _ e -> ESum1Inner ext (unMonoid e)
- EUnit _ e -> EUnit ext (unMonoid e)
- EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b)
- EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)
- EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
- EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b)
- EZip _ a b -> EZip ext (unMonoid a) (unMonoid b)
- EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
- EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
- EConst _ t x -> EConst ext t x
- EIdx0 _ e -> EIdx0 ext (unMonoid e)
- EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b)
- EIdx _ a b -> EIdx ext (unMonoid a) (unMonoid b)
- EShape _ e -> EShape ext (unMonoid e)
- EOp _ op e -> EOp ext op (unMonoid e)
- ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)
- ERecompute _ e -> ERecompute ext (unMonoid e)
- EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b)
- EAccum _ t p eidx sp eval eacc ->
- accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 ->
- acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' ->
- EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc))
- EError _ t s -> EError ext t s
-
-zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t
--- don't destroy the effects!
-zero SMTNil e = ELet ext e $ ENil ext
-zero (SMTPair t1 t2) e =
- ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ)))
- (zero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
-zero (SMTLEither t1 t2) _ = ELNil ext (fromSMTy t1) (fromSMTy t2)
-zero (SMTMaybe t) _ = ENothing ext (fromSMTy t)
-zero (SMTArr _ t) e = emap (zero t (EVar ext (tZeroInfo t) IZ)) e
-zero (SMTScal t) _ = case t of
- STI32 -> EConst ext STI32 0
- STI64 -> EConst ext STI64 0
- STF32 -> EConst ext STF32 0.0
- STF64 -> EConst ext STF64 0.0
-
-deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t
-deepZero SMTNil e = elet e $ ENil ext
-deepZero (SMTPair t1 t2) e =
- ELet ext e $ EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ)))
- (deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
-deepZero (SMTLEither t1 t2) e =
- elcase e
- (ELNil ext (fromSMTy t1) (fromSMTy t2))
- (ELInl ext (fromSMTy t2) (deepZero t1 (evar IZ)))
- (ELInr ext (fromSMTy t1) (deepZero t2 (evar IZ)))
-deepZero (SMTMaybe t) e =
- emaybe e
- (ENothing ext (fromSMTy t))
- (EJust ext (deepZero t (evar IZ)))
-deepZero (SMTArr _ t) e = emap (deepZero t (evar IZ)) e
-deepZero (SMTScal t) _ = case t of
- STI32 -> EConst ext STI32 0
- STI64 -> EConst ext STI64 0
- STF32 -> EConst ext STF32 0.0
- STF64 -> EConst ext STF64 0.0
-
-plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t
--- don't destroy the effects!
-plus SMTNil a b = ELet ext a $ ELet ext (weakenExpr WSink b) $ ENil ext
-plus (SMTPair t1 t2) a b =
- let t = STPair (fromSMTy t1) (fromSMTy t2)
- in ELet ext a $
- ELet ext (weakenExpr WSink b) $
- EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ)))
- (EFst ext (EVar ext t IZ)))
- (plus t2 (ESnd ext (EVar ext t (IS IZ)))
- (ESnd ext (EVar ext t IZ)))
-plus (SMTLEither t1 t2) a b =
- let t = STLEither (fromSMTy t1) (fromSMTy t2)
- in ELet ext a $
- ELet ext (weakenExpr WSink b) $
- ELCase ext (EVar ext t (IS IZ))
- (EVar ext t IZ)
- (ELCase ext (EVar ext t (IS IZ))
- (EVar ext t (IS (IS IZ)))
- (ELInl ext (fromSMTy t2) (plus t1 (EVar ext (fromSMTy t1) (IS IZ)) (EVar ext (fromSMTy t1) IZ)))
- (EError ext t "plus l+r"))
- (ELCase ext (EVar ext t (IS IZ))
- (EVar ext t (IS (IS IZ)))
- (EError ext t "plus r+l")
- (ELInr ext (fromSMTy t1) (plus t2 (EVar ext (fromSMTy t2) (IS IZ)) (EVar ext (fromSMTy t2) IZ))))
-plus (SMTMaybe t) a b =
- ELet ext b $
- EMaybe ext
- (EVar ext (STMaybe (fromSMTy t)) IZ)
- (EJust ext
- (EMaybe ext
- (EVar ext (fromSMTy t) IZ)
- (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ))
- (EVar ext (STMaybe (fromSMTy t)) (IS IZ))))
- (weakenExpr WSink a)
-plus (SMTArr _ t) a b =
- ezipWith (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ))
- a b
-plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b)
-
-onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t
-onehot typ topprj idx arg = case (typ, topprj) of
- (_, SAPHere) ->
- ELet ext arg $
- EVar ext (fromSMTy typ) IZ
-
- (SMTPair t1 t2, SAPFst prj) ->
- ELet ext idx $
- let tidx = typeOf idx in
- ELet ext (onehot t1 prj (EFst ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $
- let toh = fromSMTy t1 in
- EPair ext (EVar ext toh IZ)
- (zero t2 (ESnd ext (EVar ext tidx (IS IZ))))
-
- (SMTPair t1 t2, SAPSnd prj) ->
- ELet ext idx $
- let tidx = typeOf idx in
- ELet ext (onehot t2 prj (ESnd ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $
- let toh = fromSMTy t2 in
- EPair ext (zero t1 (EFst ext (EVar ext tidx (IS IZ))))
- (EVar ext toh IZ)
-
- (SMTLEither t1 t2, SAPLeft prj) ->
- ELInl ext (fromSMTy t2) (onehot t1 prj idx arg)
- (SMTLEither t1 t2, SAPRight prj) ->
- ELInr ext (fromSMTy t1) (onehot t2 prj idx arg)
-
- (SMTMaybe t1, SAPJust prj) ->
- EJust ext (onehot t1 prj idx arg)
-
- (SMTArr n t1, SAPArrIdx prj) ->
- let tidx = tTup (sreplicate n tIx)
- in ELet ext idx $
- EBuild ext n (EShape ext (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ)))) $
- eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))))
- (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg))
- (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $
- zero t1 (EVar ext (tZeroInfo t1) IZ))
-
-accumulateSparse
- :: SMTy t -> Sparse t t' -> Ex env t'
- -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' (AcIdxD p t) -> Ex env' b -> Ex env' TNil)
- -> Ex env TNil
-accumulateSparse topty topsp arg accum = case (topty, topsp) of
- (_, s) | Just Refl <- isDense topty s ->
- accum WId SAPHere (ENil ext) arg
- (SMTScal _, SpScal) ->
- accum WId SAPHere (ENil ext) arg -- should be handled by isDense already, but meh
- (_, SpSparse s) ->
- emaybe arg
- (ENil ext)
- (accumulateSparse topty s (evar IZ) (\w -> accum (WPop w)))
- (_, SpAbsent) ->
- ENil ext
- (SMTPair t1 t2, SpPair s1 s2) ->
- eunPair arg $ \w1 e1 e2 ->
- elet (accumulateSparse t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $
- accumulateSparse t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj))
- (SMTLEither t1 t2, SpLEither s1 s2) ->
- elcase arg
- (ENil ext)
- (accumulateSparse t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj)))
- (accumulateSparse t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj)))
- (SMTMaybe t, SpMaybe s) ->
- emaybe arg
- (ENil ext)
- (accumulateSparse t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj)))
- (SMTArr n t, SpArr s) ->
- let tn = tTup (sreplicate n tIx) in
- elet arg $
- elet (EBuild ext n (EShape ext (evar IZ)) $
- accumulateSparse t s
- (EIdx ext (evar (IS IZ)) (EVar ext tn IZ))
- (\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $
- ENil ext
-
-acPrjCompose
- :: SAIDense dense
- -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a)
- -> SAcPrj p2 b c -> Ex env (AcIdx dense p2 b)
- -> (forall p'. SAcPrj p' a c -> Ex env (AcIdx dense p' a) -> r) -> r
-acPrjCompose _ SAPHere _ p2 idx2 k = k p2 idx2
-acPrjCompose SAID (SAPFst p1) idx1 p2 idx2 k =
- acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' ->
- k (SAPFst p') idx'
-acPrjCompose SAID (SAPSnd p1) idx1 p2 idx2 k =
- acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' ->
- k (SAPSnd p') idx'
-acPrjCompose SAIS (SAPFst p1) idx1 p2 idx2 k
- | Dict <- styKnown (typeOf idx1) =
- acPrjCompose SAIS p1 (efst (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
- k (SAPFst p') (elet idx1 $ EPair ext idx' (esnd (evar IZ)))
-acPrjCompose SAIS (SAPSnd p1) idx1 p2 idx2 k
- | Dict <- styKnown (typeOf idx1) =
- acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
- k (SAPSnd p') (elet idx1 $ EPair ext (efst (evar IZ)) idx')
-acPrjCompose d (SAPLeft p1) idx1 p2 idx2 k =
- acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' ->
- k (SAPLeft p') idx'
-acPrjCompose d (SAPRight p1) idx1 p2 idx2 k =
- acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' ->
- k (SAPRight p') idx'
-acPrjCompose d (SAPJust p1) idx1 p2 idx2 k =
- acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' ->
- k (SAPJust p') idx'
-acPrjCompose SAID (SAPArrIdx p1) idx1 p2 idx2 k
- | Dict <- styKnown (typeOf idx1) =
- acPrjCompose SAID p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
- k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx')
-acPrjCompose SAIS (SAPArrIdx p1) idx1 p2 idx2 k
- | Dict <- styKnown (typeOf idx1) =
- acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
- k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx')
diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs
deleted file mode 100644
index f0820b8..0000000
--- a/src/AST/Weaken.hs
+++ /dev/null
@@ -1,138 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE EmptyCase #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE TypeAbstractions #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeOperators #-}
-
-{-# LANGUAGE PartialTypeSignatures #-}
-{-# OPTIONS -Wno-partial-type-signatures #-}
-
--- The reason why this is a separate module with "little" in it:
-{-# LANGUAGE AllowAmbiguousTypes #-}
-
-module AST.Weaken (module AST.Weaken, Append) where
-
-import Data.Bifunctor (first)
-import Data.Functor.Const
-import Data.GADT.Compare
-import Data.Kind (Type)
-
-import Data
-import Lemmas
-
-
-type Idx :: [k] -> k -> Type
-data Idx env t where
- IZ :: Idx (t : env) t
- IS :: Idx env t -> Idx (a : env) t
-deriving instance Show (Idx env t)
-
-instance GEq (Idx env) where
- geq IZ IZ = Just Refl
- geq (IS i) (IS j) | Just Refl <- geq i j = Just Refl
- geq _ _ = Nothing
-
-splitIdx :: forall env2 env1 t f. SList f env1 -> Idx (Append env1 env2) t -> Either (Idx env1 t) (Idx env2 t)
-splitIdx SNil i = Right i
-splitIdx (SCons _ _) IZ = Left IZ
-splitIdx (SCons _ l) (IS i) = first IS (splitIdx l i)
-
-slistIdx :: SList f list -> Idx list t -> f t
-slistIdx (SCons x _) IZ = x
-slistIdx (SCons _ list) (IS i) = slistIdx list i
-slistIdx SNil i = case i of {}
-
-idx2int :: Idx env t -> Int
-idx2int IZ = 0
-idx2int (IS n) = 1 + idx2int n
-
-data env :> env' where
- WId :: env :> env
- WSink :: forall t env. env :> (t : env)
- WCopy :: forall t env env'. env :> env' -> (t : env) :> (t : env')
- WPop :: (t : env) :> env' -> env :> env'
- WThen :: env1 :> env2 -> env2 :> env3 -> env1 :> env3
- WClosed :: '[] :> env
- WIdx :: Idx env t -> (t : env) :> env
- WPick :: forall t pre env env'. SList (Const ()) pre -> env :> env'
- -> Append pre (t : env) :> t : Append pre env'
- WSwap :: forall env as bs. SList (Const ()) as -> SList (Const ()) bs
- -> Append as (Append bs env) :> Append bs (Append as env)
- WStack :: forall env1 env2 as bs. SList (Const ()) as -> SList (Const ()) bs
- -> as :> bs -> env1 :> env2
- -> Append as env1 :> Append bs env2
-deriving instance Show (env :> env')
-infix 4 :>
-
-infixr 2 @>
-(@>) :: env :> env' -> Idx env t -> Idx env' t
-WId @> i = i
-WSink @> i = IS i
-WCopy _ @> IZ = IZ
-WCopy w @> IS i = IS (w @> i)
-WPop w @> i = w @> IS i
-WThen w1 w2 @> i = w2 @> w1 @> i
-WClosed @> i = case i of {}
-WIdx j @> IZ = j
-WIdx _ @> IS i = i
-WPick SNil w @> i = WCopy w @> i
-WPick (_ `SCons` _) _ @> IZ = IS IZ
-WPick @t (_ `SCons` pre) w @> IS i = WCopy WSink .> WPick @t pre w @> i
-WSwap @env (as :: SList _ as) (bs :: SList _ bs) @> i =
- case splitIdx @(Append bs env) as i of
- Left i' -> indexSinks bs (indexRaiseAbove @env as i')
- Right i' -> case splitIdx @env bs i' of
- Left j -> indexRaiseAbove @(Append as env) bs j
- Right j -> indexSinks bs (indexSinks as j)
-WStack @env1 @env2 as bs wlo whi @> i =
- case splitIdx @env1 as i of
- Left i' -> indexRaiseAbove @env2 bs (wlo @> i')
- Right i' -> indexSinks bs (whi @> i')
-
-indexSinks :: SList f as -> Idx bs t -> Idx (Append as bs) t
-indexSinks SNil j = j
-indexSinks (_ `SCons` bs') j = IS (indexSinks bs' j)
-
-indexRaiseAbove :: forall env as t f. SList f as -> Idx as t -> Idx (Append as env) t
-indexRaiseAbove = flip go
- where
- go :: forall as'. Idx as' t -> SList f as' -> Idx (Append as' env) t
- go IZ (_ `SCons` _) = IZ
- go (IS i) (_ `SCons` as) = IS (go i as)
-
-infixr 3 .>
-(.>) :: env2 :> env3 -> env1 :> env2 -> env1 :> env3
-(.>) = flip WThen
-
-class KnownListSpine list where knownListSpine :: SList (Const ()) list
-instance KnownListSpine '[] where knownListSpine = SNil
-instance KnownListSpine list => KnownListSpine (t : list) where knownListSpine = SCons (Const ()) knownListSpine
-
-wSinks' :: forall list env. KnownListSpine list => env :> Append list env
-wSinks' = wSinks (knownListSpine :: SList (Const ()) list)
-
-wSinks :: forall env bs f. SList f bs -> env :> Append bs env
-wSinks SNil = WId
-wSinks (SCons _ spine) = WSink .> wSinks spine
-
-wSinksAnd :: forall env env' bs f. SList f bs -> env :> env' -> env :> Append bs env'
-wSinksAnd SNil w = w
-wSinksAnd (SCons _ spine) w = WSink .> wSinksAnd spine w
-
-wCopies :: SList f bs -> env1 :> env2 -> Append bs env1 :> Append bs env2
-wCopies bs w =
- let bs' = slistMap (\_ -> Const ()) bs
- in WStack bs' bs' WId w
-
-wRaiseAbove :: SList f env1 -> proxy env -> env1 :> Append env1 env
-wRaiseAbove SNil _ = WClosed
-wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env)
-
-wPops :: SList f bs -> Append bs env1 :> env2 -> env1 :> env2
-wPops SNil w = w
-wPops (_ `SCons` bs) w = wPops bs (WPop w)
diff --git a/src/AST/Weaken/Auto.hs b/src/AST/Weaken/Auto.hs
deleted file mode 100644
index 7370df1..0000000
--- a/src/AST/Weaken/Auto.hs
+++ /dev/null
@@ -1,192 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE FunctionalDependencies #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE KindSignatures #-}
-{-# LANGUAGE MultiParamTypeClasses #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TypeAbstractions #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
-
-{-# LANGUAGE AllowAmbiguousTypes #-}
-
-{-# LANGUAGE PartialTypeSignatures #-}
-{-# OPTIONS_GHC -Wno-partial-type-signatures #-}
-module AST.Weaken.Auto (
- autoWeak,
- (&.), auto, auto1,
- Layout(..),
-) where
-
-import Data.Functor.Const
-import Data.Kind (Constraint)
-import GHC.OverloadedLabels
-import GHC.TypeLits
-import Unsafe.Coerce (unsafeCoerce)
-
-import AST.Weaken
-import Data
-import Lemmas
-
-
-type family Lookup name list where
- Lookup name ('(name, x) : _) = x
- Lookup name (_ : list) = Lookup name list
- Lookup name '[] = TypeError (Text "The name '" :<>: Text name :<>: Text "' does not appear in the list.")
-
-
--- | The @withPre@ type parameter indicates whether there can be 'LPreW'
--- occurrences within this layout. 'names' is the list of names that this
--- layout /produces/. That is: for LPreW, it contains the target name. The
--- 'names' list of a source layout must be a subset of the names list of the
--- target layout (which cannot contain LPreW); this is checked with SubLayout.
-data Layout (withPre :: Bool) (segments :: [(Symbol, [t])]) (names :: [Symbol]) (env :: [t]) where
- LSeg :: forall name segments withPre. SSymbol name -> Layout withPre segments '[name] (Lookup name segments)
- -- | Pre-weaken with a weakening
- LPreW :: forall name1 name2 segments.
- SegmentName name1 -> SegmentName name2
- -> Lookup name1 segments :> Lookup name2 segments
- -> Layout True segments '[name2] (Lookup name1 segments)
- (:++:) :: Layout withPre segments names1 env1 -> Layout withPre segments names2 env2 -> Layout withPre segments (Append names1 names2) (Append env1 env2)
-infixr :++:
-
-instance (KnownSymbol name, seg ~ Lookup name segments, names ~ '[name]) => IsLabel name (Layout withPre segments names seg) where
- fromLabel = LSeg (symbolSing @name)
-
-newtype SegmentName name = SegmentName (SSymbol name)
- deriving (Show)
-
-instance (KnownSymbol name, name ~ name') => IsLabel name (SegmentName name') where
- fromLabel = SegmentName symbolSing
-
-
-type family SubLayout names1 names2 where
- SubLayout '[] _ = () :: Constraint
- SubLayout (n : names1) names2 = SubLayout' n (Contains n names2) names1 names2
-type family SubLayout' n ok names1 names2 where
- SubLayout' n False _ _ = TypeError (Text "The name '" :<>: Text n :<>: Text "' appears in the source layout but not in the target.")
- SubLayout' _ True names1 names2 = SubLayout names1 names2
-type family Contains n names where
- Contains _ '[] = False
- Contains n (n : _) = True
- Contains n (_ : names) = Contains n names
-
-
-data SSegments (segments :: [(Symbol, [t])]) where
- SSegNil :: SSegments '[]
- SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list)
-
-instance (KnownSymbol name, segs ~ '[ '(name, ts)]) => IsLabel name (SList f ts -> SSegments segs) where
- fromLabel = \spine -> SSegCons symbolSing (slistMap (\_ -> Const ()) spine) SSegNil
-
-auto :: KnownListSpine list => SList (Const ()) list
-auto = knownListSpine
-
-auto1 :: SList (Const ()) '[t]
-auto1 = Const () `SCons` SNil
-
-infixr &.
-(&.) :: SSegments '[segs1] -> SSegments segs2 -> SSegments (segs1 : segs2)
-(&.) = ssegmentsAppend
- where
- ssegmentsAppend :: SSegments a -> SSegments b -> SSegments (Append a b)
- ssegmentsAppend SSegNil l2 = l2
- ssegmentsAppend (SSegCons name list l1) l2 = SSegCons name list (ssegmentsAppend l1 l2)
-
-
--- | If the found segment is a TopSeg, returns Nothing.
-segmentLookup :: forall segments name. SSegments segments -> SSymbol name -> SList (Const ()) (Lookup name segments)
-segmentLookup = \segs name -> case go segs name of
- Just ts -> ts
- Nothing -> error $ "Segment not found: " ++ fromSSymbol name
- where
- go :: forall segs'. SSegments segs' -> SSymbol name -> Maybe (SList (Const ()) (Lookup name segs'))
- go SSegNil _ = Nothing
- go (SSegCons n@(SSymbol @n) (ts :: SList _ ts) (sseg :: SSegments rest)) name@SSymbol =
- case sameSymbol n name of
- Just Refl ->
- case go sseg name of
- Nothing -> Just ts
- Just _ -> error $ "Duplicate segment with name " ++ fromSSymbol name
- Nothing ->
- case unsafeCoerce Refl :: (Lookup name ('(n, ts) : rest) :~: Lookup name rest) of
- Refl -> go sseg name
-
-data LinLayout (withPre :: Bool) (segments :: [(Symbol, [t])]) (env :: [t]) where
- LinEnd :: LinLayout withPre segments '[]
- LinApp :: SSymbol name -> LinLayout withPre segments env
- -> LinLayout withPre segments (Append (Lookup name segments) env)
- LinAppPreW :: SSymbol name1 -> SSymbol name2
- -> Lookup name1 segments :> Lookup name2 segments
- -> LinLayout True segments env
- -> LinLayout True segments (Append (Lookup name1 segments) env)
-
-linLayoutAppend :: LinLayout withPre segments env1 -> LinLayout withPre segments env2 -> LinLayout withPre segments (Append env1 env2)
-linLayoutAppend LinEnd lin = lin
-linLayoutAppend (LinApp (name :: SSymbol name) (lin1 :: LinLayout _ segments env1')) (lin2 :: LinLayout _ _ env2)
- | Refl <- lemAppendAssoc @(Lookup name segments) @env1' @env2
- = LinApp name (linLayoutAppend lin1 lin2)
-linLayoutAppend (LinAppPreW (name1 :: SSymbol name1) name2 w (lin1 :: LinLayout _ segments env1')) (lin2 :: LinLayout _ _ env2)
- | Refl <- lemAppendAssoc @(Lookup name1 segments) @env1' @env2
- = LinAppPreW name1 name2 w (linLayoutAppend lin1 lin2)
-
-lineariseLayout :: Layout withPre segments names env -> LinLayout withPre segments env
-lineariseLayout (LSeg name :: Layout _ _ _ seg)
- | Refl <- lemAppendNil @seg
- = LinApp name LinEnd
-lineariseLayout (ly1 :++: ly2) = lineariseLayout ly1 `linLayoutAppend` lineariseLayout ly2
-lineariseLayout (LPreW (SegmentName name1) (SegmentName name2) w :: Layout _ _ _ seg)
- | Refl <- lemAppendNil @seg
- = LinAppPreW name1 name2 w LinEnd
-
-preWeaken :: SSegments segments -> LinLayout True segments env
- -> (forall env'. env :> env' -> LinLayout False segments env' -> r) -> r
-preWeaken _ LinEnd k = k WId LinEnd
-preWeaken segs (LinApp name lin) k =
- preWeaken segs lin $ \w lin' ->
- k (wCopies (segmentLookup segs name) w) (LinApp name lin')
-preWeaken segs (LinAppPreW name1 name2 weak lin) k =
- preWeaken segs lin $ \w lin' ->
- k (WStack (segmentLookup segs name1) (segmentLookup segs name2) weak w) (LinApp name2 lin')
-
-pullDown :: SSegments segments -> SSymbol name -> LinLayout False segments env
- -> r -- Name was not found in source
- -> (forall env'. LinLayout False segments env' -> env :> Append (Lookup name segments) env' -> r)
- -> r
-pullDown segs name@SSymbol linlayout kNotFound k =
- case linlayout of
- LinEnd -> kNotFound
- LinApp n'@SSymbol lin
- | Just Refl <- sameSymbol name n' -> k lin WId
- | otherwise ->
- pullDown segs name lin kNotFound $ \(lin' :: LinLayout _ _ env') w ->
- k (LinApp n' lin') (WSwap @env' (segmentLookup segs n') (segmentLookup segs name)
- .> wCopies (segmentLookup segs n') w)
-
-sortLinLayouts :: SSegments segments
- -> LinLayout False segments env1 -> LinLayout False segments env2 -> env1 :> env2
-sortLinLayouts _ LinEnd LinEnd = WId
-sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail2)
- | Just Refl <- sameSymbol name1 name2 = wCopies (segmentLookup segs name1) (sortLinLayouts segs tail1 tail2)
- | otherwise =
- pullDown segs name2 lin1
- (wSinks (segmentLookup segs name2) .> sortLinLayouts segs lin1 tail2)
- (\tail1' w ->
- -- We've pulled down name2 in lin1 so that it's at the head; the
- -- resulting modified tail is tail1'. Thus now we have (name2 : tail1')
- -- vs (name2 : tail2). Thus we continue sorting tail1' vs tail2, and
- -- wCopies the name2 on top of that.
- wCopies (segmentLookup segs name2) (sortLinLayouts segs tail1' tail2) .> w)
-sortLinLayouts _ LinEnd LinApp{} = WClosed
-sortLinLayouts _ LinApp{} LinEnd = error "Segments in source that do not occur in target"
-
-autoWeak :: SubLayout names1 names2
- => SSegments segments -> Layout True segments names1 env1 -> Layout False segments names2 env2 -> env1 :> env2
-autoWeak segs ly1 ly2 =
- preWeaken segs (lineariseLayout ly1) $ \wPreweak lin1 ->
- sortLinLayouts segs lin1 (lineariseLayout ly2) .> wPreweak