diff options
Diffstat (limited to 'src/AST/Sparse.hs')
| -rw-r--r-- | src/AST/Sparse.hs | 287 |
1 files changed, 0 insertions, 287 deletions
diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs deleted file mode 100644 index 2a29799..0000000 --- a/src/AST/Sparse.hs +++ /dev/null @@ -1,287 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ImpredicativeTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE RankNTypes #-} - -{-# OPTIONS_GHC -fmax-pmcheck-models=80 #-} -module AST.Sparse (module AST.Sparse, module AST.Sparse.Types) where - -import Data.Type.Equality - -import AST -import AST.Sparse.Types -import Data (SBool(..)) - - -sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t' -sparsePlus _ SpAbsent e1 e2 = use e1 $ use e2 $ ENil ext -sparsePlus t sp e1 e2 | Just Refl <- isDense t sp = EPlus ext t e1 e2 -sparsePlus t (SpSparse sp) e1 e2 = sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 -- heh -sparsePlus (SMTPair t1 t2) (SpPair sp1 sp2) e1 e2 = - eunPair e1 $ \w1 e1a e1b -> - eunPair (weakenExpr w1 e2) $ \w2 e2a e2b -> - EPair ext (sparsePlus t1 sp1 (weakenExpr w2 e1a) e2a) - (sparsePlus t2 sp2 (weakenExpr w2 e1b) e2b) -sparsePlus (SMTLEither t1 t2) (SpLEither sp1 sp2) e1 e2 = - elet e2 $ - elcase (weakenExpr WSink e1) - (evar IZ) - (elcase (evar (IS IZ)) - (ELInl ext (applySparse sp2 (fromSMTy t2)) (evar IZ)) - (ELInl ext (applySparse sp2 (fromSMTy t2)) (sparsePlus t1 sp1 (evar (IS IZ)) (evar IZ))) - (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus ll+lr")) - (elcase (evar (IS IZ)) - (ELInr ext (applySparse sp1 (fromSMTy t1)) (evar IZ)) - (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus lr+ll") - (ELInr ext (applySparse sp1 (fromSMTy t1)) (sparsePlus t2 sp2 (evar (IS IZ)) (evar IZ)))) -sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 = - elet e2 $ - emaybe (weakenExpr WSink e1) - (evar IZ) - (emaybe (evar (IS IZ)) - (EJust ext (evar IZ)) - (EJust ext (sparsePlus t sp (evar (IS IZ)) (evar IZ)))) -sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS IZ)) (evar IZ)) e1 e2 -sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2 - - -cheapZero :: SMTy t -> Maybe (forall env. Ex env t) -cheapZero SMTNil = Just (ENil ext) -cheapZero (SMTPair t1 t2) - | Just e1 <- cheapZero t1 - , Just e2 <- cheapZero t2 - = Just (EPair ext e1 e2) - | otherwise - = Nothing -cheapZero (SMTLEither t1 t2) = Just (ELNil ext (fromSMTy t1) (fromSMTy t2)) -cheapZero (SMTMaybe t) = Just (ENothing ext (fromSMTy t)) -cheapZero SMTArr{} = Nothing -cheapZero (SMTScal t) = case t of - STI32 -> Just (EConst ext t 0) - STI64 -> Just (EConst ext t 0) - STF32 -> Just (EConst ext t 0.0) - STF64 -> Just (EConst ext t 0.0) - - -data Injection sp a b where - -- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that - -- 'sparsePlusS' can provide injections even if the caller doesn't require - -- them. This simplifies the sparsePlusS code. - Inj :: (forall e. Ex e a -> Ex e b) -> Injection sp a b - Noinj :: Injection False a b - -withInj :: Injection sp a b -> ((forall e. Ex e a -> Ex e b) -> (forall e'. Ex e' a' -> Ex e' b')) -> Injection sp a' b' -withInj (Inj f) k = Inj (k f) -withInj Noinj _ = Noinj - -withInj2 :: Injection sp a1 b1 -> Injection sp a2 b2 - -> ((forall e. Ex e a1 -> Ex e b1) - -> (forall e. Ex e a2 -> Ex e b2) - -> (forall e'. Ex e' a' -> Ex e' b')) - -> Injection sp a' b' -withInj2 (Inj f) (Inj g) k = Inj (k f g) -withInj2 Noinj _ _ = Noinj -withInj2 _ Noinj _ = Noinj - --- | This function produces quadratically-sized code in the presence of nested --- dynamic sparsity. TODO can this be improved? -sparsePlusS - :: SBool inj1 -> SBool inj2 - -> SMTy t -> Sparse t t1 -> Sparse t t2 - -> (forall t3. Sparse t t3 - -> Injection inj1 t1 t3 -- only available if first injection is requested (second argument may be absent) - -> Injection inj2 t2 t3 -- only available if second injection is requested (first argument may be absent) - -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3) - -> r) - -> r --- nil override (but don't destroy effects!) -sparsePlusS _ _ SMTNil _ _ k = - k SpAbsent (Inj $ \a -> use a $ ENil ext) (Inj $ \b -> use b $ ENil ext) (\a b -> use a $ use b $ ENil ext) - --- simplifications -sparsePlusS req1 req2 t (SpSparse SpAbsent) sp2 k = - sparsePlusS req1 req2 t SpAbsent sp2 $ \sp3 minj1 minj2 plus -> - k sp3 (withInj minj1 $ \inj1 -> \a -> use a $ inj1 (ENil ext)) minj2 (\a b -> use a $ plus (ENil ext) b) -sparsePlusS req1 req2 t sp1 (SpSparse SpAbsent) k = - sparsePlusS req1 req2 t sp1 SpAbsent $ \sp3 minj1 minj2 plus -> - k sp3 minj1 (withInj minj2 $ \inj2 -> \b -> use b $ inj2 (ENil ext)) (\a b -> use b $ plus a (ENil ext)) - -sparsePlusS req1 req2 t (SpSparse (SpSparse sp1)) sp2 k = - let ta = applySparse sp1 (fromSMTy t) in - sparsePlusS req1 req2 t (SpSparse sp1) sp2 $ \sp3 minj1 minj2 plus -> - k sp3 - (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ))) - minj2 - (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b) -sparsePlusS req1 req2 t sp1 (SpSparse (SpSparse sp2)) k = - let tb = applySparse sp2 (fromSMTy t) in - sparsePlusS req1 req2 t sp1 (SpSparse sp2) $ \sp3 minj1 minj2 plus -> - k sp3 - minj1 - (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) - (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) - -sparsePlusS req1 req2 t (SpSparse (SpLEither sp1a sp1b)) sp2 k = - let STLEither ta tb = applySparse (SpLEither sp1a sp1b) (fromSMTy t) in - sparsePlusS req1 req2 t (SpLEither sp1a sp1b) sp2 $ \sp3 minj1 minj2 plus -> - k sp3 - (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) - minj2 - (\a b -> plus (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)) b) -sparsePlusS req1 req2 t sp1 (SpSparse (SpLEither sp2a sp2b)) k = - let STLEither ta tb = applySparse (SpLEither sp2a sp2b) (fromSMTy t) in - sparsePlusS req1 req2 t sp1 (SpLEither sp2a sp2b) $ \sp3 minj1 minj2 plus -> - k sp3 - minj1 - (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) - (\a b -> plus a (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) - -sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k = - let STMaybe ta = applySparse (SpMaybe sp1) (fromSMTy t) in - sparsePlusS req1 req2 t (SpMaybe sp1) sp2 $ \sp3 minj1 minj2 plus -> - k sp3 - (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (evar IZ))) - minj2 - (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b) -sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k = - let STMaybe tb = applySparse (SpMaybe sp2) (fromSMTy t) in - sparsePlusS req1 req2 t sp1 (SpMaybe sp2) $ \sp3 minj1 minj2 plus -> - k sp3 - minj1 - (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (evar IZ))) - (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) -sparsePlusS req1 req2 t (SpMaybe (SpSparse sp1)) sp2 k = sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k -sparsePlusS req1 req2 t sp1 (SpMaybe (SpSparse sp2)) k = sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k - --- TODO: sparse of Just is just Maybe - --- dense plus -sparsePlusS _ _ t sp1 sp2 k - | Just Refl <- isDense t sp1 - , Just Refl <- isDense t sp2 - = k (spDense t) (Inj id) (Inj id) (\a b -> EPlus ext t a b) - --- handle absents -sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\a b -> use a $ b) -sparsePlusS ST _ t SpAbsent sp2 k - | Just zero2 <- cheapZero (applySparse sp2 t) = - k sp2 (Inj $ \a -> use a $ zero2) (Inj id) (\a b -> use a $ b) - | otherwise = - k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b) - -sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a b -> use b $ a) -sparsePlusS _ ST t sp1 SpAbsent k - | Just zero1 <- cheapZero (applySparse sp1 t) = - k sp1 (Inj id) (Inj $ \b -> use b $ zero1) (\a b -> use b $ a) - | otherwise = - k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a) - --- double sparse yields sparse -sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k = - sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> - k (SpSparse sp3) - (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) - (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) - (\a b -> - elet b $ - emaybe (weakenExpr WSink a) - (emaybe (evar IZ) - (ENothing ext (applySparse sp3 (fromSMTy t))) - (EJust ext (inj2 (evar IZ)))) - (emaybe (evar (IS IZ)) - (EJust ext (inj1 (evar IZ))) - (EJust ext (plus (evar (IS IZ)) (evar IZ))))) - --- single sparse can yield non-sparse if the other argument is always present -sparsePlusS SF _ t (SpSparse sp1) sp2 k = - sparsePlusS SF ST t sp1 sp2 $ \sp3 _ (Inj inj2) plus -> - k sp3 Noinj (Inj inj2) - (\a b -> - elet b $ - emaybe (weakenExpr WSink a) - (inj2 (evar IZ)) - (plus (evar IZ) (evar (IS IZ)))) -sparsePlusS ST _ t (SpSparse sp1) sp2 k = - sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> - k (SpSparse sp3) - (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) - (Inj $ \b -> EJust ext (inj2 b)) - (\a b -> - elet b $ - emaybe (weakenExpr WSink a) - (EJust ext (inj2 (evar IZ))) - (EJust ext (plus (evar IZ) (evar (IS IZ))))) -sparsePlusS req1 req2 t sp1 (SpSparse sp2) k = - sparsePlusS req2 req1 t (SpSparse sp2) sp1 $ \sp3 inj1 inj2 plus -> - k sp3 inj2 inj1 (flip plus) - --- products -sparsePlusS req1 req2 (SMTPair ta tb) (SpPair sp1a sp1b) (SpPair sp2a sp2b) k = - sparsePlusS req1 req2 ta sp1a sp2a $ \sp3a minj13a minj23a plusa -> - sparsePlusS req1 req2 tb sp1b sp2b $ \sp3b minj13b minj23b plusb -> - k (SpPair sp3a sp3b) - (withInj2 minj13a minj13b $ \inj13a inj13b -> - \x1 -> eunPair x1 $ \_ x1a x1b -> EPair ext (inj13a x1a) (inj13b x1b)) - (withInj2 minj23a minj23b $ \inj23a inj23b -> - \x2 -> eunPair x2 $ \_ x2a x2b -> EPair ext (inj23a x2a) (inj23b x2b)) - (\x1 x2 -> - eunPair x1 $ \w1 x1a x1b -> - eunPair (weakenExpr w1 x2) $ \w2 x2a x2b -> - EPair ext (plusa (weakenExpr w2 x1a) x2a) (plusb (weakenExpr w2 x1b) x2b)) - --- coproducts -sparsePlusS _ _ (SMTLEither ta tb) (SpLEither sp1a sp1b) (SpLEither sp2a sp2b) k = - sparsePlusS ST ST ta sp1a sp2a $ \(sp3a :: Sparse _t3 t3a) (Inj inj13a) (Inj inj23a) plusa -> - sparsePlusS ST ST tb sp1b sp2b $ \(sp3b :: Sparse _t3' t3b) (Inj inj13b) (Inj inj23b) plusb -> - let nil :: Ex e (TLEither t3a t3b) ; nil = ELNil ext (applySparse sp3a (fromSMTy ta)) (applySparse sp3b (fromSMTy tb)) - inl :: Ex e t3a -> Ex e (TLEither t3a t3b) ; inl = ELInl ext (applySparse sp3b (fromSMTy tb)) - inr :: Ex e t3b -> Ex e (TLEither t3a t3b) ; inr = ELInr ext (applySparse sp3a (fromSMTy ta)) - in - k (SpLEither sp3a sp3b) - (Inj $ \x1 -> elcase x1 nil (inl (inj13a (evar IZ))) (inr (inj13b (evar IZ)))) - (Inj $ \x2 -> elcase x2 nil (inl (inj23a (evar IZ))) (inr (inj23b (evar IZ)))) - (\x1 x2 -> - elet x2 $ - elcase (weakenExpr WSink x1) - (elcase (evar IZ) - nil - (inl (inj23a (evar IZ))) - (inr (inj23b (evar IZ)))) - (elcase (evar (IS IZ)) - (inl (inj13a (evar IZ))) - (inl (plusa (evar (IS IZ)) (evar IZ))) - (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS ll+lr")) - (elcase (evar (IS IZ)) - (inr (inj13b (evar IZ))) - (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS lr+ll") - (inr (plusb (evar (IS IZ)) (evar IZ))))) - --- maybe -sparsePlusS _ _ (SMTMaybe t) (SpMaybe sp1) (SpMaybe sp2) k = - sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> - k (SpMaybe sp3) - (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) - (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) - (\a b -> - elet b $ - emaybe (weakenExpr WSink a) - (emaybe (evar IZ) - (ENothing ext (applySparse sp3 (fromSMTy t))) - (EJust ext (inj2 (evar IZ)))) - (emaybe (evar (IS IZ)) - (EJust ext (inj1 (evar IZ))) - (EJust ext (plus (evar (IS IZ)) (evar IZ))))) - --- dense array cotangents simply recurse -sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k = - sparsePlusS req1 req2 t sp1 sp2 $ \sp3 minj1 minj2 plus -> - k (SpArr sp3) - (withInj minj1 $ \inj1 -> emap (inj1 (EVar ext (applySparse sp1 (fromSMTy t)) IZ))) - (withInj minj2 $ \inj2 -> emap (inj2 (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) - (ezipWith (plus (EVar ext (applySparse sp1 (fromSMTy t)) (IS IZ)) - (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) - --- scalars -sparsePlusS _ _ (SMTScal t) SpScal SpScal k = k SpScal (Inj id) (Inj id) (EPlus ext (SMTScal t)) |
