diff options
Diffstat (limited to 'src/AST')
| -rw-r--r-- | src/AST/Accum.hs | 75 | ||||
| -rw-r--r-- | src/AST/Bindings.hs | 2 | ||||
| -rw-r--r-- | src/AST/Count.hs | 9 | ||||
| -rw-r--r-- | src/AST/Env.hs | 74 | ||||
| -rw-r--r-- | src/AST/Pretty.hs | 21 | ||||
| -rw-r--r-- | src/AST/Sparse.hs | 290 | ||||
| -rw-r--r-- | src/AST/Sparse/Types.hs | 107 | ||||
| -rw-r--r-- | src/AST/SplitLets.hs | 3 | ||||
| -rw-r--r-- | src/AST/Types.hs | 2 | ||||
| -rw-r--r-- | src/AST/UnMonoid.hs | 113 | ||||
| -rw-r--r-- | src/AST/Weaken/Auto.hs | 2 | 
11 files changed, 632 insertions, 66 deletions
| diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs index 03369c8..988a450 100644 --- a/src/AST/Accum.hs +++ b/src/AST/Accum.hs @@ -1,14 +1,13 @@  {-# LANGUAGE DataKinds #-}  {-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE RankNTypes #-}  {-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeData #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE UndecidableInstances #-}  module AST.Accum where  import AST.Types -import CHAD.Types  import Data @@ -35,21 +34,39 @@ data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where    -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t)  deriving instance Show (SAcPrj p a b) -type family AcIdx p t where -  AcIdx APHere t = TNil -  AcIdx (APFst p) (TPair a b) = TPair (AcIdx p a) (ZeroInfo b) -  AcIdx (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx p b) -  AcIdx (APLeft p) (TLEither a b) = AcIdx p a -  AcIdx (APRight p) (TLEither a b) = AcIdx p b -  AcIdx (APJust p) (TMaybe a) = AcIdx p a -  AcIdx (APArrIdx p) (TArr n a) = -    -- ((index, shapes info), recursive info) +type data AIDense = AID | AIS + +data SAIDense d where +  SAID :: SAIDense AID +  SAIS :: SAIDense AIS +deriving instance Show (SAIDense d) + +type family AcIdx d p t where +  AcIdx d APHere t = TNil +  AcIdx AID (APFst p) (TPair a b) = AcIdx AID p a +  AcIdx AID (APSnd p) (TPair a b) = AcIdx AID p b +  AcIdx AIS (APFst p) (TPair a b) = TPair (AcIdx AIS p a) (ZeroInfo b) +  AcIdx AIS (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AIS p b) +  AcIdx d (APLeft p) (TLEither a b) = AcIdx d p a +  AcIdx d (APRight p) (TLEither a b) = AcIdx d p b +  AcIdx d (APJust p) (TMaybe a) = AcIdx d p a +  AcIdx AID (APArrIdx p) (TArr n a) = +    -- (index, recursive info) +    TPair (Tup (Replicate n TIx)) (AcIdx AID p a) +  AcIdx AIS (APArrIdx p) (TArr n a) = +    -- ((index, shape info), recursive info)      TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a))) -          (AcIdx p a) -  -- AcIdx (APArrSlice m) (TArr n a) = +          (AcIdx AIS p a) +  -- AcIdx AID (APArrSlice m) (TArr n a) = +  --   -- index +  --   Tup (Replicate m TIx) +  -- AcIdx AIS (APArrSlice m) (TArr n a) =    --   -- (index, array shape)    --   TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx)) +type AcIdxD p t = AcIdx AID p t +type AcIdxS p t = AcIdx AIS p t +  acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b  acPrjTy SAPHere t = t  acPrjTy (SAPFst prj) (SMTPair t _) = acPrjTy prj t @@ -75,19 +92,23 @@ tZeroInfo (SMTMaybe _) = STNil  tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t)  tZeroInfo (SMTScal _) = STNil -lemZeroInfoD2 :: STy t -> ZeroInfo (D2 t) :~: TNil -lemZeroInfoD2 STNil = Refl -lemZeroInfoD2 (STPair a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl -lemZeroInfoD2 (STEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl -lemZeroInfoD2 (STLEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl -lemZeroInfoD2 (STMaybe a) | Refl <- lemZeroInfoD2 a = Refl -lemZeroInfoD2 (STArr _ a) | Refl <- lemZeroInfoD2 a = Refl -lemZeroInfoD2 (STScal STI32) = Refl -lemZeroInfoD2 (STScal STI64) = Refl -lemZeroInfoD2 (STScal STF32) = Refl -lemZeroInfoD2 (STScal STF64) = Refl -lemZeroInfoD2 (STScal STBool) = Refl -lemZeroInfoD2 (STAccum _) = error "Accumulators disallowed in source program" +-- | Info needed to create a zero-valued deep accumulator for a monoid type. +-- Should be constructable from a D1. +type family DeepZeroInfo t where +  DeepZeroInfo TNil = TNil +  DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b) +  DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b) +  DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a) +  DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a) +  DeepZeroInfo (TScal t) = TNil + +tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t) +tDeepZeroInfo SMTNil = STNil +tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a) +tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t) +tDeepZeroInfo (SMTScal _) = STNil  -- -- | Additional info needed for accumulation. This is empty unless there is  -- -- sparsity in the monoid. diff --git a/src/AST/Bindings.hs b/src/AST/Bindings.hs index 745a93b..2310f4b 100644 --- a/src/AST/Bindings.hs +++ b/src/AST/Bindings.hs @@ -69,7 +69,7 @@ collectBindings = \env -> fst . go env WId    where      go :: SList STy env -> env :> env0 -> Subenv env env' -> (Bindings Ex env0 env', env0 :> Append env' env0)      go _ _ SETop = (BTop, WId) -    go (ty `SCons` env) w (SEYes sub) = +    go (ty `SCons` env) w (SEYesR sub) =        let (bs, w') = go env (WPop w) sub        in (BPush bs (ty, EVar ext ty (w' .> w @> IZ)), WSink .> w')      go (_ `SCons` env) w (SENo sub) = go env (WPop w) sub diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 0c682c6..ca4d7ab 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -134,8 +134,9 @@ occCountGeneral onehot unpush alter many = go WId        ECustom _ _ _ _ _ _ _ a b -> re a <> re b        ERecompute _ e -> re e        EWith _ _ a b -> re a <> re1 b -      EAccum _ _ _ a b e -> re a <> re b <> re e +      EAccum _ _ _ a _ b e -> re a <> re b <> re e        EZero _ _ e -> re e +      EDeepZero _ _ e -> re e        EPlus _ _ a b -> re a <> re b        EOneHot _ _ _ a b -> re a <> re b        EError{} -> mempty @@ -154,7 +155,7 @@ deleteUnused (_ `SCons` env) OccEnd k =  deleteUnused (_ `SCons` env) (OccPush occenv (Occ _ count)) k =    deleteUnused env occenv $ \sub ->      case count of Zero -> k (SENo sub) -                  _    -> k (SEYes sub) +                  _    -> k (SEYesR sub)  unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t  unsafeWeakenWithSubenv = \sub -> @@ -163,7 +164,7 @@ unsafeWeakenWithSubenv = \sub ->                       Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away")    where      sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t) -    sinkViaSubenv IZ (SEYes _) = Just IZ +    sinkViaSubenv IZ (SEYesR _) = Just IZ      sinkViaSubenv IZ (SENo _) = Nothing -    sinkViaSubenv (IS i) (SEYes sub) = IS <$> sinkViaSubenv i sub +    sinkViaSubenv (IS i) (SEYesR sub) = IS <$> sinkViaSubenv i sub      sinkViaSubenv (IS i) (SENo sub) = sinkViaSubenv i sub diff --git a/src/AST/Env.hs b/src/AST/Env.hs index 4f34166..422f0f7 100644 --- a/src/AST/Env.hs +++ b/src/AST/Env.hs @@ -1,59 +1,85 @@  {-# LANGUAGE DataKinds #-}  {-# LANGUAGE EmptyCase #-} -{-# LANGUAGE ExplicitForAll #-}  {-# LANGUAGE GADTs #-} +{-# LANGUAGE PatternSynonyms #-}  {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-}  {-# LANGUAGE StandaloneDeriving #-}  {-# LANGUAGE TypeOperators #-}  module AST.Env where +import Data.Type.Equality + +import AST.Sparse  import AST.Weaken +import CHAD.Types  import Data  -- | @env'@ is a subset of @env@: each element of @env@ is either included in  -- @env'@ ('SEYes') or not included in @env'@ ('SENo'). -data Subenv env env' where -  SETop :: Subenv '[] '[] -  SEYes :: forall t env env'. Subenv env env' -> Subenv (t : env) (t : env') -  SENo  :: forall t env env'. Subenv env env' -> Subenv (t : env) env' -deriving instance Show (Subenv env env') +data Subenv' s env env' where +  SETop :: Subenv' s '[] '[] +  SEYes :: forall t t' env env' s. s t t' -> Subenv' s env env' -> Subenv' s (t : env) (t' : env') +  SENo  :: forall t env env' s. Subenv' s env env' -> Subenv' s (t : env) env' +deriving instance (forall t t'. Show (s t t')) => Show (Subenv' s env env') + +type Subenv = Subenv' (:~:) +type SubenvS = Subenv' Sparse + +pattern SEYesR :: forall tenv tenv'. () +               => forall t env env'. (tenv ~ t : env, tenv' ~ t : env') +               => Subenv env env' -> Subenv tenv tenv' +pattern SEYesR s = SEYes Refl s -subList :: SList f env -> Subenv env env' -> SList f env' +{-# COMPLETE SETop, SEYesR, SENo #-} + +subList :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env' -> SList f env'  subList SNil SETop = SNil -subList (SCons x xs) (SEYes sub) = SCons x (subList xs sub) +subList (SCons x xs) (SEYes s sub) = SCons (subtApply s x) (subList xs sub)  subList (SCons _ xs) (SENo sub) = subList xs sub -subenvAll :: SList f env -> Subenv env env +subenvAll :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env  subenvAll SNil = SETop -subenvAll (SCons _ env) = SEYes (subenvAll env) +subenvAll (SCons t env) = SEYes (subtFull t) (subenvAll env) -subenvNone :: SList f env -> Subenv env '[] +subenvNone :: SList f env -> Subenv' s env '[]  subenvNone SNil = SETop  subenvNone (SCons _ env) = SENo (subenvNone env) -subenvOnehot :: SList f env -> Idx env t -> Subenv env '[t] -subenvOnehot (SCons _ env) IZ = SEYes (subenvNone env) -subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i) -subenvOnehot SNil i = case i of {} +subenvOnehot :: SList f env -> Idx env t -> s t t' -> Subenv' s env '[t'] +subenvOnehot (SCons _ env) IZ sp = SEYes sp (subenvNone env) +subenvOnehot (SCons _ env) (IS i) sp = SENo (subenvOnehot env i sp) +subenvOnehot SNil i _ = case i of {} -subenvCompose :: Subenv env1 env2 -> Subenv env2 env3 -> Subenv env1 env3 +subenvCompose :: IsSubType s => Subenv' s env1 env2 -> Subenv' s env2 env3 -> Subenv' s env1 env3  subenvCompose SETop SETop = SETop -subenvCompose (SEYes sub1) (SEYes sub2) = SEYes (subenvCompose sub1 sub2) -subenvCompose (SEYes sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2) +subenvCompose (SEYes s1 sub1) (SEYes s2 sub2) = SEYes (subtTrans s1 s2) (subenvCompose sub1 sub2) +subenvCompose (SEYes _ sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2)  subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2) -subenvConcat :: Subenv env1 env1' -> Subenv env2 env2' -> Subenv (Append env2 env1) (Append env2' env1') +subenvConcat :: Subenv' s env1 env1' -> Subenv' s env2 env2' -> Subenv' s (Append env2 env1) (Append env2' env1')  subenvConcat sub1 SETop = sub1 -subenvConcat sub1 (SEYes sub2) = SEYes (subenvConcat sub1 sub2) +subenvConcat sub1 (SEYes s sub2) = SEYes s (subenvConcat sub1 sub2)  subenvConcat sub1 (SENo sub2) = SENo (subenvConcat sub1 sub2) -sinkWithSubenv :: Subenv env env' -> env0 :> Append env' env0 +sinkWithSubenv :: Subenv' s env env' -> env0 :> Append env' env0  sinkWithSubenv SETop = WId -sinkWithSubenv (SEYes sub) = WSink .> sinkWithSubenv sub +sinkWithSubenv (SEYes _ sub) = WSink .> sinkWithSubenv sub  sinkWithSubenv (SENo sub) = sinkWithSubenv sub -wUndoSubenv :: Subenv env env' -> env' :> env +wUndoSubenv :: Subenv' (:~:) env env' -> env' :> env  wUndoSubenv SETop = WId -wUndoSubenv (SEYes sub) = WCopy (wUndoSubenv sub) +wUndoSubenv (SEYes Refl sub) = WCopy (wUndoSubenv sub)  wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub + +subenvMap :: (forall a a'. f a -> s a a' -> s' a a') -> SList f env -> Subenv' s env env' -> Subenv' s' env env' +subenvMap _ SNil SETop = SETop +subenvMap f (t `SCons` l) (SEYes s sub) = SEYes (f t s) (subenvMap f l sub) +subenvMap f (_ `SCons` l) (SENo sub) = SENo (subenvMap f l sub) + +subenvD2E :: Subenv env env' -> Subenv (D2E env) (D2E env') +subenvD2E SETop = SETop +subenvD2E (SEYesR sub) = SEYesR (subenvD2E sub) +subenvD2E (SENo sub) = SENo (subenvD2E sub) diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 41da656..fef9686 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -25,6 +25,7 @@ import System.IO.Unsafe (unsafePerformIO)  import AST  import AST.Count +import AST.Sparse.Types  import CHAD.Types  import Data @@ -304,18 +305,24 @@ ppExpr' d val expr = case expr of             <> hardline <> e2')          (ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2']) -  EAccum _ t prj e1 e2 e3 -> do +  EAccum _ t prj e1 sp e2 e3 -> do      e1' <- ppExpr' 11 val e1      e2' <- ppExpr' 11 val e2      e3' <- ppExpr' 11 val e3      return $ ppParen (d > 10) $ -      ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (acPrjTy prj t)) [ppString (ppAcPrj t prj), e1', e2', e3'] +      ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (applySparse sp (acPrjTy prj t))) +            [ppString (ppAcPrj t prj), ppString (ppSparse (acPrjTy prj t) sp), e1', e2', e3']    EZero _ t e1 -> do      e1' <- ppExpr' 11 val e1      return $ ppParen (d > 0) $        annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1' +  EDeepZero _ t e1 -> do +    e1' <- ppExpr' 11 val e1 +    return $ ppParen (d > 0) $ +      annotate AMonoid (ppString "deepzero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1' +    EPlus _ t a b -> do      a' <- ppExpr' 11 val a      b' <- ppExpr' 11 val b @@ -368,6 +375,16 @@ ppAcPrj (SMTLEither _ t) (SAPRight prj) = "(|" ++ ppAcPrj t prj ++ ")"  ppAcPrj (SMTMaybe t) (SAPJust prj) = "J" ++ ppAcPrj t prj  ppAcPrj (SMTArr n t) (SAPArrIdx prj) = "[" ++ ppAcPrj t prj ++ "]" ++ intSubscript (fromSNat n) +ppSparse :: SMTy a -> Sparse a b -> String +ppSparse t sp | Just Refl <- isDense t sp = "D" +ppSparse _ SpAbsent = "A" +ppSparse t (SpSparse s) = "S" ++ ppSparse t s +ppSparse (SMTPair t1 t2) (SpPair s1 s2) = "(" ++ ppSparse t1 s1 ++ "," ++ ppSparse t2 s2 ++ ")" +ppSparse (SMTLEither t1 t2) (SpLEither s1 s2) = "(" ++ ppSparse t1 s1 ++ "|" ++ ppSparse t2 s2 ++ ")" +ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s +ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s +ppSparse (SMTScal _) SpScal = "." +  ppX :: PrettyX x => Expr x env t -> ADoc  ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr) diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs new file mode 100644 index 0000000..93258b7 --- /dev/null +++ b/src/AST/Sparse.hs @@ -0,0 +1,290 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImpredicativeTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE RankNTypes #-} + +{-# OPTIONS_GHC -fmax-pmcheck-models=80 #-} +module AST.Sparse (module AST.Sparse, module AST.Sparse.Types) where + +import Data.Type.Equality + +import AST +import AST.Sparse.Types +import Data (SBool(..)) + + +sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t' +sparsePlus _ SpAbsent e1 e2 = use e1 $ use e2 $ ENil ext +sparsePlus t sp e1 e2 | Just Refl <- isDense t sp = EPlus ext t e1 e2 +sparsePlus t (SpSparse sp) e1 e2 = sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2  -- heh +sparsePlus (SMTPair t1 t2) (SpPair sp1 sp2) e1 e2 = +  eunPair e1 $ \w1 e1a e1b -> +  eunPair (weakenExpr w1 e2) $ \w2 e2a e2b -> +    EPair ext (sparsePlus t1 sp1 (weakenExpr w2 e1a) e2a) +              (sparsePlus t2 sp2 (weakenExpr w2 e1b) e2b) +sparsePlus (SMTLEither t1 t2) (SpLEither sp1 sp2) e1 e2 = +  elet e2 $ +    elcase (weakenExpr WSink e1) +      (evar IZ) +      (elcase (evar (IS IZ)) +        (ELInl ext (applySparse sp2 (fromSMTy t2)) (evar IZ)) +        (ELInl ext (applySparse sp2 (fromSMTy t2)) (sparsePlus t1 sp1 (evar (IS IZ)) (evar IZ))) +        (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus ll+lr")) +      (elcase (evar (IS IZ)) +        (ELInr ext (applySparse sp1 (fromSMTy t1)) (evar IZ)) +        (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus lr+ll") +        (ELInr ext (applySparse sp1 (fromSMTy t1)) (sparsePlus t2 sp2 (evar (IS IZ)) (evar IZ)))) +sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 = +  elet e2 $ +    emaybe (weakenExpr WSink e1) +      (evar IZ) +      (emaybe (evar (IS IZ)) +        (EJust ext (evar IZ)) +        (EJust ext (sparsePlus t sp (evar (IS IZ)) (evar IZ)))) +sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS IZ)) (evar IZ)) e1 e2 +sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2 + + +cheapZero :: SMTy t -> Maybe (forall env. Ex env t) +cheapZero SMTNil = Just (ENil ext) +cheapZero (SMTPair t1 t2) +  | Just e1 <- cheapZero t1 +  , Just e2 <- cheapZero t2 +  = Just (EPair ext e1 e2) +  | otherwise +  = Nothing +cheapZero (SMTLEither t1 t2) = Just (ELNil ext (fromSMTy t1) (fromSMTy t2)) +cheapZero (SMTMaybe t) = Just (ENothing ext (fromSMTy t)) +cheapZero SMTArr{} = Nothing +cheapZero (SMTScal t) = case t of +  STI32 -> Just (EConst ext t 0) +  STI64 -> Just (EConst ext t 0) +  STF32 -> Just (EConst ext t 0.0) +  STF64 -> Just (EConst ext t 0.0) + + +data Injection sp a b where +  -- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that +  -- 'sparsePlusS' can provide injections even if the caller doesn't require +  -- them. This simplifies the sparsePlusS code. +  Inj :: (forall e. Ex e a -> Ex e b) -> Injection sp a b +  Noinj :: Injection False a b + +withInj :: Injection sp a b -> ((forall e. Ex e a -> Ex e b) -> (forall e'. Ex e' a' -> Ex e' b')) -> Injection sp a' b' +withInj (Inj f) k = Inj (k f) +withInj Noinj _ = Noinj + +withInj2 :: Injection sp a1 b1 -> Injection sp a2 b2 +         -> ((forall e. Ex e a1 -> Ex e b1) +             -> (forall e. Ex e a2 -> Ex e b2) +             -> (forall e'. Ex e' a' -> Ex e' b')) +         -> Injection sp a' b' +withInj2 (Inj f) (Inj g) k = Inj (k f g) +withInj2 Noinj _ _ = Noinj +withInj2 _ Noinj _ = Noinj + +use :: Ex env a -> Ex env b -> Ex env b +use a b = elet a $ weakenExpr WSink b + +-- | This function produces quadratically-sized code in the presence of nested +-- dynamic sparsity. TODO can this be improved? +sparsePlusS +  :: SBool inj1 -> SBool inj2 +  -> SMTy t -> Sparse t t1 -> Sparse t t2 +  -> (forall t3. Sparse t t3 +              -> Injection inj1 t1 t3  -- only available if first injection is requested (second argument may be absent) +              -> Injection inj2 t2 t3  -- only available if second injection is requested (first argument may be absent) +              -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3) +              -> r) +  -> r +-- nil override (but don't destroy effects!) +sparsePlusS _ _ SMTNil _ _ k = +  k SpAbsent (Inj $ \a -> use a $ ENil ext) (Inj $ \b -> use b $ ENil ext) (\a b -> use a $ use b $ ENil ext) + +-- simplifications +sparsePlusS req1 req2 t (SpSparse SpAbsent) sp2 k = +  sparsePlusS req1 req2 t SpAbsent sp2 $ \sp3 minj1 minj2 plus -> +    k sp3 (withInj minj1 $ \inj1 -> \a -> use a $ inj1 (ENil ext)) minj2 (\a b -> use a $ plus (ENil ext) b) +sparsePlusS req1 req2 t sp1 (SpSparse SpAbsent) k = +  sparsePlusS req1 req2 t sp1 SpAbsent $ \sp3 minj1 minj2 plus -> +    k sp3 minj1 (withInj minj2 $ \inj2 -> \b -> use b $ inj2 (ENil ext)) (\a b -> use b $ plus a (ENil ext)) + +sparsePlusS req1 req2 t (SpSparse (SpSparse sp1)) sp2 k = +  let ta = applySparse sp1 (fromSMTy t) in +  sparsePlusS req1 req2 t (SpSparse sp1) sp2 $ \sp3 minj1 minj2 plus -> +    k sp3 +      (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ))) +      minj2 +      (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b) +sparsePlusS req1 req2 t sp1 (SpSparse (SpSparse sp2)) k = +  let tb = applySparse sp2 (fromSMTy t) in +  sparsePlusS req1 req2 t sp1 (SpSparse sp2) $ \sp3 minj1 minj2 plus -> +    k sp3 +      minj1 +      (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) +      (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) + +sparsePlusS req1 req2 t (SpSparse (SpLEither sp1a sp1b)) sp2 k = +  let STLEither ta tb = applySparse (SpLEither sp1a sp1b) (fromSMTy t) in +  sparsePlusS req1 req2 t (SpLEither sp1a sp1b) sp2 $ \sp3 minj1 minj2 plus -> +    k sp3 +      (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) +      minj2 +      (\a b -> plus (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)) b) +sparsePlusS req1 req2 t sp1 (SpSparse (SpLEither sp2a sp2b)) k = +  let STLEither ta tb = applySparse (SpLEither sp2a sp2b) (fromSMTy t) in +  sparsePlusS req1 req2 t sp1 (SpLEither sp2a sp2b) $ \sp3 minj1 minj2 plus -> +    k sp3 +      minj1 +      (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) +      (\a b -> plus a (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) + +sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k = +  let STMaybe ta = applySparse (SpMaybe sp1) (fromSMTy t) in +  sparsePlusS req1 req2 t (SpMaybe sp1) sp2 $ \sp3 minj1 minj2 plus -> +    k sp3 +      (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (evar IZ))) +      minj2 +      (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b) +sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k = +  let STMaybe tb = applySparse (SpMaybe sp2) (fromSMTy t) in +  sparsePlusS req1 req2 t sp1 (SpMaybe sp2) $ \sp3 minj1 minj2 plus -> +    k sp3 +      minj1 +      (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (evar IZ))) +      (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) +sparsePlusS req1 req2 t (SpMaybe (SpSparse sp1)) sp2 k = sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k +sparsePlusS req1 req2 t sp1 (SpMaybe (SpSparse sp2)) k = sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k + +-- TODO: sparse of Just is just Maybe + +-- dense plus +sparsePlusS _ _ t sp1 sp2 k +  | Just Refl <- isDense t sp1 +  , Just Refl <- isDense t sp2 +  = k (spDense t) (Inj id) (Inj id) (\a b -> EPlus ext t a b) + +-- handle absents +sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\a b -> use a $ b) +sparsePlusS ST _ t SpAbsent sp2 k +  | Just zero2 <- cheapZero (applySparse sp2 t) = +      k sp2 (Inj $ \a -> use a $ zero2) (Inj id) (\a b -> use a $ b) +  | otherwise = +      k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b) + +sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a b -> use b $ a) +sparsePlusS _ ST t sp1 SpAbsent k +  | Just zero1 <- cheapZero (applySparse sp1 t) = +      k sp1 (Inj id) (Inj $ \b -> use b $ zero1) (\a b -> use b $ a) +  | otherwise = +      k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a) + +-- double sparse yields sparse +sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k = +  sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> +    k (SpSparse sp3) +      (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) +      (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) +      (\a b -> +        elet b $ +          emaybe (weakenExpr WSink a) +            (emaybe (evar IZ) +              (ENothing ext (applySparse sp3 (fromSMTy t))) +              (EJust ext (inj2 (evar IZ)))) +            (emaybe (evar (IS IZ)) +              (EJust ext (inj1 (evar IZ))) +              (EJust ext (plus (evar (IS IZ)) (evar IZ))))) + +-- single sparse can yield non-sparse if the other argument is always present +sparsePlusS SF _ t (SpSparse sp1) sp2 k = +  sparsePlusS SF ST t sp1 sp2 $ \sp3 _ (Inj inj2) plus -> +    k sp3 Noinj (Inj inj2) +      (\a b -> +        elet b $ +          emaybe (weakenExpr WSink a) +            (inj2 (evar IZ)) +            (plus (evar IZ) (evar (IS IZ)))) +sparsePlusS ST _ t (SpSparse sp1) sp2 k = +  sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> +    k (SpSparse sp3) +      (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) +      (Inj $ \b -> EJust ext (inj2 b)) +      (\a b -> +        elet b $ +          emaybe (weakenExpr WSink a) +            (EJust ext (inj2 (evar IZ))) +            (EJust ext (plus (evar IZ) (evar (IS IZ))))) +sparsePlusS req1 req2 t sp1 (SpSparse sp2) k = +  sparsePlusS req2 req1 t (SpSparse sp2) sp1 $ \sp3 inj1 inj2 plus -> +    k sp3 inj2 inj1 (flip plus) + +-- products +sparsePlusS req1 req2 (SMTPair ta tb) (SpPair sp1a sp1b) (SpPair sp2a sp2b) k = +  sparsePlusS req1 req2 ta sp1a sp2a $ \sp3a minj13a minj23a plusa -> +  sparsePlusS req1 req2 tb sp1b sp2b $ \sp3b minj13b minj23b plusb -> +    k (SpPair sp3a sp3b) +      (withInj2 minj13a minj13b $ \inj13a inj13b -> +        \x1 -> eunPair x1 $ \_ x1a x1b -> EPair ext (inj13a x1a) (inj13b x1b)) +      (withInj2 minj23a minj23b $ \inj23a inj23b -> +        \x2 -> eunPair x2 $ \_ x2a x2b -> EPair ext (inj23a x2a) (inj23b x2b)) +      (\x1 x2 -> +        eunPair x1 $ \w1 x1a x1b -> +        eunPair (weakenExpr w1 x2) $ \w2 x2a x2b -> +          EPair ext (plusa (weakenExpr w2 x1a) x2a) (plusb (weakenExpr w2 x1b) x2b)) + +-- coproducts +sparsePlusS _ _ (SMTLEither ta tb) (SpLEither sp1a sp1b) (SpLEither sp2a sp2b) k = +  sparsePlusS ST ST ta sp1a sp2a $ \(sp3a :: Sparse _t3 t3a) (Inj inj13a) (Inj inj23a) plusa -> +  sparsePlusS ST ST tb sp1b sp2b $ \(sp3b :: Sparse _t3' t3b) (Inj inj13b) (Inj inj23b) plusb -> +    let nil :: Ex e (TLEither t3a t3b) ; nil = ELNil ext (applySparse sp3a (fromSMTy ta)) (applySparse sp3b (fromSMTy tb)) +        inl :: Ex e t3a -> Ex e (TLEither t3a t3b) ; inl = ELInl ext (applySparse sp3b (fromSMTy tb)) +        inr :: Ex e t3b -> Ex e (TLEither t3a t3b) ; inr = ELInr ext (applySparse sp3a (fromSMTy ta)) +    in +    k (SpLEither sp3a sp3b) +      (Inj $ \x1 -> elcase x1 nil (inl (inj13a (evar IZ))) (inr (inj13b (evar IZ)))) +      (Inj $ \x2 -> elcase x2 nil (inl (inj23a (evar IZ))) (inr (inj23b (evar IZ)))) +      (\x1 x2 -> +        elet x2 $ +          elcase (weakenExpr WSink x1) +            (elcase (evar IZ) +              nil +              (inl (inj23a (evar IZ))) +              (inr (inj23b (evar IZ)))) +            (elcase (evar (IS IZ)) +              (inl (inj13a (evar IZ))) +              (inl (plusa (evar (IS IZ)) (evar IZ))) +              (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS ll+lr")) +            (elcase (evar (IS IZ)) +              (inr (inj13b (evar IZ))) +              (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS lr+ll") +              (inr (plusb (evar (IS IZ)) (evar IZ))))) + +-- maybe +sparsePlusS _ _ (SMTMaybe t) (SpMaybe sp1) (SpMaybe sp2) k = +  sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> +    k (SpMaybe sp3) +      (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) +      (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) +      (\a b -> +        elet b $ +          emaybe (weakenExpr WSink a) +            (emaybe (evar IZ) +              (ENothing ext (applySparse sp3 (fromSMTy t))) +              (EJust ext (inj2 (evar IZ)))) +            (emaybe (evar (IS IZ)) +              (EJust ext (inj1 (evar IZ))) +              (EJust ext (plus (evar (IS IZ)) (evar IZ))))) + +-- dense array cotangents simply recurse +sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k = +  sparsePlusS req1 req2 t sp1 sp2 $ \sp3 minj1 minj2 plus -> +    k (SpArr sp3) +      (withInj minj1 $ \inj1 -> emap (inj1 (EVar ext (applySparse sp1 (fromSMTy t)) IZ))) +      (withInj minj2 $ \inj2 -> emap (inj2 (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) +      (ezipWith (plus (EVar ext (applySparse sp1 (fromSMTy t)) (IS IZ)) +                      (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) + +-- scalars +sparsePlusS _ _ (SMTScal t) SpScal SpScal k = k SpScal (Inj id) (Inj id) (EPlus ext (SMTScal t)) diff --git a/src/AST/Sparse/Types.hs b/src/AST/Sparse/Types.hs new file mode 100644 index 0000000..10cac4e --- /dev/null +++ b/src/AST/Sparse/Types.hs @@ -0,0 +1,107 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module AST.Sparse.Types where + +import AST.Types + +import Data.Kind (Type, Constraint) +import Data.Type.Equality + + +data Sparse t t' where +  SpSparse :: Sparse t t' -> Sparse t (TMaybe t') +  SpAbsent :: Sparse t TNil + +  SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b') +  SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b') +  SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t') +  SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t') +  SpScal :: Sparse (TScal t) (TScal t) +deriving instance Show (Sparse t t') + +class ApplySparse f where +  applySparse :: Sparse t t' -> f t -> f t' + +instance ApplySparse STy where +  applySparse (SpSparse s) t = STMaybe (applySparse s t) +  applySparse SpAbsent _ = STNil +  applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2) +  applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2) +  applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t) +  applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t) +  applySparse SpScal t = t + +instance ApplySparse SMTy where +  applySparse (SpSparse s) t = SMTMaybe (applySparse s t) +  applySparse SpAbsent _ = SMTNil +  applySparse (SpPair s1 s2) (SMTPair t1 t2) = SMTPair (applySparse s1 t1) (applySparse s2 t2) +  applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2) +  applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t) +  applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t) +  applySparse SpScal t = t + + +class IsSubType s where +  type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint +  subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t' +  subtTrans :: s a b -> s b c -> s a c +  subtFull :: IsSubTypeSubject s f => f t -> s t t + +instance IsSubType (:~:) where +  type IsSubTypeSubject (:~:) f = () +  subtApply = gcastWith +  subtTrans = trans +  subtFull _ = Refl + +instance IsSubType Sparse where +  type IsSubTypeSubject Sparse f = f ~ SMTy +  subtApply = applySparse + +  subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2) +  subtTrans _ SpAbsent = SpAbsent +  subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b) +  subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b) +  subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2) +  subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2) +  subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2) +  subtTrans SpScal SpScal = SpScal + +  subtFull = spDense + +spDense :: SMTy t -> Sparse t t +spDense SMTNil = SpAbsent +spDense (SMTPair t1 t2) = SpPair (spDense t1) (spDense t2) +spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2) +spDense (SMTMaybe t) = SpMaybe (spDense t) +spDense (SMTArr _ t) = SpArr (spDense t) +spDense (SMTScal _) = SpScal + +isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t') +isDense SMTNil SpAbsent = Just Refl +isDense _ SpSparse{} = Nothing +isDense _ SpAbsent = Nothing +isDense (SMTPair t1 t2) (SpPair s1 s2) +  | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl +  | otherwise = Nothing +isDense (SMTLEither t1 t2) (SpLEither s1 s2) +  | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl +  | otherwise = Nothing +isDense (SMTMaybe t) (SpMaybe s) +  | Just Refl <- isDense t s = Just Refl +  | otherwise = Nothing +isDense (SMTArr _ t) (SpArr s) +  | Just Refl <- isDense t s = Just Refl +  | otherwise = Nothing +isDense (SMTScal _) SpScal = Just Refl + +isAbsent :: Sparse t t' -> Bool +isAbsent (SpSparse s) = isAbsent s +isAbsent SpAbsent = True +isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2 +isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2 +isAbsent (SpMaybe s) = isAbsent s +isAbsent (SpArr s) = isAbsent s +isAbsent SpScal = False diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index 3c353d4..dcaf82f 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -63,8 +63,9 @@ splitLets' = \sub -> \case    ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2)    ERecompute x e -> ERecompute x (splitLets' sub e)    EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2) -  EAccum x t p e1 e2 e3 -> EAccum x t p (splitLets' sub e1) (splitLets' sub e2) (splitLets' sub e3) +  EAccum x t p e1 sp e2 e3 -> EAccum x t p (splitLets' sub e1) sp (splitLets' sub e2) (splitLets' sub e3)    EZero x t ezi -> EZero x t (splitLets' sub ezi) +  EDeepZero x t ezi -> EDeepZero x t (splitLets' sub ezi)    EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b)    EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b)    EError x t s -> EError x t s diff --git a/src/AST/Types.hs b/src/AST/Types.hs index a3b7302..42bfb92 100644 --- a/src/AST/Types.hs +++ b/src/AST/Types.hs @@ -5,9 +5,9 @@  {-# LANGUAGE RankNTypes #-}  {-# LANGUAGE StandaloneDeriving #-}  {-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeData #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE TypeData #-}  module AST.Types where  import Data.Int (Int32, Int64) diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index f5841e0..48dd709 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -1,18 +1,22 @@  {-# LANGUAGE DataKinds #-}  {-# LANGUAGE GADTs #-}  {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-}  {-# LANGUAGE TypeOperators #-} -module AST.UnMonoid (unMonoid, zero, plus) where +module AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where  import AST +import AST.Sparse.Types  import Data --- | Remove 'EZero', 'EPlus' and 'EOneHot' from the program by expanding them --- into their concrete implementations. +-- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by +-- expanding them into their concrete implementations. Also ensure that +-- 'EAccum' has a dense sparsity.  unMonoid :: Ex env t -> Ex env t  unMonoid = \case    EZero _ t e -> zero t e +  EDeepZero _ t e -> deepZero t e    EPlus _ t a b -> plus t (unMonoid a) (unMonoid b)    EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b) @@ -49,7 +53,10 @@ unMonoid = \case    ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)    ERecompute _ e -> ERecompute ext (unMonoid e)    EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b) -  EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e) +  EAccum _ t p eidx sp eval eacc -> +    accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 -> +    acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' -> +      EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc))    EError _ t s -> EError ext t s  zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t @@ -67,6 +74,27 @@ zero (SMTScal t) _ = case t of    STF32 -> EConst ext STF32 0.0    STF64 -> EConst ext STF64 0.0 +deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t +deepZero SMTNil e = elet e $ ENil ext +deepZero (SMTPair t1 t2) e = +  ELet ext e $ EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ))) +                         (deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ))) +deepZero (SMTLEither t1 t2) e = +  elcase e +    (ELNil ext (fromSMTy t1) (fromSMTy t2)) +    (ELInl ext (fromSMTy t2) (deepZero t1 (evar IZ))) +    (ELInr ext (fromSMTy t1) (deepZero t2 (evar IZ))) +deepZero (SMTMaybe t) e = +  emaybe e +    (ENothing ext (fromSMTy t)) +    (EJust ext (deepZero t (evar IZ))) +deepZero (SMTArr _ t) e = emap (deepZero t (evar IZ)) e +deepZero (SMTScal t) _ = case t of +  STI32 -> EConst ext STI32 0 +  STI64 -> EConst ext STI64 0 +  STF32 -> EConst ext STF32 0.0 +  STF64 -> EConst ext STF64 0.0 +  plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t  -- don't destroy the effects!  plus SMTNil a b = ELet ext a $ ELet ext (weakenExpr WSink b) $ ENil ext @@ -107,7 +135,7 @@ plus (SMTArr _ t) a b =             a b  plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b) -onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env t +onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t  onehot typ topprj idx arg = case (typ, topprj) of    (_, SAPHere) ->      ELet ext arg $ @@ -145,3 +173,78 @@ onehot typ topprj idx arg = case (typ, topprj) of               (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg))               (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $                  zero t1 (EVar ext (tZeroInfo t1) IZ)) + +accumulateSparse +  :: SMTy t -> Sparse t t' -> Ex env t' +  -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' (AcIdxD p t) -> Ex env' b -> Ex env' TNil) +  -> Ex env TNil +accumulateSparse topty topsp arg accum = case (topty, topsp) of +  (_, s) | Just Refl <- isDense topty s -> +    accum WId SAPHere (ENil ext) arg +  (SMTScal _, SpScal) -> +    accum WId SAPHere (ENil ext) arg  -- should be handled by isDense already, but meh +  (_, SpSparse s) -> +    emaybe arg +      (ENil ext) +      (accumulateSparse topty s (evar IZ) (\w -> accum (WPop w))) +  (_, SpAbsent) -> +    ENil ext +  (SMTPair t1 t2, SpPair s1 s2) -> +    eunPair arg $ \w1 e1 e2 -> +      elet (accumulateSparse t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $ +        accumulateSparse t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj)) +  (SMTLEither t1 t2, SpLEither s1 s2) -> +    elcase arg +      (ENil ext) +      (accumulateSparse t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj))) +      (accumulateSparse t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj))) +  (SMTMaybe t, SpMaybe s) -> +    emaybe arg +      (ENil ext) +      (accumulateSparse t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj))) +  (SMTArr n t, SpArr s) -> +    let tn = tTup (sreplicate n tIx) in +    elet arg $ +    elet (EBuild ext n (EShape ext (evar IZ)) $ +            accumulateSparse t s +              (EIdx ext (evar (IS IZ)) (EVar ext tn IZ)) +              (\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $ +      ENil ext + +acPrjCompose +  :: SAIDense dense +  -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a) +  -> SAcPrj p2 b c -> Ex env (AcIdx dense p2 b) +  -> (forall p'. SAcPrj p' a c -> Ex env (AcIdx dense p' a) -> r) -> r +acPrjCompose _ SAPHere _ p2 idx2 k = k p2 idx2 +acPrjCompose SAID (SAPFst p1) idx1 p2 idx2 k = +  acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> +    k (SAPFst p') idx' +acPrjCompose SAID (SAPSnd p1) idx1 p2 idx2 k = +  acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> +    k (SAPSnd p') idx' +acPrjCompose SAIS (SAPFst p1) idx1 p2 idx2 k +  | Dict <- styKnown (typeOf idx1) = +  acPrjCompose SAIS p1 (efst (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> +    k (SAPFst p') (elet idx1 $ EPair ext idx' (esnd (evar IZ))) +acPrjCompose SAIS (SAPSnd p1) idx1 p2 idx2 k +  | Dict <- styKnown (typeOf idx1) = +  acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> +    k (SAPSnd p') (elet idx1 $ EPair ext (efst (evar IZ)) idx') +acPrjCompose d (SAPLeft p1) idx1 p2 idx2 k = +  acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> +    k (SAPLeft p') idx' +acPrjCompose d (SAPRight p1) idx1 p2 idx2 k = +  acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> +    k (SAPRight p') idx' +acPrjCompose d (SAPJust p1) idx1 p2 idx2 k = +  acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> +    k (SAPJust p') idx' +acPrjCompose SAID (SAPArrIdx p1) idx1 p2 idx2 k +  | Dict <- styKnown (typeOf idx1) = +  acPrjCompose SAID p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> +    k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx') +acPrjCompose SAIS (SAPArrIdx p1) idx1 p2 idx2 k +  | Dict <- styKnown (typeOf idx1) = +  acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> +    k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx') diff --git a/src/AST/Weaken/Auto.hs b/src/AST/Weaken/Auto.hs index 6752c24..c6efe37 100644 --- a/src/AST/Weaken/Auto.hs +++ b/src/AST/Weaken/Auto.hs @@ -64,7 +64,7 @@ data SSegments (segments :: [(Symbol, [t])]) where    SSegNil :: SSegments '[]    SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list) -instance (KnownSymbol name, name ~ name', segs ~ '[ '(name', ts)]) => IsLabel name (SList f ts -> SSegments segs) where +instance (KnownSymbol name, segs ~ '[ '(name, ts)]) => IsLabel name (SList f ts -> SSegments segs) where    fromLabel = \spine -> SSegCons symbolSing (slistMap (\_ -> Const ()) spine) SSegNil  auto :: KnownListSpine list => SList (Const ()) list | 
