diff options
Diffstat (limited to 'src/AST')
| -rw-r--r-- | src/AST/Accum.hs | 137 | ||||
| -rw-r--r-- | src/AST/Bindings.hs | 84 | ||||
| -rw-r--r-- | src/AST/Count.hs | 930 | ||||
| -rw-r--r-- | src/AST/Env.hs | 95 | ||||
| -rw-r--r-- | src/AST/Pretty.hs | 525 | ||||
| -rw-r--r-- | src/AST/Sparse.hs | 287 | ||||
| -rw-r--r-- | src/AST/Sparse/Types.hs | 107 | ||||
| -rw-r--r-- | src/AST/SplitLets.hs | 191 | ||||
| -rw-r--r-- | src/AST/Types.hs | 215 | ||||
| -rw-r--r-- | src/AST/UnMonoid.hs | 255 | ||||
| -rw-r--r-- | src/AST/Weaken.hs | 138 | ||||
| -rw-r--r-- | src/AST/Weaken/Auto.hs | 192 |
12 files changed, 0 insertions, 3156 deletions
diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs deleted file mode 100644 index 988a450..0000000 --- a/src/AST/Accum.hs +++ /dev/null @@ -1,137 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeData #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE UndecidableInstances #-} -module AST.Accum where - -import AST.Types -import Data - - -data AcPrj - = APHere - | APFst AcPrj - | APSnd AcPrj - | APLeft AcPrj - | APRight AcPrj - | APJust AcPrj - | APArrIdx AcPrj - | APArrSlice Nat - --- | @b@ is a small part of @a@, indicated by the projection @p@. -data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where - SAPHere :: SAcPrj APHere a a - SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair a t) b - SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair t a) b - SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TLEither a t) b - SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TLEither t a) b - SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe a) b - SAPArrIdx :: SAcPrj p a b -> SAcPrj (APArrIdx p) (TArr n a) b - -- TODO: - -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t) -deriving instance Show (SAcPrj p a b) - -type data AIDense = AID | AIS - -data SAIDense d where - SAID :: SAIDense AID - SAIS :: SAIDense AIS -deriving instance Show (SAIDense d) - -type family AcIdx d p t where - AcIdx d APHere t = TNil - AcIdx AID (APFst p) (TPair a b) = AcIdx AID p a - AcIdx AID (APSnd p) (TPair a b) = AcIdx AID p b - AcIdx AIS (APFst p) (TPair a b) = TPair (AcIdx AIS p a) (ZeroInfo b) - AcIdx AIS (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AIS p b) - AcIdx d (APLeft p) (TLEither a b) = AcIdx d p a - AcIdx d (APRight p) (TLEither a b) = AcIdx d p b - AcIdx d (APJust p) (TMaybe a) = AcIdx d p a - AcIdx AID (APArrIdx p) (TArr n a) = - -- (index, recursive info) - TPair (Tup (Replicate n TIx)) (AcIdx AID p a) - AcIdx AIS (APArrIdx p) (TArr n a) = - -- ((index, shape info), recursive info) - TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a))) - (AcIdx AIS p a) - -- AcIdx AID (APArrSlice m) (TArr n a) = - -- -- index - -- Tup (Replicate m TIx) - -- AcIdx AIS (APArrSlice m) (TArr n a) = - -- -- (index, array shape) - -- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx)) - -type AcIdxD p t = AcIdx AID p t -type AcIdxS p t = AcIdx AIS p t - -acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b -acPrjTy SAPHere t = t -acPrjTy (SAPFst prj) (SMTPair t _) = acPrjTy prj t -acPrjTy (SAPSnd prj) (SMTPair _ t) = acPrjTy prj t -acPrjTy (SAPLeft prj) (SMTLEither t _) = acPrjTy prj t -acPrjTy (SAPRight prj) (SMTLEither _ t) = acPrjTy prj t -acPrjTy (SAPJust prj) (SMTMaybe t) = acPrjTy prj t -acPrjTy (SAPArrIdx prj) (SMTArr _ t) = acPrjTy prj t - -type family ZeroInfo t where - ZeroInfo TNil = TNil - ZeroInfo (TPair a b) = TPair (ZeroInfo a) (ZeroInfo b) - ZeroInfo (TLEither a b) = TNil - ZeroInfo (TMaybe a) = TNil - ZeroInfo (TArr n t) = TArr n (ZeroInfo t) - ZeroInfo (TScal t) = TNil - -tZeroInfo :: SMTy t -> STy (ZeroInfo t) -tZeroInfo SMTNil = STNil -tZeroInfo (SMTPair a b) = STPair (tZeroInfo a) (tZeroInfo b) -tZeroInfo (SMTLEither _ _) = STNil -tZeroInfo (SMTMaybe _) = STNil -tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t) -tZeroInfo (SMTScal _) = STNil - --- | Info needed to create a zero-valued deep accumulator for a monoid type. --- Should be constructable from a D1. -type family DeepZeroInfo t where - DeepZeroInfo TNil = TNil - DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b) - DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b) - DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a) - DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a) - DeepZeroInfo (TScal t) = TNil - -tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t) -tDeepZeroInfo SMTNil = STNil -tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b) -tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b) -tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a) -tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t) -tDeepZeroInfo (SMTScal _) = STNil - --- -- | Additional info needed for accumulation. This is empty unless there is --- -- sparsity in the monoid. --- type family AccumInfo t where --- AccumInfo TNil = TNil --- AccumInfo (TPair a b) = TPair (AccumInfo a) (AccumInfo b) --- AccumInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b) --- AccumInfo (TMaybe a) = TMaybe (AccumInfo a) --- AccumInfo (TArr n t) = TArr n (AccumInfo t) --- AccumInfo (TScal t) = TNil - --- type family PrimalInfo t where --- PrimalInfo TNil = TNil --- PrimalInfo (TPair a b) = TPair (PrimalInfo a) (PrimalInfo b) --- PrimalInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b) --- PrimalInfo (TMaybe a) = TMaybe (PrimalInfo a) --- PrimalInfo (TArr n t) = TArr n (PrimalInfo t) --- PrimalInfo (TScal t) = TNil - --- tPrimalInfo :: SMTy t -> STy (PrimalInfo t) --- tPrimalInfo SMTNil = STNil --- tPrimalInfo (SMTPair a b) = STPair (tPrimalInfo a) (tPrimalInfo b) --- tPrimalInfo (SMTLEither a b) = STLEither (tPrimalInfo a) (tPrimalInfo b) --- tPrimalInfo (SMTMaybe a) = STMaybe (tPrimalInfo a) --- tPrimalInfo (SMTArr n t) = STArr n (tPrimalInfo t) --- tPrimalInfo (SMTScal _) = STNil diff --git a/src/AST/Bindings.hs b/src/AST/Bindings.hs deleted file mode 100644 index 463586a..0000000 --- a/src/AST/Bindings.hs +++ /dev/null @@ -1,84 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} - --- I want to bring various type variables in scope using type annotations in --- patterns, but I don't want to have to mention all the other type parameters --- of the types in question as well then. Partial type signatures (with '_') are --- useful here. -{-# LANGUAGE PartialTypeSignatures #-} -{-# OPTIONS -Wno-partial-type-signatures #-} -module AST.Bindings where - -import AST -import AST.Env -import Data -import Lemmas - - --- binding lists: a let stack without a body. The stack lives in 'env' and defines 'binds'. -data Bindings f env binds where - BTop :: Bindings f env '[] - BPush :: Bindings f env binds -> (STy t, f (Append binds env) t) -> Bindings f env (t : binds) -deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env') -infixl `BPush` - -bpush :: Bindings (Expr x) env binds -> Expr x (Append binds env) t -> Bindings (Expr x) env (t : binds) -bpush b e = b `BPush` (typeOf e, e) -infixl `bpush` - -mapBindings :: (forall env' t'. f env' t' -> g env' t') - -> Bindings f env binds -> Bindings g env binds -mapBindings _ BTop = BTop -mapBindings f (BPush b (t, e)) = BPush (mapBindings f b) (t, f e) - -weakenBindings :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t) - -> env1 :> env2 - -> Bindings f env1 binds - -> (Bindings f env2 binds, Append binds env1 :> Append binds env2) -weakenBindings _ w BTop = (BTop, w) -weakenBindings wf w (BPush b (t, x)) = - let (b', w') = weakenBindings wf w b - in (BPush b' (t, wf w' x), WCopy w') - -weakenBindingsE :: env1 :> env2 - -> Bindings (Expr x) env1 binds - -> (Bindings (Expr x) env2 binds, Append binds env1 :> Append binds env2) -weakenBindingsE = weakenBindings weakenExpr - -weakenOver :: SList STy ts -> env :> env' -> Append ts env :> Append ts env' -weakenOver SNil w = w -weakenOver (SCons _ ts) w = WCopy (weakenOver ts w) - -sinkWithBindings :: forall env' env binds f. Bindings f env binds -> env' :> Append binds env' -sinkWithBindings BTop = WId -sinkWithBindings (BPush b _) = WSink .> sinkWithBindings b - -bconcat :: forall f env binds1 binds2. Bindings f env binds1 -> Bindings f (Append binds1 env) binds2 -> Bindings f env (Append binds2 binds1) -bconcat b1 BTop = b1 -bconcat b1 (BPush (b2 :: Bindings _ (Append binds1 env) binds2C) (t, x)) - | Refl <- lemAppendAssoc @binds2C @binds1 @env - = BPush (bconcat b1 b2) (t, x) - -bindingsBinds :: Bindings f env binds -> SList STy binds -bindingsBinds BTop = SNil -bindingsBinds (BPush binds (t, _)) = SCons t (bindingsBinds binds) - -letBinds :: Bindings Ex env binds -> Ex (Append binds env) t -> Ex env t -letBinds BTop = id -letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs - -collectBindings :: SList STy env -> Subenv env env' -> Bindings Ex env env' -collectBindings = \env -> fst . go env WId - where - go :: SList STy env -> env :> env0 -> Subenv env env' -> (Bindings Ex env0 env', env0 :> Append env' env0) - go _ _ SETop = (BTop, WId) - go (ty `SCons` env) w (SEYesR sub) = - let (bs, w') = go env (WPop w) sub - in (BPush bs (ty, EVar ext ty (w' .> w @> IZ)), WSink .> w') - go (_ `SCons` env) w (SENo sub) = go env (WPop w) sub diff --git a/src/AST/Count.hs b/src/AST/Count.hs deleted file mode 100644 index ac8634e..0000000 --- a/src/AST/Count.hs +++ /dev/null @@ -1,930 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE ViewPatterns #-} -{-# LANGUAGE PatternSynonyms #-} -module AST.Count where - -import Data.Functor.Product -import Data.Some -import Data.Type.Equality -import GHC.Generics (Generic, Generically(..)) - -import Array -import AST -import AST.Env -import Data - - --- | The monoid operation combines assuming that /both/ branches are taken. -class Monoid a => Occurrence a where - -- | One of the two branches is taken - (<||>) :: a -> a -> a - -- | This code is executed many times - scaleMany :: a -> a - - -data Count = Zero | One | Many - deriving (Show, Eq, Ord) - -instance Semigroup Count where - Zero <> n = n - n <> Zero = n - _ <> _ = Many -instance Monoid Count where - mempty = Zero -instance Occurrence Count where - (<||>) = max - scaleMany Zero = Zero - scaleMany _ = Many - -data Occ = Occ { _occLexical :: Count - , _occRuntime :: Count } - deriving (Eq, Generic) - deriving (Semigroup, Monoid) via Generically Occ - -instance Show Occ where - showsPrec d (Occ l r) = showParen (d > 10) $ - showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r - -instance Occurrence Occ where - Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (r1 <||> r2) - scaleMany (Occ l c) = Occ l (scaleMany c) - - -data Substruc t t' where - -- If you add constructors here, do not forget to update the COMPLETE pragmas of any pattern synonyms below - SsFull :: Substruc t t - SsNone :: Substruc t TNil - SsPair :: Substruc a a' -> Substruc b b' -> Substruc (TPair a b) (TPair a' b') - SsEither :: Substruc a a' -> Substruc b b' -> Substruc (TEither a b) (TEither a' b') - SsLEither :: Substruc a a' -> Substruc b b' -> Substruc (TLEither a b) (TLEither a' b') - SsMaybe :: Substruc a a' -> Substruc (TMaybe a) (TMaybe a') - SsArr :: Substruc a a' -> Substruc (TArr n a) (TArr n a') -- ^ union of usages of all array elements - SsAccum :: Substruc a a' -> Substruc (TAccum a) (TAccum a') - -pattern SsPair' :: forall a b t'. forall a' b'. t' ~ TPair a' b' => Substruc a a' -> Substruc b b' -> Substruc (TPair a b) t' -pattern SsPair' s1 s2 <- ((\case { SsFull -> SsPair SsFull SsFull ; s -> s }) -> SsPair s1 s2) - where SsPair' = SsPair -{-# COMPLETE SsNone, SsPair', SsEither, SsLEither, SsMaybe, SsArr, SsAccum #-} - -pattern SsArr' :: forall n a t'. forall a'. t' ~ TArr n a' => Substruc a a' -> Substruc (TArr n a) t' -pattern SsArr' s <- ((\case { SsFull -> SsArr SsFull ; s -> s }) -> SsArr s) - where SsArr' = SsArr -{-# COMPLETE SsNone, SsPair, SsEither, SsLEither, SsMaybe, SsArr', SsAccum #-} - -instance Semigroup (Some (Substruc t)) where - Some SsFull <> _ = Some SsFull - _ <> Some SsFull = Some SsFull - Some SsNone <> s = s - s <> Some SsNone = s - Some (SsPair a b) <> Some (SsPair a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsPair a2 b2) - Some (SsEither a b) <> Some (SsEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsEither a2 b2) - Some (SsLEither a b) <> Some (SsLEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsLEither a2 b2) - Some (SsMaybe a) <> Some (SsMaybe a') = withSome (Some a <> Some a') $ \a2 -> Some (SsMaybe a2) - Some (SsArr a) <> Some (SsArr a') = withSome (Some a <> Some a') $ \a2 -> Some (SsArr a2) - Some (SsAccum a) <> Some (SsAccum a') = withSome (Some a <> Some a') $ \a2 -> Some (SsAccum a2) -instance Monoid (Some (Substruc t)) where - mempty = Some SsNone - -instance TestEquality (Substruc t) where - testEquality SsFull s = isFull s - testEquality s SsFull = sym <$> isFull s - testEquality SsNone SsNone = Just Refl - testEquality SsNone _ = Nothing - testEquality _ SsNone = Nothing - testEquality (SsPair a b) (SsPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing - testEquality (SsEither a b) (SsEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing - testEquality (SsLEither a b) (SsLEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing - testEquality (SsMaybe s) (SsMaybe s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing - testEquality (SsArr s) (SsArr s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing - testEquality (SsAccum s) (SsAccum s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing - -isFull :: Substruc t t' -> Maybe (t :~: t') -isFull SsFull = Just Refl -isFull SsNone = Nothing -- TODO: nil? -isFull (SsPair a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing -isFull (SsEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing -isFull (SsLEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing -isFull (SsMaybe s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing -isFull (SsArr s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing -isFull (SsAccum s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing - -applySubstruc :: Substruc t t' -> STy t -> STy t' -applySubstruc SsFull t = t -applySubstruc SsNone _ = STNil -applySubstruc (SsPair s1 s2) (STPair a b) = STPair (applySubstruc s1 a) (applySubstruc s2 b) -applySubstruc (SsEither s1 s2) (STEither a b) = STEither (applySubstruc s1 a) (applySubstruc s2 b) -applySubstruc (SsLEither s1 s2) (STLEither a b) = STLEither (applySubstruc s1 a) (applySubstruc s2 b) -applySubstruc (SsMaybe s) (STMaybe t) = STMaybe (applySubstruc s t) -applySubstruc (SsArr s) (STArr n t) = STArr n (applySubstruc s t) -applySubstruc (SsAccum s) (STAccum t) = STAccum (applySubstrucM s t) - -applySubstrucM :: Substruc t t' -> SMTy t -> SMTy t' -applySubstrucM SsFull t = t -applySubstrucM SsNone _ = SMTNil -applySubstrucM (SsPair s1 s2) (SMTPair a b) = SMTPair (applySubstrucM s1 a) (applySubstrucM s2 b) -applySubstrucM (SsLEither s1 s2) (SMTLEither a b) = SMTLEither (applySubstrucM s1 a) (applySubstrucM s2 b) -applySubstrucM (SsMaybe s) (SMTMaybe t) = SMTMaybe (applySubstrucM s t) -applySubstrucM (SsArr s) (SMTArr n t) = SMTArr n (applySubstrucM s t) -applySubstrucM _ t = case t of {} - -data ExMap a b = ExMap (forall env. Ex env a -> Ex env b) - | a ~ b => ExMapId - -fromExMap :: ExMap a b -> Ex env a -> Ex env b -fromExMap (ExMap f) = f -fromExMap ExMapId = id - -simplifySubstruc :: STy t -> Substruc t t' -> Substruc t t' -simplifySubstruc STNil SsNone = SsFull - -simplifySubstruc _ SsFull = SsFull -simplifySubstruc _ SsNone = SsNone -simplifySubstruc (STPair t1 t2) (SsPair s1 s2) = SsPair (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) -simplifySubstruc (STEither t1 t2) (SsEither s1 s2) = SsEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) -simplifySubstruc (STLEither t1 t2) (SsLEither s1 s2) = SsLEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) -simplifySubstruc (STMaybe t) (SsMaybe s) = SsMaybe (simplifySubstruc t s) -simplifySubstruc (STArr _ t) (SsArr s) = SsArr (simplifySubstruc t s) -simplifySubstruc (STAccum t) (SsAccum s) = SsAccum (simplifySubstruc (fromSMTy t) s) - --- simplifySubstruc' :: Substruc t t' --- -> (forall t'2. Substruc t t'2 -> ExMap t'2 t' -> r) -> r --- simplifySubstruc' SsFull k = k SsFull ExMapId --- simplifySubstruc' SsNone k = k SsNone ExMapId --- simplifySubstruc' (SsPair s1 s2) k = --- simplifySubstruc' s1 $ \s1' f1 -> --- simplifySubstruc' s2 $ \s2' f2 -> --- case (s1', s2') of --- (SsFull, SsFull) -> --- k SsFull (case (f1, f2) of --- (ExMapId, ExMapId) -> ExMapId --- _ -> ExMap (\e -> eunPair e $ \_ e1 e2 -> --- EPair ext (fromExMap f1 e1) (fromExMap f2 e2))) --- (SsNone, SsNone) -> k SsNone (ExMap (\_ -> EPair ext (fromExMap f1 (ENil ext)) (fromExMap f2 (ENil ext)))) --- _ -> k (SsPair s1' s2') (ExMap (\e -> elet e $ EPair ext (fromExMap f1 (EFst ext (evar IZ))) (fromExMap f2 (ESnd ext (evar IZ))))) --- simplifySubstruc' _ _ = _ - --- ssUnpair :: Substruc (TPair a b) -> (Substruc a, Substruc b) --- ssUnpair SsFull = (SsFull, SsFull) --- ssUnpair SsNone = (SsNone, SsNone) --- ssUnpair (SsPair a b) = (a, b) - --- ssUnleft :: Substruc (TEither a b) -> Substruc a --- ssUnleft SsFull = SsFull --- ssUnleft SsNone = SsNone --- ssUnleft (SsEither a _) = a - --- ssUnright :: Substruc (TEither a b) -> Substruc b --- ssUnright SsFull = SsFull --- ssUnright SsNone = SsNone --- ssUnright (SsEither _ b) = b - --- ssUnlleft :: Substruc (TLEither a b) -> Substruc a --- ssUnlleft SsFull = SsFull --- ssUnlleft SsNone = SsNone --- ssUnlleft (SsLEither a _) = a - --- ssUnlright :: Substruc (TLEither a b) -> Substruc b --- ssUnlright SsFull = SsFull --- ssUnlright SsNone = SsNone --- ssUnlright (SsLEither _ b) = b - --- ssUnjust :: Substruc (TMaybe a) -> Substruc a --- ssUnjust SsFull = SsFull --- ssUnjust SsNone = SsNone --- ssUnjust (SsMaybe a) = a - --- ssUnarr :: Substruc (TArr n a) -> Substruc a --- ssUnarr SsFull = SsFull --- ssUnarr SsNone = SsNone --- ssUnarr (SsArr a) = a - --- ssUnaccum :: Substruc (TAccum a) -> Substruc a --- ssUnaccum SsFull = SsFull --- ssUnaccum SsNone = SsNone --- ssUnaccum (SsAccum a) = a - - -type family MapEmpty env where - MapEmpty '[] = '[] - MapEmpty (t : env) = TNil : MapEmpty env - -data OccEnv a env env' where - OccEnd :: OccEnv a env (MapEmpty env) -- not necessarily top! - OccPush :: OccEnv a env env' -> a -> Substruc t t' -> OccEnv a (t : env) (t' : env') - -instance Semigroup a => Semigroup (Some (OccEnv a env)) where - Some OccEnd <> e = e - e <> Some OccEnd = e - Some (OccPush e o s) <> Some (OccPush e' o' s') = withSome (Some e <> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <> o') s2) - -instance Semigroup a => Monoid (Some (OccEnv a env)) where - mempty = Some OccEnd - -instance Occurrence a => Occurrence (Some (OccEnv a env)) where - Some OccEnd <||> e = e - e <||> Some OccEnd = e - Some (OccPush e o s) <||> Some (OccPush e' o' s') = withSome (Some e <||> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <||> o') s2) - - scaleMany (Some OccEnd) = Some OccEnd - scaleMany (Some (OccPush e o s)) = withSome (scaleMany (Some e)) $ \e2 -> Some (OccPush e2 (scaleMany o) s) - -onehotOccEnv :: Monoid a => Idx env t -> a -> Substruc t t' -> Some (OccEnv a env) -onehotOccEnv IZ v s = Some (OccPush OccEnd v s) -onehotOccEnv (IS i) v s - | Some env' <- onehotOccEnv i v s - = Some (OccPush env' mempty SsNone) - -occEnvPop :: OccEnv a (t : env) (t' : env') -> (OccEnv a env env', Substruc t t') -occEnvPop (OccPush e _ s) = (e, s) -occEnvPop OccEnd = (OccEnd, SsNone) - -occEnvPop' :: OccEnv a (t : env) env' -> (forall t' env''. env' ~ t' : env'' => OccEnv a env env'' -> Substruc t t' -> r) -> r -occEnvPop' (OccPush e _ s) k = k e s -occEnvPop' OccEnd k = k OccEnd SsNone - -occEnvPopSome :: Some (OccEnv a (t : env)) -> Some (OccEnv a env) -occEnvPopSome (Some (OccPush e _ _)) = Some e -occEnvPopSome (Some OccEnd) = Some OccEnd - -occEnvPrj :: Monoid a => OccEnv a env env' -> Idx env t -> (a, Some (Substruc t)) -occEnvPrj OccEnd _ = mempty -occEnvPrj (OccPush _ o s) IZ = (o, Some s) -occEnvPrj (OccPush e _ _) (IS i) = occEnvPrj e i - -occEnvPrjS :: OccEnv a env env' -> Idx env t -> Some (Product (Substruc t) (Idx env')) -occEnvPrjS OccEnd IZ = Some (Pair SsNone IZ) -occEnvPrjS OccEnd (IS i) | Some (Pair s i') <- occEnvPrjS OccEnd i = Some (Pair s (IS i')) -occEnvPrjS (OccPush _ _ s) IZ = Some (Pair s IZ) -occEnvPrjS (OccPush e _ _) (IS i) - | Some (Pair s' i') <- occEnvPrjS e i - = Some (Pair s' (IS i')) - -projectSmallerSubstruc :: Substruc t t'big -> Substruc t t'small -> Ex env t'big -> Ex env t'small -projectSmallerSubstruc topsbig topssmall ex = case (topsbig, topssmall) of - _ | Just Refl <- testEquality topsbig topssmall -> ex - - (SsFull, SsFull) -> ex - (SsNone, SsNone) -> ex - (SsNone, _) -> error "projectSmallerSubstruc: smaller substructure not smaller" - (_, SsNone) -> - case typeOf ex of - STNil -> ex - _ -> use ex $ ENil ext - - (SsPair s1 s2, SsPair s1' s2') -> - eunPair ex $ \_ e1 e2 -> - EPair ext (projectSmallerSubstruc s1 s1' e1) (projectSmallerSubstruc s2 s2' e2) - (s@SsPair{}, SsFull) -> projectSmallerSubstruc s (SsPair SsFull SsFull) ex - (SsFull, s@SsPair{}) -> projectSmallerSubstruc (SsPair SsFull SsFull) s ex - - (SsEither s1 s2, SsEither s1' s2') - | STEither t1 t2 <- typeOf ex -> - let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) - e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ) - in ecase ex - (EInl ext (typeOf e2) e1) - (EInr ext (typeOf e1) e2) - (s@SsEither{}, SsFull) -> projectSmallerSubstruc s (SsEither SsFull SsFull) ex - (SsFull, s@SsEither{}) -> projectSmallerSubstruc (SsEither SsFull SsFull) s ex - - (SsLEither s1 s2, SsLEither s1' s2') - | STLEither t1 t2 <- typeOf ex -> - let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) - e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ) - in elcase ex - (ELNil ext (typeOf e1) (typeOf e2)) - (ELInl ext (typeOf e2) e1) - (ELInr ext (typeOf e1) e2) - (s@SsLEither{}, SsFull) -> projectSmallerSubstruc s (SsLEither SsFull SsFull) ex - (SsFull, s@SsLEither{}) -> projectSmallerSubstruc (SsLEither SsFull SsFull) s ex - - (SsMaybe s1, SsMaybe s1') - | STMaybe t1 <- typeOf ex -> - let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) - in emaybe ex - (ENothing ext (typeOf e1)) - (EJust ext e1) - (s@SsMaybe{}, SsFull) -> projectSmallerSubstruc s (SsMaybe SsFull) ex - (SsFull, s@SsMaybe{}) -> projectSmallerSubstruc (SsMaybe SsFull) s ex - - (SsArr s1, SsArr s2) -> emap (projectSmallerSubstruc s1 s2 (evar IZ)) ex - (s@SsArr{}, SsFull) -> projectSmallerSubstruc s (SsArr SsFull) ex - (SsFull, s@SsArr{}) -> projectSmallerSubstruc (SsArr SsFull) s ex - - (SsAccum _, SsAccum _) -> error "TODO smaller ssaccum" - (s@SsAccum{}, SsFull) -> projectSmallerSubstruc s (SsAccum SsFull) ex - (SsFull, s@SsAccum{}) -> projectSmallerSubstruc (SsAccum SsFull) s ex - - --- | A boolean for each entry in the environment, with the ability to uniformly --- mask the top part above a certain index. -data EnvMask env where - EMRest :: Bool -> EnvMask env - EMPush :: EnvMask env -> Bool -> EnvMask (t : env) - -envMaskPrj :: EnvMask env -> Idx env t -> Bool -envMaskPrj (EMRest b) _ = b -envMaskPrj (_ `EMPush` b) IZ = b -envMaskPrj (env `EMPush` _) (IS i) = envMaskPrj env i - -occCount :: Idx env a -> Expr x env t -> Occ -occCount idx ex - | Some env <- occCountAll ex - = fst (occEnvPrj env idx) - -occCountAll :: Expr x env t -> Some (OccEnv Occ env) -occCountAll ex = occCountX SsFull ex $ \env _ -> Some env - -pruneExpr :: SList f env -> Expr x env t -> Ex env t -pruneExpr env ex = occCountX SsFull ex $ \_ mkex -> mkex (fullOccEnv env) - where - fullOccEnv :: SList f env -> OccEnv () env env - fullOccEnv SNil = OccEnd - fullOccEnv (_ `SCons` e) = OccPush (fullOccEnv e) () SsFull - --- In one traversal, count occurrences of variables and determine what parts of --- expressions are actually used. These two results are computed independently: --- even if (almost) nothing of a particular term is actually used, variable --- references in that term still count as usual. --- --- In @occCountX s t k@: --- * s: how much of the result of this term is required --- * t: the term to analyse --- * k: is passed the actual environment usage of this expression, including --- occurrence counts. The callback reconstructs a new expression in an --- updated "response" environment. The response must be at least as large as --- the computed usages. -occCountX :: forall env t t' x r. Substruc t t' -> Expr x env t - -> (forall env'. OccEnv Occ env env' - -- response OccEnv must be at least as large as the OccEnv returned above - -> (forall env''. OccEnv () env env'' -> Ex env'' t') - -> r) - -> r -occCountX initialS topexpr k = case topexpr of - EVar _ t i -> - withSome (onehotOccEnv i (Occ One One) s) $ \env -> - k env $ \env' -> - withSome (occEnvPrjS env' i) $ \(Pair s' i') -> - projectSmallerSubstruc s' s (EVar ext (applySubstruc s' t) i') - ELet _ rhs body -> - occCountX s body $ \envB mkbody -> - occEnvPop' envB $ \envB' s1 -> - occCountX s1 rhs $ \envR mkrhs -> - withSome (Some envB' <> Some envR) $ \env -> - k env $ \env' -> - ELet ext (mkrhs env') (mkbody (OccPush env' () s1)) - EPair _ a b -> - case s of - SsNone -> - occCountX SsNone a $ \env1 mka -> - occCountX SsNone b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mka env') $ use (mkb env') $ ENil ext - SsPair' s1 s2 -> - occCountX s1 a $ \env1 mka -> - occCountX s2 b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - EPair ext (mka env') (mkb env') - EFst _ e -> - occCountX (SsPair s SsNone) e $ \env1 mke -> - k env1 $ \env' -> - EFst ext (mke env') - ESnd _ e -> - occCountX (SsPair SsNone s) e $ \env1 mke -> - k env1 $ \env' -> - ESnd ext (mke env') - ENil _ -> - case s of - SsFull -> k OccEnd (\_ -> ENil ext) - SsNone -> k OccEnd (\_ -> ENil ext) - EInl _ t e -> - case s of - SsNone -> - occCountX SsNone e $ \env1 mke -> - k env1 $ \env' -> - use (mke env') $ ENil ext - SsEither s1 s2 -> - occCountX s1 e $ \env1 mke -> - k env1 $ \env' -> - EInl ext (applySubstruc s2 t) (mke env') - SsFull -> occCountX (SsEither SsFull SsFull) topexpr k - EInr _ t e -> - case s of - SsNone -> - occCountX SsNone e $ \env1 mke -> - k env1 $ \env' -> - use (mke env') $ ENil ext - SsEither s1 s2 -> - occCountX s2 e $ \env1 mke -> - k env1 $ \env' -> - EInr ext (applySubstruc s1 t) (mke env') - SsFull -> occCountX (SsEither SsFull SsFull) topexpr k - ECase _ e a b -> - occCountX s a $ \env1' mka -> - occCountX s b $ \env2' mkb -> - occEnvPop' env1' $ \env1 s1 -> - occEnvPop' env2' $ \env2 s2 -> - occCountX (SsEither s1 s2) e $ \env0 mke -> - withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env -> - k env $ \env' -> - ECase ext (mke env') (mka (OccPush env' () s1)) (mkb (OccPush env' () s2)) - ENothing _ t -> - case s of - SsNone -> k OccEnd (\_ -> ENil ext) - SsMaybe s' -> k OccEnd (\_ -> ENothing ext (applySubstruc s' t)) - SsFull -> occCountX (SsMaybe SsFull) topexpr k - EJust _ e -> - case s of - SsNone -> - occCountX SsNone e $ \env1 mke -> - k env1 $ \env' -> - use (mke env') $ ENil ext - SsMaybe s' -> - occCountX s' e $ \env1 mke -> - k env1 $ \env' -> - EJust ext (mke env') - SsFull -> occCountX (SsMaybe SsFull) topexpr k - EMaybe _ a b e -> - occCountX s a $ \env1 mka -> - occCountX s b $ \env2' mkb -> - occEnvPop' env2' $ \env2 s2 -> - occCountX (SsMaybe s2) e $ \env0 mke -> - withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env -> - k env $ \env' -> - EMaybe ext (mka env') (mkb (OccPush env' () s2)) (mke env') - ELNil _ t1 t2 -> - case s of - SsNone -> k OccEnd (\_ -> ENil ext) - SsLEither s1 s2 -> k OccEnd (\_ -> ELNil ext (applySubstruc s1 t1) (applySubstruc s2 t2)) - SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k - ELInl _ t e -> - case s of - SsNone -> - occCountX SsNone e $ \env1 mke -> - k env1 $ \env' -> - use (mke env') $ ENil ext - SsLEither s1 s2 -> - occCountX s1 e $ \env1 mke -> - k env1 $ \env' -> - ELInl ext (applySubstruc s2 t) (mke env') - SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k - ELInr _ t e -> - case s of - SsNone -> - occCountX SsNone e $ \env1 mke -> - k env1 $ \env' -> - use (mke env') $ ENil ext - SsLEither s1 s2 -> - occCountX s2 e $ \env1 mke -> - k env1 $ \env' -> - ELInr ext (applySubstruc s1 t) (mke env') - SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k - ELCase _ e a b c -> - occCountX s a $ \env1 mka -> - occCountX s b $ \env2' mkb -> - occCountX s c $ \env3' mkc -> - occEnvPop' env2' $ \env2 s1 -> - occEnvPop' env3' $ \env3 s2 -> - occCountX (SsLEither s1 s2) e $ \env0 mke -> - withSome (Some env0 <> (Some env1 <||> Some env2 <||> Some env3)) $ \env -> - k env $ \env' -> - ELCase ext (mke env') (mka env') (mkb (OccPush env' () s1)) (mkc (OccPush env' () s2)) - - EConstArr _ n t x -> - case s of - SsNone -> k OccEnd (\_ -> ENil ext) - SsArr' SsNone -> k OccEnd (\_ -> EBuild ext n (eshapeConst (arrayShape x)) (ENil ext)) - SsArr' SsFull -> k OccEnd (\_ -> EConstArr ext n t x) - - EBuild _ n a b -> - case s of - SsNone -> - occCountX SsFull a $ \env1 mka -> - occCountX SsNone b $ \env2'' mkb -> - occEnvPop' env2'' $ \env2' s2 -> - withSome (Some env1 <> scaleMany (Some env2')) $ \env -> - k env $ \env' -> - use (EBuild ext n (mka env') $ - use (elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ - weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))) $ - ENil ext) $ - ENil ext - SsArr' s' -> - occCountX SsFull a $ \env1 mka -> - occCountX s' b $ \env2'' mkb -> - occEnvPop' env2'' $ \env2' s2 -> - withSome (Some env1 <> scaleMany (Some env2')) $ \env -> - k env $ \env' -> - EBuild ext n (mka env') $ - elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ - weakenExpr (WCopy WSink) (mkb (OccPush env' () s2)) - - EMap _ a b -> - case s of - SsNone -> - occCountX SsNone a $ \env1'' mka -> - occEnvPop' env1'' $ \env1' s1 -> - occCountX (SsArr s1) b $ \env2 mkb -> - withSome (scaleMany (Some env1') <> Some env2) $ \env -> - k env $ \env' -> - use (EMap ext (mka (OccPush env' () s1)) (mkb env')) $ - ENil ext - SsArr' s' -> - occCountX s' a $ \env1'' mka -> - occEnvPop' env1'' $ \env1' s1 -> - occCountX (SsArr s1) b $ \env2 mkb -> - withSome (scaleMany (Some env1') <> Some env2) $ \env -> - k env $ \env' -> - EMap ext (mka (OccPush env' () s1)) (mkb env') - - EFold1Inner _ commut a b c -> - occCountX SsFull a $ \env1'' mka -> - occEnvPop' env1'' $ \env1' s1' -> - let s1 = case s1' of - SsNone -> Some SsNone - SsPair' s1'a s1'b -> Some s1'a <> Some s1'b - s0 = case s of - SsNone -> Some SsNone - SsArr' s' -> Some s' in - withSome (s1 <> s0) $ \sElt -> - occCountX sElt b $ \env2 mkb -> - occCountX (SsArr sElt) c $ \env3 mkc -> - withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> - k env $ \env' -> - projectSmallerSubstruc (SsArr sElt) s $ - EFold1Inner ext commut - (projectSmallerSubstruc SsFull sElt $ - mka (OccPush env' () (SsPair sElt sElt))) - (mkb env') (mkc env') - - ESum1Inner _ e -> handleReduction (ESum1Inner ext) e - - EUnit _ e -> - case s of - SsNone -> - occCountX SsNone e $ \env mke -> - k env $ \env' -> - use (mke env') $ ENil ext - SsArr' s' -> - occCountX s' e $ \env mke -> - k env $ \env' -> - EUnit ext (mke env') - - EReplicate1Inner _ a b -> - case s of - SsNone -> - occCountX SsNone a $ \env1 mka -> - occCountX SsNone b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mka env') $ use (mkb env') $ ENil ext - SsArr' s' -> - occCountX SsFull a $ \env1 mka -> - occCountX (SsArr s') b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - EReplicate1Inner ext (mka env') (mkb env') - - EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e - EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e - - EReshape _ n esh e -> - case s of - SsNone -> - occCountX SsNone esh $ \env1 mkesh -> - occCountX SsNone e $ \env2 mke -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mkesh env') $ use (mke env') $ ENil ext - SsArr' s' -> - occCountX SsFull esh $ \env1 mkesh -> - occCountX (SsArr s') e $ \env2 mke -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - EReshape ext n (mkesh env') (mke env') - - EZip _ a b -> - case s of - SsNone -> - occCountX SsNone a $ \env1 mka -> - occCountX SsNone b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mka env') $ use (mkb env') $ ENil ext - SsArr' SsNone -> - occCountX (SsArr SsNone) a $ \env1 mka -> - occCountX SsNone b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mkb env') $ mka env' - SsArr' (SsPair' SsNone s2) -> - occCountX SsNone a $ \env1 mka -> - occCountX (SsArr s2) b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mka env') $ - emap (EPair ext (ENil ext) (evar IZ)) (mkb env') - SsArr' (SsPair' s1 SsNone) -> - occCountX (SsArr s1) a $ \env1 mka -> - occCountX SsNone b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mkb env') $ - emap (EPair ext (evar IZ) (ENil ext)) (mka env') - SsArr' (SsPair' s1 s2) -> - occCountX (SsArr s1) a $ \env1 mka -> - occCountX (SsArr s2) b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - EZip ext (mka env') (mkb env') - - EFold1InnerD1 _ cm e1 e2 e3 -> - case s of - -- If nothing is necessary, we can execute a fold and then proceed to ignore it - SsNone -> - let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1)) - (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3) - in occCountX SsNone foldex $ \env1 mkfoldex -> k env1 mkfoldex - -- If we don't need the stores, still a fold suffices - SsPair' sP SsNone -> - let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1)) - (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3) - in occCountX sP foldex $ \env1 mkfoldex -> k env1 $ \env' -> EPair ext (mkfoldex env') (ENil ext) - -- If for whatever reason the additional stores themselves are - -- unnecessary but the shape of the array is, then oblige - SsPair' sP (SsArr' SsNone) -> - let STArr sn _ = typeOf e3 - foldex = - elet (mapExt (\_ -> ext) e3) $ - EPair ext - (EShape ext (evar IZ)) - (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy WSink) e1))) - (mapExt (\_ -> ext) (weakenExpr WSink e2)) - (evar IZ)) - in occCountX (SsPair SsFull sP) foldex $ \env1 mkfoldex -> - k env1 $ \env' -> - eunPair (mkfoldex env') $ \_ eshape earr -> - EPair ext earr (EBuild ext sn eshape (ENil ext)) - -- If at least some of the additional stores are required, we need to keep this a mapAccum - SsPair' _ (SsArr' sB) -> - -- TODO: propagate usage of primals - occCountX (SsPair SsFull sB) e1 $ \env1_1' mka -> - occEnvPop' env1_1' $ \env1' _ -> - occCountX SsFull e2 $ \env2 mkb -> - occCountX SsFull e3 $ \env3 mkc -> - withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> - k env $ \env' -> - projectSmallerSubstruc (SsPair SsFull (SsArr sB)) s $ - EFold1InnerD1 ext cm (mka (OccPush env' () SsFull)) - (mkb env') (mkc env') - - EFold1InnerD2 _ cm ef ebog ed -> - -- TODO: propagate usage of duals - occCountX SsFull ef $ \env1_2' mkef -> - occEnvPop' env1_2' $ \env1_1' _ -> - occEnvPop' env1_1' $ \env1' sB -> - occCountX (SsArr sB) ebog $ \env2 mkebog -> - occCountX SsFull ed $ \env3 mked -> - withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> - k env $ \env' -> - projectSmallerSubstruc SsFull s $ - EFold1InnerD2 ext cm - (mkef (OccPush (OccPush env' () sB) () SsFull)) - (mkebog env') (mked env') - - EConst _ t x -> - k OccEnd $ \_ -> - case s of - SsNone -> ENil ext - SsFull -> EConst ext t x - - EIdx0 _ e -> - occCountX (SsArr s) e $ \env1 mke -> - k env1 $ \env' -> - EIdx0 ext (mke env') - - EIdx1 _ a b -> - case s of - SsNone -> - occCountX SsNone a $ \env1 mka -> - occCountX SsNone b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mka env') $ use (mkb env') $ ENil ext - SsArr' s' -> - occCountX (SsArr s') a $ \env1 mka -> - occCountX SsFull b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - EIdx1 ext (mka env') (mkb env') - - EIdx _ a b -> - case s of - SsNone -> - occCountX SsNone a $ \env1 mka -> - occCountX SsNone b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mka env') $ use (mkb env') $ ENil ext - _ -> - occCountX (SsArr s) a $ \env1 mka -> - occCountX SsFull b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - EIdx ext (mka env') (mkb env') - - EShape _ e -> - case s of - SsNone -> - occCountX SsNone e $ \env1 mke -> - k env1 $ \env' -> - use (mke env') $ ENil ext - _ -> - occCountX (SsArr SsNone) e $ \env1 mke -> - k env1 $ \env' -> - projectSmallerSubstruc SsFull s $ EShape ext (mke env') - - EOp _ op e -> - case s of - SsNone -> - occCountX SsNone e $ \env1 mke -> - k env1 $ \env' -> - use (mke env') $ ENil ext - _ -> - occCountX SsFull e $ \env1 mke -> - k env1 $ \env' -> - projectSmallerSubstruc SsFull s $ EOp ext op (mke env') - - ECustom _ t1 t2 t3 e1 e2 e3 a b - | typeHasAccums t1 || typeHasAccums t2 || typeHasAccums t3 -> - error "Accumulators not allowed in input/output/tape of an ECustom" - | otherwise -> - case s of - SsNone -> - -- Allowed to ignore e1/e2/e3 here because no accumulators are - -- communicated, and hence no relevant effects exist - occCountX SsNone a $ \env1 mka -> - occCountX SsNone b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mka env') $ use (mkb env') $ ENil ext - s' -> -- Let's be pessimistic for safety - occCountX SsFull a $ \env1 mka -> - occCountX SsFull b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - projectSmallerSubstruc SsFull s' $ - ECustom ext t1 t2 t3 (mapExt (const ext) e1) (mapExt (const ext) e2) (mapExt (const ext) e3) (mka env') (mkb env') - - ERecompute _ e -> - occCountX s e $ \env1 mke -> - k env1 $ \env' -> - ERecompute ext (mke env') - - EWith _ t a b -> - case s of - SsNone -> -- TODO: simplifier should remove accumulations to an unused with, and then remove the with - occCountX SsNone b $ \env2' mkb -> - occEnvPop' env2' $ \env2 s1 -> - withSome (case s1 of - SsFull -> Some SsFull - SsAccum s' -> Some s' - SsNone -> Some SsNone) $ \s1' -> - occCountX s1' a $ \env1 mka -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (EWith ext (applySubstrucM s1' t) (mka env') (mkb (OccPush env' () (SsAccum s1')))) $ - ENil ext - SsPair sB sA -> - occCountX sB b $ \env2' mkb -> - occEnvPop' env2' $ \env2 s1 -> - let s1' = case s1 of - SsFull -> Some SsFull - SsAccum s' -> Some s' - SsNone -> Some SsNone in - withSome (Some sA <> s1') $ \sA' -> - occCountX sA' a $ \env1 mka -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - projectSmallerSubstruc (SsPair sB sA') (SsPair sB sA) $ - EWith ext (applySubstrucM sA' t) (mka env') (mkb (OccPush env' () (SsAccum sA'))) - SsFull -> occCountX (SsPair SsFull SsFull) topexpr k - - EAccum _ t p a sp b e -> - -- TODO: do better! - occCountX SsFull a $ \env1 mka -> - occCountX SsFull b $ \env2 mkb -> - occCountX SsFull e $ \env3 mke -> - withSome (Some env1 <> Some env2) $ \env12 -> - withSome (Some env12 <> Some env3) $ \env -> - k env $ \env' -> - case s of {SsFull -> id; SsNone -> id} $ - EAccum ext t p (mka env') sp (mkb env') (mke env') - - EZero _ t e -> - occCountX (subZeroInfo s) e $ \env1 mke -> - k env1 $ \env' -> - EZero ext (applySubstrucM s t) (mke env') - where - subZeroInfo :: Substruc t1 t2 -> Substruc (ZeroInfo t1) (ZeroInfo t2) - subZeroInfo SsFull = SsFull - subZeroInfo SsNone = SsNone - subZeroInfo (SsPair s1 s2) = SsPair (subZeroInfo s1) (subZeroInfo s2) - subZeroInfo SsEither{} = error "Either is not a monoid" - subZeroInfo SsLEither{} = SsNone - subZeroInfo SsMaybe{} = SsNone - subZeroInfo (SsArr s') = SsArr (subZeroInfo s') - subZeroInfo SsAccum{} = error "Accum is not a monoid" - - EDeepZero _ t e -> - occCountX (subDeepZeroInfo s) e $ \env1 mke -> - k env1 $ \env' -> - EDeepZero ext (applySubstrucM s t) (mke env') - where - subDeepZeroInfo :: Substruc t1 t2 -> Substruc (DeepZeroInfo t1) (DeepZeroInfo t2) - subDeepZeroInfo SsFull = SsFull - subDeepZeroInfo SsNone = SsNone - subDeepZeroInfo (SsPair s1 s2) = SsPair (subDeepZeroInfo s1) (subDeepZeroInfo s2) - subDeepZeroInfo SsEither{} = error "Either is not a monoid" - subDeepZeroInfo (SsLEither s1 s2) = SsLEither (subDeepZeroInfo s1) (subDeepZeroInfo s2) - subDeepZeroInfo (SsMaybe s') = SsMaybe (subDeepZeroInfo s') - subDeepZeroInfo (SsArr s') = SsArr (subDeepZeroInfo s') - subDeepZeroInfo SsAccum{} = error "Accum is not a monoid" - - EPlus _ t a b -> - occCountX s a $ \env1 mka -> - occCountX s b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - EPlus ext (applySubstrucM s t) (mka env') (mkb env') - - EOneHot _ t p a b -> - occCountX SsFull a $ \env1 mka -> - occCountX SsFull b $ \env2 mkb -> -- TODO: do better - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - projectSmallerSubstruc SsFull s $ EOneHot ext t p (mka env') (mkb env') - - EError _ t msg -> - k OccEnd $ \_ -> EError ext (applySubstruc s t) msg - where - s = simplifySubstruc (typeOf topexpr) initialS - - handleReduction :: t ~ TArr n (TScal t2) - => (forall env2. Ex env2 (TArr (S n) (TScal t2)) -> Ex env2 (TArr n (TScal t2))) - -> Expr x env (TArr (S n) (TScal t2)) - -> r - handleReduction reduce e - | STArr (SS n) _ <- typeOf e = - case s of - SsNone -> - occCountX SsNone e $ \env mke -> - k env $ \env' -> - use (mke env') $ ENil ext - SsArr' SsNone -> - occCountX (SsArr SsNone) e $ \env mke -> - k env $ \env' -> - elet (mke env') $ - EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext) - SsArr' SsFull -> - occCountX (SsArr SsFull) e $ \env mke -> - k env $ \env' -> - reduce (mke env') - - -deleteUnused :: SList f env -> Some (OccEnv Occ env) -> (forall env'. Subenv env env' -> r) -> r -deleteUnused SNil (Some OccEnd) k = k SETop -deleteUnused (_ `SCons` env) (Some OccEnd) k = - deleteUnused env (Some OccEnd) $ \sub -> k (SENo sub) -deleteUnused (_ `SCons` env) (Some (OccPush occenv (Occ _ count) _)) k = - deleteUnused env (Some occenv) $ \sub -> - case count of Zero -> k (SENo sub) - _ -> k (SEYesR sub) - -unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t -unsafeWeakenWithSubenv = \sub -> - subst (\x t i -> case sinkViaSubenv i sub of - Just i' -> EVar x t i' - Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away") - where - sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t) - sinkViaSubenv IZ (SEYesR _) = Just IZ - sinkViaSubenv IZ (SENo _) = Nothing - sinkViaSubenv (IS i) (SEYesR sub) = IS <$> sinkViaSubenv i sub - sinkViaSubenv (IS i) (SENo sub) = sinkViaSubenv i sub diff --git a/src/AST/Env.hs b/src/AST/Env.hs deleted file mode 100644 index 85faba3..0000000 --- a/src/AST/Env.hs +++ /dev/null @@ -1,95 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeOperators #-} -module AST.Env where - -import Data.Type.Equality - -import AST.Sparse -import AST.Weaken -import CHAD.Types -import Data - - --- | @env'@ is a subset of @env@: each element of @env@ is either included in --- @env'@ ('SEYes') or not included in @env'@ ('SENo'). -data Subenv' s env env' where - SETop :: Subenv' s '[] '[] - SEYes :: forall t t' env env' s. s t t' -> Subenv' s env env' -> Subenv' s (t : env) (t' : env') - SENo :: forall t env env' s. Subenv' s env env' -> Subenv' s (t : env) env' -deriving instance (forall t t'. Show (s t t')) => Show (Subenv' s env env') - -type Subenv = Subenv' (:~:) -type SubenvS = Subenv' Sparse - -pattern SEYesR :: forall tenv tenv'. () - => forall t env env'. (tenv ~ t : env, tenv' ~ t : env') - => Subenv env env' -> Subenv tenv tenv' -pattern SEYesR s = SEYes Refl s - -{-# COMPLETE SETop, SEYesR, SENo #-} - -subList :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env' -> SList f env' -subList SNil SETop = SNil -subList (SCons x xs) (SEYes s sub) = SCons (subtApply s x) (subList xs sub) -subList (SCons _ xs) (SENo sub) = subList xs sub - -subenvAll :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env -subenvAll SNil = SETop -subenvAll (SCons t env) = SEYes (subtFull t) (subenvAll env) - -subenvNone :: SList f env -> Subenv' s env '[] -subenvNone SNil = SETop -subenvNone (SCons _ env) = SENo (subenvNone env) - -subenvOnehot :: SList f env -> Idx env t -> s t t' -> Subenv' s env '[t'] -subenvOnehot (SCons _ env) IZ sp = SEYes sp (subenvNone env) -subenvOnehot (SCons _ env) (IS i) sp = SENo (subenvOnehot env i sp) -subenvOnehot SNil i _ = case i of {} - -subenvCompose :: IsSubType s => Subenv' s env1 env2 -> Subenv' s env2 env3 -> Subenv' s env1 env3 -subenvCompose SETop SETop = SETop -subenvCompose (SEYes s1 sub1) (SEYes s2 sub2) = SEYes (subtTrans s1 s2) (subenvCompose sub1 sub2) -subenvCompose (SEYes _ sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2) -subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2) - -subenvConcat :: Subenv' s env1 env1' -> Subenv' s env2 env2' -> Subenv' s (Append env2 env1) (Append env2' env1') -subenvConcat sub1 SETop = sub1 -subenvConcat sub1 (SEYes s sub2) = SEYes s (subenvConcat sub1 sub2) -subenvConcat sub1 (SENo sub2) = SENo (subenvConcat sub1 sub2) - --- subenvSplit :: SList f env1a -> Subenv' s (Append env1a env1b) env2 --- -> (forall env2a env2b. Subenv' s env1a env2a -> Subenv' s env1b env2b -> r) -> r --- subenvSplit SNil sub k = k SETop sub --- subenvSplit (SCons _ list) (SENo sub) k = --- subenvSplit list sub $ \sub1 sub2 -> --- k (SENo sub1) sub2 --- subenvSplit (SCons _ list) (SEYes s sub) k = --- subenvSplit list sub $ \sub1 sub2 -> --- k (SEYes s sub1) sub2 - -sinkWithSubenv :: Subenv' s env env' -> env0 :> Append env' env0 -sinkWithSubenv SETop = WId -sinkWithSubenv (SEYes _ sub) = WSink .> sinkWithSubenv sub -sinkWithSubenv (SENo sub) = sinkWithSubenv sub - -wUndoSubenv :: Subenv' (:~:) env env' -> env' :> env -wUndoSubenv SETop = WId -wUndoSubenv (SEYes Refl sub) = WCopy (wUndoSubenv sub) -wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub - -subenvMap :: (forall a a'. f a -> s a a' -> s' a a') -> SList f env -> Subenv' s env env' -> Subenv' s' env env' -subenvMap _ SNil SETop = SETop -subenvMap f (t `SCons` l) (SEYes s sub) = SEYes (f t s) (subenvMap f l sub) -subenvMap f (_ `SCons` l) (SENo sub) = SENo (subenvMap f l sub) - -subenvD2E :: Subenv env env' -> Subenv (D2E env) (D2E env') -subenvD2E SETop = SETop -subenvD2E (SEYesR sub) = SEYesR (subenvD2E sub) -subenvD2E (SENo sub) = SENo (subenvD2E sub) diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs deleted file mode 100644 index bbcfd9e..0000000 --- a/src/AST/Pretty.hs +++ /dev/null @@ -1,525 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE TupleSections #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where - -import Control.Monad (ap) -import Data.List (intersperse, intercalate) -import Data.Functor.Const -import qualified Data.Functor.Product as Product -import Data.String (fromString) -import Prettyprinter -import Prettyprinter.Render.String - -import qualified Data.Text.Lazy as TL -import qualified Prettyprinter.Render.Terminal as PT -import System.Console.ANSI (hSupportsANSI) -import System.IO (stdout) -import System.IO.Unsafe (unsafePerformIO) - -import AST -import AST.Count -import AST.Sparse.Types -import CHAD.Types -import Data - - -class PrettyX x where - prettyX :: x t -> String - - prettyXsuffix :: x t -> String - prettyXsuffix x = "<" ++ prettyX x ++ ">" - -instance PrettyX (Const ()) where - prettyX _ = "" - prettyXsuffix _ = "" - - -type SVal = SList (Const String) - -newtype M a = M { runM :: Int -> (a, Int) } - deriving (Functor) -instance Applicative M where { pure x = M (\i -> (x, i)) ; (<*>) = ap } -instance Monad M where { M f >>= g = M (\i -> let (x, j) = f i in runM (g x) j) } - -genId :: M Int -genId = M (\i -> (i, i + 1)) - -nameBaseForType :: STy t -> String -nameBaseForType STNil = "nil" -nameBaseForType (STPair{}) = "p" -nameBaseForType (STEither{}) = "e" -nameBaseForType (STMaybe{}) = "m" -nameBaseForType (STScal STI32) = "n" -nameBaseForType (STScal STI64) = "n" -nameBaseForType (STArr{}) = "a" -nameBaseForType (STAccum{}) = "ac" -nameBaseForType _ = "x" - -genName' :: String -> M String -genName' prefix = (prefix ++) . show <$> genId - -genNameIfUsedIn' :: String -> STy a -> Idx env a -> Expr x env t -> M String -genNameIfUsedIn' prefix ty idx ex - | occCount idx ex == mempty = case ty of STNil -> return "()" - _ -> return "_" - | otherwise = genName' prefix - --- TODO: let this return a type-tagged thing so that name environments are more typed than Const -genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String -genNameIfUsedIn = \t -> genNameIfUsedIn' (nameBaseForType t) t - -pprintExpr :: (KnownEnv env, PrettyX x) => Expr x env t -> IO () -pprintExpr = putStrLn . ppExpr knownEnv - -ppExpr :: PrettyX x => SList STy env -> Expr x env t -> String -ppExpr senv e = render $ fst . flip runM 1 $ do - val <- mkVal senv - e' <- ppExpr' 0 val e - let lam = "λ" ++ intercalate " " (reverse (unSList (\(Product.Pair (Const name) ty) -> "(" ++ name ++ " : " ++ ppSTy 0 ty ++ ")") (slistZip val senv))) ++ "." - return $ group $ flatAlt - (hang 2 $ - ppString lam - <> hardline <> e') - (ppString lam <+> e') - where - mkVal :: SList f env -> M (SVal env) - mkVal SNil = return SNil - mkVal (SCons _ v) = do - val <- mkVal v - name <- genName' "arg" - return (Const name `SCons` val) - -ppExpr' :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc -ppExpr' d val expr = case expr of - EVar _ _ i -> return $ ppString (getConst (slistIdx val i)) <> ppX expr - - e@ELet{} -> ppExprLet d val e - - EPair _ a b -> do - a' <- ppExpr' 0 val a - b' <- ppExpr' 0 val b - return $ group $ flatAlt (align $ ppString "(" <> a' <> hardline <> ppString "," <> b' <> ppString ")" <> ppX expr) - (ppString "(" <> a' <> ppString "," <+> b' <> ppString ")" <> ppX expr) - - EFst _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString "fst" <> ppX expr <+> e' - - ESnd _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString "snd" <> ppX expr <+> e' - - ENil _ -> return $ ppString "()" - - EInl _ _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString "Inl" <> ppX expr <+> e' - - EInr _ _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString "Inr" <> ppX expr <+> e' - - ECase _ e a b -> do - e' <- ppExpr' 0 val e - let STEither t1 t2 = typeOf e - name1 <- genNameIfUsedIn t1 IZ a - a' <- ppExpr' 0 (Const name1 `SCons` val) a - name2 <- genNameIfUsedIn t2 IZ b - b' <- ppExpr' 0 (Const name2 `SCons` val) b - return $ ppParen (d > 0) $ - hang 2 $ - annotate AKey (ppString "case") <> ppX expr <+> e' <+> annotate AKey (ppString "of") - <> hardline <> ppString "Inl" <+> ppString name1 <+> ppString "->" <+> a' - <> hardline <> ppString "Inr" <+> ppString name2 <+> ppString "->" <+> b' - - ENothing _ _ -> return $ ppString "Nothing" - - EJust _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString "Just" <> ppX expr <+> e' - - EMaybe _ a b e -> do - let STMaybe t = typeOf e - e' <- ppExpr' 0 val e - a' <- ppExpr' 0 val a - name <- genNameIfUsedIn t IZ b - b' <- ppExpr' 0 (Const name `SCons` val) b - return $ ppParen (d > 0) $ - align $ - group (flatAlt - (annotate AKey (ppString "case") <> ppX expr <+> e' - <> hardline <> annotate AKey (ppString "of")) - (annotate AKey (ppString "case") <> ppX expr <+> e' <+> annotate AKey (ppString "of"))) - <> hardline - <> indent 2 - (ppString "Nothing" <+> ppString "->" <+> a' - <> hardline <> ppString "Just" <+> ppString name <+> ppString "->" <+> b') - - ELNil _ _ _ -> return (ppString "LNil") - - ELInl _ _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString "LInl" <> ppX expr <+> e' - - ELInr _ _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString "LInr" <> ppX expr <+> e' - - ELCase _ e a b c -> do - e' <- ppExpr' 0 val e - let STLEither t1 t2 = typeOf e - a' <- ppExpr' 11 val a - name1 <- genNameIfUsedIn t1 IZ b - b' <- ppExpr' 0 (Const name1 `SCons` val) b - name2 <- genNameIfUsedIn t2 IZ c - c' <- ppExpr' 0 (Const name2 `SCons` val) c - return $ ppParen (d > 0) $ - hang 2 $ - annotate AKey (ppString "lcase") <> ppX expr <+> e' <+> annotate AKey (ppString "of") - <> hardline <> ppString "LNil" <+> ppString "->" <+> a' - <> hardline <> ppString "LInl" <+> ppString name1 <+> ppString "->" <+> b' - <> hardline <> ppString "LInr" <+> ppString name2 <+> ppString "->" <+> c' - - EConstArr _ _ ty v - | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr - - EBuild _ n a b -> do - a' <- ppExpr' 11 val a - name <- genNameIfUsedIn' "i" (tTup (sreplicate n tIx)) IZ b - e' <- ppExpr' 0 (Const name `SCons` val) b - let primName = ppString ("build" ++ intSubscript (fromSNat n)) - return $ ppParen (d > 0) $ - group $ flatAlt - (hang 2 $ - annotate AHighlight primName <> ppX expr <+> a' - <+> ppString "$" <+> ppString "\\" <> ppString name <+> ppString "->" - <> hardline <> e') - (ppApp (annotate AHighlight primName <> ppX expr) [a', ppLam [ppString name] e']) - - EMap _ a b -> do - let STArr _ t1 = typeOf b - name <- genNameIfUsedIn t1 IZ a - a' <- ppExpr' 0 (Const name `SCons` val) a - b' <- ppExpr' 11 val b - return $ ppParen (d > 0) $ - ppApp (annotate AHighlight (ppString "map") <> ppX expr) [ppLam [ppString name] a', b'] - - EFold1Inner _ cm a b c -> do - name <- genNameIfUsedIn (STPair (typeOf a) (typeOf a)) IZ a - a' <- ppExpr' 0 (Const name `SCons` val) a - b' <- ppExpr' 11 val b - c' <- ppExpr' 11 val c - let opname = "fold1i" ++ ppCommut cm - return $ ppParen (d > 10) $ - ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c'] - - ESum1Inner _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString "sum1i" <> ppX expr <+> e' - - EUnit _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString "unit" <> ppX expr <+> e' - - EReplicate1Inner _ a b -> do - a' <- ppExpr' 11 val a - b' <- ppExpr' 11 val b - return $ ppParen (d > 10) $ ppApp (ppString "replicate1i" <> ppX expr) [a', b'] - - EMaximum1Inner _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString "maximum1i" <> ppX expr <+> e' - - EMinimum1Inner _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString "minimum1i" <> ppX expr <+> e' - - EReshape _ n esh e -> do - esh' <- ppExpr' 11 val esh - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppApp (ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr) [esh', e'] - - EZip _ e1 e2 -> do - e1' <- ppExpr' 11 val e1 - e2' <- ppExpr' 11 val e2 - return $ ppParen (d > 10) $ ppApp (ppString "zip" <> ppX expr) [e1', e2'] - - EFold1InnerD1 _ cm a b c -> do - name <- genNameIfUsedIn (STPair (typeOf b) (typeOf b)) IZ a - a' <- ppExpr' 0 (Const name `SCons` val) a - b' <- ppExpr' 11 val b - c' <- ppExpr' 11 val c - let opname = "fold1iD1" ++ ppCommut cm - return $ ppParen (d > 10) $ - ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c'] - - EFold1InnerD2 _ cm ef ebog ed -> do - let STArr _ tB = typeOf ebog - STArr _ t2 = typeOf ed - namef1 <- genNameIfUsedIn tB (IS IZ) ef - namef2 <- genNameIfUsedIn t2 IZ ef - ef' <- ppExpr' 0 (Const namef2 `SCons` Const namef1 `SCons` val) ef - ebog' <- ppExpr' 11 val ebog - ed' <- ppExpr' 11 val ed - let opname = "fold1iD2" ++ ppCommut cm - return $ ppParen (d > 10) $ - ppApp (annotate AHighlight (ppString opname) <> ppX expr) - [ppLam [ppString namef1, ppString namef2] ef', ebog', ed'] - - EConst _ ty v - | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr - - EIdx0 _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString "idx0" <> ppX expr <+> e' - - EIdx1 _ a b -> do - a' <- ppExpr' 9 val a - b' <- ppExpr' 9 val b - return $ ppParen (d > 8) $ a' <+> ppString ".!" <> ppX expr <+> b' - - EIdx _ a b -> do - a' <- ppExpr' 9 val a - b' <- ppExpr' 10 val b - return $ ppParen (d > 8) $ - a' <+> ppString "!" <> ppX expr <+> b' - - EShape _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString "shape" <> ppX expr <+> e' - - EOp _ op (EPair _ a b) - | (Infix, ops) <- operator op -> do - a' <- ppExpr' 9 val a - b' <- ppExpr' 9 val b - return $ ppParen (d > 8) $ a' <+> ppString ops <> ppX expr <+> b' - - EOp _ op e -> do - e' <- ppExpr' 11 val e - let ops = case operator op of - (Infix, s) -> "(" ++ s ++ ")" - (Prefix, s) -> s - return $ ppParen (d > 10) $ ppString ops <> ppX expr <+> e' - - ECustom _ t1 t2 t3 a b c e1 e2 -> do - en1 <- genNameIfUsedIn t1 (IS IZ) a - en2 <- genNameIfUsedIn t2 IZ a - pn1 <- genNameIfUsedIn (d1 t1) (IS IZ) b - pn2 <- genNameIfUsedIn (d1 t2) IZ b - dn1 <- genNameIfUsedIn' "tape" t3 (IS IZ) c - dn2 <- genNameIfUsedIn' "d" (d2 (typeOf a)) IZ c - a' <- ppExpr' 11 (Const pn2 `SCons` Const pn1 `SCons` SNil) a - b' <- ppExpr' 11 (Const pn2 `SCons` Const pn1 `SCons` SNil) b - c' <- ppExpr' 11 (Const dn2 `SCons` Const dn1 `SCons` SNil) c - e1' <- ppExpr' 11 val e1 - e2' <- ppExpr' 11 val e2 - return $ ppParen (d > 10) $ - ppApp (ppString "custom" <> ppX expr) - [ppLam [ppString en1, ppString en2] a' - ,ppLam [ppString pn1, ppString pn2] b' - ,ppLam [ppString dn1, ppString dn2] c' - ,e1' - ,e2'] - - ERecompute _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppApp (ppString "recompute" <> ppX expr) [e'] - - EWith _ t e1 e2 -> do - e1' <- ppExpr' 11 val e1 - name <- genNameIfUsedIn' "ac" (STAccum t) IZ e2 - e2' <- ppExpr' 0 (Const name `SCons` val) e2 - return $ ppParen (d > 0) $ - group $ flatAlt - (hang 2 $ - annotate AWith (ppString "with") <> ppX expr <+> e1' - <+> ppString "$" <+> ppString "\\" <> ppString name <+> ppString "->" - <> hardline <> e2') - (ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2']) - - EAccum _ t prj e1 sp e2 e3 -> do - e1' <- ppExpr' 11 val e1 - e2' <- ppExpr' 11 val e2 - e3' <- ppExpr' 11 val e3 - return $ ppParen (d > 10) $ - ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (applySparse sp (acPrjTy prj t))) - [ppString (ppAcPrj t prj), ppString (ppSparse (acPrjTy prj t) sp), e1', e2', e3'] - - EZero _ t e1 -> do - e1' <- ppExpr' 11 val e1 - return $ ppParen (d > 0) $ - annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1' - - EDeepZero _ t e1 -> do - e1' <- ppExpr' 11 val e1 - return $ ppParen (d > 0) $ - annotate AMonoid (ppString "deepzero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1' - - EPlus _ t a b -> do - a' <- ppExpr' 11 val a - b' <- ppExpr' 11 val b - return $ ppParen (d > 10) $ - ppApp (annotate AMonoid (ppString "plus") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t) [a', b'] - - EOneHot _ t prj a b -> do - a' <- ppExpr' 11 val a - b' <- ppExpr' 11 val b - return $ ppParen (d > 10) $ - ppApp (annotate AMonoid (ppString "onehot") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (acPrjTy prj t)) [ppString (ppAcPrj t prj), a', b'] - - EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s) - -ppExprLet :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc -ppExprLet d val etop = do - let collect :: PrettyX x => SVal env -> Expr x env t -> M ([(String, Occ, ADoc)], ADoc) - collect val' (ELet _ rhs body) = do - let occ = occCount IZ body - name <- genNameIfUsedIn (typeOf rhs) IZ body - rhs' <- ppExpr' 0 val' rhs - (binds, core) <- collect (Const name `SCons` val') body - return ((name, occ, rhs') : binds, core) - collect val' e = ([],) <$> ppExpr' 0 val' e - - (binds, core) <- collect val etop - - return $ ppParen (d > 0) $ - align $ - annotate AKey (ppString "let") - <+> align (mconcat $ intersperse hardline $ - map (\(name, _occ, rhs) -> - ppString (name ++ {- " (" ++ show _occ ++ ")" ++ -} " = ") <> rhs) - binds) - <> hardline <> annotate AKey (ppString "in") <+> core - -ppApp :: ADoc -> [ADoc] -> ADoc -ppApp fun args = group $ fun <+> align (sep args) - -ppLam :: [ADoc] -> ADoc -> ADoc -ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"]) - <> softline <> body <> ppString ")") - -ppAcPrj :: SMTy a -> SAcPrj p a b -> String -ppAcPrj _ SAPHere = "." -ppAcPrj (SMTPair t _) (SAPFst prj) = "(" ++ ppAcPrj t prj ++ ",)" -ppAcPrj (SMTPair _ t) (SAPSnd prj) = "(," ++ ppAcPrj t prj ++ ")" -ppAcPrj (SMTLEither t _) (SAPLeft prj) = "(" ++ ppAcPrj t prj ++ "|)" -ppAcPrj (SMTLEither _ t) (SAPRight prj) = "(|" ++ ppAcPrj t prj ++ ")" -ppAcPrj (SMTMaybe t) (SAPJust prj) = "J" ++ ppAcPrj t prj -ppAcPrj (SMTArr n t) (SAPArrIdx prj) = "[" ++ ppAcPrj t prj ++ "]" ++ intSubscript (fromSNat n) - -ppSparse :: SMTy a -> Sparse a b -> String -ppSparse t sp | Just Refl <- isDense t sp = "D" -ppSparse _ SpAbsent = "A" -ppSparse t (SpSparse s) = "S" ++ ppSparse t s -ppSparse (SMTPair t1 t2) (SpPair s1 s2) = "(" ++ ppSparse t1 s1 ++ "," ++ ppSparse t2 s2 ++ ")" -ppSparse (SMTLEither t1 t2) (SpLEither s1 s2) = "(" ++ ppSparse t1 s1 ++ "|" ++ ppSparse t2 s2 ++ ")" -ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s -ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s -ppSparse (SMTScal _) SpScal = "." - -ppCommut :: Commutative -> String -ppCommut Commut = "(C)" -ppCommut Noncommut = "" - -ppX :: PrettyX x => Expr x env t -> ADoc -ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr) - -data Fixity = Prefix | Infix - deriving (Show) - -operator :: SOp a t -> (Fixity, String) -operator OAdd{} = (Infix, "+") -operator OMul{} = (Infix, "*") -operator ONeg{} = (Prefix, "negate") -operator OLt{} = (Infix, "<") -operator OLe{} = (Infix, "<=") -operator OEq{} = (Infix, "==") -operator ONot = (Prefix, "not") -operator OAnd = (Infix, "&&") -operator OOr = (Infix, "||") -operator OIf = (Prefix, "ifB") -operator ORound64 = (Prefix, "round") -operator OToFl64 = (Prefix, "toFl64") -operator ORecip{} = (Prefix, "recip") -operator OExp{} = (Prefix, "exp") -operator OLog{} = (Prefix, "log") -operator OIDiv{} = (Infix, "`div`") -operator OMod{} = (Infix, "`mod`") - -ppSTy :: Int -> STy t -> String -ppSTy d ty = render $ ppSTy' d ty - -ppSTy' :: Int -> STy t -> Doc q -ppSTy' _ STNil = ppString "1" -ppSTy' d (STPair a b) = ppParen (d > 7) $ ppSTy' 8 a <> ppString " * " <> ppSTy' 8 b -ppSTy' d (STEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " + " <> ppSTy' 7 b -ppSTy' d (STLEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " ⊕ " <> ppSTy' 7 b -ppSTy' d (STMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSTy' 11 t -ppSTy' d (STArr n t) = ppParen (d > 10) $ - ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSTy' 11 t -ppSTy' _ (STScal sty) = ppString $ case sty of - STI32 -> "i32" - STI64 -> "i64" - STF32 -> "f32" - STF64 -> "f64" - STBool -> "bool" -ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSMTy' 11 t - -ppSMTy :: Int -> SMTy t -> String -ppSMTy d ty = render $ ppSMTy' d ty - -ppSMTy' :: Int -> SMTy t -> Doc q -ppSMTy' _ SMTNil = ppString "1" -ppSMTy' d (SMTPair a b) = ppParen (d > 7) $ ppSMTy' 8 a <> ppString " * " <> ppSMTy' 8 b -ppSMTy' d (SMTLEither a b) = ppParen (d > 6) $ ppSMTy' 7 a <> ppString " ⊕ " <> ppSMTy' 7 b -ppSMTy' d (SMTMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSMTy' 11 t -ppSMTy' d (SMTArr n t) = ppParen (d > 10) $ - ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSMTy' 11 t -ppSMTy' _ (SMTScal sty) = ppString $ case sty of - STI32 -> "i32" - STI64 -> "i64" - STF32 -> "f32" - STF64 -> "f64" - -ppString :: String -> Doc x -ppString = fromString - -ppParen :: Bool -> Doc x -> Doc x -ppParen True = parens -ppParen False = id - -intSubscript :: Int -> String -intSubscript = \case 0 -> "₀" - n | n < 0 -> '₋' : go (-n) "" - | otherwise -> go n "" - where go 0 suff = suff - go n suff = let (q, r) = n `quotRem` 10 - in go q ("₀₁₂₃₄₅₆₇₈₉" !! r : suff) - -data Annot = AKey | AWith | AHighlight | AMonoid | AExt - deriving (Show) - -annotToANSI :: Annot -> PT.AnsiStyle -annotToANSI AKey = PT.bold -annotToANSI AWith = PT.color PT.Red <> PT.underlined -annotToANSI AHighlight = PT.color PT.Blue -annotToANSI AMonoid = PT.color PT.Green -annotToANSI AExt = PT.colorDull PT.White - -type ADoc = Doc Annot - -render :: Doc Annot -> String -render = - (if stdoutTTY then TL.unpack . PT.renderLazy . reAnnotateS annotToANSI - else renderString) - . layoutPretty LayoutOptions { layoutPageWidth = AvailablePerLine 120 1.0 } - where - stdoutTTY = unsafePerformIO $ hSupportsANSI stdout diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs deleted file mode 100644 index 2a29799..0000000 --- a/src/AST/Sparse.hs +++ /dev/null @@ -1,287 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ImpredicativeTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE RankNTypes #-} - -{-# OPTIONS_GHC -fmax-pmcheck-models=80 #-} -module AST.Sparse (module AST.Sparse, module AST.Sparse.Types) where - -import Data.Type.Equality - -import AST -import AST.Sparse.Types -import Data (SBool(..)) - - -sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t' -sparsePlus _ SpAbsent e1 e2 = use e1 $ use e2 $ ENil ext -sparsePlus t sp e1 e2 | Just Refl <- isDense t sp = EPlus ext t e1 e2 -sparsePlus t (SpSparse sp) e1 e2 = sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 -- heh -sparsePlus (SMTPair t1 t2) (SpPair sp1 sp2) e1 e2 = - eunPair e1 $ \w1 e1a e1b -> - eunPair (weakenExpr w1 e2) $ \w2 e2a e2b -> - EPair ext (sparsePlus t1 sp1 (weakenExpr w2 e1a) e2a) - (sparsePlus t2 sp2 (weakenExpr w2 e1b) e2b) -sparsePlus (SMTLEither t1 t2) (SpLEither sp1 sp2) e1 e2 = - elet e2 $ - elcase (weakenExpr WSink e1) - (evar IZ) - (elcase (evar (IS IZ)) - (ELInl ext (applySparse sp2 (fromSMTy t2)) (evar IZ)) - (ELInl ext (applySparse sp2 (fromSMTy t2)) (sparsePlus t1 sp1 (evar (IS IZ)) (evar IZ))) - (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus ll+lr")) - (elcase (evar (IS IZ)) - (ELInr ext (applySparse sp1 (fromSMTy t1)) (evar IZ)) - (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus lr+ll") - (ELInr ext (applySparse sp1 (fromSMTy t1)) (sparsePlus t2 sp2 (evar (IS IZ)) (evar IZ)))) -sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 = - elet e2 $ - emaybe (weakenExpr WSink e1) - (evar IZ) - (emaybe (evar (IS IZ)) - (EJust ext (evar IZ)) - (EJust ext (sparsePlus t sp (evar (IS IZ)) (evar IZ)))) -sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS IZ)) (evar IZ)) e1 e2 -sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2 - - -cheapZero :: SMTy t -> Maybe (forall env. Ex env t) -cheapZero SMTNil = Just (ENil ext) -cheapZero (SMTPair t1 t2) - | Just e1 <- cheapZero t1 - , Just e2 <- cheapZero t2 - = Just (EPair ext e1 e2) - | otherwise - = Nothing -cheapZero (SMTLEither t1 t2) = Just (ELNil ext (fromSMTy t1) (fromSMTy t2)) -cheapZero (SMTMaybe t) = Just (ENothing ext (fromSMTy t)) -cheapZero SMTArr{} = Nothing -cheapZero (SMTScal t) = case t of - STI32 -> Just (EConst ext t 0) - STI64 -> Just (EConst ext t 0) - STF32 -> Just (EConst ext t 0.0) - STF64 -> Just (EConst ext t 0.0) - - -data Injection sp a b where - -- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that - -- 'sparsePlusS' can provide injections even if the caller doesn't require - -- them. This simplifies the sparsePlusS code. - Inj :: (forall e. Ex e a -> Ex e b) -> Injection sp a b - Noinj :: Injection False a b - -withInj :: Injection sp a b -> ((forall e. Ex e a -> Ex e b) -> (forall e'. Ex e' a' -> Ex e' b')) -> Injection sp a' b' -withInj (Inj f) k = Inj (k f) -withInj Noinj _ = Noinj - -withInj2 :: Injection sp a1 b1 -> Injection sp a2 b2 - -> ((forall e. Ex e a1 -> Ex e b1) - -> (forall e. Ex e a2 -> Ex e b2) - -> (forall e'. Ex e' a' -> Ex e' b')) - -> Injection sp a' b' -withInj2 (Inj f) (Inj g) k = Inj (k f g) -withInj2 Noinj _ _ = Noinj -withInj2 _ Noinj _ = Noinj - --- | This function produces quadratically-sized code in the presence of nested --- dynamic sparsity. TODO can this be improved? -sparsePlusS - :: SBool inj1 -> SBool inj2 - -> SMTy t -> Sparse t t1 -> Sparse t t2 - -> (forall t3. Sparse t t3 - -> Injection inj1 t1 t3 -- only available if first injection is requested (second argument may be absent) - -> Injection inj2 t2 t3 -- only available if second injection is requested (first argument may be absent) - -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3) - -> r) - -> r --- nil override (but don't destroy effects!) -sparsePlusS _ _ SMTNil _ _ k = - k SpAbsent (Inj $ \a -> use a $ ENil ext) (Inj $ \b -> use b $ ENil ext) (\a b -> use a $ use b $ ENil ext) - --- simplifications -sparsePlusS req1 req2 t (SpSparse SpAbsent) sp2 k = - sparsePlusS req1 req2 t SpAbsent sp2 $ \sp3 minj1 minj2 plus -> - k sp3 (withInj minj1 $ \inj1 -> \a -> use a $ inj1 (ENil ext)) minj2 (\a b -> use a $ plus (ENil ext) b) -sparsePlusS req1 req2 t sp1 (SpSparse SpAbsent) k = - sparsePlusS req1 req2 t sp1 SpAbsent $ \sp3 minj1 minj2 plus -> - k sp3 minj1 (withInj minj2 $ \inj2 -> \b -> use b $ inj2 (ENil ext)) (\a b -> use b $ plus a (ENil ext)) - -sparsePlusS req1 req2 t (SpSparse (SpSparse sp1)) sp2 k = - let ta = applySparse sp1 (fromSMTy t) in - sparsePlusS req1 req2 t (SpSparse sp1) sp2 $ \sp3 minj1 minj2 plus -> - k sp3 - (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ))) - minj2 - (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b) -sparsePlusS req1 req2 t sp1 (SpSparse (SpSparse sp2)) k = - let tb = applySparse sp2 (fromSMTy t) in - sparsePlusS req1 req2 t sp1 (SpSparse sp2) $ \sp3 minj1 minj2 plus -> - k sp3 - minj1 - (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) - (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) - -sparsePlusS req1 req2 t (SpSparse (SpLEither sp1a sp1b)) sp2 k = - let STLEither ta tb = applySparse (SpLEither sp1a sp1b) (fromSMTy t) in - sparsePlusS req1 req2 t (SpLEither sp1a sp1b) sp2 $ \sp3 minj1 minj2 plus -> - k sp3 - (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) - minj2 - (\a b -> plus (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)) b) -sparsePlusS req1 req2 t sp1 (SpSparse (SpLEither sp2a sp2b)) k = - let STLEither ta tb = applySparse (SpLEither sp2a sp2b) (fromSMTy t) in - sparsePlusS req1 req2 t sp1 (SpLEither sp2a sp2b) $ \sp3 minj1 minj2 plus -> - k sp3 - minj1 - (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) - (\a b -> plus a (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) - -sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k = - let STMaybe ta = applySparse (SpMaybe sp1) (fromSMTy t) in - sparsePlusS req1 req2 t (SpMaybe sp1) sp2 $ \sp3 minj1 minj2 plus -> - k sp3 - (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (evar IZ))) - minj2 - (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b) -sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k = - let STMaybe tb = applySparse (SpMaybe sp2) (fromSMTy t) in - sparsePlusS req1 req2 t sp1 (SpMaybe sp2) $ \sp3 minj1 minj2 plus -> - k sp3 - minj1 - (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (evar IZ))) - (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) -sparsePlusS req1 req2 t (SpMaybe (SpSparse sp1)) sp2 k = sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k -sparsePlusS req1 req2 t sp1 (SpMaybe (SpSparse sp2)) k = sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k - --- TODO: sparse of Just is just Maybe - --- dense plus -sparsePlusS _ _ t sp1 sp2 k - | Just Refl <- isDense t sp1 - , Just Refl <- isDense t sp2 - = k (spDense t) (Inj id) (Inj id) (\a b -> EPlus ext t a b) - --- handle absents -sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\a b -> use a $ b) -sparsePlusS ST _ t SpAbsent sp2 k - | Just zero2 <- cheapZero (applySparse sp2 t) = - k sp2 (Inj $ \a -> use a $ zero2) (Inj id) (\a b -> use a $ b) - | otherwise = - k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b) - -sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a b -> use b $ a) -sparsePlusS _ ST t sp1 SpAbsent k - | Just zero1 <- cheapZero (applySparse sp1 t) = - k sp1 (Inj id) (Inj $ \b -> use b $ zero1) (\a b -> use b $ a) - | otherwise = - k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a) - --- double sparse yields sparse -sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k = - sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> - k (SpSparse sp3) - (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) - (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) - (\a b -> - elet b $ - emaybe (weakenExpr WSink a) - (emaybe (evar IZ) - (ENothing ext (applySparse sp3 (fromSMTy t))) - (EJust ext (inj2 (evar IZ)))) - (emaybe (evar (IS IZ)) - (EJust ext (inj1 (evar IZ))) - (EJust ext (plus (evar (IS IZ)) (evar IZ))))) - --- single sparse can yield non-sparse if the other argument is always present -sparsePlusS SF _ t (SpSparse sp1) sp2 k = - sparsePlusS SF ST t sp1 sp2 $ \sp3 _ (Inj inj2) plus -> - k sp3 Noinj (Inj inj2) - (\a b -> - elet b $ - emaybe (weakenExpr WSink a) - (inj2 (evar IZ)) - (plus (evar IZ) (evar (IS IZ)))) -sparsePlusS ST _ t (SpSparse sp1) sp2 k = - sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> - k (SpSparse sp3) - (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) - (Inj $ \b -> EJust ext (inj2 b)) - (\a b -> - elet b $ - emaybe (weakenExpr WSink a) - (EJust ext (inj2 (evar IZ))) - (EJust ext (plus (evar IZ) (evar (IS IZ))))) -sparsePlusS req1 req2 t sp1 (SpSparse sp2) k = - sparsePlusS req2 req1 t (SpSparse sp2) sp1 $ \sp3 inj1 inj2 plus -> - k sp3 inj2 inj1 (flip plus) - --- products -sparsePlusS req1 req2 (SMTPair ta tb) (SpPair sp1a sp1b) (SpPair sp2a sp2b) k = - sparsePlusS req1 req2 ta sp1a sp2a $ \sp3a minj13a minj23a plusa -> - sparsePlusS req1 req2 tb sp1b sp2b $ \sp3b minj13b minj23b plusb -> - k (SpPair sp3a sp3b) - (withInj2 minj13a minj13b $ \inj13a inj13b -> - \x1 -> eunPair x1 $ \_ x1a x1b -> EPair ext (inj13a x1a) (inj13b x1b)) - (withInj2 minj23a minj23b $ \inj23a inj23b -> - \x2 -> eunPair x2 $ \_ x2a x2b -> EPair ext (inj23a x2a) (inj23b x2b)) - (\x1 x2 -> - eunPair x1 $ \w1 x1a x1b -> - eunPair (weakenExpr w1 x2) $ \w2 x2a x2b -> - EPair ext (plusa (weakenExpr w2 x1a) x2a) (plusb (weakenExpr w2 x1b) x2b)) - --- coproducts -sparsePlusS _ _ (SMTLEither ta tb) (SpLEither sp1a sp1b) (SpLEither sp2a sp2b) k = - sparsePlusS ST ST ta sp1a sp2a $ \(sp3a :: Sparse _t3 t3a) (Inj inj13a) (Inj inj23a) plusa -> - sparsePlusS ST ST tb sp1b sp2b $ \(sp3b :: Sparse _t3' t3b) (Inj inj13b) (Inj inj23b) plusb -> - let nil :: Ex e (TLEither t3a t3b) ; nil = ELNil ext (applySparse sp3a (fromSMTy ta)) (applySparse sp3b (fromSMTy tb)) - inl :: Ex e t3a -> Ex e (TLEither t3a t3b) ; inl = ELInl ext (applySparse sp3b (fromSMTy tb)) - inr :: Ex e t3b -> Ex e (TLEither t3a t3b) ; inr = ELInr ext (applySparse sp3a (fromSMTy ta)) - in - k (SpLEither sp3a sp3b) - (Inj $ \x1 -> elcase x1 nil (inl (inj13a (evar IZ))) (inr (inj13b (evar IZ)))) - (Inj $ \x2 -> elcase x2 nil (inl (inj23a (evar IZ))) (inr (inj23b (evar IZ)))) - (\x1 x2 -> - elet x2 $ - elcase (weakenExpr WSink x1) - (elcase (evar IZ) - nil - (inl (inj23a (evar IZ))) - (inr (inj23b (evar IZ)))) - (elcase (evar (IS IZ)) - (inl (inj13a (evar IZ))) - (inl (plusa (evar (IS IZ)) (evar IZ))) - (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS ll+lr")) - (elcase (evar (IS IZ)) - (inr (inj13b (evar IZ))) - (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS lr+ll") - (inr (plusb (evar (IS IZ)) (evar IZ))))) - --- maybe -sparsePlusS _ _ (SMTMaybe t) (SpMaybe sp1) (SpMaybe sp2) k = - sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> - k (SpMaybe sp3) - (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) - (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) - (\a b -> - elet b $ - emaybe (weakenExpr WSink a) - (emaybe (evar IZ) - (ENothing ext (applySparse sp3 (fromSMTy t))) - (EJust ext (inj2 (evar IZ)))) - (emaybe (evar (IS IZ)) - (EJust ext (inj1 (evar IZ))) - (EJust ext (plus (evar (IS IZ)) (evar IZ))))) - --- dense array cotangents simply recurse -sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k = - sparsePlusS req1 req2 t sp1 sp2 $ \sp3 minj1 minj2 plus -> - k (SpArr sp3) - (withInj minj1 $ \inj1 -> emap (inj1 (EVar ext (applySparse sp1 (fromSMTy t)) IZ))) - (withInj minj2 $ \inj2 -> emap (inj2 (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) - (ezipWith (plus (EVar ext (applySparse sp1 (fromSMTy t)) (IS IZ)) - (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) - --- scalars -sparsePlusS _ _ (SMTScal t) SpScal SpScal k = k SpScal (Inj id) (Inj id) (EPlus ext (SMTScal t)) diff --git a/src/AST/Sparse/Types.hs b/src/AST/Sparse/Types.hs deleted file mode 100644 index 10cac4e..0000000 --- a/src/AST/Sparse/Types.hs +++ /dev/null @@ -1,107 +0,0 @@ -{-# LANGUAGE GADTs #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -module AST.Sparse.Types where - -import AST.Types - -import Data.Kind (Type, Constraint) -import Data.Type.Equality - - -data Sparse t t' where - SpSparse :: Sparse t t' -> Sparse t (TMaybe t') - SpAbsent :: Sparse t TNil - - SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b') - SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b') - SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t') - SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t') - SpScal :: Sparse (TScal t) (TScal t) -deriving instance Show (Sparse t t') - -class ApplySparse f where - applySparse :: Sparse t t' -> f t -> f t' - -instance ApplySparse STy where - applySparse (SpSparse s) t = STMaybe (applySparse s t) - applySparse SpAbsent _ = STNil - applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2) - applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2) - applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t) - applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t) - applySparse SpScal t = t - -instance ApplySparse SMTy where - applySparse (SpSparse s) t = SMTMaybe (applySparse s t) - applySparse SpAbsent _ = SMTNil - applySparse (SpPair s1 s2) (SMTPair t1 t2) = SMTPair (applySparse s1 t1) (applySparse s2 t2) - applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2) - applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t) - applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t) - applySparse SpScal t = t - - -class IsSubType s where - type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint - subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t' - subtTrans :: s a b -> s b c -> s a c - subtFull :: IsSubTypeSubject s f => f t -> s t t - -instance IsSubType (:~:) where - type IsSubTypeSubject (:~:) f = () - subtApply = gcastWith - subtTrans = trans - subtFull _ = Refl - -instance IsSubType Sparse where - type IsSubTypeSubject Sparse f = f ~ SMTy - subtApply = applySparse - - subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2) - subtTrans _ SpAbsent = SpAbsent - subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b) - subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b) - subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2) - subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2) - subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2) - subtTrans SpScal SpScal = SpScal - - subtFull = spDense - -spDense :: SMTy t -> Sparse t t -spDense SMTNil = SpAbsent -spDense (SMTPair t1 t2) = SpPair (spDense t1) (spDense t2) -spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2) -spDense (SMTMaybe t) = SpMaybe (spDense t) -spDense (SMTArr _ t) = SpArr (spDense t) -spDense (SMTScal _) = SpScal - -isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t') -isDense SMTNil SpAbsent = Just Refl -isDense _ SpSparse{} = Nothing -isDense _ SpAbsent = Nothing -isDense (SMTPair t1 t2) (SpPair s1 s2) - | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl - | otherwise = Nothing -isDense (SMTLEither t1 t2) (SpLEither s1 s2) - | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl - | otherwise = Nothing -isDense (SMTMaybe t) (SpMaybe s) - | Just Refl <- isDense t s = Just Refl - | otherwise = Nothing -isDense (SMTArr _ t) (SpArr s) - | Just Refl <- isDense t s = Just Refl - | otherwise = Nothing -isDense (SMTScal _) SpScal = Just Refl - -isAbsent :: Sparse t t' -> Bool -isAbsent (SpSparse s) = isAbsent s -isAbsent SpAbsent = True -isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2 -isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2 -isAbsent (SpMaybe s) = isAbsent s -isAbsent (SpArr s) = isAbsent s -isAbsent SpScal = False diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs deleted file mode 100644 index 267dd87..0000000 --- a/src/AST/SplitLets.hs +++ /dev/null @@ -1,191 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -module AST.SplitLets (splitLets) where - -import Data.Type.Equality - -import AST -import AST.Bindings -import Lemmas - - -splitLets :: Ex env t -> Ex env t -splitLets = splitLets' (\t i w -> EVar ext t (w @> i)) - -splitLets' :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> Ex env t -> Ex env' t -splitLets' = \sub -> \case - EVar _ t i -> sub t i WId - ELet _ rhs body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body) - ECase x e a b -> - let STEither t1 t2 = typeOf e - in ECase x (splitLets' sub e) (split1 sub t1 a) (split1 sub t2 b) - EMaybe x a b e -> - let STMaybe t1 = typeOf e - in EMaybe x (splitLets' sub a) (split1 sub t1 b) (splitLets' sub e) - ELCase x e a b c -> - let STLEither t1 t2 = typeOf e - in ELCase x (splitLets' sub e) (splitLets' sub a) (split1 sub t1 b) (split1 sub t2 c) - EFold1Inner x cm a b c -> - let STArr _ t1 = typeOf c - in EFold1Inner x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c) - EFold1InnerD1 x cm a b c -> - let STArr _ t1 = typeOf c - in EFold1InnerD1 x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c) - EFold1InnerD2 x cm a b c -> - let STArr _ tB = typeOf b - STArr _ t2 = typeOf c - in EFold1InnerD2 x cm (split2 sub tB t2 a) (splitLets' sub b) (splitLets' sub c) - - EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b) - EFst x e -> EFst x (splitLets' sub e) - ESnd x e -> ESnd x (splitLets' sub e) - ENil x -> ENil x - EInl x t e -> EInl x t (splitLets' sub e) - EInr x t e -> EInr x t (splitLets' sub e) - ENothing x t -> ENothing x t - EJust x e -> EJust x (splitLets' sub e) - ELNil x t1 t2 -> ELNil x t1 t2 - ELInl x t e -> ELInl x t (splitLets' sub e) - ELInr x t e -> ELInr x t (splitLets' sub e) - EConstArr x n t a -> EConstArr x n t a - EBuild x n a b -> EBuild x n (splitLets' sub a) (splitLets' (sinkF sub) b) - EMap x a b -> EMap x (splitLets' (sinkF sub) a) (splitLets' sub b) - ESum1Inner x e -> ESum1Inner x (splitLets' sub e) - EUnit x e -> EUnit x (splitLets' sub e) - EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b) - EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e) - EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e) - EReshape x n a b -> EReshape x n (splitLets' sub a) (splitLets' sub b) - EZip x a b -> EZip x (splitLets' sub a) (splitLets' sub b) - EConst x t v -> EConst x t v - EIdx0 x e -> EIdx0 x (splitLets' sub e) - EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b) - EIdx x e es -> EIdx x (splitLets' sub e) (splitLets' sub es) - EShape x e -> EShape x (splitLets' sub e) - EOp x op e -> EOp x op (splitLets' sub e) - ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2) - ERecompute x e -> ERecompute x (splitLets' sub e) - EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2) - EAccum x t p e1 sp e2 e3 -> EAccum x t p (splitLets' sub e1) sp (splitLets' sub e2) (splitLets' sub e3) - EZero x t ezi -> EZero x t (splitLets' sub ezi) - EDeepZero x t ezi -> EDeepZero x t (splitLets' sub ezi) - EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b) - EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b) - EError x t s -> EError x t s - where - sinkF :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) - -> STy t -> Idx (b : env) t -> (b : env') :> env3 -> Ex env3 t - sinkF _ t IZ w = EVar ext t (w @> IZ) - sinkF f t (IS i) w = f t i (w .> WSink) - - split1 :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) - -> STy bind -> Ex (bind : env) t -> Ex (bind : env') t - split1 sub (tbind :: STy bind) body = - let (ptrs, bs) = split tbind - in letBinds bs $ - splitLets' (\cases _ IZ w -> subPointers ptrs w - t (IS i) w -> sub t i (WPop @bind (wPops (bindingsBinds bs) w))) - body - - split2 :: forall bind1 bind2 env' env t. - (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) - -> STy bind1 -> STy bind2 -> Ex (bind2 : bind1 : env) t -> Ex (bind2 : bind1 : env') t - split2 sub tbind1 tbind2 body = - let (ptrs1', bs1') = split @env' tbind1 - bs1 = fst (weakenBindingsE WSink bs1') - (ptrs2, bs2) = split @(bind1 : env') tbind2 - in letBinds bs1 $ - letBinds (fst (weakenBindingsE (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $ - splitLets' (\cases _ IZ w -> subPointers ptrs2 (w .> wCopies (bindingsBinds bs2) (wSinks @(bind2 : bind1 : env') (bindingsBinds bs1))) - _ (IS IZ) w -> subPointers ptrs1' (w .> wSinks (bindingsBinds bs2) .> wCopies (bindingsBinds bs1) (WSink @bind2 @(bind1 : env'))) - t (IS (IS i)) w -> sub t i (WPop @bind1 (WPop @bind2 (wPops (bindingsBinds bs1) (wPops (bindingsBinds bs2) w))))) - body - - -- TODO: abstract this to splitN lol wtf - _split4 :: forall bind1 bind2 bind3 bind4 env' env t. - (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) - -> STy bind1 -> STy bind2 -> STy bind3 -> STy bind4 -> Ex (bind4 : bind3 : bind2 : bind1 : env) t -> Ex (bind4 : bind3 : bind2 : bind1 : env') t - _split4 sub tbind1 tbind2 tbind3 tbind4 body = - let (ptrs1, bs1') = split @env' tbind1 - (ptrs2, bs2') = split @(bind1 : env') tbind2 - (ptrs3, bs3') = split @(bind2 : bind1 : env') tbind3 - (ptrs4, bs4) = split @(bind3 : bind2 : bind1 : env') tbind4 - bs1 = fst (weakenBindingsE (WSink .> WSink .> WSink) bs1') - bs2 = fst (weakenBindingsE (WSink .> WSink) bs2') - bs3 = fst (weakenBindingsE WSink bs3') - b1 = bindingsBinds bs1 - b2 = bindingsBinds bs2 - b3 = bindingsBinds bs3 - b4 = bindingsBinds bs4 - in letBinds bs1 $ - letBinds (fst (weakenBindingsE ( sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs2)) $ - letBinds (fst (weakenBindingsE ( sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs3)) $ - letBinds (fst (weakenBindingsE (sinkWithBindings bs3 .> sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs4)) $ - splitLets' (\cases _ IZ w -> subPointers ptrs4 (w .> wCopies b4 (wSinks b3 .> wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1)) - _ (IS IZ) w -> subPointers ptrs3 (w .> wSinks b4 .> wCopies b3 (wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink)) - _ (IS (IS IZ)) w -> subPointers ptrs2 (w .> wSinks b4 .> wSinks b3 .> wCopies b2 (wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink .> WSink)) - _ (IS (IS (IS IZ))) w -> subPointers ptrs1 (w .> wSinks b4 .> wSinks b3 .> wSinks b2 .> wCopies b1 (WSink @bind4 .> WSink @bind3 .> WSink @bind2 @(bind1 : env'))) - t (IS (IS (IS (IS i)))) w -> sub t i (WPop @bind1 (WPop @bind2 (WPop @bind3 (WPop @bind4 (wPops b1 (wPops b2 (wPops b3 (wPops b4 w))))))))) - body - -type family Split t where - Split (TPair a b) = SplitRec (TPair a b) - Split _ = '[] - -type family SplitRec t where - SplitRec TNil = '[] - SplitRec (TPair a b) = Append (SplitRec b) (SplitRec a) - SplitRec t = '[t] - -data Pointers env t where - Point :: STy t -> Idx env t -> Pointers env t - PNil :: Pointers env TNil - PPair :: Pointers env a -> Pointers env b -> Pointers env (TPair a b) - PWeak :: env' :> env -> Pointers env' t -> Pointers env t - -subPointers :: Pointers env t -> env :> env' -> Ex env' t -subPointers (Point t i) w = EVar ext t (w @> i) -subPointers PNil _ = ENil ext -subPointers (PPair a b) w = EPair ext (subPointers a w) (subPointers b w) -subPointers (PWeak w' p) w = subPointers p (w .> w') - -split :: forall env t. STy t - -> (Pointers (Append (Split t) (t : env)) t, Bindings Ex (t : env) (Split t)) -split typ = case typ of - STPair{} -> splitRec (EVar ext typ IZ) typ - STNil -> other - STEither{} -> other - STLEither{} -> other - STMaybe{} -> other - STArr{} -> other - STScal{} -> other - STAccum{} -> other - where - other :: (Pointers (t : env) t, Bindings Ex (t : env) '[]) - other = (Point typ IZ, BTop) - -splitRec :: forall env t. Ex env t -> STy t - -> (Pointers (Append (SplitRec t) env) t, Bindings Ex env (SplitRec t)) -splitRec rhs typ = case typ of - STNil -> (PNil, BTop) - STPair (a :: STy a) (b :: STy b) - | Refl <- lemAppendAssoc @(SplitRec b) @(SplitRec a) @env -> - let (p1, bs1) = splitRec (EFst ext rhs) a - (p2, bs2) = splitRec (ESnd ext (sinkWithBindings bs1 `weakenExpr` rhs)) b - in (PPair (PWeak (sinkWithBindings bs2) p1) p2, bconcat bs1 bs2) - STEither{} -> other - STLEither{} -> other - STMaybe{} -> other - STArr{} -> other - STScal{} -> other - STAccum{} -> other - where - other :: (Pointers (t : env) t, Bindings Ex env '[t]) - other = (Point typ IZ, BPush BTop (typ, rhs)) diff --git a/src/AST/Types.hs b/src/AST/Types.hs deleted file mode 100644 index 4ddcb50..0000000 --- a/src/AST/Types.hs +++ /dev/null @@ -1,215 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeData #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -module AST.Types where - -import Data.Int (Int32, Int64) -import Data.GADT.Compare -import Data.GADT.Show -import Data.Kind (Type) -import Data.Type.Equality - -import Data - - -type data Ty - = TNil - | TPair Ty Ty - | TEither Ty Ty - | TLEither Ty Ty - | TMaybe Ty - | TArr Nat Ty -- ^ rank, element type - | TScal ScalTy - | TAccum Ty -- ^ contained type must be a monoid type - -type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool - -type STy :: Ty -> Type -data STy t where - STNil :: STy TNil - STPair :: STy a -> STy b -> STy (TPair a b) - STEither :: STy a -> STy b -> STy (TEither a b) - STLEither :: STy a -> STy b -> STy (TLEither a b) - STMaybe :: STy a -> STy (TMaybe a) - STArr :: SNat n -> STy t -> STy (TArr n t) - STScal :: SScalTy t -> STy (TScal t) - STAccum :: SMTy t -> STy (TAccum t) -deriving instance Show (STy t) - -instance GCompare STy where - gcompare = \cases - STNil STNil -> GEQ - STNil _ -> GLT ; _ STNil -> GGT - (STPair a b) (STPair a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') - STPair{} _ -> GLT ; _ STPair{} -> GGT - (STEither a b) (STEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') - STEither{} _ -> GLT ; _ STEither{} -> GGT - (STLEither a b) (STLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') - STLEither{} _ -> GLT ; _ STLEither{} -> GGT - (STMaybe a) (STMaybe a') -> gorderingLift1 (gcompare a a') - STMaybe{} _ -> GLT ; _ STMaybe{} -> GGT - (STArr n t) (STArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t') - STArr{} _ -> GLT ; _ STArr{} -> GGT - (STScal t) (STScal t') -> gorderingLift1 (gcompare t t') - STScal{} _ -> GLT ; _ STScal{} -> GGT - (STAccum t) (STAccum t') -> gorderingLift1 (gcompare t t') - -- STAccum{} _ -> GLT ; _ STAccum{} -> GGT - -instance TestEquality STy where testEquality = geq -instance GEq STy where geq = defaultGeq -instance GShow STy where gshowsPrec = defaultGshowsPrec - --- | Monoid types -type SMTy :: Ty -> Type -data SMTy t where - SMTNil :: SMTy TNil - SMTPair :: SMTy a -> SMTy b -> SMTy (TPair a b) - SMTLEither :: SMTy a -> SMTy b -> SMTy (TLEither a b) - SMTMaybe :: SMTy a -> SMTy (TMaybe a) - SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t) - SMTScal :: ScalIsNumeric t ~ True => SScalTy t -> SMTy (TScal t) -deriving instance Show (SMTy t) - -instance GCompare SMTy where - gcompare = \cases - SMTNil SMTNil -> GEQ - SMTNil _ -> GLT ; _ SMTNil -> GGT - (SMTPair a b) (SMTPair a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') - SMTPair{} _ -> GLT ; _ SMTPair{} -> GGT - (SMTLEither a b) (SMTLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') - SMTLEither{} _ -> GLT ; _ SMTLEither{} -> GGT - (SMTMaybe a) (SMTMaybe a') -> gorderingLift1 (gcompare a a') - SMTMaybe{} _ -> GLT ; _ SMTMaybe{} -> GGT - (SMTArr n t) (SMTArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t') - SMTArr{} _ -> GLT ; _ SMTArr{} -> GGT - (SMTScal t) (SMTScal t') -> gorderingLift1 (gcompare t t') - -- SMTScal{} _ -> GLT ; _ SMTScal{} -> GGT - -instance TestEquality SMTy where testEquality = geq -instance GEq SMTy where geq = defaultGeq -instance GShow SMTy where gshowsPrec = defaultGshowsPrec - -fromSMTy :: SMTy t -> STy t -fromSMTy = \case - SMTNil -> STNil - SMTPair t1 t2 -> STPair (fromSMTy t1) (fromSMTy t2) - SMTLEither t1 t2 -> STLEither (fromSMTy t1) (fromSMTy t2) - SMTMaybe t -> STMaybe (fromSMTy t) - SMTArr n t -> STArr n (fromSMTy t) - SMTScal sty -> STScal sty - -data SScalTy t where - STI32 :: SScalTy TI32 - STI64 :: SScalTy TI64 - STF32 :: SScalTy TF32 - STF64 :: SScalTy TF64 - STBool :: SScalTy TBool -deriving instance Show (SScalTy t) - -instance GCompare SScalTy where - gcompare = \cases - STI32 STI32 -> GEQ - STI32 _ -> GLT ; _ STI32 -> GGT - STI64 STI64 -> GEQ - STI64 _ -> GLT ; _ STI64 -> GGT - STF32 STF32 -> GEQ - STF32 _ -> GLT ; _ STF32 -> GGT - STF64 STF64 -> GEQ - STF64 _ -> GLT ; _ STF64 -> GGT - STBool STBool -> GEQ - -- STBool _ -> GLT ; _ STBool -> GGT - -instance TestEquality SScalTy where testEquality = geq -instance GEq SScalTy where geq = defaultGeq -instance GShow SScalTy where gshowsPrec = defaultGshowsPrec - -scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t)) -scalRepIsShow STI32 = Dict -scalRepIsShow STI64 = Dict -scalRepIsShow STF32 = Dict -scalRepIsShow STF64 = Dict -scalRepIsShow STBool = Dict - -type TIx = TScal TI64 - -tIx :: STy TIx -tIx = STScal STI64 - -type family ScalRep t where - ScalRep TI32 = Int32 - ScalRep TI64 = Int64 - ScalRep TF32 = Float - ScalRep TF64 = Double - ScalRep TBool = Bool - -type family ScalIsNumeric t where - ScalIsNumeric TI32 = True - ScalIsNumeric TI64 = True - ScalIsNumeric TF32 = True - ScalIsNumeric TF64 = True - ScalIsNumeric TBool = False - -type family ScalIsFloating t where - ScalIsFloating TI32 = False - ScalIsFloating TI64 = False - ScalIsFloating TF32 = True - ScalIsFloating TF64 = True - ScalIsFloating TBool = False - -type family ScalIsIntegral t where - ScalIsIntegral TI32 = True - ScalIsIntegral TI64 = True - ScalIsIntegral TF32 = False - ScalIsIntegral TF64 = False - ScalIsIntegral TBool = False - --- | Returns true for arrays /and/ accumulators. -typeHasArrays :: STy t' -> Bool -typeHasArrays STNil = False -typeHasArrays (STPair a b) = typeHasArrays a || typeHasArrays b -typeHasArrays (STEither a b) = typeHasArrays a || typeHasArrays b -typeHasArrays (STLEither a b) = typeHasArrays a || typeHasArrays b -typeHasArrays (STMaybe t) = typeHasArrays t -typeHasArrays STArr{} = True -typeHasArrays STScal{} = False -typeHasArrays STAccum{} = True - -typeHasAccums :: STy t' -> Bool -typeHasAccums STNil = False -typeHasAccums (STPair a b) = typeHasAccums a || typeHasAccums b -typeHasAccums (STEither a b) = typeHasAccums a || typeHasAccums b -typeHasAccums (STLEither a b) = typeHasAccums a || typeHasAccums b -typeHasAccums (STMaybe t) = typeHasAccums t -typeHasAccums STArr{} = False -typeHasAccums STScal{} = False -typeHasAccums STAccum{} = True - -type family Tup env where - Tup '[] = TNil - Tup (t : ts) = TPair (Tup ts) t - -mkTup :: f TNil -> (forall a b. f a -> f b -> f (TPair a b)) - -> SList f list -> f (Tup list) -mkTup nil _ SNil = nil -mkTup nil pair (e `SCons` es) = pair (mkTup nil pair es) e - -tTup :: SList STy env -> STy (Tup env) -tTup = mkTup STNil STPair - -unTup :: (forall a b. c (TPair a b) -> (c a, c b)) - -> SList f list -> c (Tup list) -> SList c list -unTup _ SNil _ = SNil -unTup unpack (_ `SCons` list) tup = - let (xs, x) = unpack tup - in x `SCons` unTup unpack list xs - -type family InvTup core env where - InvTup core '[] = core - InvTup core (t : ts) = InvTup (TPair core t) ts diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs deleted file mode 100644 index 1712ba5..0000000 --- a/src/AST/UnMonoid.hs +++ /dev/null @@ -1,255 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE TypeOperators #-} -module AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where - -import AST -import AST.Sparse.Types -import Data - - --- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by --- expanding them into their concrete implementations. Also ensure that --- 'EAccum' has a dense sparsity. -unMonoid :: Ex env t -> Ex env t -unMonoid = \case - EZero _ t e -> zero t e - EDeepZero _ t e -> deepZero t e - EPlus _ t a b -> plus t (unMonoid a) (unMonoid b) - EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b) - - EVar _ t i -> EVar ext t i - ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body) - EPair _ a b -> EPair ext (unMonoid a) (unMonoid b) - EFst _ e -> EFst ext (unMonoid e) - ESnd _ e -> ESnd ext (unMonoid e) - ENil _ -> ENil ext - EInl _ t e -> EInl ext t (unMonoid e) - EInr _ t e -> EInr ext t (unMonoid e) - ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b) - ENothing _ t -> ENothing ext t - EJust _ e -> EJust ext (unMonoid e) - EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e) - ELNil _ t1 t2 -> ELNil ext t1 t2 - ELInl _ t e -> ELInl ext t (unMonoid e) - ELInr _ t e -> ELInr ext t (unMonoid e) - ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c) - EConstArr _ n t x -> EConstArr ext n t x - EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b) - EMap _ a b -> EMap ext (unMonoid a) (unMonoid b) - EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c) - ESum1Inner _ e -> ESum1Inner ext (unMonoid e) - EUnit _ e -> EUnit ext (unMonoid e) - EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b) - EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e) - EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e) - EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b) - EZip _ a b -> EZip ext (unMonoid a) (unMonoid b) - EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c) - EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c) - EConst _ t x -> EConst ext t x - EIdx0 _ e -> EIdx0 ext (unMonoid e) - EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b) - EIdx _ a b -> EIdx ext (unMonoid a) (unMonoid b) - EShape _ e -> EShape ext (unMonoid e) - EOp _ op e -> EOp ext op (unMonoid e) - ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2) - ERecompute _ e -> ERecompute ext (unMonoid e) - EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b) - EAccum _ t p eidx sp eval eacc -> - accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 -> - acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' -> - EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc)) - EError _ t s -> EError ext t s - -zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t --- don't destroy the effects! -zero SMTNil e = ELet ext e $ ENil ext -zero (SMTPair t1 t2) e = - ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ))) - (zero t2 (ESnd ext (EVar ext (typeOf e) IZ))) -zero (SMTLEither t1 t2) _ = ELNil ext (fromSMTy t1) (fromSMTy t2) -zero (SMTMaybe t) _ = ENothing ext (fromSMTy t) -zero (SMTArr _ t) e = emap (zero t (EVar ext (tZeroInfo t) IZ)) e -zero (SMTScal t) _ = case t of - STI32 -> EConst ext STI32 0 - STI64 -> EConst ext STI64 0 - STF32 -> EConst ext STF32 0.0 - STF64 -> EConst ext STF64 0.0 - -deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t -deepZero SMTNil e = elet e $ ENil ext -deepZero (SMTPair t1 t2) e = - ELet ext e $ EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ))) - (deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ))) -deepZero (SMTLEither t1 t2) e = - elcase e - (ELNil ext (fromSMTy t1) (fromSMTy t2)) - (ELInl ext (fromSMTy t2) (deepZero t1 (evar IZ))) - (ELInr ext (fromSMTy t1) (deepZero t2 (evar IZ))) -deepZero (SMTMaybe t) e = - emaybe e - (ENothing ext (fromSMTy t)) - (EJust ext (deepZero t (evar IZ))) -deepZero (SMTArr _ t) e = emap (deepZero t (evar IZ)) e -deepZero (SMTScal t) _ = case t of - STI32 -> EConst ext STI32 0 - STI64 -> EConst ext STI64 0 - STF32 -> EConst ext STF32 0.0 - STF64 -> EConst ext STF64 0.0 - -plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t --- don't destroy the effects! -plus SMTNil a b = ELet ext a $ ELet ext (weakenExpr WSink b) $ ENil ext -plus (SMTPair t1 t2) a b = - let t = STPair (fromSMTy t1) (fromSMTy t2) - in ELet ext a $ - ELet ext (weakenExpr WSink b) $ - EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ))) - (EFst ext (EVar ext t IZ))) - (plus t2 (ESnd ext (EVar ext t (IS IZ))) - (ESnd ext (EVar ext t IZ))) -plus (SMTLEither t1 t2) a b = - let t = STLEither (fromSMTy t1) (fromSMTy t2) - in ELet ext a $ - ELet ext (weakenExpr WSink b) $ - ELCase ext (EVar ext t (IS IZ)) - (EVar ext t IZ) - (ELCase ext (EVar ext t (IS IZ)) - (EVar ext t (IS (IS IZ))) - (ELInl ext (fromSMTy t2) (plus t1 (EVar ext (fromSMTy t1) (IS IZ)) (EVar ext (fromSMTy t1) IZ))) - (EError ext t "plus l+r")) - (ELCase ext (EVar ext t (IS IZ)) - (EVar ext t (IS (IS IZ))) - (EError ext t "plus r+l") - (ELInr ext (fromSMTy t1) (plus t2 (EVar ext (fromSMTy t2) (IS IZ)) (EVar ext (fromSMTy t2) IZ)))) -plus (SMTMaybe t) a b = - ELet ext b $ - EMaybe ext - (EVar ext (STMaybe (fromSMTy t)) IZ) - (EJust ext - (EMaybe ext - (EVar ext (fromSMTy t) IZ) - (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ)) - (EVar ext (STMaybe (fromSMTy t)) (IS IZ)))) - (weakenExpr WSink a) -plus (SMTArr _ t) a b = - ezipWith (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ)) - a b -plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b) - -onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t -onehot typ topprj idx arg = case (typ, topprj) of - (_, SAPHere) -> - ELet ext arg $ - EVar ext (fromSMTy typ) IZ - - (SMTPair t1 t2, SAPFst prj) -> - ELet ext idx $ - let tidx = typeOf idx in - ELet ext (onehot t1 prj (EFst ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $ - let toh = fromSMTy t1 in - EPair ext (EVar ext toh IZ) - (zero t2 (ESnd ext (EVar ext tidx (IS IZ)))) - - (SMTPair t1 t2, SAPSnd prj) -> - ELet ext idx $ - let tidx = typeOf idx in - ELet ext (onehot t2 prj (ESnd ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $ - let toh = fromSMTy t2 in - EPair ext (zero t1 (EFst ext (EVar ext tidx (IS IZ)))) - (EVar ext toh IZ) - - (SMTLEither t1 t2, SAPLeft prj) -> - ELInl ext (fromSMTy t2) (onehot t1 prj idx arg) - (SMTLEither t1 t2, SAPRight prj) -> - ELInr ext (fromSMTy t1) (onehot t2 prj idx arg) - - (SMTMaybe t1, SAPJust prj) -> - EJust ext (onehot t1 prj idx arg) - - (SMTArr n t1, SAPArrIdx prj) -> - let tidx = tTup (sreplicate n tIx) - in ELet ext idx $ - EBuild ext n (EShape ext (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ)))) $ - eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ))))) - (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) - (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $ - zero t1 (EVar ext (tZeroInfo t1) IZ)) - -accumulateSparse - :: SMTy t -> Sparse t t' -> Ex env t' - -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' (AcIdxD p t) -> Ex env' b -> Ex env' TNil) - -> Ex env TNil -accumulateSparse topty topsp arg accum = case (topty, topsp) of - (_, s) | Just Refl <- isDense topty s -> - accum WId SAPHere (ENil ext) arg - (SMTScal _, SpScal) -> - accum WId SAPHere (ENil ext) arg -- should be handled by isDense already, but meh - (_, SpSparse s) -> - emaybe arg - (ENil ext) - (accumulateSparse topty s (evar IZ) (\w -> accum (WPop w))) - (_, SpAbsent) -> - ENil ext - (SMTPair t1 t2, SpPair s1 s2) -> - eunPair arg $ \w1 e1 e2 -> - elet (accumulateSparse t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $ - accumulateSparse t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj)) - (SMTLEither t1 t2, SpLEither s1 s2) -> - elcase arg - (ENil ext) - (accumulateSparse t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj))) - (accumulateSparse t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj))) - (SMTMaybe t, SpMaybe s) -> - emaybe arg - (ENil ext) - (accumulateSparse t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj))) - (SMTArr n t, SpArr s) -> - let tn = tTup (sreplicate n tIx) in - elet arg $ - elet (EBuild ext n (EShape ext (evar IZ)) $ - accumulateSparse t s - (EIdx ext (evar (IS IZ)) (EVar ext tn IZ)) - (\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $ - ENil ext - -acPrjCompose - :: SAIDense dense - -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a) - -> SAcPrj p2 b c -> Ex env (AcIdx dense p2 b) - -> (forall p'. SAcPrj p' a c -> Ex env (AcIdx dense p' a) -> r) -> r -acPrjCompose _ SAPHere _ p2 idx2 k = k p2 idx2 -acPrjCompose SAID (SAPFst p1) idx1 p2 idx2 k = - acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> - k (SAPFst p') idx' -acPrjCompose SAID (SAPSnd p1) idx1 p2 idx2 k = - acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> - k (SAPSnd p') idx' -acPrjCompose SAIS (SAPFst p1) idx1 p2 idx2 k - | Dict <- styKnown (typeOf idx1) = - acPrjCompose SAIS p1 (efst (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> - k (SAPFst p') (elet idx1 $ EPair ext idx' (esnd (evar IZ))) -acPrjCompose SAIS (SAPSnd p1) idx1 p2 idx2 k - | Dict <- styKnown (typeOf idx1) = - acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> - k (SAPSnd p') (elet idx1 $ EPair ext (efst (evar IZ)) idx') -acPrjCompose d (SAPLeft p1) idx1 p2 idx2 k = - acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> - k (SAPLeft p') idx' -acPrjCompose d (SAPRight p1) idx1 p2 idx2 k = - acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> - k (SAPRight p') idx' -acPrjCompose d (SAPJust p1) idx1 p2 idx2 k = - acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> - k (SAPJust p') idx' -acPrjCompose SAID (SAPArrIdx p1) idx1 p2 idx2 k - | Dict <- styKnown (typeOf idx1) = - acPrjCompose SAID p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> - k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx') -acPrjCompose SAIS (SAPArrIdx p1) idx1 p2 idx2 k - | Dict <- styKnown (typeOf idx1) = - acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> - k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx') diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs deleted file mode 100644 index f0820b8..0000000 --- a/src/AST/Weaken.hs +++ /dev/null @@ -1,138 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeAbstractions #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} - -{-# LANGUAGE PartialTypeSignatures #-} -{-# OPTIONS -Wno-partial-type-signatures #-} - --- The reason why this is a separate module with "little" in it: -{-# LANGUAGE AllowAmbiguousTypes #-} - -module AST.Weaken (module AST.Weaken, Append) where - -import Data.Bifunctor (first) -import Data.Functor.Const -import Data.GADT.Compare -import Data.Kind (Type) - -import Data -import Lemmas - - -type Idx :: [k] -> k -> Type -data Idx env t where - IZ :: Idx (t : env) t - IS :: Idx env t -> Idx (a : env) t -deriving instance Show (Idx env t) - -instance GEq (Idx env) where - geq IZ IZ = Just Refl - geq (IS i) (IS j) | Just Refl <- geq i j = Just Refl - geq _ _ = Nothing - -splitIdx :: forall env2 env1 t f. SList f env1 -> Idx (Append env1 env2) t -> Either (Idx env1 t) (Idx env2 t) -splitIdx SNil i = Right i -splitIdx (SCons _ _) IZ = Left IZ -splitIdx (SCons _ l) (IS i) = first IS (splitIdx l i) - -slistIdx :: SList f list -> Idx list t -> f t -slistIdx (SCons x _) IZ = x -slistIdx (SCons _ list) (IS i) = slistIdx list i -slistIdx SNil i = case i of {} - -idx2int :: Idx env t -> Int -idx2int IZ = 0 -idx2int (IS n) = 1 + idx2int n - -data env :> env' where - WId :: env :> env - WSink :: forall t env. env :> (t : env) - WCopy :: forall t env env'. env :> env' -> (t : env) :> (t : env') - WPop :: (t : env) :> env' -> env :> env' - WThen :: env1 :> env2 -> env2 :> env3 -> env1 :> env3 - WClosed :: '[] :> env - WIdx :: Idx env t -> (t : env) :> env - WPick :: forall t pre env env'. SList (Const ()) pre -> env :> env' - -> Append pre (t : env) :> t : Append pre env' - WSwap :: forall env as bs. SList (Const ()) as -> SList (Const ()) bs - -> Append as (Append bs env) :> Append bs (Append as env) - WStack :: forall env1 env2 as bs. SList (Const ()) as -> SList (Const ()) bs - -> as :> bs -> env1 :> env2 - -> Append as env1 :> Append bs env2 -deriving instance Show (env :> env') -infix 4 :> - -infixr 2 @> -(@>) :: env :> env' -> Idx env t -> Idx env' t -WId @> i = i -WSink @> i = IS i -WCopy _ @> IZ = IZ -WCopy w @> IS i = IS (w @> i) -WPop w @> i = w @> IS i -WThen w1 w2 @> i = w2 @> w1 @> i -WClosed @> i = case i of {} -WIdx j @> IZ = j -WIdx _ @> IS i = i -WPick SNil w @> i = WCopy w @> i -WPick (_ `SCons` _) _ @> IZ = IS IZ -WPick @t (_ `SCons` pre) w @> IS i = WCopy WSink .> WPick @t pre w @> i -WSwap @env (as :: SList _ as) (bs :: SList _ bs) @> i = - case splitIdx @(Append bs env) as i of - Left i' -> indexSinks bs (indexRaiseAbove @env as i') - Right i' -> case splitIdx @env bs i' of - Left j -> indexRaiseAbove @(Append as env) bs j - Right j -> indexSinks bs (indexSinks as j) -WStack @env1 @env2 as bs wlo whi @> i = - case splitIdx @env1 as i of - Left i' -> indexRaiseAbove @env2 bs (wlo @> i') - Right i' -> indexSinks bs (whi @> i') - -indexSinks :: SList f as -> Idx bs t -> Idx (Append as bs) t -indexSinks SNil j = j -indexSinks (_ `SCons` bs') j = IS (indexSinks bs' j) - -indexRaiseAbove :: forall env as t f. SList f as -> Idx as t -> Idx (Append as env) t -indexRaiseAbove = flip go - where - go :: forall as'. Idx as' t -> SList f as' -> Idx (Append as' env) t - go IZ (_ `SCons` _) = IZ - go (IS i) (_ `SCons` as) = IS (go i as) - -infixr 3 .> -(.>) :: env2 :> env3 -> env1 :> env2 -> env1 :> env3 -(.>) = flip WThen - -class KnownListSpine list where knownListSpine :: SList (Const ()) list -instance KnownListSpine '[] where knownListSpine = SNil -instance KnownListSpine list => KnownListSpine (t : list) where knownListSpine = SCons (Const ()) knownListSpine - -wSinks' :: forall list env. KnownListSpine list => env :> Append list env -wSinks' = wSinks (knownListSpine :: SList (Const ()) list) - -wSinks :: forall env bs f. SList f bs -> env :> Append bs env -wSinks SNil = WId -wSinks (SCons _ spine) = WSink .> wSinks spine - -wSinksAnd :: forall env env' bs f. SList f bs -> env :> env' -> env :> Append bs env' -wSinksAnd SNil w = w -wSinksAnd (SCons _ spine) w = WSink .> wSinksAnd spine w - -wCopies :: SList f bs -> env1 :> env2 -> Append bs env1 :> Append bs env2 -wCopies bs w = - let bs' = slistMap (\_ -> Const ()) bs - in WStack bs' bs' WId w - -wRaiseAbove :: SList f env1 -> proxy env -> env1 :> Append env1 env -wRaiseAbove SNil _ = WClosed -wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env) - -wPops :: SList f bs -> Append bs env1 :> env2 -> env1 :> env2 -wPops SNil w = w -wPops (_ `SCons` bs) w = wPops bs (WPop w) diff --git a/src/AST/Weaken/Auto.hs b/src/AST/Weaken/Auto.hs deleted file mode 100644 index 7370df1..0000000 --- a/src/AST/Weaken/Auto.hs +++ /dev/null @@ -1,192 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE FunctionalDependencies #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeAbstractions #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} - -{-# LANGUAGE AllowAmbiguousTypes #-} - -{-# LANGUAGE PartialTypeSignatures #-} -{-# OPTIONS_GHC -Wno-partial-type-signatures #-} -module AST.Weaken.Auto ( - autoWeak, - (&.), auto, auto1, - Layout(..), -) where - -import Data.Functor.Const -import Data.Kind (Constraint) -import GHC.OverloadedLabels -import GHC.TypeLits -import Unsafe.Coerce (unsafeCoerce) - -import AST.Weaken -import Data -import Lemmas - - -type family Lookup name list where - Lookup name ('(name, x) : _) = x - Lookup name (_ : list) = Lookup name list - Lookup name '[] = TypeError (Text "The name '" :<>: Text name :<>: Text "' does not appear in the list.") - - --- | The @withPre@ type parameter indicates whether there can be 'LPreW' --- occurrences within this layout. 'names' is the list of names that this --- layout /produces/. That is: for LPreW, it contains the target name. The --- 'names' list of a source layout must be a subset of the names list of the --- target layout (which cannot contain LPreW); this is checked with SubLayout. -data Layout (withPre :: Bool) (segments :: [(Symbol, [t])]) (names :: [Symbol]) (env :: [t]) where - LSeg :: forall name segments withPre. SSymbol name -> Layout withPre segments '[name] (Lookup name segments) - -- | Pre-weaken with a weakening - LPreW :: forall name1 name2 segments. - SegmentName name1 -> SegmentName name2 - -> Lookup name1 segments :> Lookup name2 segments - -> Layout True segments '[name2] (Lookup name1 segments) - (:++:) :: Layout withPre segments names1 env1 -> Layout withPre segments names2 env2 -> Layout withPre segments (Append names1 names2) (Append env1 env2) -infixr :++: - -instance (KnownSymbol name, seg ~ Lookup name segments, names ~ '[name]) => IsLabel name (Layout withPre segments names seg) where - fromLabel = LSeg (symbolSing @name) - -newtype SegmentName name = SegmentName (SSymbol name) - deriving (Show) - -instance (KnownSymbol name, name ~ name') => IsLabel name (SegmentName name') where - fromLabel = SegmentName symbolSing - - -type family SubLayout names1 names2 where - SubLayout '[] _ = () :: Constraint - SubLayout (n : names1) names2 = SubLayout' n (Contains n names2) names1 names2 -type family SubLayout' n ok names1 names2 where - SubLayout' n False _ _ = TypeError (Text "The name '" :<>: Text n :<>: Text "' appears in the source layout but not in the target.") - SubLayout' _ True names1 names2 = SubLayout names1 names2 -type family Contains n names where - Contains _ '[] = False - Contains n (n : _) = True - Contains n (_ : names) = Contains n names - - -data SSegments (segments :: [(Symbol, [t])]) where - SSegNil :: SSegments '[] - SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list) - -instance (KnownSymbol name, segs ~ '[ '(name, ts)]) => IsLabel name (SList f ts -> SSegments segs) where - fromLabel = \spine -> SSegCons symbolSing (slistMap (\_ -> Const ()) spine) SSegNil - -auto :: KnownListSpine list => SList (Const ()) list -auto = knownListSpine - -auto1 :: SList (Const ()) '[t] -auto1 = Const () `SCons` SNil - -infixr &. -(&.) :: SSegments '[segs1] -> SSegments segs2 -> SSegments (segs1 : segs2) -(&.) = ssegmentsAppend - where - ssegmentsAppend :: SSegments a -> SSegments b -> SSegments (Append a b) - ssegmentsAppend SSegNil l2 = l2 - ssegmentsAppend (SSegCons name list l1) l2 = SSegCons name list (ssegmentsAppend l1 l2) - - --- | If the found segment is a TopSeg, returns Nothing. -segmentLookup :: forall segments name. SSegments segments -> SSymbol name -> SList (Const ()) (Lookup name segments) -segmentLookup = \segs name -> case go segs name of - Just ts -> ts - Nothing -> error $ "Segment not found: " ++ fromSSymbol name - where - go :: forall segs'. SSegments segs' -> SSymbol name -> Maybe (SList (Const ()) (Lookup name segs')) - go SSegNil _ = Nothing - go (SSegCons n@(SSymbol @n) (ts :: SList _ ts) (sseg :: SSegments rest)) name@SSymbol = - case sameSymbol n name of - Just Refl -> - case go sseg name of - Nothing -> Just ts - Just _ -> error $ "Duplicate segment with name " ++ fromSSymbol name - Nothing -> - case unsafeCoerce Refl :: (Lookup name ('(n, ts) : rest) :~: Lookup name rest) of - Refl -> go sseg name - -data LinLayout (withPre :: Bool) (segments :: [(Symbol, [t])]) (env :: [t]) where - LinEnd :: LinLayout withPre segments '[] - LinApp :: SSymbol name -> LinLayout withPre segments env - -> LinLayout withPre segments (Append (Lookup name segments) env) - LinAppPreW :: SSymbol name1 -> SSymbol name2 - -> Lookup name1 segments :> Lookup name2 segments - -> LinLayout True segments env - -> LinLayout True segments (Append (Lookup name1 segments) env) - -linLayoutAppend :: LinLayout withPre segments env1 -> LinLayout withPre segments env2 -> LinLayout withPre segments (Append env1 env2) -linLayoutAppend LinEnd lin = lin -linLayoutAppend (LinApp (name :: SSymbol name) (lin1 :: LinLayout _ segments env1')) (lin2 :: LinLayout _ _ env2) - | Refl <- lemAppendAssoc @(Lookup name segments) @env1' @env2 - = LinApp name (linLayoutAppend lin1 lin2) -linLayoutAppend (LinAppPreW (name1 :: SSymbol name1) name2 w (lin1 :: LinLayout _ segments env1')) (lin2 :: LinLayout _ _ env2) - | Refl <- lemAppendAssoc @(Lookup name1 segments) @env1' @env2 - = LinAppPreW name1 name2 w (linLayoutAppend lin1 lin2) - -lineariseLayout :: Layout withPre segments names env -> LinLayout withPre segments env -lineariseLayout (LSeg name :: Layout _ _ _ seg) - | Refl <- lemAppendNil @seg - = LinApp name LinEnd -lineariseLayout (ly1 :++: ly2) = lineariseLayout ly1 `linLayoutAppend` lineariseLayout ly2 -lineariseLayout (LPreW (SegmentName name1) (SegmentName name2) w :: Layout _ _ _ seg) - | Refl <- lemAppendNil @seg - = LinAppPreW name1 name2 w LinEnd - -preWeaken :: SSegments segments -> LinLayout True segments env - -> (forall env'. env :> env' -> LinLayout False segments env' -> r) -> r -preWeaken _ LinEnd k = k WId LinEnd -preWeaken segs (LinApp name lin) k = - preWeaken segs lin $ \w lin' -> - k (wCopies (segmentLookup segs name) w) (LinApp name lin') -preWeaken segs (LinAppPreW name1 name2 weak lin) k = - preWeaken segs lin $ \w lin' -> - k (WStack (segmentLookup segs name1) (segmentLookup segs name2) weak w) (LinApp name2 lin') - -pullDown :: SSegments segments -> SSymbol name -> LinLayout False segments env - -> r -- Name was not found in source - -> (forall env'. LinLayout False segments env' -> env :> Append (Lookup name segments) env' -> r) - -> r -pullDown segs name@SSymbol linlayout kNotFound k = - case linlayout of - LinEnd -> kNotFound - LinApp n'@SSymbol lin - | Just Refl <- sameSymbol name n' -> k lin WId - | otherwise -> - pullDown segs name lin kNotFound $ \(lin' :: LinLayout _ _ env') w -> - k (LinApp n' lin') (WSwap @env' (segmentLookup segs n') (segmentLookup segs name) - .> wCopies (segmentLookup segs n') w) - -sortLinLayouts :: SSegments segments - -> LinLayout False segments env1 -> LinLayout False segments env2 -> env1 :> env2 -sortLinLayouts _ LinEnd LinEnd = WId -sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail2) - | Just Refl <- sameSymbol name1 name2 = wCopies (segmentLookup segs name1) (sortLinLayouts segs tail1 tail2) - | otherwise = - pullDown segs name2 lin1 - (wSinks (segmentLookup segs name2) .> sortLinLayouts segs lin1 tail2) - (\tail1' w -> - -- We've pulled down name2 in lin1 so that it's at the head; the - -- resulting modified tail is tail1'. Thus now we have (name2 : tail1') - -- vs (name2 : tail2). Thus we continue sorting tail1' vs tail2, and - -- wCopies the name2 on top of that. - wCopies (segmentLookup segs name2) (sortLinLayouts segs tail1' tail2) .> w) -sortLinLayouts _ LinEnd LinApp{} = WClosed -sortLinLayouts _ LinApp{} LinEnd = error "Segments in source that do not occur in target" - -autoWeak :: SubLayout names1 names2 - => SSegments segments -> Layout True segments names1 env1 -> Layout False segments names2 env2 -> env1 :> env2 -autoWeak segs ly1 ly2 = - preWeaken segs (lineariseLayout ly1) $ \wPreweak lin1 -> - sortLinLayouts segs lin1 (lineariseLayout ly2) .> wPreweak |
