diff options
Diffstat (limited to 'src/AST')
| -rw-r--r-- | src/AST/Accum.hs | 127 | ||||
| -rw-r--r-- | src/AST/Bindings.hs | 20 | ||||
| -rw-r--r-- | src/AST/Count.hs | 956 | ||||
| -rw-r--r-- | src/AST/Env.hs | 84 | ||||
| -rw-r--r-- | src/AST/Pretty.hs | 204 | ||||
| -rw-r--r-- | src/AST/Sparse.hs | 287 | ||||
| -rw-r--r-- | src/AST/Sparse/Types.hs | 107 | ||||
| -rw-r--r-- | src/AST/SplitLets.hs | 75 | ||||
| -rw-r--r-- | src/AST/Types.hs | 184 | ||||
| -rw-r--r-- | src/AST/UnMonoid.hs | 260 | ||||
| -rw-r--r-- | src/AST/Weaken.hs | 8 | ||||
| -rw-r--r-- | src/AST/Weaken/Auto.hs | 46 |
12 files changed, 1985 insertions, 373 deletions
diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs index 67c5de7..988a450 100644 --- a/src/AST/Accum.hs +++ b/src/AST/Accum.hs @@ -1,8 +1,8 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeData #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} module AST.Accum where @@ -26,35 +26,112 @@ data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where SAPHere :: SAcPrj APHere a a SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair a t) b SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair t a) b - SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TEither a t) b - SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TEither t a) b + SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TLEither a t) b + SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TLEither t a) b SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe a) b - -- TODO: This SNat is rather useless, you always have an STy around too - SAPArrIdx :: SAcPrj p a b -> SNat n -> SAcPrj (APArrIdx p) (TArr n a) b + SAPArrIdx :: SAcPrj p a b -> SAcPrj (APArrIdx p) (TArr n a) b -- TODO: -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t) deriving instance Show (SAcPrj p a b) -type family AcIdx p t where - AcIdx APHere t = TNil - AcIdx (APFst p) (TPair a b) = AcIdx p a - AcIdx (APSnd p) (TPair a b) = AcIdx p b - AcIdx (APLeft p) (TEither a b) = AcIdx p a - AcIdx (APRight p) (TEither a b) = AcIdx p b - AcIdx (APJust p) (TMaybe a) = AcIdx p a - AcIdx (APArrIdx p) (TArr n a) = - -- ((index, array shape), recursive info) - TPair (TPair (Tup (Replicate n TIx)) (Tup (Replicate n TIx))) - (AcIdx p a) - -- AcIdx (APArrSlice m) (TArr n a) = +type data AIDense = AID | AIS + +data SAIDense d where + SAID :: SAIDense AID + SAIS :: SAIDense AIS +deriving instance Show (SAIDense d) + +type family AcIdx d p t where + AcIdx d APHere t = TNil + AcIdx AID (APFst p) (TPair a b) = AcIdx AID p a + AcIdx AID (APSnd p) (TPair a b) = AcIdx AID p b + AcIdx AIS (APFst p) (TPair a b) = TPair (AcIdx AIS p a) (ZeroInfo b) + AcIdx AIS (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AIS p b) + AcIdx d (APLeft p) (TLEither a b) = AcIdx d p a + AcIdx d (APRight p) (TLEither a b) = AcIdx d p b + AcIdx d (APJust p) (TMaybe a) = AcIdx d p a + AcIdx AID (APArrIdx p) (TArr n a) = + -- (index, recursive info) + TPair (Tup (Replicate n TIx)) (AcIdx AID p a) + AcIdx AIS (APArrIdx p) (TArr n a) = + -- ((index, shape info), recursive info) + TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a))) + (AcIdx AIS p a) + -- AcIdx AID (APArrSlice m) (TArr n a) = + -- -- index + -- Tup (Replicate m TIx) + -- AcIdx AIS (APArrSlice m) (TArr n a) = -- -- (index, array shape) -- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx)) -acPrjTy :: SAcPrj p a b -> STy a -> STy b +type AcIdxD p t = AcIdx AID p t +type AcIdxS p t = AcIdx AIS p t + +acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b acPrjTy SAPHere t = t -acPrjTy (SAPFst prj) (STPair t _) = acPrjTy prj t -acPrjTy (SAPSnd prj) (STPair _ t) = acPrjTy prj t -acPrjTy (SAPLeft prj) (STEither t _) = acPrjTy prj t -acPrjTy (SAPRight prj) (STEither _ t) = acPrjTy prj t -acPrjTy (SAPJust prj) (STMaybe t) = acPrjTy prj t -acPrjTy (SAPArrIdx prj _) (STArr _ t) = acPrjTy prj t +acPrjTy (SAPFst prj) (SMTPair t _) = acPrjTy prj t +acPrjTy (SAPSnd prj) (SMTPair _ t) = acPrjTy prj t +acPrjTy (SAPLeft prj) (SMTLEither t _) = acPrjTy prj t +acPrjTy (SAPRight prj) (SMTLEither _ t) = acPrjTy prj t +acPrjTy (SAPJust prj) (SMTMaybe t) = acPrjTy prj t +acPrjTy (SAPArrIdx prj) (SMTArr _ t) = acPrjTy prj t + +type family ZeroInfo t where + ZeroInfo TNil = TNil + ZeroInfo (TPair a b) = TPair (ZeroInfo a) (ZeroInfo b) + ZeroInfo (TLEither a b) = TNil + ZeroInfo (TMaybe a) = TNil + ZeroInfo (TArr n t) = TArr n (ZeroInfo t) + ZeroInfo (TScal t) = TNil + +tZeroInfo :: SMTy t -> STy (ZeroInfo t) +tZeroInfo SMTNil = STNil +tZeroInfo (SMTPair a b) = STPair (tZeroInfo a) (tZeroInfo b) +tZeroInfo (SMTLEither _ _) = STNil +tZeroInfo (SMTMaybe _) = STNil +tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t) +tZeroInfo (SMTScal _) = STNil + +-- | Info needed to create a zero-valued deep accumulator for a monoid type. +-- Should be constructable from a D1. +type family DeepZeroInfo t where + DeepZeroInfo TNil = TNil + DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b) + DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b) + DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a) + DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a) + DeepZeroInfo (TScal t) = TNil + +tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t) +tDeepZeroInfo SMTNil = STNil +tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a) +tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t) +tDeepZeroInfo (SMTScal _) = STNil + +-- -- | Additional info needed for accumulation. This is empty unless there is +-- -- sparsity in the monoid. +-- type family AccumInfo t where +-- AccumInfo TNil = TNil +-- AccumInfo (TPair a b) = TPair (AccumInfo a) (AccumInfo b) +-- AccumInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b) +-- AccumInfo (TMaybe a) = TMaybe (AccumInfo a) +-- AccumInfo (TArr n t) = TArr n (AccumInfo t) +-- AccumInfo (TScal t) = TNil + +-- type family PrimalInfo t where +-- PrimalInfo TNil = TNil +-- PrimalInfo (TPair a b) = TPair (PrimalInfo a) (PrimalInfo b) +-- PrimalInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b) +-- PrimalInfo (TMaybe a) = TMaybe (PrimalInfo a) +-- PrimalInfo (TArr n t) = TArr n (PrimalInfo t) +-- PrimalInfo (TScal t) = TNil + +-- tPrimalInfo :: SMTy t -> STy (PrimalInfo t) +-- tPrimalInfo SMTNil = STNil +-- tPrimalInfo (SMTPair a b) = STPair (tPrimalInfo a) (tPrimalInfo b) +-- tPrimalInfo (SMTLEither a b) = STLEither (tPrimalInfo a) (tPrimalInfo b) +-- tPrimalInfo (SMTMaybe a) = STMaybe (tPrimalInfo a) +-- tPrimalInfo (SMTArr n t) = STArr n (tPrimalInfo t) +-- tPrimalInfo (SMTScal _) = STNil diff --git a/src/AST/Bindings.hs b/src/AST/Bindings.hs index 3d99afe..463586a 100644 --- a/src/AST/Bindings.hs +++ b/src/AST/Bindings.hs @@ -16,6 +16,7 @@ module AST.Bindings where import AST +import AST.Env import Data import Lemmas @@ -27,6 +28,10 @@ data Bindings f env binds where deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env') infixl `BPush` +bpush :: Bindings (Expr x) env binds -> Expr x (Append binds env) t -> Bindings (Expr x) env (t : binds) +bpush b e = b `BPush` (typeOf e, e) +infixl `bpush` + mapBindings :: (forall env' t'. f env' t' -> g env' t') -> Bindings f env binds -> Bindings g env binds mapBindings _ BTop = BTop @@ -41,6 +46,11 @@ weakenBindings wf w (BPush b (t, x)) = let (b', w') = weakenBindings wf w b in (BPush b' (t, wf w' x), WCopy w') +weakenBindingsE :: env1 :> env2 + -> Bindings (Expr x) env1 binds + -> (Bindings (Expr x) env2 binds, Append binds env1 :> Append binds env2) +weakenBindingsE = weakenBindings weakenExpr + weakenOver :: SList STy ts -> env :> env' -> Append ts env :> Append ts env' weakenOver SNil w = w weakenOver (SCons _ ts) w = WCopy (weakenOver ts w) @@ -62,3 +72,13 @@ bindingsBinds (BPush binds (t, _)) = SCons t (bindingsBinds binds) letBinds :: Bindings Ex env binds -> Ex (Append binds env) t -> Ex env t letBinds BTop = id letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs + +collectBindings :: SList STy env -> Subenv env env' -> Bindings Ex env env' +collectBindings = \env -> fst . go env WId + where + go :: SList STy env -> env :> env0 -> Subenv env env' -> (Bindings Ex env0 env', env0 :> Append env' env0) + go _ _ SETop = (BTop, WId) + go (ty `SCons` env) w (SEYesR sub) = + let (bs, w') = go env (WPop w) sub + in (BPush bs (ty, EVar ext ty (w' .> w @> IZ)), WSink .> w') + go (_ `SCons` env) w (SENo sub) = go env (WPop w) sub diff --git a/src/AST/Count.hs b/src/AST/Count.hs index dc8ec72..ac8634e 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -3,6 +3,7 @@ {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE DerivingVia #-} {-# LANGUAGE EmptyCase #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} @@ -10,17 +11,31 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE PatternSynonyms #-} module AST.Count where -import Data.Functor.Const +import Data.Functor.Product +import Data.Some +import Data.Type.Equality import GHC.Generics (Generic, Generically(..)) +import Array import AST import AST.Env import Data +-- | The monoid operation combines assuming that /both/ branches are taken. +class Monoid a => Occurrence a where + -- | One of the two branches is taken + (<||>) :: a -> a -> a + -- | This code is executed many times + scaleMany :: a -> a + + data Count = Zero | One | Many deriving (Show, Eq, Ord) @@ -30,6 +45,10 @@ instance Semigroup Count where _ <> _ = Many instance Monoid Count where mempty = Zero +instance Occurrence Count where + (<||>) = max + scaleMany Zero = Zero + scaleMany _ = Many data Occ = Occ { _occLexical :: Count , _occRuntime :: Count } @@ -40,116 +59,863 @@ instance Show Occ where showsPrec d (Occ l r) = showParen (d > 10) $ showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r --- | One of the two branches is taken -(<||>) :: Occ -> Occ -> Occ -Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2) +instance Occurrence Occ where + Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (r1 <||> r2) + scaleMany (Occ l c) = Occ l (scaleMany c) + + +data Substruc t t' where + -- If you add constructors here, do not forget to update the COMPLETE pragmas of any pattern synonyms below + SsFull :: Substruc t t + SsNone :: Substruc t TNil + SsPair :: Substruc a a' -> Substruc b b' -> Substruc (TPair a b) (TPair a' b') + SsEither :: Substruc a a' -> Substruc b b' -> Substruc (TEither a b) (TEither a' b') + SsLEither :: Substruc a a' -> Substruc b b' -> Substruc (TLEither a b) (TLEither a' b') + SsMaybe :: Substruc a a' -> Substruc (TMaybe a) (TMaybe a') + SsArr :: Substruc a a' -> Substruc (TArr n a) (TArr n a') -- ^ union of usages of all array elements + SsAccum :: Substruc a a' -> Substruc (TAccum a) (TAccum a') + +pattern SsPair' :: forall a b t'. forall a' b'. t' ~ TPair a' b' => Substruc a a' -> Substruc b b' -> Substruc (TPair a b) t' +pattern SsPair' s1 s2 <- ((\case { SsFull -> SsPair SsFull SsFull ; s -> s }) -> SsPair s1 s2) + where SsPair' = SsPair +{-# COMPLETE SsNone, SsPair', SsEither, SsLEither, SsMaybe, SsArr, SsAccum #-} + +pattern SsArr' :: forall n a t'. forall a'. t' ~ TArr n a' => Substruc a a' -> Substruc (TArr n a) t' +pattern SsArr' s <- ((\case { SsFull -> SsArr SsFull ; s -> s }) -> SsArr s) + where SsArr' = SsArr +{-# COMPLETE SsNone, SsPair, SsEither, SsLEither, SsMaybe, SsArr', SsAccum #-} + +instance Semigroup (Some (Substruc t)) where + Some SsFull <> _ = Some SsFull + _ <> Some SsFull = Some SsFull + Some SsNone <> s = s + s <> Some SsNone = s + Some (SsPair a b) <> Some (SsPair a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsPair a2 b2) + Some (SsEither a b) <> Some (SsEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsEither a2 b2) + Some (SsLEither a b) <> Some (SsLEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsLEither a2 b2) + Some (SsMaybe a) <> Some (SsMaybe a') = withSome (Some a <> Some a') $ \a2 -> Some (SsMaybe a2) + Some (SsArr a) <> Some (SsArr a') = withSome (Some a <> Some a') $ \a2 -> Some (SsArr a2) + Some (SsAccum a) <> Some (SsAccum a') = withSome (Some a <> Some a') $ \a2 -> Some (SsAccum a2) +instance Monoid (Some (Substruc t)) where + mempty = Some SsNone + +instance TestEquality (Substruc t) where + testEquality SsFull s = isFull s + testEquality s SsFull = sym <$> isFull s + testEquality SsNone SsNone = Just Refl + testEquality SsNone _ = Nothing + testEquality _ SsNone = Nothing + testEquality (SsPair a b) (SsPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing + testEquality (SsEither a b) (SsEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing + testEquality (SsLEither a b) (SsLEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing + testEquality (SsMaybe s) (SsMaybe s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing + testEquality (SsArr s) (SsArr s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing + testEquality (SsAccum s) (SsAccum s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing + +isFull :: Substruc t t' -> Maybe (t :~: t') +isFull SsFull = Just Refl +isFull SsNone = Nothing -- TODO: nil? +isFull (SsPair a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing +isFull (SsEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing +isFull (SsLEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing +isFull (SsMaybe s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing +isFull (SsArr s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing +isFull (SsAccum s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing + +applySubstruc :: Substruc t t' -> STy t -> STy t' +applySubstruc SsFull t = t +applySubstruc SsNone _ = STNil +applySubstruc (SsPair s1 s2) (STPair a b) = STPair (applySubstruc s1 a) (applySubstruc s2 b) +applySubstruc (SsEither s1 s2) (STEither a b) = STEither (applySubstruc s1 a) (applySubstruc s2 b) +applySubstruc (SsLEither s1 s2) (STLEither a b) = STLEither (applySubstruc s1 a) (applySubstruc s2 b) +applySubstruc (SsMaybe s) (STMaybe t) = STMaybe (applySubstruc s t) +applySubstruc (SsArr s) (STArr n t) = STArr n (applySubstruc s t) +applySubstruc (SsAccum s) (STAccum t) = STAccum (applySubstrucM s t) + +applySubstrucM :: Substruc t t' -> SMTy t -> SMTy t' +applySubstrucM SsFull t = t +applySubstrucM SsNone _ = SMTNil +applySubstrucM (SsPair s1 s2) (SMTPair a b) = SMTPair (applySubstrucM s1 a) (applySubstrucM s2 b) +applySubstrucM (SsLEither s1 s2) (SMTLEither a b) = SMTLEither (applySubstrucM s1 a) (applySubstrucM s2 b) +applySubstrucM (SsMaybe s) (SMTMaybe t) = SMTMaybe (applySubstrucM s t) +applySubstrucM (SsArr s) (SMTArr n t) = SMTArr n (applySubstrucM s t) +applySubstrucM _ t = case t of {} + +data ExMap a b = ExMap (forall env. Ex env a -> Ex env b) + | a ~ b => ExMapId + +fromExMap :: ExMap a b -> Ex env a -> Ex env b +fromExMap (ExMap f) = f +fromExMap ExMapId = id + +simplifySubstruc :: STy t -> Substruc t t' -> Substruc t t' +simplifySubstruc STNil SsNone = SsFull + +simplifySubstruc _ SsFull = SsFull +simplifySubstruc _ SsNone = SsNone +simplifySubstruc (STPair t1 t2) (SsPair s1 s2) = SsPair (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) +simplifySubstruc (STEither t1 t2) (SsEither s1 s2) = SsEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) +simplifySubstruc (STLEither t1 t2) (SsLEither s1 s2) = SsLEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) +simplifySubstruc (STMaybe t) (SsMaybe s) = SsMaybe (simplifySubstruc t s) +simplifySubstruc (STArr _ t) (SsArr s) = SsArr (simplifySubstruc t s) +simplifySubstruc (STAccum t) (SsAccum s) = SsAccum (simplifySubstruc (fromSMTy t) s) + +-- simplifySubstruc' :: Substruc t t' +-- -> (forall t'2. Substruc t t'2 -> ExMap t'2 t' -> r) -> r +-- simplifySubstruc' SsFull k = k SsFull ExMapId +-- simplifySubstruc' SsNone k = k SsNone ExMapId +-- simplifySubstruc' (SsPair s1 s2) k = +-- simplifySubstruc' s1 $ \s1' f1 -> +-- simplifySubstruc' s2 $ \s2' f2 -> +-- case (s1', s2') of +-- (SsFull, SsFull) -> +-- k SsFull (case (f1, f2) of +-- (ExMapId, ExMapId) -> ExMapId +-- _ -> ExMap (\e -> eunPair e $ \_ e1 e2 -> +-- EPair ext (fromExMap f1 e1) (fromExMap f2 e2))) +-- (SsNone, SsNone) -> k SsNone (ExMap (\_ -> EPair ext (fromExMap f1 (ENil ext)) (fromExMap f2 (ENil ext)))) +-- _ -> k (SsPair s1' s2') (ExMap (\e -> elet e $ EPair ext (fromExMap f1 (EFst ext (evar IZ))) (fromExMap f2 (ESnd ext (evar IZ))))) +-- simplifySubstruc' _ _ = _ + +-- ssUnpair :: Substruc (TPair a b) -> (Substruc a, Substruc b) +-- ssUnpair SsFull = (SsFull, SsFull) +-- ssUnpair SsNone = (SsNone, SsNone) +-- ssUnpair (SsPair a b) = (a, b) + +-- ssUnleft :: Substruc (TEither a b) -> Substruc a +-- ssUnleft SsFull = SsFull +-- ssUnleft SsNone = SsNone +-- ssUnleft (SsEither a _) = a + +-- ssUnright :: Substruc (TEither a b) -> Substruc b +-- ssUnright SsFull = SsFull +-- ssUnright SsNone = SsNone +-- ssUnright (SsEither _ b) = b + +-- ssUnlleft :: Substruc (TLEither a b) -> Substruc a +-- ssUnlleft SsFull = SsFull +-- ssUnlleft SsNone = SsNone +-- ssUnlleft (SsLEither a _) = a + +-- ssUnlright :: Substruc (TLEither a b) -> Substruc b +-- ssUnlright SsFull = SsFull +-- ssUnlright SsNone = SsNone +-- ssUnlright (SsLEither _ b) = b + +-- ssUnjust :: Substruc (TMaybe a) -> Substruc a +-- ssUnjust SsFull = SsFull +-- ssUnjust SsNone = SsNone +-- ssUnjust (SsMaybe a) = a + +-- ssUnarr :: Substruc (TArr n a) -> Substruc a +-- ssUnarr SsFull = SsFull +-- ssUnarr SsNone = SsNone +-- ssUnarr (SsArr a) = a + +-- ssUnaccum :: Substruc (TAccum a) -> Substruc a +-- ssUnaccum SsFull = SsFull +-- ssUnaccum SsNone = SsNone +-- ssUnaccum (SsAccum a) = a + + +type family MapEmpty env where + MapEmpty '[] = '[] + MapEmpty (t : env) = TNil : MapEmpty env + +data OccEnv a env env' where + OccEnd :: OccEnv a env (MapEmpty env) -- not necessarily top! + OccPush :: OccEnv a env env' -> a -> Substruc t t' -> OccEnv a (t : env) (t' : env') + +instance Semigroup a => Semigroup (Some (OccEnv a env)) where + Some OccEnd <> e = e + e <> Some OccEnd = e + Some (OccPush e o s) <> Some (OccPush e' o' s') = withSome (Some e <> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <> o') s2) + +instance Semigroup a => Monoid (Some (OccEnv a env)) where + mempty = Some OccEnd + +instance Occurrence a => Occurrence (Some (OccEnv a env)) where + Some OccEnd <||> e = e + e <||> Some OccEnd = e + Some (OccPush e o s) <||> Some (OccPush e' o' s') = withSome (Some e <||> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <||> o') s2) + + scaleMany (Some OccEnd) = Some OccEnd + scaleMany (Some (OccPush e o s)) = withSome (scaleMany (Some e)) $ \e2 -> Some (OccPush e2 (scaleMany o) s) + +onehotOccEnv :: Monoid a => Idx env t -> a -> Substruc t t' -> Some (OccEnv a env) +onehotOccEnv IZ v s = Some (OccPush OccEnd v s) +onehotOccEnv (IS i) v s + | Some env' <- onehotOccEnv i v s + = Some (OccPush env' mempty SsNone) + +occEnvPop :: OccEnv a (t : env) (t' : env') -> (OccEnv a env env', Substruc t t') +occEnvPop (OccPush e _ s) = (e, s) +occEnvPop OccEnd = (OccEnd, SsNone) + +occEnvPop' :: OccEnv a (t : env) env' -> (forall t' env''. env' ~ t' : env'' => OccEnv a env env'' -> Substruc t t' -> r) -> r +occEnvPop' (OccPush e _ s) k = k e s +occEnvPop' OccEnd k = k OccEnd SsNone --- | This code is executed many times -scaleMany :: Occ -> Occ -scaleMany (Occ l Zero) = Occ l Zero -scaleMany (Occ l _) = Occ l Many +occEnvPopSome :: Some (OccEnv a (t : env)) -> Some (OccEnv a env) +occEnvPopSome (Some (OccPush e _ _)) = Some e +occEnvPopSome (Some OccEnd) = Some OccEnd + +occEnvPrj :: Monoid a => OccEnv a env env' -> Idx env t -> (a, Some (Substruc t)) +occEnvPrj OccEnd _ = mempty +occEnvPrj (OccPush _ o s) IZ = (o, Some s) +occEnvPrj (OccPush e _ _) (IS i) = occEnvPrj e i + +occEnvPrjS :: OccEnv a env env' -> Idx env t -> Some (Product (Substruc t) (Idx env')) +occEnvPrjS OccEnd IZ = Some (Pair SsNone IZ) +occEnvPrjS OccEnd (IS i) | Some (Pair s i') <- occEnvPrjS OccEnd i = Some (Pair s (IS i')) +occEnvPrjS (OccPush _ _ s) IZ = Some (Pair s IZ) +occEnvPrjS (OccPush e _ _) (IS i) + | Some (Pair s' i') <- occEnvPrjS e i + = Some (Pair s' (IS i')) + +projectSmallerSubstruc :: Substruc t t'big -> Substruc t t'small -> Ex env t'big -> Ex env t'small +projectSmallerSubstruc topsbig topssmall ex = case (topsbig, topssmall) of + _ | Just Refl <- testEquality topsbig topssmall -> ex + + (SsFull, SsFull) -> ex + (SsNone, SsNone) -> ex + (SsNone, _) -> error "projectSmallerSubstruc: smaller substructure not smaller" + (_, SsNone) -> + case typeOf ex of + STNil -> ex + _ -> use ex $ ENil ext + + (SsPair s1 s2, SsPair s1' s2') -> + eunPair ex $ \_ e1 e2 -> + EPair ext (projectSmallerSubstruc s1 s1' e1) (projectSmallerSubstruc s2 s2' e2) + (s@SsPair{}, SsFull) -> projectSmallerSubstruc s (SsPair SsFull SsFull) ex + (SsFull, s@SsPair{}) -> projectSmallerSubstruc (SsPair SsFull SsFull) s ex + + (SsEither s1 s2, SsEither s1' s2') + | STEither t1 t2 <- typeOf ex -> + let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) + e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ) + in ecase ex + (EInl ext (typeOf e2) e1) + (EInr ext (typeOf e1) e2) + (s@SsEither{}, SsFull) -> projectSmallerSubstruc s (SsEither SsFull SsFull) ex + (SsFull, s@SsEither{}) -> projectSmallerSubstruc (SsEither SsFull SsFull) s ex + + (SsLEither s1 s2, SsLEither s1' s2') + | STLEither t1 t2 <- typeOf ex -> + let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) + e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ) + in elcase ex + (ELNil ext (typeOf e1) (typeOf e2)) + (ELInl ext (typeOf e2) e1) + (ELInr ext (typeOf e1) e2) + (s@SsLEither{}, SsFull) -> projectSmallerSubstruc s (SsLEither SsFull SsFull) ex + (SsFull, s@SsLEither{}) -> projectSmallerSubstruc (SsLEither SsFull SsFull) s ex + + (SsMaybe s1, SsMaybe s1') + | STMaybe t1 <- typeOf ex -> + let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) + in emaybe ex + (ENothing ext (typeOf e1)) + (EJust ext e1) + (s@SsMaybe{}, SsFull) -> projectSmallerSubstruc s (SsMaybe SsFull) ex + (SsFull, s@SsMaybe{}) -> projectSmallerSubstruc (SsMaybe SsFull) s ex + + (SsArr s1, SsArr s2) -> emap (projectSmallerSubstruc s1 s2 (evar IZ)) ex + (s@SsArr{}, SsFull) -> projectSmallerSubstruc s (SsArr SsFull) ex + (SsFull, s@SsArr{}) -> projectSmallerSubstruc (SsArr SsFull) s ex + + (SsAccum _, SsAccum _) -> error "TODO smaller ssaccum" + (s@SsAccum{}, SsFull) -> projectSmallerSubstruc s (SsAccum SsFull) ex + (SsFull, s@SsAccum{}) -> projectSmallerSubstruc (SsAccum SsFull) s ex + + +-- | A boolean for each entry in the environment, with the ability to uniformly +-- mask the top part above a certain index. +data EnvMask env where + EMRest :: Bool -> EnvMask env + EMPush :: EnvMask env -> Bool -> EnvMask (t : env) + +envMaskPrj :: EnvMask env -> Idx env t -> Bool +envMaskPrj (EMRest b) _ = b +envMaskPrj (_ `EMPush` b) IZ = b +envMaskPrj (env `EMPush` _) (IS i) = envMaskPrj env i occCount :: Idx env a -> Expr x env t -> Occ -occCount idx = - getConst . occCountGeneral - (\w i o -> if idx2int i == idx2int (w @> idx) then Const o else mempty) - (\(Const o) -> Const o) - (\(Const o1) (Const o2) -> Const (o1 <||> o2)) - (\(Const o) -> Const (scaleMany o)) +occCount idx ex + | Some env <- occCountAll ex + = fst (occEnvPrj env idx) + +occCountAll :: Expr x env t -> Some (OccEnv Occ env) +occCountAll ex = occCountX SsFull ex $ \env _ -> Some env + +pruneExpr :: SList f env -> Expr x env t -> Ex env t +pruneExpr env ex = occCountX SsFull ex $ \_ mkex -> mkex (fullOccEnv env) + where + fullOccEnv :: SList f env -> OccEnv () env env + fullOccEnv SNil = OccEnd + fullOccEnv (_ `SCons` e) = OccPush (fullOccEnv e) () SsFull + +-- In one traversal, count occurrences of variables and determine what parts of +-- expressions are actually used. These two results are computed independently: +-- even if (almost) nothing of a particular term is actually used, variable +-- references in that term still count as usual. +-- +-- In @occCountX s t k@: +-- * s: how much of the result of this term is required +-- * t: the term to analyse +-- * k: is passed the actual environment usage of this expression, including +-- occurrence counts. The callback reconstructs a new expression in an +-- updated "response" environment. The response must be at least as large as +-- the computed usages. +occCountX :: forall env t t' x r. Substruc t t' -> Expr x env t + -> (forall env'. OccEnv Occ env env' + -- response OccEnv must be at least as large as the OccEnv returned above + -> (forall env''. OccEnv () env env'' -> Ex env'' t') + -> r) + -> r +occCountX initialS topexpr k = case topexpr of + EVar _ t i -> + withSome (onehotOccEnv i (Occ One One) s) $ \env -> + k env $ \env' -> + withSome (occEnvPrjS env' i) $ \(Pair s' i') -> + projectSmallerSubstruc s' s (EVar ext (applySubstruc s' t) i') + ELet _ rhs body -> + occCountX s body $ \envB mkbody -> + occEnvPop' envB $ \envB' s1 -> + occCountX s1 rhs $ \envR mkrhs -> + withSome (Some envB' <> Some envR) $ \env -> + k env $ \env' -> + ELet ext (mkrhs env') (mkbody (OccPush env' () s1)) + EPair _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + SsPair' s1 s2 -> + occCountX s1 a $ \env1 mka -> + occCountX s2 b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EPair ext (mka env') (mkb env') + EFst _ e -> + occCountX (SsPair s SsNone) e $ \env1 mke -> + k env1 $ \env' -> + EFst ext (mke env') + ESnd _ e -> + occCountX (SsPair SsNone s) e $ \env1 mke -> + k env1 $ \env' -> + ESnd ext (mke env') + ENil _ -> + case s of + SsFull -> k OccEnd (\_ -> ENil ext) + SsNone -> k OccEnd (\_ -> ENil ext) + EInl _ t e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsEither s1 s2 -> + occCountX s1 e $ \env1 mke -> + k env1 $ \env' -> + EInl ext (applySubstruc s2 t) (mke env') + SsFull -> occCountX (SsEither SsFull SsFull) topexpr k + EInr _ t e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsEither s1 s2 -> + occCountX s2 e $ \env1 mke -> + k env1 $ \env' -> + EInr ext (applySubstruc s1 t) (mke env') + SsFull -> occCountX (SsEither SsFull SsFull) topexpr k + ECase _ e a b -> + occCountX s a $ \env1' mka -> + occCountX s b $ \env2' mkb -> + occEnvPop' env1' $ \env1 s1 -> + occEnvPop' env2' $ \env2 s2 -> + occCountX (SsEither s1 s2) e $ \env0 mke -> + withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env -> + k env $ \env' -> + ECase ext (mke env') (mka (OccPush env' () s1)) (mkb (OccPush env' () s2)) + ENothing _ t -> + case s of + SsNone -> k OccEnd (\_ -> ENil ext) + SsMaybe s' -> k OccEnd (\_ -> ENothing ext (applySubstruc s' t)) + SsFull -> occCountX (SsMaybe SsFull) topexpr k + EJust _ e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsMaybe s' -> + occCountX s' e $ \env1 mke -> + k env1 $ \env' -> + EJust ext (mke env') + SsFull -> occCountX (SsMaybe SsFull) topexpr k + EMaybe _ a b e -> + occCountX s a $ \env1 mka -> + occCountX s b $ \env2' mkb -> + occEnvPop' env2' $ \env2 s2 -> + occCountX (SsMaybe s2) e $ \env0 mke -> + withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env -> + k env $ \env' -> + EMaybe ext (mka env') (mkb (OccPush env' () s2)) (mke env') + ELNil _ t1 t2 -> + case s of + SsNone -> k OccEnd (\_ -> ENil ext) + SsLEither s1 s2 -> k OccEnd (\_ -> ELNil ext (applySubstruc s1 t1) (applySubstruc s2 t2)) + SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k + ELInl _ t e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsLEither s1 s2 -> + occCountX s1 e $ \env1 mke -> + k env1 $ \env' -> + ELInl ext (applySubstruc s2 t) (mke env') + SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k + ELInr _ t e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsLEither s1 s2 -> + occCountX s2 e $ \env1 mke -> + k env1 $ \env' -> + ELInr ext (applySubstruc s1 t) (mke env') + SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k + ELCase _ e a b c -> + occCountX s a $ \env1 mka -> + occCountX s b $ \env2' mkb -> + occCountX s c $ \env3' mkc -> + occEnvPop' env2' $ \env2 s1 -> + occEnvPop' env3' $ \env3 s2 -> + occCountX (SsLEither s1 s2) e $ \env0 mke -> + withSome (Some env0 <> (Some env1 <||> Some env2 <||> Some env3)) $ \env -> + k env $ \env' -> + ELCase ext (mke env') (mka env') (mkb (OccPush env' () s1)) (mkc (OccPush env' () s2)) + + EConstArr _ n t x -> + case s of + SsNone -> k OccEnd (\_ -> ENil ext) + SsArr' SsNone -> k OccEnd (\_ -> EBuild ext n (eshapeConst (arrayShape x)) (ENil ext)) + SsArr' SsFull -> k OccEnd (\_ -> EConstArr ext n t x) + + EBuild _ n a b -> + case s of + SsNone -> + occCountX SsFull a $ \env1 mka -> + occCountX SsNone b $ \env2'' mkb -> + occEnvPop' env2'' $ \env2' s2 -> + withSome (Some env1 <> scaleMany (Some env2')) $ \env -> + k env $ \env' -> + use (EBuild ext n (mka env') $ + use (elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))) $ + ENil ext) $ + ENil ext + SsArr' s' -> + occCountX SsFull a $ \env1 mka -> + occCountX s' b $ \env2'' mkb -> + occEnvPop' env2'' $ \env2' s2 -> + withSome (Some env1 <> scaleMany (Some env2')) $ \env -> + k env $ \env' -> + EBuild ext n (mka env') $ + elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + weakenExpr (WCopy WSink) (mkb (OccPush env' () s2)) + + EMap _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1'' mka -> + occEnvPop' env1'' $ \env1' s1 -> + occCountX (SsArr s1) b $ \env2 mkb -> + withSome (scaleMany (Some env1') <> Some env2) $ \env -> + k env $ \env' -> + use (EMap ext (mka (OccPush env' () s1)) (mkb env')) $ + ENil ext + SsArr' s' -> + occCountX s' a $ \env1'' mka -> + occEnvPop' env1'' $ \env1' s1 -> + occCountX (SsArr s1) b $ \env2 mkb -> + withSome (scaleMany (Some env1') <> Some env2) $ \env -> + k env $ \env' -> + EMap ext (mka (OccPush env' () s1)) (mkb env') + + EFold1Inner _ commut a b c -> + occCountX SsFull a $ \env1'' mka -> + occEnvPop' env1'' $ \env1' s1' -> + let s1 = case s1' of + SsNone -> Some SsNone + SsPair' s1'a s1'b -> Some s1'a <> Some s1'b + s0 = case s of + SsNone -> Some SsNone + SsArr' s' -> Some s' in + withSome (s1 <> s0) $ \sElt -> + occCountX sElt b $ \env2 mkb -> + occCountX (SsArr sElt) c $ \env3 mkc -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> + k env $ \env' -> + projectSmallerSubstruc (SsArr sElt) s $ + EFold1Inner ext commut + (projectSmallerSubstruc SsFull sElt $ + mka (OccPush env' () (SsPair sElt sElt))) + (mkb env') (mkc env') + + ESum1Inner _ e -> handleReduction (ESum1Inner ext) e + + EUnit _ e -> + case s of + SsNone -> + occCountX SsNone e $ \env mke -> + k env $ \env' -> + use (mke env') $ ENil ext + SsArr' s' -> + occCountX s' e $ \env mke -> + k env $ \env' -> + EUnit ext (mke env') + + EReplicate1Inner _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + SsArr' s' -> + occCountX SsFull a $ \env1 mka -> + occCountX (SsArr s') b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EReplicate1Inner ext (mka env') (mkb env') + + EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e + EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e + + EReshape _ n esh e -> + case s of + SsNone -> + occCountX SsNone esh $ \env1 mkesh -> + occCountX SsNone e $ \env2 mke -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mkesh env') $ use (mke env') $ ENil ext + SsArr' s' -> + occCountX SsFull esh $ \env1 mkesh -> + occCountX (SsArr s') e $ \env2 mke -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EReshape ext n (mkesh env') (mke env') + + EZip _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + SsArr' SsNone -> + occCountX (SsArr SsNone) a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mkb env') $ mka env' + SsArr' (SsPair' SsNone s2) -> + occCountX SsNone a $ \env1 mka -> + occCountX (SsArr s2) b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ + emap (EPair ext (ENil ext) (evar IZ)) (mkb env') + SsArr' (SsPair' s1 SsNone) -> + occCountX (SsArr s1) a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mkb env') $ + emap (EPair ext (evar IZ) (ENil ext)) (mka env') + SsArr' (SsPair' s1 s2) -> + occCountX (SsArr s1) a $ \env1 mka -> + occCountX (SsArr s2) b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EZip ext (mka env') (mkb env') + + EFold1InnerD1 _ cm e1 e2 e3 -> + case s of + -- If nothing is necessary, we can execute a fold and then proceed to ignore it + SsNone -> + let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1)) + (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3) + in occCountX SsNone foldex $ \env1 mkfoldex -> k env1 mkfoldex + -- If we don't need the stores, still a fold suffices + SsPair' sP SsNone -> + let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1)) + (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3) + in occCountX sP foldex $ \env1 mkfoldex -> k env1 $ \env' -> EPair ext (mkfoldex env') (ENil ext) + -- If for whatever reason the additional stores themselves are + -- unnecessary but the shape of the array is, then oblige + SsPair' sP (SsArr' SsNone) -> + let STArr sn _ = typeOf e3 + foldex = + elet (mapExt (\_ -> ext) e3) $ + EPair ext + (EShape ext (evar IZ)) + (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy WSink) e1))) + (mapExt (\_ -> ext) (weakenExpr WSink e2)) + (evar IZ)) + in occCountX (SsPair SsFull sP) foldex $ \env1 mkfoldex -> + k env1 $ \env' -> + eunPair (mkfoldex env') $ \_ eshape earr -> + EPair ext earr (EBuild ext sn eshape (ENil ext)) + -- If at least some of the additional stores are required, we need to keep this a mapAccum + SsPair' _ (SsArr' sB) -> + -- TODO: propagate usage of primals + occCountX (SsPair SsFull sB) e1 $ \env1_1' mka -> + occEnvPop' env1_1' $ \env1' _ -> + occCountX SsFull e2 $ \env2 mkb -> + occCountX SsFull e3 $ \env3 mkc -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> + k env $ \env' -> + projectSmallerSubstruc (SsPair SsFull (SsArr sB)) s $ + EFold1InnerD1 ext cm (mka (OccPush env' () SsFull)) + (mkb env') (mkc env') + + EFold1InnerD2 _ cm ef ebog ed -> + -- TODO: propagate usage of duals + occCountX SsFull ef $ \env1_2' mkef -> + occEnvPop' env1_2' $ \env1_1' _ -> + occEnvPop' env1_1' $ \env1' sB -> + occCountX (SsArr sB) ebog $ \env2 mkebog -> + occCountX SsFull ed $ \env3 mked -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> + k env $ \env' -> + projectSmallerSubstruc SsFull s $ + EFold1InnerD2 ext cm + (mkef (OccPush (OccPush env' () sB) () SsFull)) + (mkebog env') (mked env') + + EConst _ t x -> + k OccEnd $ \_ -> + case s of + SsNone -> ENil ext + SsFull -> EConst ext t x + + EIdx0 _ e -> + occCountX (SsArr s) e $ \env1 mke -> + k env1 $ \env' -> + EIdx0 ext (mke env') + + EIdx1 _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + SsArr' s' -> + occCountX (SsArr s') a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EIdx1 ext (mka env') (mkb env') + + EIdx _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + _ -> + occCountX (SsArr s) a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EIdx ext (mka env') (mkb env') + + EShape _ e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + _ -> + occCountX (SsArr SsNone) e $ \env1 mke -> + k env1 $ \env' -> + projectSmallerSubstruc SsFull s $ EShape ext (mke env') + EOp _ op e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + _ -> + occCountX SsFull e $ \env1 mke -> + k env1 $ \env' -> + projectSmallerSubstruc SsFull s $ EOp ext op (mke env') -data OccEnv env where - OccEnd :: OccEnv env -- not necessarily top! - OccPush :: OccEnv env -> Occ -> OccEnv (t : env) + ECustom _ t1 t2 t3 e1 e2 e3 a b + | typeHasAccums t1 || typeHasAccums t2 || typeHasAccums t3 -> + error "Accumulators not allowed in input/output/tape of an ECustom" + | otherwise -> + case s of + SsNone -> + -- Allowed to ignore e1/e2/e3 here because no accumulators are + -- communicated, and hence no relevant effects exist + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + s' -> -- Let's be pessimistic for safety + occCountX SsFull a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + projectSmallerSubstruc SsFull s' $ + ECustom ext t1 t2 t3 (mapExt (const ext) e1) (mapExt (const ext) e2) (mapExt (const ext) e3) (mka env') (mkb env') -instance Semigroup (OccEnv env) where - OccEnd <> e = e - e <> OccEnd = e - OccPush e o <> OccPush e' o' = OccPush (e <> e') (o <> o') + ERecompute _ e -> + occCountX s e $ \env1 mke -> + k env1 $ \env' -> + ERecompute ext (mke env') -instance Monoid (OccEnv env) where - mempty = OccEnd + EWith _ t a b -> + case s of + SsNone -> -- TODO: simplifier should remove accumulations to an unused with, and then remove the with + occCountX SsNone b $ \env2' mkb -> + occEnvPop' env2' $ \env2 s1 -> + withSome (case s1 of + SsFull -> Some SsFull + SsAccum s' -> Some s' + SsNone -> Some SsNone) $ \s1' -> + occCountX s1' a $ \env1 mka -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (EWith ext (applySubstrucM s1' t) (mka env') (mkb (OccPush env' () (SsAccum s1')))) $ + ENil ext + SsPair sB sA -> + occCountX sB b $ \env2' mkb -> + occEnvPop' env2' $ \env2 s1 -> + let s1' = case s1 of + SsFull -> Some SsFull + SsAccum s' -> Some s' + SsNone -> Some SsNone in + withSome (Some sA <> s1') $ \sA' -> + occCountX sA' a $ \env1 mka -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + projectSmallerSubstruc (SsPair sB sA') (SsPair sB sA) $ + EWith ext (applySubstrucM sA' t) (mka env') (mkb (OccPush env' () (SsAccum sA'))) + SsFull -> occCountX (SsPair SsFull SsFull) topexpr k -onehotOccEnv :: Idx env t -> Occ -> OccEnv env -onehotOccEnv IZ v = OccPush OccEnd v -onehotOccEnv (IS i) v = OccPush (onehotOccEnv i v) mempty + EAccum _ t p a sp b e -> + -- TODO: do better! + occCountX SsFull a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> + occCountX SsFull e $ \env3 mke -> + withSome (Some env1 <> Some env2) $ \env12 -> + withSome (Some env12 <> Some env3) $ \env -> + k env $ \env' -> + case s of {SsFull -> id; SsNone -> id} $ + EAccum ext t p (mka env') sp (mkb env') (mke env') -(<||>!) :: OccEnv env -> OccEnv env -> OccEnv env -OccEnd <||>! e = e -e <||>! OccEnd = e -OccPush e o <||>! OccPush e' o' = OccPush (e <||>! e') (o <||> o') + EZero _ t e -> + occCountX (subZeroInfo s) e $ \env1 mke -> + k env1 $ \env' -> + EZero ext (applySubstrucM s t) (mke env') + where + subZeroInfo :: Substruc t1 t2 -> Substruc (ZeroInfo t1) (ZeroInfo t2) + subZeroInfo SsFull = SsFull + subZeroInfo SsNone = SsNone + subZeroInfo (SsPair s1 s2) = SsPair (subZeroInfo s1) (subZeroInfo s2) + subZeroInfo SsEither{} = error "Either is not a monoid" + subZeroInfo SsLEither{} = SsNone + subZeroInfo SsMaybe{} = SsNone + subZeroInfo (SsArr s') = SsArr (subZeroInfo s') + subZeroInfo SsAccum{} = error "Accum is not a monoid" -scaleManyOccEnv :: OccEnv env -> OccEnv env -scaleManyOccEnv OccEnd = OccEnd -scaleManyOccEnv (OccPush e o) = OccPush (scaleManyOccEnv e) (scaleMany o) + EDeepZero _ t e -> + occCountX (subDeepZeroInfo s) e $ \env1 mke -> + k env1 $ \env' -> + EDeepZero ext (applySubstrucM s t) (mke env') + where + subDeepZeroInfo :: Substruc t1 t2 -> Substruc (DeepZeroInfo t1) (DeepZeroInfo t2) + subDeepZeroInfo SsFull = SsFull + subDeepZeroInfo SsNone = SsNone + subDeepZeroInfo (SsPair s1 s2) = SsPair (subDeepZeroInfo s1) (subDeepZeroInfo s2) + subDeepZeroInfo SsEither{} = error "Either is not a monoid" + subDeepZeroInfo (SsLEither s1 s2) = SsLEither (subDeepZeroInfo s1) (subDeepZeroInfo s2) + subDeepZeroInfo (SsMaybe s') = SsMaybe (subDeepZeroInfo s') + subDeepZeroInfo (SsArr s') = SsArr (subDeepZeroInfo s') + subDeepZeroInfo SsAccum{} = error "Accum is not a monoid" -occEnvPop :: OccEnv (t : env) -> OccEnv env -occEnvPop (OccPush o _) = o -occEnvPop OccEnd = OccEnd + EPlus _ t a b -> + occCountX s a $ \env1 mka -> + occCountX s b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EPlus ext (applySubstrucM s t) (mka env') (mkb env') -occCountAll :: Expr x env t -> OccEnv env -occCountAll = occCountGeneral (const onehotOccEnv) occEnvPop (<||>!) scaleManyOccEnv + EOneHot _ t p a b -> + occCountX SsFull a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> -- TODO: do better + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + projectSmallerSubstruc SsFull s $ EOneHot ext t p (mka env') (mkb env') -occCountGeneral :: forall r env t x. - (forall env'. Monoid (r env')) - => (forall env' a. env :> env' -> Idx env' a -> Occ -> r env') -- ^ one-hot - -> (forall env' a. r (a : env') -> r env') -- ^ unpush - -> (forall env'. r env' -> r env' -> r env') -- ^ alternation - -> (forall env'. r env' -> r env') -- ^ scale-many - -> Expr x env t -> r env -occCountGeneral onehot unpush alter many = go WId + EError _ t msg -> + k OccEnd $ \_ -> EError ext (applySubstruc s t) msg where - go :: forall env' t'. Monoid (r env') => env :> env' -> Expr x env' t' -> r env' - go w = \case - EVar _ _ i -> onehot w i (Occ One One) - ELet _ rhs body -> re rhs <> re1 body - EPair _ a b -> re a <> re b - EFst _ e -> re e - ESnd _ e -> re e - ENil _ -> mempty - EInl _ _ e -> re e - EInr _ _ e -> re e - ECase _ e a b -> re e <> (re1 a `alter` re1 b) - ENothing _ _ -> mempty - EJust _ e -> re e - EMaybe _ a b e -> re a <> re1 b <> re e - EConstArr{} -> mempty - EBuild _ _ a b -> re a <> many (re1 b) - EFold1Inner _ _ a b c -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b <> re c - ESum1Inner _ e -> re e - EUnit _ e -> re e - EReplicate1Inner _ a b -> re a <> re b - EMaximum1Inner _ e -> re e - EMinimum1Inner _ e -> re e - EConst{} -> mempty - EIdx0 _ e -> re e - EIdx1 _ a b -> re a <> re b - EIdx _ a b -> re a <> re b - EShape _ e -> re e - EOp _ _ e -> re e - ECustom _ _ _ _ _ _ _ a b -> re a <> re b - EWith _ _ a b -> re a <> re1 b - EAccum _ _ _ a b e -> re a <> re b <> re e - EZero _ _ -> mempty - EPlus _ _ a b -> re a <> re b - EOneHot _ _ _ a b -> re a <> re b - EError{} -> mempty - where - re :: Monoid (r env') => Expr x env' t'' -> r env' - re = go w + s = simplifySubstruc (typeOf topexpr) initialS - re1 :: Monoid (r env') => Expr x (a : env') t'' -> r env' - re1 = unpush . go (WSink .> w) + handleReduction :: t ~ TArr n (TScal t2) + => (forall env2. Ex env2 (TArr (S n) (TScal t2)) -> Ex env2 (TArr n (TScal t2))) + -> Expr x env (TArr (S n) (TScal t2)) + -> r + handleReduction reduce e + | STArr (SS n) _ <- typeOf e = + case s of + SsNone -> + occCountX SsNone e $ \env mke -> + k env $ \env' -> + use (mke env') $ ENil ext + SsArr' SsNone -> + occCountX (SsArr SsNone) e $ \env mke -> + k env $ \env' -> + elet (mke env') $ + EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext) + SsArr' SsFull -> + occCountX (SsArr SsFull) e $ \env mke -> + k env $ \env' -> + reduce (mke env') -deleteUnused :: SList f env -> OccEnv env -> (forall env'. Subenv env env' -> r) -> r -deleteUnused SNil OccEnd k = k SETop -deleteUnused (_ `SCons` env) OccEnd k = - deleteUnused env OccEnd $ \sub -> k (SENo sub) -deleteUnused (_ `SCons` env) (OccPush occenv (Occ _ count)) k = - deleteUnused env occenv $ \sub -> +deleteUnused :: SList f env -> Some (OccEnv Occ env) -> (forall env'. Subenv env env' -> r) -> r +deleteUnused SNil (Some OccEnd) k = k SETop +deleteUnused (_ `SCons` env) (Some OccEnd) k = + deleteUnused env (Some OccEnd) $ \sub -> k (SENo sub) +deleteUnused (_ `SCons` env) (Some (OccPush occenv (Occ _ count) _)) k = + deleteUnused env (Some occenv) $ \sub -> case count of Zero -> k (SENo sub) - _ -> k (SEYes sub) + _ -> k (SEYesR sub) unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t unsafeWeakenWithSubenv = \sub -> @@ -158,7 +924,7 @@ unsafeWeakenWithSubenv = \sub -> Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away") where sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t) - sinkViaSubenv IZ (SEYes _) = Just IZ + sinkViaSubenv IZ (SEYesR _) = Just IZ sinkViaSubenv IZ (SENo _) = Nothing - sinkViaSubenv (IS i) (SEYes sub) = IS <$> sinkViaSubenv i sub + sinkViaSubenv (IS i) (SEYesR sub) = IS <$> sinkViaSubenv i sub sinkViaSubenv (IS i) (SENo sub) = sinkViaSubenv i sub diff --git a/src/AST/Env.hs b/src/AST/Env.hs index 4f34166..85faba3 100644 --- a/src/AST/Env.hs +++ b/src/AST/Env.hs @@ -1,59 +1,95 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE EmptyCase #-} -{-# LANGUAGE ExplicitForAll #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeOperators #-} module AST.Env where +import Data.Type.Equality + +import AST.Sparse import AST.Weaken +import CHAD.Types import Data -- | @env'@ is a subset of @env@: each element of @env@ is either included in -- @env'@ ('SEYes') or not included in @env'@ ('SENo'). -data Subenv env env' where - SETop :: Subenv '[] '[] - SEYes :: forall t env env'. Subenv env env' -> Subenv (t : env) (t : env') - SENo :: forall t env env'. Subenv env env' -> Subenv (t : env) env' -deriving instance Show (Subenv env env') +data Subenv' s env env' where + SETop :: Subenv' s '[] '[] + SEYes :: forall t t' env env' s. s t t' -> Subenv' s env env' -> Subenv' s (t : env) (t' : env') + SENo :: forall t env env' s. Subenv' s env env' -> Subenv' s (t : env) env' +deriving instance (forall t t'. Show (s t t')) => Show (Subenv' s env env') + +type Subenv = Subenv' (:~:) +type SubenvS = Subenv' Sparse + +pattern SEYesR :: forall tenv tenv'. () + => forall t env env'. (tenv ~ t : env, tenv' ~ t : env') + => Subenv env env' -> Subenv tenv tenv' +pattern SEYesR s = SEYes Refl s + +{-# COMPLETE SETop, SEYesR, SENo #-} -subList :: SList f env -> Subenv env env' -> SList f env' +subList :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env' -> SList f env' subList SNil SETop = SNil -subList (SCons x xs) (SEYes sub) = SCons x (subList xs sub) +subList (SCons x xs) (SEYes s sub) = SCons (subtApply s x) (subList xs sub) subList (SCons _ xs) (SENo sub) = subList xs sub -subenvAll :: SList f env -> Subenv env env +subenvAll :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env subenvAll SNil = SETop -subenvAll (SCons _ env) = SEYes (subenvAll env) +subenvAll (SCons t env) = SEYes (subtFull t) (subenvAll env) -subenvNone :: SList f env -> Subenv env '[] +subenvNone :: SList f env -> Subenv' s env '[] subenvNone SNil = SETop subenvNone (SCons _ env) = SENo (subenvNone env) -subenvOnehot :: SList f env -> Idx env t -> Subenv env '[t] -subenvOnehot (SCons _ env) IZ = SEYes (subenvNone env) -subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i) -subenvOnehot SNil i = case i of {} +subenvOnehot :: SList f env -> Idx env t -> s t t' -> Subenv' s env '[t'] +subenvOnehot (SCons _ env) IZ sp = SEYes sp (subenvNone env) +subenvOnehot (SCons _ env) (IS i) sp = SENo (subenvOnehot env i sp) +subenvOnehot SNil i _ = case i of {} -subenvCompose :: Subenv env1 env2 -> Subenv env2 env3 -> Subenv env1 env3 +subenvCompose :: IsSubType s => Subenv' s env1 env2 -> Subenv' s env2 env3 -> Subenv' s env1 env3 subenvCompose SETop SETop = SETop -subenvCompose (SEYes sub1) (SEYes sub2) = SEYes (subenvCompose sub1 sub2) -subenvCompose (SEYes sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2) +subenvCompose (SEYes s1 sub1) (SEYes s2 sub2) = SEYes (subtTrans s1 s2) (subenvCompose sub1 sub2) +subenvCompose (SEYes _ sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2) subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2) -subenvConcat :: Subenv env1 env1' -> Subenv env2 env2' -> Subenv (Append env2 env1) (Append env2' env1') +subenvConcat :: Subenv' s env1 env1' -> Subenv' s env2 env2' -> Subenv' s (Append env2 env1) (Append env2' env1') subenvConcat sub1 SETop = sub1 -subenvConcat sub1 (SEYes sub2) = SEYes (subenvConcat sub1 sub2) +subenvConcat sub1 (SEYes s sub2) = SEYes s (subenvConcat sub1 sub2) subenvConcat sub1 (SENo sub2) = SENo (subenvConcat sub1 sub2) -sinkWithSubenv :: Subenv env env' -> env0 :> Append env' env0 +-- subenvSplit :: SList f env1a -> Subenv' s (Append env1a env1b) env2 +-- -> (forall env2a env2b. Subenv' s env1a env2a -> Subenv' s env1b env2b -> r) -> r +-- subenvSplit SNil sub k = k SETop sub +-- subenvSplit (SCons _ list) (SENo sub) k = +-- subenvSplit list sub $ \sub1 sub2 -> +-- k (SENo sub1) sub2 +-- subenvSplit (SCons _ list) (SEYes s sub) k = +-- subenvSplit list sub $ \sub1 sub2 -> +-- k (SEYes s sub1) sub2 + +sinkWithSubenv :: Subenv' s env env' -> env0 :> Append env' env0 sinkWithSubenv SETop = WId -sinkWithSubenv (SEYes sub) = WSink .> sinkWithSubenv sub +sinkWithSubenv (SEYes _ sub) = WSink .> sinkWithSubenv sub sinkWithSubenv (SENo sub) = sinkWithSubenv sub -wUndoSubenv :: Subenv env env' -> env' :> env +wUndoSubenv :: Subenv' (:~:) env env' -> env' :> env wUndoSubenv SETop = WId -wUndoSubenv (SEYes sub) = WCopy (wUndoSubenv sub) +wUndoSubenv (SEYes Refl sub) = WCopy (wUndoSubenv sub) wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub + +subenvMap :: (forall a a'. f a -> s a a' -> s' a a') -> SList f env -> Subenv' s env env' -> Subenv' s' env env' +subenvMap _ SNil SETop = SETop +subenvMap f (t `SCons` l) (SEYes s sub) = SEYes (f t s) (subenvMap f l sub) +subenvMap f (_ `SCons` l) (SENo sub) = SENo (subenvMap f l sub) + +subenvD2E :: Subenv env env' -> Subenv (D2E env) (D2E env') +subenvD2E SETop = SETop +subenvD2E (SEYesR sub) = SEYesR (subenvD2E sub) +subenvD2E (SENo sub) = SENo (subenvD2E sub) diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 4f637f2..bbcfd9e 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -6,8 +6,9 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} -module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppTy, PrettyX(..)) where +module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where import Control.Monad (ap) import Data.List (intersperse, intercalate) @@ -25,6 +26,7 @@ import System.IO.Unsafe (unsafePerformIO) import AST import AST.Count +import AST.Sparse.Types import CHAD.Types import Data @@ -70,6 +72,7 @@ genNameIfUsedIn' prefix ty idx ex _ -> return "_" | otherwise = genName' prefix +-- TODO: let this return a type-tagged thing so that name environments are more typed than Const genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String genNameIfUsedIn = \t -> genNameIfUsedIn' (nameBaseForType t) t @@ -145,12 +148,45 @@ ppExpr' d val expr = case expr of EMaybe _ a b e -> do let STMaybe t = typeOf e - a' <- ppExpr' 11 val a + e' <- ppExpr' 0 val e + a' <- ppExpr' 0 val a name <- genNameIfUsedIn t IZ b b' <- ppExpr' 0 (Const name `SCons` val) b + return $ ppParen (d > 0) $ + align $ + group (flatAlt + (annotate AKey (ppString "case") <> ppX expr <+> e' + <> hardline <> annotate AKey (ppString "of")) + (annotate AKey (ppString "case") <> ppX expr <+> e' <+> annotate AKey (ppString "of"))) + <> hardline + <> indent 2 + (ppString "Nothing" <+> ppString "->" <+> a' + <> hardline <> ppString "Just" <+> ppString name <+> ppString "->" <+> b') + + ELNil _ _ _ -> return (ppString "LNil") + + ELInl _ _ e -> do e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ - ppApp (ppString "maybe" <> ppX expr) [a', ppLam [ppString name] b', e'] + return $ ppParen (d > 10) $ ppString "LInl" <> ppX expr <+> e' + + ELInr _ _ e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppString "LInr" <> ppX expr <+> e' + + ELCase _ e a b c -> do + e' <- ppExpr' 0 val e + let STLEither t1 t2 = typeOf e + a' <- ppExpr' 11 val a + name1 <- genNameIfUsedIn t1 IZ b + b' <- ppExpr' 0 (Const name1 `SCons` val) b + name2 <- genNameIfUsedIn t2 IZ c + c' <- ppExpr' 0 (Const name2 `SCons` val) c + return $ ppParen (d > 0) $ + hang 2 $ + annotate AKey (ppString "lcase") <> ppX expr <+> e' <+> annotate AKey (ppString "of") + <> hardline <> ppString "LNil" <+> ppString "->" <+> a' + <> hardline <> ppString "LInl" <+> ppString name1 <+> ppString "->" <+> b' + <> hardline <> ppString "LInr" <+> ppString name2 <+> ppString "->" <+> c' EConstArr _ _ ty v | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr @@ -159,24 +195,31 @@ ppExpr' d val expr = case expr of a' <- ppExpr' 11 val a name <- genNameIfUsedIn' "i" (tTup (sreplicate n tIx)) IZ b e' <- ppExpr' 0 (Const name `SCons` val) b + let primName = ppString ("build" ++ intSubscript (fromSNat n)) return $ ppParen (d > 0) $ group $ flatAlt (hang 2 $ - annotate AHighlight (ppString "build") <> ppX expr <+> a' + annotate AHighlight primName <> ppX expr <+> a' <+> ppString "$" <+> ppString "\\" <> ppString name <+> ppString "->" <> hardline <> e') - (ppApp (annotate AHighlight (ppString "build") <> ppX expr) [a', ppLam [ppString name] e']) + (ppApp (annotate AHighlight primName <> ppX expr) [a', ppLam [ppString name] e']) + + EMap _ a b -> do + let STArr _ t1 = typeOf b + name <- genNameIfUsedIn t1 IZ a + a' <- ppExpr' 0 (Const name `SCons` val) a + b' <- ppExpr' 11 val b + return $ ppParen (d > 0) $ + ppApp (annotate AHighlight (ppString "map") <> ppX expr) [ppLam [ppString name] a', b'] EFold1Inner _ cm a b c -> do - name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a - name2 <- genNameIfUsedIn (typeOf a) IZ a - a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a + name <- genNameIfUsedIn (STPair (typeOf a) (typeOf a)) IZ a + a' <- ppExpr' 0 (Const name `SCons` val) a b' <- ppExpr' 11 val b c' <- ppExpr' 11 val c - let opname = case cm of Commut -> "fold1i(C)" - Noncommut -> "fold1i" + let opname = "fold1i" ++ ppCommut cm return $ ppParen (d > 10) $ - ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c'] + ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c'] ESum1Inner _ e -> do e' <- ppExpr' 11 val e @@ -199,6 +242,38 @@ ppExpr' d val expr = case expr of e' <- ppExpr' 11 val e return $ ppParen (d > 10) $ ppString "minimum1i" <> ppX expr <+> e' + EReshape _ n esh e -> do + esh' <- ppExpr' 11 val esh + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppApp (ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr) [esh', e'] + + EZip _ e1 e2 -> do + e1' <- ppExpr' 11 val e1 + e2' <- ppExpr' 11 val e2 + return $ ppParen (d > 10) $ ppApp (ppString "zip" <> ppX expr) [e1', e2'] + + EFold1InnerD1 _ cm a b c -> do + name <- genNameIfUsedIn (STPair (typeOf b) (typeOf b)) IZ a + a' <- ppExpr' 0 (Const name `SCons` val) a + b' <- ppExpr' 11 val b + c' <- ppExpr' 11 val c + let opname = "fold1iD1" ++ ppCommut cm + return $ ppParen (d > 10) $ + ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c'] + + EFold1InnerD2 _ cm ef ebog ed -> do + let STArr _ tB = typeOf ebog + STArr _ t2 = typeOf ed + namef1 <- genNameIfUsedIn tB (IS IZ) ef + namef2 <- genNameIfUsedIn t2 IZ ef + ef' <- ppExpr' 0 (Const namef2 `SCons` Const namef1 `SCons` val) ef + ebog' <- ppExpr' 11 val ebog + ed' <- ppExpr' 11 val ed + let opname = "fold1iD2" ++ ppCommut cm + return $ ppParen (d > 10) $ + ppApp (annotate AHighlight (ppString opname) <> ppX expr) + [ppLam [ppString namef1, ppString namef2] ef', ebog', ed'] + EConst _ ty v | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr @@ -254,6 +329,10 @@ ppExpr' d val expr = case expr of ,e1' ,e2'] + ERecompute _ e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppApp (ppString "recompute" <> ppX expr) [e'] + EWith _ t e1 e2 -> do e1' <- ppExpr' 11 val e1 name <- genNameIfUsedIn' "ac" (STAccum t) IZ e2 @@ -266,27 +345,35 @@ ppExpr' d val expr = case expr of <> hardline <> e2') (ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2']) - EAccum _ _ prj e1 e2 e3 -> do + EAccum _ t prj e1 sp e2 e3 -> do e1' <- ppExpr' 11 val e1 e2' <- ppExpr' 11 val e2 e3' <- ppExpr' 11 val e3 return $ ppParen (d > 10) $ - ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (ppAcPrj prj), e1', e2', e3'] + ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (applySparse sp (acPrjTy prj t))) + [ppString (ppAcPrj t prj), ppString (ppSparse (acPrjTy prj t) sp), e1', e2', e3'] - EZero _ t -> return $ ppParen (d > 0) $ - annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSTy' 11 t + EZero _ t e1 -> do + e1' <- ppExpr' 11 val e1 + return $ ppParen (d > 0) $ + annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1' - EPlus _ _ a b -> do + EDeepZero _ t e1 -> do + e1' <- ppExpr' 11 val e1 + return $ ppParen (d > 0) $ + annotate AMonoid (ppString "deepzero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1' + + EPlus _ t a b -> do a' <- ppExpr' 11 val a b' <- ppExpr' 11 val b return $ ppParen (d > 10) $ - ppApp (annotate AMonoid (ppString "plus") <> ppX expr) [a', b'] + ppApp (annotate AMonoid (ppString "plus") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t) [a', b'] - EOneHot _ _ prj a b -> do + EOneHot _ t prj a b -> do a' <- ppExpr' 11 val a b' <- ppExpr' 11 val b return $ ppParen (d > 10) $ - ppApp (annotate AMonoid (ppString "onehot") <> ppX expr) [ppString (ppAcPrj prj), a', b'] + ppApp (annotate AMonoid (ppString "onehot") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (acPrjTy prj t)) [ppString (ppAcPrj t prj), a', b'] EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s) @@ -319,14 +406,28 @@ ppLam :: [ADoc] -> ADoc -> ADoc ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"]) <> softline <> body <> ppString ")") -ppAcPrj :: SAcPrj p a b -> String -ppAcPrj SAPHere = "@" -ppAcPrj (SAPFst prj) = "(" ++ ppAcPrj prj ++ ",)" -ppAcPrj (SAPSnd prj) = "(," ++ ppAcPrj prj ++ ")" -ppAcPrj (SAPLeft prj) = "(" ++ ppAcPrj prj ++ "|)" -ppAcPrj (SAPRight prj) = "(|" ++ ppAcPrj prj ++ ")" -ppAcPrj (SAPJust prj) = "J" ++ ppAcPrj prj -ppAcPrj (SAPArrIdx prj n) = "[" ++ ppAcPrj prj ++ "]" ++ intSubscript (fromSNat n) +ppAcPrj :: SMTy a -> SAcPrj p a b -> String +ppAcPrj _ SAPHere = "." +ppAcPrj (SMTPair t _) (SAPFst prj) = "(" ++ ppAcPrj t prj ++ ",)" +ppAcPrj (SMTPair _ t) (SAPSnd prj) = "(," ++ ppAcPrj t prj ++ ")" +ppAcPrj (SMTLEither t _) (SAPLeft prj) = "(" ++ ppAcPrj t prj ++ "|)" +ppAcPrj (SMTLEither _ t) (SAPRight prj) = "(|" ++ ppAcPrj t prj ++ ")" +ppAcPrj (SMTMaybe t) (SAPJust prj) = "J" ++ ppAcPrj t prj +ppAcPrj (SMTArr n t) (SAPArrIdx prj) = "[" ++ ppAcPrj t prj ++ "]" ++ intSubscript (fromSNat n) + +ppSparse :: SMTy a -> Sparse a b -> String +ppSparse t sp | Just Refl <- isDense t sp = "D" +ppSparse _ SpAbsent = "A" +ppSparse t (SpSparse s) = "S" ++ ppSparse t s +ppSparse (SMTPair t1 t2) (SpPair s1 s2) = "(" ++ ppSparse t1 s1 ++ "," ++ ppSparse t2 s2 ++ ")" +ppSparse (SMTLEither t1 t2) (SpLEither s1 s2) = "(" ++ ppSparse t1 s1 ++ "|" ++ ppSparse t2 s2 ++ ")" +ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s +ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s +ppSparse (SMTScal _) SpScal = "." + +ppCommut :: Commutative -> String +ppCommut Commut = "(C)" +ppCommut Noncommut = "" ppX :: PrettyX x => Expr x env t -> ADoc ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr) @@ -354,28 +455,39 @@ operator OIDiv{} = (Infix, "`div`") operator OMod{} = (Infix, "`mod`") ppSTy :: Int -> STy t -> String -ppSTy d ty = ppTy d (unSTy ty) +ppSTy d ty = render $ ppSTy' d ty ppSTy' :: Int -> STy t -> Doc q -ppSTy' d ty = ppTy' d (unSTy ty) +ppSTy' _ STNil = ppString "1" +ppSTy' d (STPair a b) = ppParen (d > 7) $ ppSTy' 8 a <> ppString " * " <> ppSTy' 8 b +ppSTy' d (STEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " + " <> ppSTy' 7 b +ppSTy' d (STLEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " ⊕ " <> ppSTy' 7 b +ppSTy' d (STMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSTy' 11 t +ppSTy' d (STArr n t) = ppParen (d > 10) $ + ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSTy' 11 t +ppSTy' _ (STScal sty) = ppString $ case sty of + STI32 -> "i32" + STI64 -> "i64" + STF32 -> "f32" + STF64 -> "f64" + STBool -> "bool" +ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSMTy' 11 t -ppTy :: Int -> Ty -> String -ppTy d ty = render $ ppTy' d ty +ppSMTy :: Int -> SMTy t -> String +ppSMTy d ty = render $ ppSMTy' d ty -ppTy' :: Int -> Ty -> Doc q -ppTy' _ TNil = ppString "1" -ppTy' d (TPair a b) = ppParen (d > 7) $ ppTy' 8 a <> ppString " * " <> ppTy' 8 b -ppTy' d (TEither a b) = ppParen (d > 6) $ ppTy' 7 a <> ppString " + " <> ppTy' 7 b -ppTy' d (TMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppTy' 11 t -ppTy' d (TArr n t) = ppParen (d > 10) $ - ppString "Arr " <> ppString (show (fromNat n)) <> ppString " " <> ppTy' 11 t -ppTy' _ (TScal sty) = ppString $ case sty of - TI32 -> "i32" - TI64 -> "i64" - TF32 -> "f32" - TF64 -> "f64" - TBool -> "bool" -ppTy' d (TAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppTy' 11 t +ppSMTy' :: Int -> SMTy t -> Doc q +ppSMTy' _ SMTNil = ppString "1" +ppSMTy' d (SMTPair a b) = ppParen (d > 7) $ ppSMTy' 8 a <> ppString " * " <> ppSMTy' 8 b +ppSMTy' d (SMTLEither a b) = ppParen (d > 6) $ ppSMTy' 7 a <> ppString " ⊕ " <> ppSMTy' 7 b +ppSMTy' d (SMTMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSMTy' 11 t +ppSMTy' d (SMTArr n t) = ppParen (d > 10) $ + ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSMTy' 11 t +ppSMTy' _ (SMTScal sty) = ppString $ case sty of + STI32 -> "i32" + STI64 -> "i64" + STF32 -> "f32" + STF64 -> "f64" ppString :: String -> Doc x ppString = fromString diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs new file mode 100644 index 0000000..2a29799 --- /dev/null +++ b/src/AST/Sparse.hs @@ -0,0 +1,287 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImpredicativeTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE RankNTypes #-} + +{-# OPTIONS_GHC -fmax-pmcheck-models=80 #-} +module AST.Sparse (module AST.Sparse, module AST.Sparse.Types) where + +import Data.Type.Equality + +import AST +import AST.Sparse.Types +import Data (SBool(..)) + + +sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t' +sparsePlus _ SpAbsent e1 e2 = use e1 $ use e2 $ ENil ext +sparsePlus t sp e1 e2 | Just Refl <- isDense t sp = EPlus ext t e1 e2 +sparsePlus t (SpSparse sp) e1 e2 = sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 -- heh +sparsePlus (SMTPair t1 t2) (SpPair sp1 sp2) e1 e2 = + eunPair e1 $ \w1 e1a e1b -> + eunPair (weakenExpr w1 e2) $ \w2 e2a e2b -> + EPair ext (sparsePlus t1 sp1 (weakenExpr w2 e1a) e2a) + (sparsePlus t2 sp2 (weakenExpr w2 e1b) e2b) +sparsePlus (SMTLEither t1 t2) (SpLEither sp1 sp2) e1 e2 = + elet e2 $ + elcase (weakenExpr WSink e1) + (evar IZ) + (elcase (evar (IS IZ)) + (ELInl ext (applySparse sp2 (fromSMTy t2)) (evar IZ)) + (ELInl ext (applySparse sp2 (fromSMTy t2)) (sparsePlus t1 sp1 (evar (IS IZ)) (evar IZ))) + (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus ll+lr")) + (elcase (evar (IS IZ)) + (ELInr ext (applySparse sp1 (fromSMTy t1)) (evar IZ)) + (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus lr+ll") + (ELInr ext (applySparse sp1 (fromSMTy t1)) (sparsePlus t2 sp2 (evar (IS IZ)) (evar IZ)))) +sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 = + elet e2 $ + emaybe (weakenExpr WSink e1) + (evar IZ) + (emaybe (evar (IS IZ)) + (EJust ext (evar IZ)) + (EJust ext (sparsePlus t sp (evar (IS IZ)) (evar IZ)))) +sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS IZ)) (evar IZ)) e1 e2 +sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2 + + +cheapZero :: SMTy t -> Maybe (forall env. Ex env t) +cheapZero SMTNil = Just (ENil ext) +cheapZero (SMTPair t1 t2) + | Just e1 <- cheapZero t1 + , Just e2 <- cheapZero t2 + = Just (EPair ext e1 e2) + | otherwise + = Nothing +cheapZero (SMTLEither t1 t2) = Just (ELNil ext (fromSMTy t1) (fromSMTy t2)) +cheapZero (SMTMaybe t) = Just (ENothing ext (fromSMTy t)) +cheapZero SMTArr{} = Nothing +cheapZero (SMTScal t) = case t of + STI32 -> Just (EConst ext t 0) + STI64 -> Just (EConst ext t 0) + STF32 -> Just (EConst ext t 0.0) + STF64 -> Just (EConst ext t 0.0) + + +data Injection sp a b where + -- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that + -- 'sparsePlusS' can provide injections even if the caller doesn't require + -- them. This simplifies the sparsePlusS code. + Inj :: (forall e. Ex e a -> Ex e b) -> Injection sp a b + Noinj :: Injection False a b + +withInj :: Injection sp a b -> ((forall e. Ex e a -> Ex e b) -> (forall e'. Ex e' a' -> Ex e' b')) -> Injection sp a' b' +withInj (Inj f) k = Inj (k f) +withInj Noinj _ = Noinj + +withInj2 :: Injection sp a1 b1 -> Injection sp a2 b2 + -> ((forall e. Ex e a1 -> Ex e b1) + -> (forall e. Ex e a2 -> Ex e b2) + -> (forall e'. Ex e' a' -> Ex e' b')) + -> Injection sp a' b' +withInj2 (Inj f) (Inj g) k = Inj (k f g) +withInj2 Noinj _ _ = Noinj +withInj2 _ Noinj _ = Noinj + +-- | This function produces quadratically-sized code in the presence of nested +-- dynamic sparsity. TODO can this be improved? +sparsePlusS + :: SBool inj1 -> SBool inj2 + -> SMTy t -> Sparse t t1 -> Sparse t t2 + -> (forall t3. Sparse t t3 + -> Injection inj1 t1 t3 -- only available if first injection is requested (second argument may be absent) + -> Injection inj2 t2 t3 -- only available if second injection is requested (first argument may be absent) + -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3) + -> r) + -> r +-- nil override (but don't destroy effects!) +sparsePlusS _ _ SMTNil _ _ k = + k SpAbsent (Inj $ \a -> use a $ ENil ext) (Inj $ \b -> use b $ ENil ext) (\a b -> use a $ use b $ ENil ext) + +-- simplifications +sparsePlusS req1 req2 t (SpSparse SpAbsent) sp2 k = + sparsePlusS req1 req2 t SpAbsent sp2 $ \sp3 minj1 minj2 plus -> + k sp3 (withInj minj1 $ \inj1 -> \a -> use a $ inj1 (ENil ext)) minj2 (\a b -> use a $ plus (ENil ext) b) +sparsePlusS req1 req2 t sp1 (SpSparse SpAbsent) k = + sparsePlusS req1 req2 t sp1 SpAbsent $ \sp3 minj1 minj2 plus -> + k sp3 minj1 (withInj minj2 $ \inj2 -> \b -> use b $ inj2 (ENil ext)) (\a b -> use b $ plus a (ENil ext)) + +sparsePlusS req1 req2 t (SpSparse (SpSparse sp1)) sp2 k = + let ta = applySparse sp1 (fromSMTy t) in + sparsePlusS req1 req2 t (SpSparse sp1) sp2 $ \sp3 minj1 minj2 plus -> + k sp3 + (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ))) + minj2 + (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b) +sparsePlusS req1 req2 t sp1 (SpSparse (SpSparse sp2)) k = + let tb = applySparse sp2 (fromSMTy t) in + sparsePlusS req1 req2 t sp1 (SpSparse sp2) $ \sp3 minj1 minj2 plus -> + k sp3 + minj1 + (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) + (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) + +sparsePlusS req1 req2 t (SpSparse (SpLEither sp1a sp1b)) sp2 k = + let STLEither ta tb = applySparse (SpLEither sp1a sp1b) (fromSMTy t) in + sparsePlusS req1 req2 t (SpLEither sp1a sp1b) sp2 $ \sp3 minj1 minj2 plus -> + k sp3 + (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) + minj2 + (\a b -> plus (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)) b) +sparsePlusS req1 req2 t sp1 (SpSparse (SpLEither sp2a sp2b)) k = + let STLEither ta tb = applySparse (SpLEither sp2a sp2b) (fromSMTy t) in + sparsePlusS req1 req2 t sp1 (SpLEither sp2a sp2b) $ \sp3 minj1 minj2 plus -> + k sp3 + minj1 + (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) + (\a b -> plus a (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) + +sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k = + let STMaybe ta = applySparse (SpMaybe sp1) (fromSMTy t) in + sparsePlusS req1 req2 t (SpMaybe sp1) sp2 $ \sp3 minj1 minj2 plus -> + k sp3 + (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (evar IZ))) + minj2 + (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b) +sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k = + let STMaybe tb = applySparse (SpMaybe sp2) (fromSMTy t) in + sparsePlusS req1 req2 t sp1 (SpMaybe sp2) $ \sp3 minj1 minj2 plus -> + k sp3 + minj1 + (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (evar IZ))) + (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) +sparsePlusS req1 req2 t (SpMaybe (SpSparse sp1)) sp2 k = sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k +sparsePlusS req1 req2 t sp1 (SpMaybe (SpSparse sp2)) k = sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k + +-- TODO: sparse of Just is just Maybe + +-- dense plus +sparsePlusS _ _ t sp1 sp2 k + | Just Refl <- isDense t sp1 + , Just Refl <- isDense t sp2 + = k (spDense t) (Inj id) (Inj id) (\a b -> EPlus ext t a b) + +-- handle absents +sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\a b -> use a $ b) +sparsePlusS ST _ t SpAbsent sp2 k + | Just zero2 <- cheapZero (applySparse sp2 t) = + k sp2 (Inj $ \a -> use a $ zero2) (Inj id) (\a b -> use a $ b) + | otherwise = + k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b) + +sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a b -> use b $ a) +sparsePlusS _ ST t sp1 SpAbsent k + | Just zero1 <- cheapZero (applySparse sp1 t) = + k sp1 (Inj id) (Inj $ \b -> use b $ zero1) (\a b -> use b $ a) + | otherwise = + k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a) + +-- double sparse yields sparse +sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k = + sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> + k (SpSparse sp3) + (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) + (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (emaybe (evar IZ) + (ENothing ext (applySparse sp3 (fromSMTy t))) + (EJust ext (inj2 (evar IZ)))) + (emaybe (evar (IS IZ)) + (EJust ext (inj1 (evar IZ))) + (EJust ext (plus (evar (IS IZ)) (evar IZ))))) + +-- single sparse can yield non-sparse if the other argument is always present +sparsePlusS SF _ t (SpSparse sp1) sp2 k = + sparsePlusS SF ST t sp1 sp2 $ \sp3 _ (Inj inj2) plus -> + k sp3 Noinj (Inj inj2) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (inj2 (evar IZ)) + (plus (evar IZ) (evar (IS IZ)))) +sparsePlusS ST _ t (SpSparse sp1) sp2 k = + sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> + k (SpSparse sp3) + (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) + (Inj $ \b -> EJust ext (inj2 b)) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (EJust ext (inj2 (evar IZ))) + (EJust ext (plus (evar IZ) (evar (IS IZ))))) +sparsePlusS req1 req2 t sp1 (SpSparse sp2) k = + sparsePlusS req2 req1 t (SpSparse sp2) sp1 $ \sp3 inj1 inj2 plus -> + k sp3 inj2 inj1 (flip plus) + +-- products +sparsePlusS req1 req2 (SMTPair ta tb) (SpPair sp1a sp1b) (SpPair sp2a sp2b) k = + sparsePlusS req1 req2 ta sp1a sp2a $ \sp3a minj13a minj23a plusa -> + sparsePlusS req1 req2 tb sp1b sp2b $ \sp3b minj13b minj23b plusb -> + k (SpPair sp3a sp3b) + (withInj2 minj13a minj13b $ \inj13a inj13b -> + \x1 -> eunPair x1 $ \_ x1a x1b -> EPair ext (inj13a x1a) (inj13b x1b)) + (withInj2 minj23a minj23b $ \inj23a inj23b -> + \x2 -> eunPair x2 $ \_ x2a x2b -> EPair ext (inj23a x2a) (inj23b x2b)) + (\x1 x2 -> + eunPair x1 $ \w1 x1a x1b -> + eunPair (weakenExpr w1 x2) $ \w2 x2a x2b -> + EPair ext (plusa (weakenExpr w2 x1a) x2a) (plusb (weakenExpr w2 x1b) x2b)) + +-- coproducts +sparsePlusS _ _ (SMTLEither ta tb) (SpLEither sp1a sp1b) (SpLEither sp2a sp2b) k = + sparsePlusS ST ST ta sp1a sp2a $ \(sp3a :: Sparse _t3 t3a) (Inj inj13a) (Inj inj23a) plusa -> + sparsePlusS ST ST tb sp1b sp2b $ \(sp3b :: Sparse _t3' t3b) (Inj inj13b) (Inj inj23b) plusb -> + let nil :: Ex e (TLEither t3a t3b) ; nil = ELNil ext (applySparse sp3a (fromSMTy ta)) (applySparse sp3b (fromSMTy tb)) + inl :: Ex e t3a -> Ex e (TLEither t3a t3b) ; inl = ELInl ext (applySparse sp3b (fromSMTy tb)) + inr :: Ex e t3b -> Ex e (TLEither t3a t3b) ; inr = ELInr ext (applySparse sp3a (fromSMTy ta)) + in + k (SpLEither sp3a sp3b) + (Inj $ \x1 -> elcase x1 nil (inl (inj13a (evar IZ))) (inr (inj13b (evar IZ)))) + (Inj $ \x2 -> elcase x2 nil (inl (inj23a (evar IZ))) (inr (inj23b (evar IZ)))) + (\x1 x2 -> + elet x2 $ + elcase (weakenExpr WSink x1) + (elcase (evar IZ) + nil + (inl (inj23a (evar IZ))) + (inr (inj23b (evar IZ)))) + (elcase (evar (IS IZ)) + (inl (inj13a (evar IZ))) + (inl (plusa (evar (IS IZ)) (evar IZ))) + (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS ll+lr")) + (elcase (evar (IS IZ)) + (inr (inj13b (evar IZ))) + (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS lr+ll") + (inr (plusb (evar (IS IZ)) (evar IZ))))) + +-- maybe +sparsePlusS _ _ (SMTMaybe t) (SpMaybe sp1) (SpMaybe sp2) k = + sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> + k (SpMaybe sp3) + (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) + (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (emaybe (evar IZ) + (ENothing ext (applySparse sp3 (fromSMTy t))) + (EJust ext (inj2 (evar IZ)))) + (emaybe (evar (IS IZ)) + (EJust ext (inj1 (evar IZ))) + (EJust ext (plus (evar (IS IZ)) (evar IZ))))) + +-- dense array cotangents simply recurse +sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k = + sparsePlusS req1 req2 t sp1 sp2 $ \sp3 minj1 minj2 plus -> + k (SpArr sp3) + (withInj minj1 $ \inj1 -> emap (inj1 (EVar ext (applySparse sp1 (fromSMTy t)) IZ))) + (withInj minj2 $ \inj2 -> emap (inj2 (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) + (ezipWith (plus (EVar ext (applySparse sp1 (fromSMTy t)) (IS IZ)) + (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) + +-- scalars +sparsePlusS _ _ (SMTScal t) SpScal SpScal k = k SpScal (Inj id) (Inj id) (EPlus ext (SMTScal t)) diff --git a/src/AST/Sparse/Types.hs b/src/AST/Sparse/Types.hs new file mode 100644 index 0000000..10cac4e --- /dev/null +++ b/src/AST/Sparse/Types.hs @@ -0,0 +1,107 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module AST.Sparse.Types where + +import AST.Types + +import Data.Kind (Type, Constraint) +import Data.Type.Equality + + +data Sparse t t' where + SpSparse :: Sparse t t' -> Sparse t (TMaybe t') + SpAbsent :: Sparse t TNil + + SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b') + SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b') + SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t') + SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t') + SpScal :: Sparse (TScal t) (TScal t) +deriving instance Show (Sparse t t') + +class ApplySparse f where + applySparse :: Sparse t t' -> f t -> f t' + +instance ApplySparse STy where + applySparse (SpSparse s) t = STMaybe (applySparse s t) + applySparse SpAbsent _ = STNil + applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t) + applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t) + applySparse SpScal t = t + +instance ApplySparse SMTy where + applySparse (SpSparse s) t = SMTMaybe (applySparse s t) + applySparse SpAbsent _ = SMTNil + applySparse (SpPair s1 s2) (SMTPair t1 t2) = SMTPair (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t) + applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t) + applySparse SpScal t = t + + +class IsSubType s where + type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint + subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t' + subtTrans :: s a b -> s b c -> s a c + subtFull :: IsSubTypeSubject s f => f t -> s t t + +instance IsSubType (:~:) where + type IsSubTypeSubject (:~:) f = () + subtApply = gcastWith + subtTrans = trans + subtFull _ = Refl + +instance IsSubType Sparse where + type IsSubTypeSubject Sparse f = f ~ SMTy + subtApply = applySparse + + subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2) + subtTrans _ SpAbsent = SpAbsent + subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b) + subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b) + subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2) + subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2) + subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2) + subtTrans SpScal SpScal = SpScal + + subtFull = spDense + +spDense :: SMTy t -> Sparse t t +spDense SMTNil = SpAbsent +spDense (SMTPair t1 t2) = SpPair (spDense t1) (spDense t2) +spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2) +spDense (SMTMaybe t) = SpMaybe (spDense t) +spDense (SMTArr _ t) = SpArr (spDense t) +spDense (SMTScal _) = SpScal + +isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t') +isDense SMTNil SpAbsent = Just Refl +isDense _ SpSparse{} = Nothing +isDense _ SpAbsent = Nothing +isDense (SMTPair t1 t2) (SpPair s1 s2) + | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl + | otherwise = Nothing +isDense (SMTLEither t1 t2) (SpLEither s1 s2) + | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl + | otherwise = Nothing +isDense (SMTMaybe t) (SpMaybe s) + | Just Refl <- isDense t s = Just Refl + | otherwise = Nothing +isDense (SMTArr _ t) (SpArr s) + | Just Refl <- isDense t s = Just Refl + | otherwise = Nothing +isDense (SMTScal _) SpScal = Just Refl + +isAbsent :: Sparse t t' -> Bool +isAbsent (SpSparse s) = isAbsent s +isAbsent SpAbsent = True +isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2 +isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2 +isAbsent (SpMaybe s) = isAbsent s +isAbsent (SpArr s) = isAbsent s +isAbsent SpScal = False diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index dcba1ad..267dd87 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -22,16 +22,26 @@ splitLets = splitLets' (\t i w -> EVar ext t (w @> i)) splitLets' :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> Ex env t -> Ex env' t splitLets' = \sub -> \case EVar _ t i -> sub t i WId - ELet _ (rhs :: Ex env t1) body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body) + ELet _ rhs body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body) ECase x e a b -> let STEither t1 t2 = typeOf e in ECase x (splitLets' sub e) (split1 sub t1 a) (split1 sub t2 b) EMaybe x a b e -> let STMaybe t1 = typeOf e in EMaybe x (splitLets' sub a) (split1 sub t1 b) (splitLets' sub e) + ELCase x e a b c -> + let STLEither t1 t2 = typeOf e + in ELCase x (splitLets' sub e) (splitLets' sub a) (split1 sub t1 b) (split1 sub t2 c) EFold1Inner x cm a b c -> let STArr _ t1 = typeOf c - in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) + in EFold1Inner x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c) + EFold1InnerD1 x cm a b c -> + let STArr _ t1 = typeOf c + in EFold1InnerD1 x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c) + EFold1InnerD2 x cm a b c -> + let STArr _ tB = typeOf b + STArr _ t2 = typeOf c + in EFold1InnerD2 x cm (split2 sub tB t2 a) (splitLets' sub b) (splitLets' sub c) EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b) EFst x e -> EFst x (splitLets' sub e) @@ -41,13 +51,19 @@ splitLets' = \sub -> \case EInr x t e -> EInr x t (splitLets' sub e) ENothing x t -> ENothing x t EJust x e -> EJust x (splitLets' sub e) + ELNil x t1 t2 -> ELNil x t1 t2 + ELInl x t e -> ELInl x t (splitLets' sub e) + ELInr x t e -> ELInr x t (splitLets' sub e) EConstArr x n t a -> EConstArr x n t a EBuild x n a b -> EBuild x n (splitLets' sub a) (splitLets' (sinkF sub) b) + EMap x a b -> EMap x (splitLets' (sinkF sub) a) (splitLets' sub b) ESum1Inner x e -> ESum1Inner x (splitLets' sub e) EUnit x e -> EUnit x (splitLets' sub e) EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b) EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e) EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e) + EReshape x n a b -> EReshape x n (splitLets' sub a) (splitLets' sub b) + EZip x a b -> EZip x (splitLets' sub a) (splitLets' sub b) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (splitLets' sub e) EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b) @@ -55,9 +71,11 @@ splitLets' = \sub -> \case EShape x e -> EShape x (splitLets' sub e) EOp x op e -> EOp x op (splitLets' sub e) ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2) + ERecompute x e -> ERecompute x (splitLets' sub e) EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2) - EAccum x t p e1 e2 e3 -> EAccum x t p (splitLets' sub e1) (splitLets' sub e2) (splitLets' sub e3) - EZero x t -> EZero x t + EAccum x t p e1 sp e2 e3 -> EAccum x t p (splitLets' sub e1) sp (splitLets' sub e2) (splitLets' sub e3) + EZero x t ezi -> EZero x t (splitLets' sub ezi) + EDeepZero x t ezi -> EDeepZero x t (splitLets' sub ezi) EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b) EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b) EError x t s -> EError x t s @@ -81,15 +99,42 @@ splitLets' = \sub -> \case -> STy bind1 -> STy bind2 -> Ex (bind2 : bind1 : env) t -> Ex (bind2 : bind1 : env') t split2 sub tbind1 tbind2 body = let (ptrs1', bs1') = split @env' tbind1 - bs1 = fst (weakenBindings weakenExpr WSink bs1') + bs1 = fst (weakenBindingsE WSink bs1') (ptrs2, bs2) = split @(bind1 : env') tbind2 in letBinds bs1 $ - letBinds (fst (weakenBindings weakenExpr (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $ + letBinds (fst (weakenBindingsE (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $ splitLets' (\cases _ IZ w -> subPointers ptrs2 (w .> wCopies (bindingsBinds bs2) (wSinks @(bind2 : bind1 : env') (bindingsBinds bs1))) _ (IS IZ) w -> subPointers ptrs1' (w .> wSinks (bindingsBinds bs2) .> wCopies (bindingsBinds bs1) (WSink @bind2 @(bind1 : env'))) t (IS (IS i)) w -> sub t i (WPop @bind1 (WPop @bind2 (wPops (bindingsBinds bs1) (wPops (bindingsBinds bs2) w))))) body + -- TODO: abstract this to splitN lol wtf + _split4 :: forall bind1 bind2 bind3 bind4 env' env t. + (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) + -> STy bind1 -> STy bind2 -> STy bind3 -> STy bind4 -> Ex (bind4 : bind3 : bind2 : bind1 : env) t -> Ex (bind4 : bind3 : bind2 : bind1 : env') t + _split4 sub tbind1 tbind2 tbind3 tbind4 body = + let (ptrs1, bs1') = split @env' tbind1 + (ptrs2, bs2') = split @(bind1 : env') tbind2 + (ptrs3, bs3') = split @(bind2 : bind1 : env') tbind3 + (ptrs4, bs4) = split @(bind3 : bind2 : bind1 : env') tbind4 + bs1 = fst (weakenBindingsE (WSink .> WSink .> WSink) bs1') + bs2 = fst (weakenBindingsE (WSink .> WSink) bs2') + bs3 = fst (weakenBindingsE WSink bs3') + b1 = bindingsBinds bs1 + b2 = bindingsBinds bs2 + b3 = bindingsBinds bs3 + b4 = bindingsBinds bs4 + in letBinds bs1 $ + letBinds (fst (weakenBindingsE ( sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs2)) $ + letBinds (fst (weakenBindingsE ( sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs3)) $ + letBinds (fst (weakenBindingsE (sinkWithBindings bs3 .> sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs4)) $ + splitLets' (\cases _ IZ w -> subPointers ptrs4 (w .> wCopies b4 (wSinks b3 .> wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1)) + _ (IS IZ) w -> subPointers ptrs3 (w .> wSinks b4 .> wCopies b3 (wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink)) + _ (IS (IS IZ)) w -> subPointers ptrs2 (w .> wSinks b4 .> wSinks b3 .> wCopies b2 (wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink .> WSink)) + _ (IS (IS (IS IZ))) w -> subPointers ptrs1 (w .> wSinks b4 .> wSinks b3 .> wSinks b2 .> wCopies b1 (WSink @bind4 .> WSink @bind3 .> WSink @bind2 @(bind1 : env'))) + t (IS (IS (IS (IS i)))) w -> sub t i (WPop @bind1 (WPop @bind2 (WPop @bind3 (WPop @bind4 (wPops b1 (wPops b2 (wPops b3 (wPops b4 w))))))))) + body + type family Split t where Split (TPair a b) = SplitRec (TPair a b) Split _ = '[] @@ -117,6 +162,7 @@ split typ = case typ of STPair{} -> splitRec (EVar ext typ IZ) typ STNil -> other STEither{} -> other + STLEither{} -> other STMaybe{} -> other STArr{} -> other STScal{} -> other @@ -127,18 +173,19 @@ split typ = case typ of splitRec :: forall env t. Ex env t -> STy t -> (Pointers (Append (SplitRec t) env) t, Bindings Ex env (SplitRec t)) -splitRec rhs = \case +splitRec rhs typ = case typ of STNil -> (PNil, BTop) STPair (a :: STy a) (b :: STy b) | Refl <- lemAppendAssoc @(SplitRec b) @(SplitRec a) @env -> let (p1, bs1) = splitRec (EFst ext rhs) a (p2, bs2) = splitRec (ESnd ext (sinkWithBindings bs1 `weakenExpr` rhs)) b in (PPair (PWeak (sinkWithBindings bs2) p1) p2, bconcat bs1 bs2) - t@STEither{} -> other t - t@STMaybe{} -> other t - t@STArr{} -> other t - t@STScal{} -> other t - t@STAccum{} -> other t + STEither{} -> other + STLEither{} -> other + STMaybe{} -> other + STArr{} -> other + STScal{} -> other + STAccum{} -> other where - other :: STy t -> (Pointers (t : env) t, Bindings Ex env '[t]) - other t = (Point t IZ, BPush BTop (t, rhs)) + other :: (Pointers (t : env) t, Bindings Ex env '[t]) + other = (Point typ IZ, BPush BTop (typ, rhs)) diff --git a/src/AST/Types.hs b/src/AST/Types.hs index 217b2f5..4ddcb50 100644 --- a/src/AST/Types.hs +++ b/src/AST/Types.hs @@ -1,64 +1,110 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeData #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} module AST.Types where import Data.Int (Int32, Int64) +import Data.GADT.Compare import Data.GADT.Show import Data.Kind (Type) -import Data.Some import Data.Type.Equality import Data -data Ty +type data Ty = TNil | TPair Ty Ty | TEither Ty Ty + | TLEither Ty Ty | TMaybe Ty | TArr Nat Ty -- ^ rank, element type | TScal ScalTy - | TAccum Ty -- ^ the accumulator contains D2 of this type - deriving (Show, Eq, Ord) + | TAccum Ty -- ^ contained type must be a monoid type -data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool - deriving (Show, Eq, Ord) +type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool type STy :: Ty -> Type data STy t where STNil :: STy TNil STPair :: STy a -> STy b -> STy (TPair a b) STEither :: STy a -> STy b -> STy (TEither a b) + STLEither :: STy a -> STy b -> STy (TLEither a b) STMaybe :: STy a -> STy (TMaybe a) STArr :: SNat n -> STy t -> STy (TArr n t) STScal :: SScalTy t -> STy (TScal t) - STAccum :: STy t -> STy (TAccum t) + STAccum :: SMTy t -> STy (TAccum t) deriving instance Show (STy t) -instance TestEquality STy where - testEquality STNil STNil = Just Refl - testEquality STNil _ = Nothing - testEquality (STPair a b) (STPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl - testEquality STPair{} _ = Nothing - testEquality (STEither a b) (STEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl - testEquality STEither{} _ = Nothing - testEquality (STMaybe a) (STMaybe a') | Just Refl <- testEquality a a' = Just Refl - testEquality STMaybe{} _ = Nothing - testEquality (STArr a b) (STArr a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl - testEquality STArr{} _ = Nothing - testEquality (STScal a) (STScal a') | Just Refl <- testEquality a a' = Just Refl - testEquality STScal{} _ = Nothing - testEquality (STAccum a) (STAccum a') | Just Refl <- testEquality a a' = Just Refl - testEquality STAccum{} _ = Nothing +instance GCompare STy where + gcompare = \cases + STNil STNil -> GEQ + STNil _ -> GLT ; _ STNil -> GGT + (STPair a b) (STPair a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') + STPair{} _ -> GLT ; _ STPair{} -> GGT + (STEither a b) (STEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') + STEither{} _ -> GLT ; _ STEither{} -> GGT + (STLEither a b) (STLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') + STLEither{} _ -> GLT ; _ STLEither{} -> GGT + (STMaybe a) (STMaybe a') -> gorderingLift1 (gcompare a a') + STMaybe{} _ -> GLT ; _ STMaybe{} -> GGT + (STArr n t) (STArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t') + STArr{} _ -> GLT ; _ STArr{} -> GGT + (STScal t) (STScal t') -> gorderingLift1 (gcompare t t') + STScal{} _ -> GLT ; _ STScal{} -> GGT + (STAccum t) (STAccum t') -> gorderingLift1 (gcompare t t') + -- STAccum{} _ -> GLT ; _ STAccum{} -> GGT +instance TestEquality STy where testEquality = geq +instance GEq STy where geq = defaultGeq instance GShow STy where gshowsPrec = defaultGshowsPrec +-- | Monoid types +type SMTy :: Ty -> Type +data SMTy t where + SMTNil :: SMTy TNil + SMTPair :: SMTy a -> SMTy b -> SMTy (TPair a b) + SMTLEither :: SMTy a -> SMTy b -> SMTy (TLEither a b) + SMTMaybe :: SMTy a -> SMTy (TMaybe a) + SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t) + SMTScal :: ScalIsNumeric t ~ True => SScalTy t -> SMTy (TScal t) +deriving instance Show (SMTy t) + +instance GCompare SMTy where + gcompare = \cases + SMTNil SMTNil -> GEQ + SMTNil _ -> GLT ; _ SMTNil -> GGT + (SMTPair a b) (SMTPair a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') + SMTPair{} _ -> GLT ; _ SMTPair{} -> GGT + (SMTLEither a b) (SMTLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') + SMTLEither{} _ -> GLT ; _ SMTLEither{} -> GGT + (SMTMaybe a) (SMTMaybe a') -> gorderingLift1 (gcompare a a') + SMTMaybe{} _ -> GLT ; _ SMTMaybe{} -> GGT + (SMTArr n t) (SMTArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t') + SMTArr{} _ -> GLT ; _ SMTArr{} -> GGT + (SMTScal t) (SMTScal t') -> gorderingLift1 (gcompare t t') + -- SMTScal{} _ -> GLT ; _ SMTScal{} -> GGT + +instance TestEquality SMTy where testEquality = geq +instance GEq SMTy where geq = defaultGeq +instance GShow SMTy where gshowsPrec = defaultGshowsPrec + +fromSMTy :: SMTy t -> STy t +fromSMTy = \case + SMTNil -> STNil + SMTPair t1 t2 -> STPair (fromSMTy t1) (fromSMTy t2) + SMTLEither t1 t2 -> STLEither (fromSMTy t1) (fromSMTy t2) + SMTMaybe t -> STMaybe (fromSMTy t) + SMTArr n t -> STArr n (fromSMTy t) + SMTScal sty -> STScal sty + data SScalTy t where STI32 :: SScalTy TI32 STI64 :: SScalTy TI64 @@ -67,14 +113,21 @@ data SScalTy t where STBool :: SScalTy TBool deriving instance Show (SScalTy t) -instance TestEquality SScalTy where - testEquality STI32 STI32 = Just Refl - testEquality STI64 STI64 = Just Refl - testEquality STF32 STF32 = Just Refl - testEquality STF64 STF64 = Just Refl - testEquality STBool STBool = Just Refl - testEquality _ _ = Nothing +instance GCompare SScalTy where + gcompare = \cases + STI32 STI32 -> GEQ + STI32 _ -> GLT ; _ STI32 -> GGT + STI64 STI64 -> GEQ + STI64 _ -> GLT ; _ STI64 -> GGT + STF32 STF32 -> GEQ + STF32 _ -> GLT ; _ STF32 -> GGT + STF64 STF64 -> GEQ + STF64 _ -> GLT ; _ STF64 -> GGT + STBool STBool -> GEQ + -- STBool _ -> GLT ; _ STBool -> GGT +instance TestEquality SScalTy where testEquality = geq +instance GEq SScalTy where geq = defaultGeq instance GShow SScalTy where gshowsPrec = defaultGshowsPrec scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t)) @@ -89,50 +142,6 @@ type TIx = TScal TI64 tIx :: STy TIx tIx = STScal STI64 -unSTy :: STy t -> Ty -unSTy = \case - STNil -> TNil - STPair a b -> TPair (unSTy a) (unSTy b) - STEither a b -> TEither (unSTy a) (unSTy b) - STMaybe t -> TMaybe (unSTy t) - STArr n t -> TArr (unSNat n) (unSTy t) - STScal t -> TScal (unSScalTy t) - STAccum t -> TAccum (unSTy t) - -unSEnv :: SList STy env -> [Ty] -unSEnv SNil = [] -unSEnv (SCons t l) = unSTy t : unSEnv l - -unSScalTy :: SScalTy t -> ScalTy -unSScalTy = \case - STI32 -> TI32 - STI64 -> TI64 - STF32 -> TF32 - STF64 -> TF64 - STBool -> TBool - -reSTy :: Ty -> Some STy -reSTy = \case - TNil -> Some STNil - TPair a b | Some a' <- reSTy a, Some b' <- reSTy b -> Some $ STPair a' b' - TEither a b | Some a' <- reSTy a, Some b' <- reSTy b -> Some $ STEither a' b' - TMaybe t | Some t' <- reSTy t -> Some $ STMaybe t' - TArr n t | Some n' <- reSNat n, Some t' <- reSTy t -> Some $ STArr n' t' - TScal t | Some t' <- reSScalTy t -> Some $ STScal t' - TAccum t | Some t' <- reSTy t -> Some $ STAccum t' - -reSEnv :: [Ty] -> Some (SList STy) -reSEnv [] = Some SNil -reSEnv (t : l) | Some t' <- reSTy t, Some env <- reSEnv l = Some (SCons t' env) - -reSScalTy :: ScalTy -> Some SScalTy -reSScalTy = \case - TI32 -> Some STI32 - TI64 -> Some STI64 - TF32 -> Some STF32 - TF64 -> Some STF64 - TBool -> Some STBool - type family ScalRep t where ScalRep TI32 = Int32 ScalRep TI64 = Int64 @@ -161,15 +170,26 @@ type family ScalIsIntegral t where ScalIsIntegral TF64 = False ScalIsIntegral TBool = False --- | Returns true for arrays /and/ accumulators; -hasArrays :: STy t' -> Bool -hasArrays STNil = False -hasArrays (STPair a b) = hasArrays a || hasArrays b -hasArrays (STEither a b) = hasArrays a || hasArrays b -hasArrays (STMaybe t) = hasArrays t -hasArrays STArr{} = True -hasArrays STScal{} = False -hasArrays STAccum{} = True +-- | Returns true for arrays /and/ accumulators. +typeHasArrays :: STy t' -> Bool +typeHasArrays STNil = False +typeHasArrays (STPair a b) = typeHasArrays a || typeHasArrays b +typeHasArrays (STEither a b) = typeHasArrays a || typeHasArrays b +typeHasArrays (STLEither a b) = typeHasArrays a || typeHasArrays b +typeHasArrays (STMaybe t) = typeHasArrays t +typeHasArrays STArr{} = True +typeHasArrays STScal{} = False +typeHasArrays STAccum{} = True + +typeHasAccums :: STy t' -> Bool +typeHasAccums STNil = False +typeHasAccums (STPair a b) = typeHasAccums a || typeHasAccums b +typeHasAccums (STEither a b) = typeHasAccums a || typeHasAccums b +typeHasAccums (STLEither a b) = typeHasAccums a || typeHasAccums b +typeHasAccums (STMaybe t) = typeHasAccums t +typeHasAccums STArr{} = False +typeHasAccums STScal{} = False +typeHasAccums STAccum{} = True type family Tup env where Tup '[] = TNil diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index 0da1afc..1712ba5 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -1,17 +1,22 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE TypeOperators #-} -module AST.UnMonoid (unMonoid, zero, plus) where +module AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where import AST -import CHAD.Types +import AST.Sparse.Types import Data +-- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by +-- expanding them into their concrete implementations. Also ensure that +-- 'EAccum' has a dense sparsity. unMonoid :: Ex env t -> Ex env t unMonoid = \case - EZero _ t -> zero t + EZero _ t e -> zero t e + EDeepZero _ t e -> deepZero t e EPlus _ t a b -> plus t (unMonoid a) (unMonoid b) EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b) @@ -27,14 +32,23 @@ unMonoid = \case ENothing _ t -> ENothing ext t EJust _ e -> EJust ext (unMonoid e) EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e) + ELNil _ t1 t2 -> ELNil ext t1 t2 + ELInl _ t e -> ELInl ext t (unMonoid e) + ELInr _ t e -> ELInr ext t (unMonoid e) + ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c) EConstArr _ n t x -> EConstArr ext n t x EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b) + EMap _ a b -> EMap ext (unMonoid a) (unMonoid b) EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c) ESum1Inner _ e -> ESum1Inner ext (unMonoid e) EUnit _ e -> EUnit ext (unMonoid e) EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b) EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e) EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e) + EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b) + EZip _ a b -> EZip ext (unMonoid a) (unMonoid b) + EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c) + EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c) EConst _ t x -> EConst ext t x EIdx0 _ e -> EIdx0 ext (unMonoid e) EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b) @@ -42,96 +56,200 @@ unMonoid = \case EShape _ e -> EShape ext (unMonoid e) EOp _ op e -> EOp ext op (unMonoid e) ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2) + ERecompute _ e -> ERecompute ext (unMonoid e) EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b) - EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e) + EAccum _ t p eidx sp eval eacc -> + accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 -> + acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' -> + EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc)) EError _ t s -> EError ext t s -zero :: STy t -> Ex env (D2 t) -zero STNil = ENil ext -zero (STPair t1 t2) = ENothing ext (STPair (d2 t1) (d2 t2)) -zero (STEither t1 t2) = ENothing ext (STEither (d2 t1) (d2 t2)) -zero (STMaybe t) = ENothing ext (d2 t) -zero (STArr SZ t) = ENothing ext (STArr SZ (d2 t)) -zero (STArr n t) = ENothing ext (STArr n (d2 t)) -zero (STScal t) = case t of - STI32 -> ENil ext - STI64 -> ENil ext +zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t +-- don't destroy the effects! +zero SMTNil e = ELet ext e $ ENil ext +zero (SMTPair t1 t2) e = + ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ))) + (zero t2 (ESnd ext (EVar ext (typeOf e) IZ))) +zero (SMTLEither t1 t2) _ = ELNil ext (fromSMTy t1) (fromSMTy t2) +zero (SMTMaybe t) _ = ENothing ext (fromSMTy t) +zero (SMTArr _ t) e = emap (zero t (EVar ext (tZeroInfo t) IZ)) e +zero (SMTScal t) _ = case t of + STI32 -> EConst ext STI32 0 + STI64 -> EConst ext STI64 0 STF32 -> EConst ext STF32 0.0 STF64 -> EConst ext STF64 0.0 - STBool -> ENil ext -zero STAccum{} = error "Accumulators not allowed in input program" -plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t) -plus STNil _ _ = ENil ext -plus (STPair t1 t2) a b = - let t = STPair (d2 t1) (d2 t2) - in plusSparse t a b $ +deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t +deepZero SMTNil e = elet e $ ENil ext +deepZero (SMTPair t1 t2) e = + ELet ext e $ EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ))) + (deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ))) +deepZero (SMTLEither t1 t2) e = + elcase e + (ELNil ext (fromSMTy t1) (fromSMTy t2)) + (ELInl ext (fromSMTy t2) (deepZero t1 (evar IZ))) + (ELInr ext (fromSMTy t1) (deepZero t2 (evar IZ))) +deepZero (SMTMaybe t) e = + emaybe e + (ENothing ext (fromSMTy t)) + (EJust ext (deepZero t (evar IZ))) +deepZero (SMTArr _ t) e = emap (deepZero t (evar IZ)) e +deepZero (SMTScal t) _ = case t of + STI32 -> EConst ext STI32 0 + STI64 -> EConst ext STI64 0 + STF32 -> EConst ext STF32 0.0 + STF64 -> EConst ext STF64 0.0 + +plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t +-- don't destroy the effects! +plus SMTNil a b = ELet ext a $ ELet ext (weakenExpr WSink b) $ ENil ext +plus (SMTPair t1 t2) a b = + let t = STPair (fromSMTy t1) (fromSMTy t2) + in ELet ext a $ + ELet ext (weakenExpr WSink b) $ EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ))) (EFst ext (EVar ext t IZ))) (plus t2 (ESnd ext (EVar ext t (IS IZ))) (ESnd ext (EVar ext t IZ))) -plus (STEither t1 t2) a b = - let t = STEither (d2 t1) (d2 t2) - in plusSparse t a b $ - ECase ext (EVar ext t (IS IZ)) - (ECase ext (EVar ext t (IS IZ)) - (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ))) +plus (SMTLEither t1 t2) a b = + let t = STLEither (fromSMTy t1) (fromSMTy t2) + in ELet ext a $ + ELet ext (weakenExpr WSink b) $ + ELCase ext (EVar ext t (IS IZ)) + (EVar ext t IZ) + (ELCase ext (EVar ext t (IS IZ)) + (EVar ext t (IS (IS IZ))) + (ELInl ext (fromSMTy t2) (plus t1 (EVar ext (fromSMTy t1) (IS IZ)) (EVar ext (fromSMTy t1) IZ))) (EError ext t "plus l+r")) - (ECase ext (EVar ext t (IS IZ)) + (ELCase ext (EVar ext t (IS IZ)) + (EVar ext t (IS (IS IZ))) (EError ext t "plus r+l") - (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ)))) -plus (STMaybe t) a b = - plusSparse (d2 t) a b $ - plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ) -plus (STArr n t) a b = - plusSparse (STArr n (d2 t)) a b $ - eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) (IS IZ)))) - (EVar ext (STArr n (d2 t)) IZ) - (eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) IZ))) - (EVar ext (STArr n (d2 t)) (IS IZ)) - (ezipWith (plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ)) - (EVar ext (STArr n (d2 t)) (IS IZ)) - (EVar ext (STArr n (d2 t)) IZ))) -plus (STScal t) a b = case t of - STI32 -> ENil ext - STI64 -> ENil ext - STF32 -> EOp ext (OAdd STF32) (EPair ext a b) - STF64 -> EOp ext (OAdd STF64) (EPair ext a b) - STBool -> ENil ext -plus STAccum{} _ _ = error "Accumulators not allowed in input program" - -plusSparse :: STy a - -> Ex env (TMaybe a) -> Ex env (TMaybe a) - -> Ex (a : a : env) a - -> Ex env (TMaybe a) -plusSparse t a b adder = + (ELInr ext (fromSMTy t1) (plus t2 (EVar ext (fromSMTy t2) (IS IZ)) (EVar ext (fromSMTy t2) IZ)))) +plus (SMTMaybe t) a b = ELet ext b $ EMaybe ext - (EVar ext (STMaybe t) IZ) + (EVar ext (STMaybe (fromSMTy t)) IZ) (EJust ext (EMaybe ext - (EVar ext t IZ) - (weakenExpr (WCopy (WCopy WSink)) adder) - (EVar ext (STMaybe t) (IS IZ)))) + (EVar ext (fromSMTy t) IZ) + (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ)) + (EVar ext (STMaybe (fromSMTy t)) (IS IZ)))) (weakenExpr WSink a) +plus (SMTArr _ t) a b = + ezipWith (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ)) + a b +plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b) -onehot :: STy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env (D2 a) -> Ex env (D2 t) +onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t onehot typ topprj idx arg = case (typ, topprj) of - (_, SAPHere) -> arg + (_, SAPHere) -> + ELet ext arg $ + EVar ext (fromSMTy typ) IZ - (STPair t1 t2, SAPFst prj) -> EJust ext (EPair ext (onehot t1 prj idx arg) (zero t2)) - (STPair t1 t2, SAPSnd prj) -> EJust ext (EPair ext (zero t1) (onehot t2 prj idx arg)) + (SMTPair t1 t2, SAPFst prj) -> + ELet ext idx $ + let tidx = typeOf idx in + ELet ext (onehot t1 prj (EFst ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $ + let toh = fromSMTy t1 in + EPair ext (EVar ext toh IZ) + (zero t2 (ESnd ext (EVar ext tidx (IS IZ)))) + + (SMTPair t1 t2, SAPSnd prj) -> + ELet ext idx $ + let tidx = typeOf idx in + ELet ext (onehot t2 prj (ESnd ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $ + let toh = fromSMTy t2 in + EPair ext (zero t1 (EFst ext (EVar ext tidx (IS IZ)))) + (EVar ext toh IZ) - (STEither t1 t2, SAPLeft prj) -> EJust ext (EInl ext (d2 t2) (onehot t1 prj idx arg)) - (STEither t1 t2, SAPRight prj) -> EJust ext (EInr ext (d2 t1) (onehot t2 prj idx arg)) + (SMTLEither t1 t2, SAPLeft prj) -> + ELInl ext (fromSMTy t2) (onehot t1 prj idx arg) + (SMTLEither t1 t2, SAPRight prj) -> + ELInr ext (fromSMTy t1) (onehot t2 prj idx arg) - (STMaybe t1, SAPJust prj) -> EJust ext (onehot t1 prj idx arg) + (SMTMaybe t1, SAPJust prj) -> + EJust ext (onehot t1 prj idx arg) - (STArr n t1, SAPArrIdx prj _) -> + (SMTArr n t1, SAPArrIdx prj) -> let tidx = tTup (sreplicate n tIx) in ELet ext idx $ - EJust ext $ - EBuild ext n (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ))) $ - eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ))))) - (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) - (zero t1) + EBuild ext n (EShape ext (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ)))) $ + eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ))))) + (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) + (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $ + zero t1 (EVar ext (tZeroInfo t1) IZ)) + +accumulateSparse + :: SMTy t -> Sparse t t' -> Ex env t' + -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' (AcIdxD p t) -> Ex env' b -> Ex env' TNil) + -> Ex env TNil +accumulateSparse topty topsp arg accum = case (topty, topsp) of + (_, s) | Just Refl <- isDense topty s -> + accum WId SAPHere (ENil ext) arg + (SMTScal _, SpScal) -> + accum WId SAPHere (ENil ext) arg -- should be handled by isDense already, but meh + (_, SpSparse s) -> + emaybe arg + (ENil ext) + (accumulateSparse topty s (evar IZ) (\w -> accum (WPop w))) + (_, SpAbsent) -> + ENil ext + (SMTPair t1 t2, SpPair s1 s2) -> + eunPair arg $ \w1 e1 e2 -> + elet (accumulateSparse t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $ + accumulateSparse t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj)) + (SMTLEither t1 t2, SpLEither s1 s2) -> + elcase arg + (ENil ext) + (accumulateSparse t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj))) + (accumulateSparse t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj))) + (SMTMaybe t, SpMaybe s) -> + emaybe arg + (ENil ext) + (accumulateSparse t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj))) + (SMTArr n t, SpArr s) -> + let tn = tTup (sreplicate n tIx) in + elet arg $ + elet (EBuild ext n (EShape ext (evar IZ)) $ + accumulateSparse t s + (EIdx ext (evar (IS IZ)) (EVar ext tn IZ)) + (\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $ + ENil ext + +acPrjCompose + :: SAIDense dense + -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a) + -> SAcPrj p2 b c -> Ex env (AcIdx dense p2 b) + -> (forall p'. SAcPrj p' a c -> Ex env (AcIdx dense p' a) -> r) -> r +acPrjCompose _ SAPHere _ p2 idx2 k = k p2 idx2 +acPrjCompose SAID (SAPFst p1) idx1 p2 idx2 k = + acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPFst p') idx' +acPrjCompose SAID (SAPSnd p1) idx1 p2 idx2 k = + acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPSnd p') idx' +acPrjCompose SAIS (SAPFst p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAIS p1 (efst (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPFst p') (elet idx1 $ EPair ext idx' (esnd (evar IZ))) +acPrjCompose SAIS (SAPSnd p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPSnd p') (elet idx1 $ EPair ext (efst (evar IZ)) idx') +acPrjCompose d (SAPLeft p1) idx1 p2 idx2 k = + acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPLeft p') idx' +acPrjCompose d (SAPRight p1) idx1 p2 idx2 k = + acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPRight p') idx' +acPrjCompose d (SAPJust p1) idx1 p2 idx2 k = + acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPJust p') idx' +acPrjCompose SAID (SAPArrIdx p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAID p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx') +acPrjCompose SAIS (SAPArrIdx p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx') diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs index d882e28..f0820b8 100644 --- a/src/AST/Weaken.hs +++ b/src/AST/Weaken.hs @@ -19,6 +19,7 @@ module AST.Weaken (module AST.Weaken, Append) where import Data.Bifunctor (first) import Data.Functor.Const +import Data.GADT.Compare import Data.Kind (Type) import Data @@ -31,6 +32,11 @@ data Idx env t where IS :: Idx env t -> Idx (a : env) t deriving instance Show (Idx env t) +instance GEq (Idx env) where + geq IZ IZ = Just Refl + geq (IS i) (IS j) | Just Refl <- geq i j = Just Refl + geq _ _ = Nothing + splitIdx :: forall env2 env1 t f. SList f env1 -> Idx (Append env1 env2) t -> Either (Idx env1 t) (Idx env2 t) splitIdx SNil i = Right i splitIdx (SCons _ _) IZ = Left IZ @@ -123,7 +129,7 @@ wCopies bs w = let bs' = slistMap (\_ -> Const ()) bs in WStack bs' bs' WId w -wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env +wRaiseAbove :: SList f env1 -> proxy env -> env1 :> Append env1 env wRaiseAbove SNil _ = WClosed wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env) diff --git a/src/AST/Weaken/Auto.hs b/src/AST/Weaken/Auto.hs index 6752c24..7370df1 100644 --- a/src/AST/Weaken/Auto.hs +++ b/src/AST/Weaken/Auto.hs @@ -11,6 +11,7 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE AllowAmbiguousTypes #-} @@ -23,6 +24,7 @@ module AST.Weaken.Auto ( ) where import Data.Functor.Const +import Data.Kind (Constraint) import GHC.OverloadedLabels import GHC.TypeLits import Unsafe.Coerce (unsafeCoerce) @@ -39,18 +41,21 @@ type family Lookup name list where -- | The @withPre@ type parameter indicates whether there can be 'LPreW' --- occurrences within this layout. -data Layout (withPre :: Bool) (segments :: [(Symbol, [t])]) (env :: [t]) where - LSeg :: forall name segments withPre. SSymbol name -> Layout withPre segments (Lookup name segments) +-- occurrences within this layout. 'names' is the list of names that this +-- layout /produces/. That is: for LPreW, it contains the target name. The +-- 'names' list of a source layout must be a subset of the names list of the +-- target layout (which cannot contain LPreW); this is checked with SubLayout. +data Layout (withPre :: Bool) (segments :: [(Symbol, [t])]) (names :: [Symbol]) (env :: [t]) where + LSeg :: forall name segments withPre. SSymbol name -> Layout withPre segments '[name] (Lookup name segments) -- | Pre-weaken with a weakening LPreW :: forall name1 name2 segments. SegmentName name1 -> SegmentName name2 -> Lookup name1 segments :> Lookup name2 segments - -> Layout True segments (Lookup name1 segments) - (:++:) :: Layout withPre segments env1 -> Layout withPre segments env2 -> Layout withPre segments (Append env1 env2) + -> Layout True segments '[name2] (Lookup name1 segments) + (:++:) :: Layout withPre segments names1 env1 -> Layout withPre segments names2 env2 -> Layout withPre segments (Append names1 names2) (Append env1 env2) infixr :++: -instance (KnownSymbol name, seg ~ Lookup name segments) => IsLabel name (Layout withPre segments seg) where +instance (KnownSymbol name, seg ~ Lookup name segments, names ~ '[name]) => IsLabel name (Layout withPre segments names seg) where fromLabel = LSeg (symbolSing @name) newtype SegmentName name = SegmentName (SSymbol name) @@ -60,11 +65,23 @@ instance (KnownSymbol name, name ~ name') => IsLabel name (SegmentName name') wh fromLabel = SegmentName symbolSing +type family SubLayout names1 names2 where + SubLayout '[] _ = () :: Constraint + SubLayout (n : names1) names2 = SubLayout' n (Contains n names2) names1 names2 +type family SubLayout' n ok names1 names2 where + SubLayout' n False _ _ = TypeError (Text "The name '" :<>: Text n :<>: Text "' appears in the source layout but not in the target.") + SubLayout' _ True names1 names2 = SubLayout names1 names2 +type family Contains n names where + Contains _ '[] = False + Contains n (n : _) = True + Contains n (_ : names) = Contains n names + + data SSegments (segments :: [(Symbol, [t])]) where SSegNil :: SSegments '[] SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list) -instance (KnownSymbol name, name ~ name', segs ~ '[ '(name', ts)]) => IsLabel name (SList f ts -> SSegments segs) where +instance (KnownSymbol name, segs ~ '[ '(name, ts)]) => IsLabel name (SList f ts -> SSegments segs) where fromLabel = \spine -> SSegCons symbolSing (slistMap (\_ -> Const ()) spine) SSegNil auto :: KnownListSpine list => SList (Const ()) list @@ -74,7 +91,7 @@ auto1 :: SList (Const ()) '[t] auto1 = Const () `SCons` SNil infixr &. -(&.) :: SSegments segs1 -> SSegments segs2 -> SSegments (Append segs1 segs2) +(&.) :: SSegments '[segs1] -> SSegments segs2 -> SSegments (segs1 : segs2) (&.) = ssegmentsAppend where ssegmentsAppend :: SSegments a -> SSegments b -> SSegments (Append a b) @@ -118,12 +135,12 @@ linLayoutAppend (LinAppPreW (name1 :: SSymbol name1) name2 w (lin1 :: LinLayout | Refl <- lemAppendAssoc @(Lookup name1 segments) @env1' @env2 = LinAppPreW name1 name2 w (linLayoutAppend lin1 lin2) -lineariseLayout :: Layout withPre segments env -> LinLayout withPre segments env -lineariseLayout (LSeg name :: Layout _ _ seg) +lineariseLayout :: Layout withPre segments names env -> LinLayout withPre segments env +lineariseLayout (LSeg name :: Layout _ _ _ seg) | Refl <- lemAppendNil @seg = LinApp name LinEnd lineariseLayout (ly1 :++: ly2) = lineariseLayout ly1 `linLayoutAppend` lineariseLayout ly2 -lineariseLayout (LPreW (SegmentName name1) (SegmentName name2) w :: Layout _ _ seg) +lineariseLayout (LPreW (SegmentName name1) (SegmentName name2) w :: Layout _ _ _ seg) | Refl <- lemAppendNil @seg = LinAppPreW name1 name2 w LinEnd @@ -151,8 +168,7 @@ pullDown segs name@SSymbol linlayout kNotFound k = k (LinApp n' lin') (WSwap @env' (segmentLookup segs n') (segmentLookup segs name) .> wCopies (segmentLookup segs n') w) -sortLinLayouts :: forall segments env1 env2. - SSegments segments +sortLinLayouts :: SSegments segments -> LinLayout False segments env1 -> LinLayout False segments env2 -> env1 :> env2 sortLinLayouts _ LinEnd LinEnd = WId sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail2) @@ -169,8 +185,8 @@ sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail sortLinLayouts _ LinEnd LinApp{} = WClosed sortLinLayouts _ LinApp{} LinEnd = error "Segments in source that do not occur in target" -autoWeak :: forall segments env1 env2. - SSegments segments -> Layout True segments env1 -> Layout False segments env2 -> env1 :> env2 +autoWeak :: SubLayout names1 names2 + => SSegments segments -> Layout True segments names1 env1 -> Layout False segments names2 env2 -> env1 :> env2 autoWeak segs ly1 ly2 = preWeaken segs (lineariseLayout ly1) $ \wPreweak lin1 -> sortLinLayouts segs lin1 (lineariseLayout ly2) .> wPreweak |
