diff options
Diffstat (limited to 'src/AST')
| -rw-r--r-- | src/AST/Accum.hs | 60 | ||||
| -rw-r--r-- | src/AST/Count.hs | 3 | ||||
| -rw-r--r-- | src/AST/Env.hs | 24 | ||||
| -rw-r--r-- | src/AST/Pretty.hs | 21 | ||||
| -rw-r--r-- | src/AST/Sparse.hs | 308 | ||||
| -rw-r--r-- | src/AST/Sparse/Types.hs | 107 | ||||
| -rw-r--r-- | src/AST/SplitLets.hs | 3 | ||||
| -rw-r--r-- | src/AST/UnMonoid.hs | 118 | ||||
| -rw-r--r-- | src/AST/Weaken/Auto.hs | 2 |
9 files changed, 391 insertions, 255 deletions
diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs index 1101cc0..988a450 100644 --- a/src/AST/Accum.hs +++ b/src/AST/Accum.hs @@ -1,6 +1,8 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeData #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} module AST.Accum where @@ -32,21 +34,39 @@ data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t) deriving instance Show (SAcPrj p a b) -type family AcIdx p t where - AcIdx APHere t = TNil - AcIdx (APFst p) (TPair a b) = TPair (AcIdx p a) (ZeroInfo b) - AcIdx (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx p b) - AcIdx (APLeft p) (TLEither a b) = AcIdx p a - AcIdx (APRight p) (TLEither a b) = AcIdx p b - AcIdx (APJust p) (TMaybe a) = AcIdx p a - AcIdx (APArrIdx p) (TArr n a) = - -- ((index, shapes info), recursive info) +type data AIDense = AID | AIS + +data SAIDense d where + SAID :: SAIDense AID + SAIS :: SAIDense AIS +deriving instance Show (SAIDense d) + +type family AcIdx d p t where + AcIdx d APHere t = TNil + AcIdx AID (APFst p) (TPair a b) = AcIdx AID p a + AcIdx AID (APSnd p) (TPair a b) = AcIdx AID p b + AcIdx AIS (APFst p) (TPair a b) = TPair (AcIdx AIS p a) (ZeroInfo b) + AcIdx AIS (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AIS p b) + AcIdx d (APLeft p) (TLEither a b) = AcIdx d p a + AcIdx d (APRight p) (TLEither a b) = AcIdx d p b + AcIdx d (APJust p) (TMaybe a) = AcIdx d p a + AcIdx AID (APArrIdx p) (TArr n a) = + -- (index, recursive info) + TPair (Tup (Replicate n TIx)) (AcIdx AID p a) + AcIdx AIS (APArrIdx p) (TArr n a) = + -- ((index, shape info), recursive info) TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a))) - (AcIdx p a) - -- AcIdx (APArrSlice m) (TArr n a) = + (AcIdx AIS p a) + -- AcIdx AID (APArrSlice m) (TArr n a) = + -- -- index + -- Tup (Replicate m TIx) + -- AcIdx AIS (APArrSlice m) (TArr n a) = -- -- (index, array shape) -- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx)) +type AcIdxD p t = AcIdx AID p t +type AcIdxS p t = AcIdx AIS p t + acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b acPrjTy SAPHere t = t acPrjTy (SAPFst prj) (SMTPair t _) = acPrjTy prj t @@ -72,6 +92,24 @@ tZeroInfo (SMTMaybe _) = STNil tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t) tZeroInfo (SMTScal _) = STNil +-- | Info needed to create a zero-valued deep accumulator for a monoid type. +-- Should be constructable from a D1. +type family DeepZeroInfo t where + DeepZeroInfo TNil = TNil + DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b) + DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b) + DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a) + DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a) + DeepZeroInfo (TScal t) = TNil + +tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t) +tDeepZeroInfo SMTNil = STNil +tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a) +tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t) +tDeepZeroInfo (SMTScal _) = STNil + -- -- | Additional info needed for accumulation. This is empty unless there is -- -- sparsity in the monoid. -- type family AccumInfo t where diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 03a36f6..ca4d7ab 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -134,8 +134,9 @@ occCountGeneral onehot unpush alter many = go WId ECustom _ _ _ _ _ _ _ a b -> re a <> re b ERecompute _ e -> re e EWith _ _ a b -> re a <> re1 b - EAccum _ _ _ a b e -> re a <> re b <> re e + EAccum _ _ _ a _ b e -> re a <> re b <> re e EZero _ _ e -> re e + EDeepZero _ _ e -> re e EPlus _ _ a b -> re a <> re b EOneHot _ _ _ a b -> re a <> re b EError{} -> mempty diff --git a/src/AST/Env.hs b/src/AST/Env.hs index bc2b9e0..422f0f7 100644 --- a/src/AST/Env.hs +++ b/src/AST/Env.hs @@ -4,6 +4,7 @@ {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeOperators #-} module AST.Env where @@ -12,6 +13,7 @@ import Data.Type.Equality import AST.Sparse import AST.Weaken +import CHAD.Types import Data @@ -38,18 +40,18 @@ subList SNil SETop = SNil subList (SCons x xs) (SEYes s sub) = SCons (subtApply s x) (subList xs sub) subList (SCons _ xs) (SENo sub) = subList xs sub -subenvAll :: IsSubType s => SList f env -> Subenv' s env env +subenvAll :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env subenvAll SNil = SETop -subenvAll (SCons _ env) = SEYes subtFull (subenvAll env) +subenvAll (SCons t env) = SEYes (subtFull t) (subenvAll env) subenvNone :: SList f env -> Subenv' s env '[] subenvNone SNil = SETop subenvNone (SCons _ env) = SENo (subenvNone env) -subenvOnehot :: IsSubType s => SList f env -> Idx env t -> Subenv' s env '[t] -subenvOnehot (SCons _ env) IZ = SEYes subtFull (subenvNone env) -subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i) -subenvOnehot SNil i = case i of {} +subenvOnehot :: SList f env -> Idx env t -> s t t' -> Subenv' s env '[t'] +subenvOnehot (SCons _ env) IZ sp = SEYes sp (subenvNone env) +subenvOnehot (SCons _ env) (IS i) sp = SENo (subenvOnehot env i sp) +subenvOnehot SNil i _ = case i of {} subenvCompose :: IsSubType s => Subenv' s env1 env2 -> Subenv' s env2 env3 -> Subenv' s env1 env3 subenvCompose SETop SETop = SETop @@ -71,3 +73,13 @@ wUndoSubenv :: Subenv' (:~:) env env' -> env' :> env wUndoSubenv SETop = WId wUndoSubenv (SEYes Refl sub) = WCopy (wUndoSubenv sub) wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub + +subenvMap :: (forall a a'. f a -> s a a' -> s' a a') -> SList f env -> Subenv' s env env' -> Subenv' s' env env' +subenvMap _ SNil SETop = SETop +subenvMap f (t `SCons` l) (SEYes s sub) = SEYes (f t s) (subenvMap f l sub) +subenvMap f (_ `SCons` l) (SENo sub) = SENo (subenvMap f l sub) + +subenvD2E :: Subenv env env' -> Subenv (D2E env) (D2E env') +subenvD2E SETop = SETop +subenvD2E (SEYesR sub) = SEYesR (subenvD2E sub) +subenvD2E (SENo sub) = SENo (subenvD2E sub) diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 41da656..fef9686 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -25,6 +25,7 @@ import System.IO.Unsafe (unsafePerformIO) import AST import AST.Count +import AST.Sparse.Types import CHAD.Types import Data @@ -304,18 +305,24 @@ ppExpr' d val expr = case expr of <> hardline <> e2') (ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2']) - EAccum _ t prj e1 e2 e3 -> do + EAccum _ t prj e1 sp e2 e3 -> do e1' <- ppExpr' 11 val e1 e2' <- ppExpr' 11 val e2 e3' <- ppExpr' 11 val e3 return $ ppParen (d > 10) $ - ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (acPrjTy prj t)) [ppString (ppAcPrj t prj), e1', e2', e3'] + ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (applySparse sp (acPrjTy prj t))) + [ppString (ppAcPrj t prj), ppString (ppSparse (acPrjTy prj t) sp), e1', e2', e3'] EZero _ t e1 -> do e1' <- ppExpr' 11 val e1 return $ ppParen (d > 0) $ annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1' + EDeepZero _ t e1 -> do + e1' <- ppExpr' 11 val e1 + return $ ppParen (d > 0) $ + annotate AMonoid (ppString "deepzero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1' + EPlus _ t a b -> do a' <- ppExpr' 11 val a b' <- ppExpr' 11 val b @@ -368,6 +375,16 @@ ppAcPrj (SMTLEither _ t) (SAPRight prj) = "(|" ++ ppAcPrj t prj ++ ")" ppAcPrj (SMTMaybe t) (SAPJust prj) = "J" ++ ppAcPrj t prj ppAcPrj (SMTArr n t) (SAPArrIdx prj) = "[" ++ ppAcPrj t prj ++ "]" ++ intSubscript (fromSNat n) +ppSparse :: SMTy a -> Sparse a b -> String +ppSparse t sp | Just Refl <- isDense t sp = "D" +ppSparse _ SpAbsent = "A" +ppSparse t (SpSparse s) = "S" ++ ppSparse t s +ppSparse (SMTPair t1 t2) (SpPair s1 s2) = "(" ++ ppSparse t1 s1 ++ "," ++ ppSparse t2 s2 ++ ")" +ppSparse (SMTLEither t1 t2) (SpLEither s1 s2) = "(" ++ ppSparse t1 s1 ++ "|" ++ ppSparse t2 s2 ++ ")" +ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s +ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s +ppSparse (SMTScal _) SpScal = "." + ppX :: PrettyX x => Expr x env t -> ADoc ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr) diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs index 09dbc70..93258b7 100644 --- a/src/AST/Sparse.hs +++ b/src/AST/Sparse.hs @@ -1,93 +1,74 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ImpredicativeTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} {-# LANGUAGE RankNTypes #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# OPTIONS_GHC -fmax-pmcheck-models=60 #-} -module AST.Sparse where +{-# OPTIONS_GHC -fmax-pmcheck-models=80 #-} +module AST.Sparse (module AST.Sparse, module AST.Sparse.Types) where -import Data.Kind (Constraint, Type) import Data.Type.Equality import AST +import AST.Sparse.Types +import Data (SBool(..)) -data Sparse t t' where - SpDense :: Sparse t t - SpSparse :: Sparse t t' -> Sparse t (TMaybe t') - SpAbsent :: Sparse t TNil - - SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b') - SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b') - SpLeft :: Sparse a a' -> Sparse (TLEither a b) a' - SpRight :: Sparse b b' -> Sparse (TLEither a b) b' - SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t') - SpJust :: Sparse t t' -> Sparse (TMaybe t) t' - SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t') -deriving instance Show (Sparse t t') - -applySparse :: Sparse t t' -> STy t -> STy t' -applySparse SpDense t = t -applySparse (SpSparse s) t = STMaybe (applySparse s t) -applySparse SpAbsent _ = STNil -applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2) -applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2) -applySparse (SpLeft s) (STLEither t1 _) = applySparse s t1 -applySparse (SpRight s) (STLEither _ t2) = applySparse s t2 -applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t) -applySparse (SpJust s) (STMaybe t) = applySparse s t -applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t) - - -class IsSubType s where - type IsSubTypeSubject s (f :: k -> Type) :: Constraint - subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t' - subtTrans :: s a b -> s b c -> s a c - subtFull :: s a a - -instance IsSubType (:~:) where - type IsSubTypeSubject (:~:) f = () - subtApply = gcastWith - subtTrans = trans - subtFull = Refl +sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t' +sparsePlus _ SpAbsent e1 e2 = use e1 $ use e2 $ ENil ext +sparsePlus t sp e1 e2 | Just Refl <- isDense t sp = EPlus ext t e1 e2 +sparsePlus t (SpSparse sp) e1 e2 = sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 -- heh +sparsePlus (SMTPair t1 t2) (SpPair sp1 sp2) e1 e2 = + eunPair e1 $ \w1 e1a e1b -> + eunPair (weakenExpr w1 e2) $ \w2 e2a e2b -> + EPair ext (sparsePlus t1 sp1 (weakenExpr w2 e1a) e2a) + (sparsePlus t2 sp2 (weakenExpr w2 e1b) e2b) +sparsePlus (SMTLEither t1 t2) (SpLEither sp1 sp2) e1 e2 = + elet e2 $ + elcase (weakenExpr WSink e1) + (evar IZ) + (elcase (evar (IS IZ)) + (ELInl ext (applySparse sp2 (fromSMTy t2)) (evar IZ)) + (ELInl ext (applySparse sp2 (fromSMTy t2)) (sparsePlus t1 sp1 (evar (IS IZ)) (evar IZ))) + (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus ll+lr")) + (elcase (evar (IS IZ)) + (ELInr ext (applySparse sp1 (fromSMTy t1)) (evar IZ)) + (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus lr+ll") + (ELInr ext (applySparse sp1 (fromSMTy t1)) (sparsePlus t2 sp2 (evar (IS IZ)) (evar IZ)))) +sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 = + elet e2 $ + emaybe (weakenExpr WSink e1) + (evar IZ) + (emaybe (evar (IS IZ)) + (EJust ext (evar IZ)) + (EJust ext (sparsePlus t sp (evar (IS IZ)) (evar IZ)))) +sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS IZ)) (evar IZ)) e1 e2 +sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2 -instance IsSubType Sparse where - type IsSubTypeSubject Sparse f = f ~ STy - subtApply = applySparse - subtTrans SpDense s = s - subtTrans s SpDense = s - subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2) - subtTrans _ SpAbsent = SpAbsent - subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b) - subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b) - subtTrans (SpLEither s1 _) (SpLeft s2) = SpLeft (subtTrans s1 s2) - subtTrans (SpLEither _ s1) (SpRight s2) = SpRight (subtTrans s1 s2) - subtTrans (SpLeft s1) s2 = SpLeft (subtTrans s1 s2) - subtTrans (SpRight s1) s2 = SpRight (subtTrans s1 s2) - subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2) - subtTrans (SpSparse s1) (SpJust s2) = subtTrans s1 s2 - subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2) - subtTrans (SpMaybe s1) (SpJust s2) = SpJust (subtTrans s1 s2) - subtTrans (SpJust s1) s2 = SpJust (subtTrans s1 s2) - subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2) +cheapZero :: SMTy t -> Maybe (forall env. Ex env t) +cheapZero SMTNil = Just (ENil ext) +cheapZero (SMTPair t1 t2) + | Just e1 <- cheapZero t1 + , Just e2 <- cheapZero t2 + = Just (EPair ext e1 e2) + | otherwise + = Nothing +cheapZero (SMTLEither t1 t2) = Just (ELNil ext (fromSMTy t1) (fromSMTy t2)) +cheapZero (SMTMaybe t) = Just (ENothing ext (fromSMTy t)) +cheapZero SMTArr{} = Nothing +cheapZero (SMTScal t) = case t of + STI32 -> Just (EConst ext t 0) + STI64 -> Just (EConst ext t 0) + STF32 -> Just (EConst ext t 0.0) + STF64 -> Just (EConst ext t 0.0) - subtFull = SpDense - - -data SBool b where - SF :: SBool False - ST :: SBool True -deriving instance Show (SBool b) data Injection sp a b where -- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that -- 'sparsePlusS' can provide injections even if the caller doesn't require - -- them. This eliminates pointless checks. + -- them. This simplifies the sparsePlusS code. Inj :: (forall e. Ex e a -> Ex e b) -> Injection sp a b Noinj :: Injection False a b @@ -104,8 +85,11 @@ withInj2 (Inj f) (Inj g) k = Inj (k f g) withInj2 Noinj _ _ = Noinj withInj2 _ Noinj _ = Noinj +use :: Ex env a -> Ex env b -> Ex env b +use a b = elet a $ weakenExpr WSink b + -- | This function produces quadratically-sized code in the presence of nested --- dynamic sparsity. しょうがない。 +-- dynamic sparsity. TODO can this be improved? sparsePlusS :: SBool inj1 -> SBool inj2 -> SMTy t -> Sparse t t1 -> Sparse t t2 @@ -115,16 +99,17 @@ sparsePlusS -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3) -> r) -> r --- nil override -sparsePlusS _ _ SMTNil _ _ k = k SpAbsent (Inj $ \_ -> ENil ext) (Inj $ \_ -> ENil ext) (\_ _ -> ENil ext) +-- nil override (but don't destroy effects!) +sparsePlusS _ _ SMTNil _ _ k = + k SpAbsent (Inj $ \a -> use a $ ENil ext) (Inj $ \b -> use b $ ENil ext) (\a b -> use a $ use b $ ENil ext) -- simplifications sparsePlusS req1 req2 t (SpSparse SpAbsent) sp2 k = sparsePlusS req1 req2 t SpAbsent sp2 $ \sp3 minj1 minj2 plus -> - k sp3 (withInj minj1 $ \inj1 -> \_ -> inj1 (ENil ext)) minj2 (\_ b -> plus (ENil ext) b) + k sp3 (withInj minj1 $ \inj1 -> \a -> use a $ inj1 (ENil ext)) minj2 (\a b -> use a $ plus (ENil ext) b) sparsePlusS req1 req2 t sp1 (SpSparse SpAbsent) k = sparsePlusS req1 req2 t sp1 SpAbsent $ \sp3 minj1 minj2 plus -> - k sp3 minj1 (withInj minj2 $ \inj2 -> \_ -> inj2 (ENil ext)) (\a _ -> plus a (ENil ext)) + k sp3 minj1 (withInj minj2 $ \inj2 -> \b -> use b $ inj2 (ENil ext)) (\a b -> use b $ plus a (ENil ext)) sparsePlusS req1 req2 t (SpSparse (SpSparse sp1)) sp2 k = let ta = applySparse sp1 (fromSMTy t) in @@ -176,16 +161,25 @@ sparsePlusS req1 req2 t sp1 (SpMaybe (SpSparse sp2)) k = sparsePlusS req1 req2 t -- TODO: sparse of Just is just Maybe -- dense plus -sparsePlusS _ _ t SpDense SpDense k = k SpDense (Inj id) (Inj id) (\a b -> EPlus ext t a b) +sparsePlusS _ _ t sp1 sp2 k + | Just Refl <- isDense t sp1 + , Just Refl <- isDense t sp2 + = k (spDense t) (Inj id) (Inj id) (\a b -> EPlus ext t a b) -- handle absents -sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\_ b -> b) -sparsePlusS ST _ t SpAbsent sp2 k = - k (SpSparse sp2) (Inj $ \_ -> ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\_ b -> EJust ext b) +sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\a b -> use a $ b) +sparsePlusS ST _ t SpAbsent sp2 k + | Just zero2 <- cheapZero (applySparse sp2 t) = + k sp2 (Inj $ \a -> use a $ zero2) (Inj id) (\a b -> use a $ b) + | otherwise = + k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b) -sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a _ -> a) -sparsePlusS _ ST t sp1 SpAbsent k = - k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \_ -> ENothing ext (applySparse sp1 (fromSMTy t))) (\a _ -> EJust ext a) +sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a b -> use b $ a) +sparsePlusS _ ST t sp1 SpAbsent k + | Just zero1 <- cheapZero (applySparse sp1 t) = + k sp1 (Inj id) (Inj $ \b -> use b $ zero1) (\a b -> use b $ a) + | otherwise = + k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a) -- double sparse yields sparse sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k = @@ -239,8 +233,6 @@ sparsePlusS req1 req2 (SMTPair ta tb) (SpPair sp1a sp1b) (SpPair sp2a sp2b) k = eunPair x1 $ \w1 x1a x1b -> eunPair (weakenExpr w1 x2) $ \w2 x2a x2b -> EPair ext (plusa (weakenExpr w2 x1a) x2a) (plusb (weakenExpr w2 x1b) x2b)) -sparsePlusS req1 req2 t sp1@SpPair{} SpDense k = sparsePlusS req1 req2 t sp1 (SpPair SpDense SpDense) k -sparsePlusS req1 req2 t SpDense sp2@SpPair{} k = sparsePlusS req1 req2 t (SpPair SpDense SpDense) sp2 k -- coproducts sparsePlusS _ _ (SMTLEither ta tb) (SpLEither sp1a sp1b) (SpLEither sp2a sp2b) k = @@ -268,107 +260,6 @@ sparsePlusS _ _ (SMTLEither ta tb) (SpLEither sp1a sp1b) (SpLEither sp2a sp2b) k (inr (inj13b (evar IZ))) (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS lr+ll") (inr (plusb (evar (IS IZ)) (evar IZ))))) -sparsePlusS req1 req2 t sp1@SpLEither{} SpDense k = sparsePlusS req1 req2 t sp1 (SpLEither SpDense SpDense) k -sparsePlusS req1 req2 t SpDense sp2@SpLEither{} k = sparsePlusS req1 req2 t (SpLEither SpDense SpDense) sp2 k - --- coproducts with partially known arguments: if we have a non-nil --- always-present coproduct argument, the result is dense, otherwise we --- introduce sparsity -sparsePlusS _ SF (SMTLEither ta _) (SpLeft sp1a) (SpLEither sp2a _) k = - sparsePlusS ST SF ta sp1a sp2a $ \sp3a (Inj inj13a) _ plusa -> - k (SpLeft sp3a) - (Inj inj13a) - Noinj - (\x1 x2 -> - elet x1 $ - elcase (weakenExpr WSink x2) - (inj13a (evar IZ)) - (plusa (evar (IS IZ)) (evar IZ)) - (EError ext (applySparse sp3a (fromSMTy ta)) "plusS !ll+lr")) - -sparsePlusS _ ST (SMTLEither ta _) (SpLeft sp1a) (SpLEither sp2a _) k = - sparsePlusS ST ST ta sp1a sp2a $ \sp3a (Inj inj13a) (Inj inj23a) plusa -> - k (SpSparse (SpLeft sp3a)) - (Inj $ \x1 -> EJust ext (inj13a x1)) - (Inj $ \x2 -> - elcase x2 - (ENothing ext (applySparse sp3a (fromSMTy ta))) - (EJust ext (inj23a (evar IZ))) - (EError ext (STMaybe (applySparse sp3a (fromSMTy ta))) "plusSi2 !ll+lr")) - (\x1 x2 -> - elet x1 $ - EJust ext $ - elcase (weakenExpr WSink x2) - (inj13a (evar IZ)) - (plusa (evar (IS IZ)) (evar IZ)) - (EError ext (applySparse sp3a (fromSMTy ta)) "plusS !ll+lr")) - -sparsePlusS req1 req2 t sp1@SpLEither{} sp2@SpLeft{} k = - sparsePlusS req2 req1 t sp2 sp1 $ \sp3a inj13a inj23a plusa -> k sp3a inj23a inj13a (flip plusa) -sparsePlusS req1 req2 t sp1@SpLeft{} SpDense k = sparsePlusS req1 req2 t sp1 (SpLEither SpDense SpDense) k -sparsePlusS req1 req2 t SpDense sp2@SpLeft{} k = sparsePlusS req1 req2 t (SpLEither SpDense SpDense) sp2 k - -sparsePlusS _ SF (SMTLEither _ tb) (SpRight sp1b) (SpLEither _ sp2b) k = - sparsePlusS ST SF tb sp1b sp2b $ \sp3b (Inj inj13b) _ plusb -> - k (SpRight sp3b) - (Inj inj13b) - Noinj - (\x1 x2 -> - elet x1 $ - elcase (weakenExpr WSink x2) - (inj13b (evar IZ)) - (EError ext (applySparse sp3b (fromSMTy tb)) "plusS !lr+ll") - (plusb (evar (IS IZ)) (evar IZ))) - -sparsePlusS _ ST (SMTLEither _ tb) (SpRight sp1b) (SpLEither _ sp2b) k = - sparsePlusS ST ST tb sp1b sp2b $ \sp3b (Inj inj13b) (Inj inj23b) plusb -> - k (SpSparse (SpRight sp3b)) - (Inj $ \x1 -> EJust ext (inj13b x1)) - (Inj $ \x2 -> - elcase x2 - (ENothing ext (applySparse sp3b (fromSMTy tb))) - (EError ext (STMaybe (applySparse sp3b (fromSMTy tb))) "plusSi2 !lr+ll") - (EJust ext (inj23b (evar IZ)))) - (\x1 x2 -> - elet x1 $ - EJust ext $ - elcase (weakenExpr WSink x2) - (inj13b (evar IZ)) - (EError ext (applySparse sp3b (fromSMTy tb)) "plusS !lr+ll") - (plusb (evar (IS IZ)) (evar IZ))) - -sparsePlusS req1 req2 t sp1@SpLEither{} sp2@SpRight{} k = - sparsePlusS req2 req1 t sp2 sp1 $ \sp3b inj13b inj23b plusb -> k sp3b inj23b inj13b (flip plusb) -sparsePlusS req1 req2 t sp1@SpRight{} SpDense k = sparsePlusS req1 req2 t sp1 (SpLEither SpDense SpDense) k -sparsePlusS req1 req2 t SpDense sp2@SpRight{} k = sparsePlusS req1 req2 t (SpLEither SpDense SpDense) sp2 k - --- dense same-branch coproducts simply recurse -sparsePlusS req1 req2 (SMTLEither ta _) (SpLeft sp1) (SpLeft sp2) k = - sparsePlusS req1 req2 ta sp1 sp2 $ \sp3 inj1 inj2 plus -> - k (SpLeft sp3) inj1 inj2 plus -sparsePlusS req1 req2 (SMTLEither _ tb) (SpRight sp1) (SpRight sp2) k = - sparsePlusS req1 req2 tb sp1 sp2 $ \sp3 inj1 inj2 plus -> - k (SpRight sp3) inj1 inj2 plus - --- dense, mismatched coproducts are valid as long as we don't actually invoke --- plus at runtime (injections are fine) -sparsePlusS SF SF _ SpLeft{} SpRight{} k = - k SpAbsent Noinj Noinj (\_ _ -> EError ext STNil "plusS !ll+!lr") -sparsePlusS SF ST (SMTLEither _ tb) SpLeft{} (SpRight sp2) k = - k (SpRight sp2) Noinj (Inj id) - (\_ _ -> EError ext (applySparse sp2 (fromSMTy tb)) "plusS !ll+?lr") -sparsePlusS ST SF (SMTLEither ta _) (SpLeft sp1) SpRight{} k = - k (SpLeft sp1) (Inj id) Noinj - (\_ _ -> EError ext (applySparse sp1 (fromSMTy ta)) "plusS !lr+?ll") -sparsePlusS ST ST (SMTLEither ta tb) (SpLeft sp1) (SpRight sp2) k = - -- note: we know that this cannot be ELNil, but the returned 'Sparse' unfortunately claims to allow it. - k (SpLEither sp1 sp2) - (Inj $ \a -> ELInl ext (applySparse sp2 (fromSMTy tb)) a) - (Inj $ \b -> ELInr ext (applySparse sp1 (fromSMTy ta)) b) - (\_ _ -> EError ext (STLEither (applySparse sp1 (fromSMTy ta)) (applySparse sp2 (fromSMTy tb))) "plusS ?ll+?lr") - -sparsePlusS req1 req2 t sp1@SpRight{} sp2@SpLeft{} k = -- the errors are not flipped, but eh - sparsePlusS req2 req1 t sp2 sp1 $ \sp3 inj1 inj2 plus -> k sp3 inj2 inj1 (flip plus) -- maybe sparsePlusS _ _ (SMTMaybe t) (SpMaybe sp1) (SpMaybe sp2) k = @@ -385,42 +276,6 @@ sparsePlusS _ _ (SMTMaybe t) (SpMaybe sp1) (SpMaybe sp2) k = (emaybe (evar (IS IZ)) (EJust ext (inj1 (evar IZ))) (EJust ext (plus (evar (IS IZ)) (evar IZ))))) -sparsePlusS req1 req2 t sp1@SpMaybe{} SpDense k = sparsePlusS req1 req2 t sp1 (SpMaybe SpDense) k -sparsePlusS req1 req2 t SpDense sp2@SpMaybe{} k = sparsePlusS req1 req2 t (SpMaybe SpDense) sp2 k - --- maybe with partially known arguments: if we have an always-present Just --- argument, the result is dense, otherwise we introduce sparsity by weakening --- to SpMaybe -sparsePlusS _ SF (SMTMaybe t) (SpJust sp1) (SpMaybe sp2) k = - sparsePlusS ST SF t sp1 sp2 $ \sp3 (Inj inj1) _ plus -> - k (SpJust sp3) - (Inj inj1) - Noinj - (\a b -> - elet a $ - emaybe (weakenExpr WSink b) - (inj1 (evar IZ)) - (plus (evar (IS IZ)) (evar IZ))) -sparsePlusS _ ST (SMTMaybe t) (SpJust sp1) (SpMaybe sp2) k = - sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> - k (SpMaybe sp3) - (Inj $ \a -> EJust ext (inj1 a)) - (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) - (\a b -> - elet a $ - emaybe (weakenExpr WSink b) - (EJust ext (inj1 (evar IZ))) - (EJust ext (plus (evar (IS IZ)) (evar IZ)))) - -sparsePlusS req1 req2 t sp1@SpMaybe{} sp2@SpJust{} k = - sparsePlusS req2 req1 t sp2 sp1 $ \sp3 inj2 inj1 plus -> k sp3 inj1 inj2 (flip plus) -sparsePlusS req1 req2 t sp1@SpJust{} SpDense k = sparsePlusS req1 req2 t sp1 (SpMaybe SpDense) k -sparsePlusS req1 req2 t SpDense sp2@SpJust{} k = sparsePlusS req1 req2 t (SpMaybe SpDense) sp2 k - --- dense same-branch maybes simply recurse -sparsePlusS req1 req2 (SMTMaybe t) (SpJust sp1) (SpJust sp2) k = - sparsePlusS req1 req2 t sp1 sp2 $ \sp3 inj1 inj2 plus -> - k (SpJust sp3) inj1 inj2 plus -- dense array cotangents simply recurse sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k = @@ -430,5 +285,6 @@ sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k = (withInj minj2 $ \inj2 -> emap (inj2 (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) (ezipWith (plus (EVar ext (applySparse sp1 (fromSMTy t)) (IS IZ)) (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) -sparsePlusS req1 req2 t (SpArr sp1) SpDense k = sparsePlusS req1 req2 t (SpArr sp1) (SpArr SpDense) k -sparsePlusS req1 req2 t SpDense (SpArr sp2) k = sparsePlusS req1 req2 t (SpArr SpDense) (SpArr sp2) k + +-- scalars +sparsePlusS _ _ (SMTScal t) SpScal SpScal k = k SpScal (Inj id) (Inj id) (EPlus ext (SMTScal t)) diff --git a/src/AST/Sparse/Types.hs b/src/AST/Sparse/Types.hs new file mode 100644 index 0000000..10cac4e --- /dev/null +++ b/src/AST/Sparse/Types.hs @@ -0,0 +1,107 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module AST.Sparse.Types where + +import AST.Types + +import Data.Kind (Type, Constraint) +import Data.Type.Equality + + +data Sparse t t' where + SpSparse :: Sparse t t' -> Sparse t (TMaybe t') + SpAbsent :: Sparse t TNil + + SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b') + SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b') + SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t') + SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t') + SpScal :: Sparse (TScal t) (TScal t) +deriving instance Show (Sparse t t') + +class ApplySparse f where + applySparse :: Sparse t t' -> f t -> f t' + +instance ApplySparse STy where + applySparse (SpSparse s) t = STMaybe (applySparse s t) + applySparse SpAbsent _ = STNil + applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t) + applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t) + applySparse SpScal t = t + +instance ApplySparse SMTy where + applySparse (SpSparse s) t = SMTMaybe (applySparse s t) + applySparse SpAbsent _ = SMTNil + applySparse (SpPair s1 s2) (SMTPair t1 t2) = SMTPair (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t) + applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t) + applySparse SpScal t = t + + +class IsSubType s where + type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint + subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t' + subtTrans :: s a b -> s b c -> s a c + subtFull :: IsSubTypeSubject s f => f t -> s t t + +instance IsSubType (:~:) where + type IsSubTypeSubject (:~:) f = () + subtApply = gcastWith + subtTrans = trans + subtFull _ = Refl + +instance IsSubType Sparse where + type IsSubTypeSubject Sparse f = f ~ SMTy + subtApply = applySparse + + subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2) + subtTrans _ SpAbsent = SpAbsent + subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b) + subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b) + subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2) + subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2) + subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2) + subtTrans SpScal SpScal = SpScal + + subtFull = spDense + +spDense :: SMTy t -> Sparse t t +spDense SMTNil = SpAbsent +spDense (SMTPair t1 t2) = SpPair (spDense t1) (spDense t2) +spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2) +spDense (SMTMaybe t) = SpMaybe (spDense t) +spDense (SMTArr _ t) = SpArr (spDense t) +spDense (SMTScal _) = SpScal + +isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t') +isDense SMTNil SpAbsent = Just Refl +isDense _ SpSparse{} = Nothing +isDense _ SpAbsent = Nothing +isDense (SMTPair t1 t2) (SpPair s1 s2) + | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl + | otherwise = Nothing +isDense (SMTLEither t1 t2) (SpLEither s1 s2) + | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl + | otherwise = Nothing +isDense (SMTMaybe t) (SpMaybe s) + | Just Refl <- isDense t s = Just Refl + | otherwise = Nothing +isDense (SMTArr _ t) (SpArr s) + | Just Refl <- isDense t s = Just Refl + | otherwise = Nothing +isDense (SMTScal _) SpScal = Just Refl + +isAbsent :: Sparse t t' -> Bool +isAbsent (SpSparse s) = isAbsent s +isAbsent SpAbsent = True +isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2 +isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2 +isAbsent (SpMaybe s) = isAbsent s +isAbsent (SpArr s) = isAbsent s +isAbsent SpScal = False diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index 3c353d4..dcaf82f 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -63,8 +63,9 @@ splitLets' = \sub -> \case ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2) ERecompute x e -> ERecompute x (splitLets' sub e) EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2) - EAccum x t p e1 e2 e3 -> EAccum x t p (splitLets' sub e1) (splitLets' sub e2) (splitLets' sub e3) + EAccum x t p e1 sp e2 e3 -> EAccum x t p (splitLets' sub e1) sp (splitLets' sub e2) (splitLets' sub e3) EZero x t ezi -> EZero x t (splitLets' sub ezi) + EDeepZero x t ezi -> EDeepZero x t (splitLets' sub ezi) EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b) EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b) EError x t s -> EError x t s diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index ac4d733..ef01bf8 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -1,18 +1,22 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE TypeOperators #-} -module AST.UnMonoid (unMonoid, zero, plus) where +module AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where import AST +import AST.Sparse.Types import Data --- | Remove 'EZero', 'EPlus' and 'EOneHot' from the program by expanding them --- into their concrete implementations. +-- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by +-- expanding them into their concrete implementations. Also ensure that +-- 'EAccum' has a dense sparsity. unMonoid :: Ex env t -> Ex env t unMonoid = \case EZero _ t e -> zero t e + EDeepZero _ t e -> deepZero t e EPlus _ t a b -> plus t (unMonoid a) (unMonoid b) EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b) @@ -49,11 +53,14 @@ unMonoid = \case ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2) ERecompute _ e -> ERecompute ext (unMonoid e) EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b) - EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e) + EAccum _ t p eidx sp eval eacc -> + accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 -> + acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' -> + EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc)) EError _ t s -> EError ext t s zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t -zero SMTNil _ = ENil ext +zero SMTNil e = elet e $ ENil ext zero (SMTPair t1 t2) e = ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ))) (zero t2 (ESnd ext (EVar ext (typeOf e) IZ))) @@ -66,8 +73,30 @@ zero (SMTScal t) _ = case t of STF32 -> EConst ext STF32 0.0 STF64 -> EConst ext STF64 0.0 +deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t +deepZero SMTNil e = elet e $ ENil ext +deepZero (SMTPair t1 t2) e = + ELet ext e $ EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ))) + (deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ))) +deepZero (SMTLEither t1 t2) e = + elcase e + (ELNil ext (fromSMTy t1) (fromSMTy t2)) + (ELInl ext (fromSMTy t2) (deepZero t1 (evar IZ))) + (ELInr ext (fromSMTy t1) (deepZero t2 (evar IZ))) +deepZero (SMTMaybe t) e = + emaybe e + (ENothing ext (fromSMTy t)) + (EJust ext (deepZero t (evar IZ))) +deepZero (SMTArr _ t) e = emap (deepZero t (evar IZ)) e +deepZero (SMTScal t) _ = case t of + STI32 -> EConst ext STI32 0 + STI64 -> EConst ext STI64 0 + STF32 -> EConst ext STF32 0.0 + STF64 -> EConst ext STF64 0.0 + plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t -plus SMTNil _ _ = ENil ext +-- don't destroy the effects! +plus SMTNil a b = elet a $ elet (weakenExpr WSink b) $ ENil ext plus (SMTPair t1 t2) a b = let t = STPair (fromSMTy t1) (fromSMTy t2) in ELet ext a $ @@ -105,7 +134,7 @@ plus (SMTArr _ t) a b = a b plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b) -onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env t +onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t onehot typ topprj idx arg = case (typ, topprj) of (_, SAPHere) -> ELet ext arg $ @@ -143,3 +172,78 @@ onehot typ topprj idx arg = case (typ, topprj) of (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $ zero t1 (EVar ext (tZeroInfo t1) IZ)) + +accumulateSparse + :: SMTy t -> Sparse t t' -> Ex env t' + -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' (AcIdxD p t) -> Ex env' b -> Ex env' TNil) + -> Ex env TNil +accumulateSparse topty topsp arg accum = case (topty, topsp) of + (_, s) | Just Refl <- isDense topty s -> + accum WId SAPHere (ENil ext) arg + (SMTScal _, SpScal) -> + accum WId SAPHere (ENil ext) arg -- should be handled by isDense already, but meh + (_, SpSparse s) -> + emaybe arg + (ENil ext) + (accumulateSparse topty s (evar IZ) (\w -> accum (WPop w))) + (_, SpAbsent) -> + ENil ext + (SMTPair t1 t2, SpPair s1 s2) -> + eunPair arg $ \w1 e1 e2 -> + elet (accumulateSparse t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $ + accumulateSparse t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj)) + (SMTLEither t1 t2, SpLEither s1 s2) -> + elcase arg + (ENil ext) + (accumulateSparse t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj))) + (accumulateSparse t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj))) + (SMTMaybe t, SpMaybe s) -> + emaybe arg + (ENil ext) + (accumulateSparse t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj))) + (SMTArr n t, SpArr s) -> + let tn = tTup (sreplicate n tIx) in + elet arg $ + elet (EBuild ext n (EShape ext (evar IZ)) $ + accumulateSparse t s + (EIdx ext (evar (IS IZ)) (EVar ext tn IZ)) + (\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $ + ENil ext + +acPrjCompose + :: SAIDense dense + -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a) + -> SAcPrj p2 b c -> Ex env (AcIdx dense p2 b) + -> (forall p'. SAcPrj p' a c -> Ex env (AcIdx dense p' a) -> r) -> r +acPrjCompose _ SAPHere _ p2 idx2 k = k p2 idx2 +acPrjCompose SAID (SAPFst p1) idx1 p2 idx2 k = + acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPFst p') idx' +acPrjCompose SAID (SAPSnd p1) idx1 p2 idx2 k = + acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPSnd p') idx' +acPrjCompose SAIS (SAPFst p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAIS p1 (efst (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPFst p') (elet idx1 $ EPair ext idx' (esnd (evar IZ))) +acPrjCompose SAIS (SAPSnd p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPSnd p') (elet idx1 $ EPair ext (efst (evar IZ)) idx') +acPrjCompose d (SAPLeft p1) idx1 p2 idx2 k = + acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPLeft p') idx' +acPrjCompose d (SAPRight p1) idx1 p2 idx2 k = + acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPRight p') idx' +acPrjCompose d (SAPJust p1) idx1 p2 idx2 k = + acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPJust p') idx' +acPrjCompose SAID (SAPArrIdx p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAID p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx') +acPrjCompose SAIS (SAPArrIdx p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx') diff --git a/src/AST/Weaken/Auto.hs b/src/AST/Weaken/Auto.hs index 6752c24..c6efe37 100644 --- a/src/AST/Weaken/Auto.hs +++ b/src/AST/Weaken/Auto.hs @@ -64,7 +64,7 @@ data SSegments (segments :: [(Symbol, [t])]) where SSegNil :: SSegments '[] SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list) -instance (KnownSymbol name, name ~ name', segs ~ '[ '(name', ts)]) => IsLabel name (SList f ts -> SSegments segs) where +instance (KnownSymbol name, segs ~ '[ '(name, ts)]) => IsLabel name (SList f ts -> SSegments segs) where fromLabel = \spine -> SSegCons symbolSing (slistMap (\_ -> Const ()) spine) SSegNil auto :: KnownListSpine list => SList (Const ()) list |
