aboutsummaryrefslogtreecommitdiff
path: root/src/AST/Sparse.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST/Sparse.hs')
-rw-r--r--src/AST/Sparse.hs287
1 files changed, 0 insertions, 287 deletions
diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs
deleted file mode 100644
index 2a29799..0000000
--- a/src/AST/Sparse.hs
+++ /dev/null
@@ -1,287 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE ImpredicativeTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE RankNTypes #-}
-
-{-# OPTIONS_GHC -fmax-pmcheck-models=80 #-}
-module AST.Sparse (module AST.Sparse, module AST.Sparse.Types) where
-
-import Data.Type.Equality
-
-import AST
-import AST.Sparse.Types
-import Data (SBool(..))
-
-
-sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t'
-sparsePlus _ SpAbsent e1 e2 = use e1 $ use e2 $ ENil ext
-sparsePlus t sp e1 e2 | Just Refl <- isDense t sp = EPlus ext t e1 e2
-sparsePlus t (SpSparse sp) e1 e2 = sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 -- heh
-sparsePlus (SMTPair t1 t2) (SpPair sp1 sp2) e1 e2 =
- eunPair e1 $ \w1 e1a e1b ->
- eunPair (weakenExpr w1 e2) $ \w2 e2a e2b ->
- EPair ext (sparsePlus t1 sp1 (weakenExpr w2 e1a) e2a)
- (sparsePlus t2 sp2 (weakenExpr w2 e1b) e2b)
-sparsePlus (SMTLEither t1 t2) (SpLEither sp1 sp2) e1 e2 =
- elet e2 $
- elcase (weakenExpr WSink e1)
- (evar IZ)
- (elcase (evar (IS IZ))
- (ELInl ext (applySparse sp2 (fromSMTy t2)) (evar IZ))
- (ELInl ext (applySparse sp2 (fromSMTy t2)) (sparsePlus t1 sp1 (evar (IS IZ)) (evar IZ)))
- (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus ll+lr"))
- (elcase (evar (IS IZ))
- (ELInr ext (applySparse sp1 (fromSMTy t1)) (evar IZ))
- (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus lr+ll")
- (ELInr ext (applySparse sp1 (fromSMTy t1)) (sparsePlus t2 sp2 (evar (IS IZ)) (evar IZ))))
-sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 =
- elet e2 $
- emaybe (weakenExpr WSink e1)
- (evar IZ)
- (emaybe (evar (IS IZ))
- (EJust ext (evar IZ))
- (EJust ext (sparsePlus t sp (evar (IS IZ)) (evar IZ))))
-sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS IZ)) (evar IZ)) e1 e2
-sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2
-
-
-cheapZero :: SMTy t -> Maybe (forall env. Ex env t)
-cheapZero SMTNil = Just (ENil ext)
-cheapZero (SMTPair t1 t2)
- | Just e1 <- cheapZero t1
- , Just e2 <- cheapZero t2
- = Just (EPair ext e1 e2)
- | otherwise
- = Nothing
-cheapZero (SMTLEither t1 t2) = Just (ELNil ext (fromSMTy t1) (fromSMTy t2))
-cheapZero (SMTMaybe t) = Just (ENothing ext (fromSMTy t))
-cheapZero SMTArr{} = Nothing
-cheapZero (SMTScal t) = case t of
- STI32 -> Just (EConst ext t 0)
- STI64 -> Just (EConst ext t 0)
- STF32 -> Just (EConst ext t 0.0)
- STF64 -> Just (EConst ext t 0.0)
-
-
-data Injection sp a b where
- -- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that
- -- 'sparsePlusS' can provide injections even if the caller doesn't require
- -- them. This simplifies the sparsePlusS code.
- Inj :: (forall e. Ex e a -> Ex e b) -> Injection sp a b
- Noinj :: Injection False a b
-
-withInj :: Injection sp a b -> ((forall e. Ex e a -> Ex e b) -> (forall e'. Ex e' a' -> Ex e' b')) -> Injection sp a' b'
-withInj (Inj f) k = Inj (k f)
-withInj Noinj _ = Noinj
-
-withInj2 :: Injection sp a1 b1 -> Injection sp a2 b2
- -> ((forall e. Ex e a1 -> Ex e b1)
- -> (forall e. Ex e a2 -> Ex e b2)
- -> (forall e'. Ex e' a' -> Ex e' b'))
- -> Injection sp a' b'
-withInj2 (Inj f) (Inj g) k = Inj (k f g)
-withInj2 Noinj _ _ = Noinj
-withInj2 _ Noinj _ = Noinj
-
--- | This function produces quadratically-sized code in the presence of nested
--- dynamic sparsity. TODO can this be improved?
-sparsePlusS
- :: SBool inj1 -> SBool inj2
- -> SMTy t -> Sparse t t1 -> Sparse t t2
- -> (forall t3. Sparse t t3
- -> Injection inj1 t1 t3 -- only available if first injection is requested (second argument may be absent)
- -> Injection inj2 t2 t3 -- only available if second injection is requested (first argument may be absent)
- -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3)
- -> r)
- -> r
--- nil override (but don't destroy effects!)
-sparsePlusS _ _ SMTNil _ _ k =
- k SpAbsent (Inj $ \a -> use a $ ENil ext) (Inj $ \b -> use b $ ENil ext) (\a b -> use a $ use b $ ENil ext)
-
--- simplifications
-sparsePlusS req1 req2 t (SpSparse SpAbsent) sp2 k =
- sparsePlusS req1 req2 t SpAbsent sp2 $ \sp3 minj1 minj2 plus ->
- k sp3 (withInj minj1 $ \inj1 -> \a -> use a $ inj1 (ENil ext)) minj2 (\a b -> use a $ plus (ENil ext) b)
-sparsePlusS req1 req2 t sp1 (SpSparse SpAbsent) k =
- sparsePlusS req1 req2 t sp1 SpAbsent $ \sp3 minj1 minj2 plus ->
- k sp3 minj1 (withInj minj2 $ \inj2 -> \b -> use b $ inj2 (ENil ext)) (\a b -> use b $ plus a (ENil ext))
-
-sparsePlusS req1 req2 t (SpSparse (SpSparse sp1)) sp2 k =
- let ta = applySparse sp1 (fromSMTy t) in
- sparsePlusS req1 req2 t (SpSparse sp1) sp2 $ \sp3 minj1 minj2 plus ->
- k sp3
- (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)))
- minj2
- (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b)
-sparsePlusS req1 req2 t sp1 (SpSparse (SpSparse sp2)) k =
- let tb = applySparse sp2 (fromSMTy t) in
- sparsePlusS req1 req2 t sp1 (SpSparse sp2) $ \sp3 minj1 minj2 plus ->
- k sp3
- minj1
- (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ)))
- (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ)))
-
-sparsePlusS req1 req2 t (SpSparse (SpLEither sp1a sp1b)) sp2 k =
- let STLEither ta tb = applySparse (SpLEither sp1a sp1b) (fromSMTy t) in
- sparsePlusS req1 req2 t (SpLEither sp1a sp1b) sp2 $ \sp3 minj1 minj2 plus ->
- k sp3
- (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)))
- minj2
- (\a b -> plus (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)) b)
-sparsePlusS req1 req2 t sp1 (SpSparse (SpLEither sp2a sp2b)) k =
- let STLEither ta tb = applySparse (SpLEither sp2a sp2b) (fromSMTy t) in
- sparsePlusS req1 req2 t sp1 (SpLEither sp2a sp2b) $ \sp3 minj1 minj2 plus ->
- k sp3
- minj1
- (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)))
- (\a b -> plus a (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)))
-
-sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k =
- let STMaybe ta = applySparse (SpMaybe sp1) (fromSMTy t) in
- sparsePlusS req1 req2 t (SpMaybe sp1) sp2 $ \sp3 minj1 minj2 plus ->
- k sp3
- (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (evar IZ)))
- minj2
- (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b)
-sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k =
- let STMaybe tb = applySparse (SpMaybe sp2) (fromSMTy t) in
- sparsePlusS req1 req2 t sp1 (SpMaybe sp2) $ \sp3 minj1 minj2 plus ->
- k sp3
- minj1
- (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (evar IZ)))
- (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ)))
-sparsePlusS req1 req2 t (SpMaybe (SpSparse sp1)) sp2 k = sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k
-sparsePlusS req1 req2 t sp1 (SpMaybe (SpSparse sp2)) k = sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k
-
--- TODO: sparse of Just is just Maybe
-
--- dense plus
-sparsePlusS _ _ t sp1 sp2 k
- | Just Refl <- isDense t sp1
- , Just Refl <- isDense t sp2
- = k (spDense t) (Inj id) (Inj id) (\a b -> EPlus ext t a b)
-
--- handle absents
-sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\a b -> use a $ b)
-sparsePlusS ST _ t SpAbsent sp2 k
- | Just zero2 <- cheapZero (applySparse sp2 t) =
- k sp2 (Inj $ \a -> use a $ zero2) (Inj id) (\a b -> use a $ b)
- | otherwise =
- k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b)
-
-sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a b -> use b $ a)
-sparsePlusS _ ST t sp1 SpAbsent k
- | Just zero1 <- cheapZero (applySparse sp1 t) =
- k sp1 (Inj id) (Inj $ \b -> use b $ zero1) (\a b -> use b $ a)
- | otherwise =
- k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a)
-
--- double sparse yields sparse
-sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k =
- sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
- k (SpSparse sp3)
- (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
- (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ))))
- (\a b ->
- elet b $
- emaybe (weakenExpr WSink a)
- (emaybe (evar IZ)
- (ENothing ext (applySparse sp3 (fromSMTy t)))
- (EJust ext (inj2 (evar IZ))))
- (emaybe (evar (IS IZ))
- (EJust ext (inj1 (evar IZ)))
- (EJust ext (plus (evar (IS IZ)) (evar IZ)))))
-
--- single sparse can yield non-sparse if the other argument is always present
-sparsePlusS SF _ t (SpSparse sp1) sp2 k =
- sparsePlusS SF ST t sp1 sp2 $ \sp3 _ (Inj inj2) plus ->
- k sp3 Noinj (Inj inj2)
- (\a b ->
- elet b $
- emaybe (weakenExpr WSink a)
- (inj2 (evar IZ))
- (plus (evar IZ) (evar (IS IZ))))
-sparsePlusS ST _ t (SpSparse sp1) sp2 k =
- sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
- k (SpSparse sp3)
- (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
- (Inj $ \b -> EJust ext (inj2 b))
- (\a b ->
- elet b $
- emaybe (weakenExpr WSink a)
- (EJust ext (inj2 (evar IZ)))
- (EJust ext (plus (evar IZ) (evar (IS IZ)))))
-sparsePlusS req1 req2 t sp1 (SpSparse sp2) k =
- sparsePlusS req2 req1 t (SpSparse sp2) sp1 $ \sp3 inj1 inj2 plus ->
- k sp3 inj2 inj1 (flip plus)
-
--- products
-sparsePlusS req1 req2 (SMTPair ta tb) (SpPair sp1a sp1b) (SpPair sp2a sp2b) k =
- sparsePlusS req1 req2 ta sp1a sp2a $ \sp3a minj13a minj23a plusa ->
- sparsePlusS req1 req2 tb sp1b sp2b $ \sp3b minj13b minj23b plusb ->
- k (SpPair sp3a sp3b)
- (withInj2 minj13a minj13b $ \inj13a inj13b ->
- \x1 -> eunPair x1 $ \_ x1a x1b -> EPair ext (inj13a x1a) (inj13b x1b))
- (withInj2 minj23a minj23b $ \inj23a inj23b ->
- \x2 -> eunPair x2 $ \_ x2a x2b -> EPair ext (inj23a x2a) (inj23b x2b))
- (\x1 x2 ->
- eunPair x1 $ \w1 x1a x1b ->
- eunPair (weakenExpr w1 x2) $ \w2 x2a x2b ->
- EPair ext (plusa (weakenExpr w2 x1a) x2a) (plusb (weakenExpr w2 x1b) x2b))
-
--- coproducts
-sparsePlusS _ _ (SMTLEither ta tb) (SpLEither sp1a sp1b) (SpLEither sp2a sp2b) k =
- sparsePlusS ST ST ta sp1a sp2a $ \(sp3a :: Sparse _t3 t3a) (Inj inj13a) (Inj inj23a) plusa ->
- sparsePlusS ST ST tb sp1b sp2b $ \(sp3b :: Sparse _t3' t3b) (Inj inj13b) (Inj inj23b) plusb ->
- let nil :: Ex e (TLEither t3a t3b) ; nil = ELNil ext (applySparse sp3a (fromSMTy ta)) (applySparse sp3b (fromSMTy tb))
- inl :: Ex e t3a -> Ex e (TLEither t3a t3b) ; inl = ELInl ext (applySparse sp3b (fromSMTy tb))
- inr :: Ex e t3b -> Ex e (TLEither t3a t3b) ; inr = ELInr ext (applySparse sp3a (fromSMTy ta))
- in
- k (SpLEither sp3a sp3b)
- (Inj $ \x1 -> elcase x1 nil (inl (inj13a (evar IZ))) (inr (inj13b (evar IZ))))
- (Inj $ \x2 -> elcase x2 nil (inl (inj23a (evar IZ))) (inr (inj23b (evar IZ))))
- (\x1 x2 ->
- elet x2 $
- elcase (weakenExpr WSink x1)
- (elcase (evar IZ)
- nil
- (inl (inj23a (evar IZ)))
- (inr (inj23b (evar IZ))))
- (elcase (evar (IS IZ))
- (inl (inj13a (evar IZ)))
- (inl (plusa (evar (IS IZ)) (evar IZ)))
- (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS ll+lr"))
- (elcase (evar (IS IZ))
- (inr (inj13b (evar IZ)))
- (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS lr+ll")
- (inr (plusb (evar (IS IZ)) (evar IZ)))))
-
--- maybe
-sparsePlusS _ _ (SMTMaybe t) (SpMaybe sp1) (SpMaybe sp2) k =
- sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
- k (SpMaybe sp3)
- (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
- (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ))))
- (\a b ->
- elet b $
- emaybe (weakenExpr WSink a)
- (emaybe (evar IZ)
- (ENothing ext (applySparse sp3 (fromSMTy t)))
- (EJust ext (inj2 (evar IZ))))
- (emaybe (evar (IS IZ))
- (EJust ext (inj1 (evar IZ)))
- (EJust ext (plus (evar (IS IZ)) (evar IZ)))))
-
--- dense array cotangents simply recurse
-sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k =
- sparsePlusS req1 req2 t sp1 sp2 $ \sp3 minj1 minj2 plus ->
- k (SpArr sp3)
- (withInj minj1 $ \inj1 -> emap (inj1 (EVar ext (applySparse sp1 (fromSMTy t)) IZ)))
- (withInj minj2 $ \inj2 -> emap (inj2 (EVar ext (applySparse sp2 (fromSMTy t)) IZ)))
- (ezipWith (plus (EVar ext (applySparse sp1 (fromSMTy t)) (IS IZ))
- (EVar ext (applySparse sp2 (fromSMTy t)) IZ)))
-
--- scalars
-sparsePlusS _ _ (SMTScal t) SpScal SpScal k = k SpScal (Inj id) (Inj id) (EPlus ext (SMTScal t))