From d1b2e2c3a3cdaf49ff5e4bae6fe9b0612c3779c2 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 18 Jun 2025 00:00:11 +0200 Subject: Tests pass, should check if output is sensible --- chad-fast.cabal | 2 + src/AST.hs | 20 ++- src/AST/Accum.hs | 58 +++++---- src/AST/Count.hs | 2 +- src/AST/Pretty.hs | 21 +++- src/AST/Sparse.hs | 110 +---------------- src/AST/Sparse/Types.hs | 107 ++++++++++++++++ src/AST/SplitLets.hs | 2 +- src/AST/UnMonoid.hs | 111 ++++++++++++++++- src/Analysis/Identity.hs | 4 +- src/CHAD.hs | 117 +----------------- src/CHAD/Accum.hs | 45 +++++++ src/CHAD/Top.hs | 54 ++++----- src/CHAD/Types.hs | 16 +++ src/Compile.hs | 171 ++++++-------------------- src/Data.hs | 8 +- src/Example.hs | 3 +- src/Interpreter.hs | 151 ++++++++++------------- src/Language.hs | 6 +- src/Language/AST.hs | 5 +- src/Simplify.hs | 309 +++++++++++++++++++++++++++++++---------------- test/Main.hs | 29 +++-- 22 files changed, 726 insertions(+), 625 deletions(-) create mode 100644 src/AST/Sparse/Types.hs create mode 100644 src/CHAD/Accum.hs diff --git a/chad-fast.cabal b/chad-fast.cabal index b8510d2..b7270e4 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -19,12 +19,14 @@ library AST.Env AST.Pretty AST.Sparse + AST.Sparse.Types AST.SplitLets AST.Types AST.UnMonoid AST.Weaken AST.Weaken.Auto CHAD + CHAD.Accum CHAD.EnvDescr CHAD.Top CHAD.Types diff --git a/src/AST.hs b/src/AST.hs index c24e3e7..5aab4fc 100644 --- a/src/AST.hs +++ b/src/AST.hs @@ -25,6 +25,7 @@ import Data.Kind (Type) import Array import AST.Accum +import AST.Sparse.Types import AST.Types import AST.Weaken import CHAD.Types @@ -91,11 +92,16 @@ data Expr x env t where ERecompute :: x t -> Expr x env t -> Expr x env t -- accumulation effect on monoids + -- | The initialiser for an accumulator __MUST__ be deep! If it is zero, it + -- must be EDeepZero, not just EZero. This is to ensure that EAccum does not + -- need to create any zeros. EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t) - EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Expr x env a -> Expr x env (TAccum t) -> Expr x env TNil + -- The 'Sparse' here is eliminated to dense by UnMonoid. + EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Sparse a b -> Expr x env b -> Expr x env (TAccum t) -> Expr x env TNil -- monoidal operations (to be desugared to regular operations after simplification) EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t + EDeepZero :: x t -> SMTy t -> Expr x env (DeepZeroInfo t) -> Expr x env t EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t @@ -218,9 +224,10 @@ typeOf = \case ERecompute _ e -> typeOf e EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1) - EAccum _ _ _ _ _ _ -> STNil + EAccum _ _ _ _ _ _ _ -> STNil EZero _ t _ -> fromSMTy t + EDeepZero _ t _ -> fromSMTy t EPlus _ t _ _ -> fromSMTy t EOneHot _ t _ _ _ -> fromSMTy t @@ -261,8 +268,9 @@ extOf = \case ECustom x _ _ _ _ _ _ _ _ -> x ERecompute x _ -> x EWith x _ _ _ -> x - EAccum x _ _ _ _ _ -> x + EAccum x _ _ _ _ _ _ -> x EZero x _ _ -> x + EDeepZero x _ _ -> x EPlus x _ _ _ -> x EOneHot x _ _ _ _ -> x EError x _ _ -> x @@ -306,8 +314,9 @@ travExt f = \case ECustom x s t p a b c e1 e2 -> ECustom <$> f x <*> pure s <*> pure t <*> pure p <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f e1 <*> travExt f e2 ERecompute x e -> ERecompute <$> f x <*> travExt f e EWith x t e1 e2 -> EWith <$> f x <*> pure t <*> travExt f e1 <*> travExt f e2 - EAccum x t p e1 e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> travExt f e2 <*> travExt f e3 + EAccum x t p e1 sp e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> pure sp <*> travExt f e2 <*> travExt f e3 EZero x t e -> EZero <$> f x <*> pure t <*> travExt f e + EDeepZero x t e -> EDeepZero <$> f x <*> pure t <*> travExt f e EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b EError x t s -> EError <$> f x <*> pure t <*> pure s @@ -364,8 +373,9 @@ subst' f w = \case ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2) ERecompute x e -> ERecompute x (subst' f w e) EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) - EAccum x t p e1 e2 e3 -> EAccum x t p (subst' f w e1) (subst' f w e2) (subst' f w e3) + EAccum x t p e1 sp e2 e3 -> EAccum x t p (subst' f w e1) sp (subst' f w e2) (subst' f w e3) EZero x t e -> EZero x t (subst' f w e) + EDeepZero x t e -> EDeepZero x t (subst' f w e) EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b) EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b) EError x t s -> EError x t s diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs index 158b4d9..619c2b1 100644 --- a/src/AST/Accum.hs +++ b/src/AST/Accum.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeData #-} {-# LANGUAGE TypeFamilies #-} @@ -33,35 +34,38 @@ data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t) deriving instance Show (SAcPrj p a b) -type data StillDense = AI_D | AI_S -data SStillDense dense where - SAI_D :: SStillDense AI_D - SAI_S :: SStillDense AI_S -deriving instance Show (SStillDense dense) +type data AIDense = AID | AIS -type family AcIdx dense p t where - AcIdx dense APHere t = TNil - AcIdx AI_D (APFst p) (TPair a b) = AcIdx AI_D p a - AcIdx AI_D (APSnd p) (TPair a b) = AcIdx AI_D p b - AcIdx AI_S (APFst p) (TPair a b) = TPair (AcIdx AI_S p a) (ZeroInfo b) - AcIdx AI_S (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AI_S p b) - AcIdx dense (APLeft p) (TLEither a b) = AcIdx AI_S p a - AcIdx dense (APRight p) (TLEither a b) = AcIdx AI_S p b - AcIdx dense (APJust p) (TMaybe a) = AcIdx AI_S p a - AcIdx AI_D (APArrIdx p) (TArr n a) = TPair (Tup (Replicate n TIx)) (AcIdx AI_D p a) - AcIdx AI_S (APArrIdx p) (TArr n a) = - -- ((index, shapes info), recursive info) +data SAIDense d where + SAID :: SAIDense AID + SAIS :: SAIDense AIS +deriving instance Show (SAIDense d) + +type family AcIdx d p t where + AcIdx d APHere t = TNil + AcIdx AID (APFst p) (TPair a b) = AcIdx AID p a + AcIdx AID (APSnd p) (TPair a b) = AcIdx AID p b + AcIdx AIS (APFst p) (TPair a b) = TPair (AcIdx AIS p a) (ZeroInfo b) + AcIdx AIS (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AIS p b) + AcIdx d (APLeft p) (TLEither a b) = AcIdx d p a + AcIdx d (APRight p) (TLEither a b) = AcIdx d p b + AcIdx d (APJust p) (TMaybe a) = AcIdx d p a + AcIdx AID (APArrIdx p) (TArr n a) = + -- (index, recursive info) + TPair (Tup (Replicate n TIx)) (AcIdx AID p a) + AcIdx AIS (APArrIdx p) (TArr n a) = + -- ((index, shape info), recursive info) TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a))) - (AcIdx AI_S p a) - -- AcIdx AI_D (APArrSlice m) (TArr n a) = + (AcIdx AIS p a) + -- AcIdx AID (APArrSlice m) (TArr n a) = -- -- index -- Tup (Replicate m TIx) - -- AcIdx AI_S (APArrSlice m) (TArr n a) = + -- AcIdx AIS (APArrSlice m) (TArr n a) = -- -- (index, array shape) -- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx)) -type AcIdxD p t = AcIdx AI_D p t -type AcIdxS p t = AcIdx AI_S p t +type AcIdxD p t = AcIdx AID p t +type AcIdxS p t = AcIdx AIS p t acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b acPrjTy SAPHere t = t @@ -88,6 +92,16 @@ tZeroInfo (SMTMaybe _) = STNil tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t) tZeroInfo (SMTScal _) = STNil +-- | Info needed to create a zero-valued deep accumulator for a monoid type. +-- Should be constructable from a D1. +type family DeepZeroInfo t where + DeepZeroInfo TNil = TNil + DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b) + DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b) + DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a) + DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a) + DeepZeroInfo (TScal t) = TNil + -- -- | Additional info needed for accumulation. This is empty unless there is -- -- sparsity in the monoid. -- type family AccumInfo t where diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 03a36f6..05be524 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -134,7 +134,7 @@ occCountGeneral onehot unpush alter many = go WId ECustom _ _ _ _ _ _ _ a b -> re a <> re b ERecompute _ e -> re e EWith _ _ a b -> re a <> re1 b - EAccum _ _ _ a b e -> re a <> re b <> re e + EAccum _ _ _ a _ b e -> re a <> re b <> re e EZero _ _ e -> re e EPlus _ _ a b -> re a <> re b EOneHot _ _ _ a b -> re a <> re b diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 41da656..fef9686 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -25,6 +25,7 @@ import System.IO.Unsafe (unsafePerformIO) import AST import AST.Count +import AST.Sparse.Types import CHAD.Types import Data @@ -304,18 +305,24 @@ ppExpr' d val expr = case expr of <> hardline <> e2') (ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2']) - EAccum _ t prj e1 e2 e3 -> do + EAccum _ t prj e1 sp e2 e3 -> do e1' <- ppExpr' 11 val e1 e2' <- ppExpr' 11 val e2 e3' <- ppExpr' 11 val e3 return $ ppParen (d > 10) $ - ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (acPrjTy prj t)) [ppString (ppAcPrj t prj), e1', e2', e3'] + ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (applySparse sp (acPrjTy prj t))) + [ppString (ppAcPrj t prj), ppString (ppSparse (acPrjTy prj t) sp), e1', e2', e3'] EZero _ t e1 -> do e1' <- ppExpr' 11 val e1 return $ ppParen (d > 0) $ annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1' + EDeepZero _ t e1 -> do + e1' <- ppExpr' 11 val e1 + return $ ppParen (d > 0) $ + annotate AMonoid (ppString "deepzero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1' + EPlus _ t a b -> do a' <- ppExpr' 11 val a b' <- ppExpr' 11 val b @@ -368,6 +375,16 @@ ppAcPrj (SMTLEither _ t) (SAPRight prj) = "(|" ++ ppAcPrj t prj ++ ")" ppAcPrj (SMTMaybe t) (SAPJust prj) = "J" ++ ppAcPrj t prj ppAcPrj (SMTArr n t) (SAPArrIdx prj) = "[" ++ ppAcPrj t prj ++ "]" ++ intSubscript (fromSNat n) +ppSparse :: SMTy a -> Sparse a b -> String +ppSparse t sp | Just Refl <- isDense t sp = "D" +ppSparse _ SpAbsent = "A" +ppSparse t (SpSparse s) = "S" ++ ppSparse t s +ppSparse (SMTPair t1 t2) (SpPair s1 s2) = "(" ++ ppSparse t1 s1 ++ "," ++ ppSparse t2 s2 ++ ")" +ppSparse (SMTLEither t1 t2) (SpLEither s1 s2) = "(" ++ ppSparse t1 s1 ++ "|" ++ ppSparse t2 s2 ++ ")" +ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s +ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s +ppSparse (SMTScal _) SpScal = "." + ppX :: PrettyX x => Expr x env t -> ADoc ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr) diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs index 369d395..f0a1f2a 100644 --- a/src/AST/Sparse.hs +++ b/src/AST/Sparse.hs @@ -1,116 +1,19 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE PolyKinds #-} {-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} {-# LANGUAGE RankNTypes #-} -{-# LANGUAGE DataKinds #-} {-# LANGUAGE ScopedTypeVariables #-} {-# OPTIONS_GHC -fmax-pmcheck-models=80 #-} -module AST.Sparse where +module AST.Sparse (module AST.Sparse, module AST.Sparse.Types) where -import Data.Kind (Constraint, Type) import Data.Type.Equality import AST +import AST.Sparse.Types +import Data (SBool(..)) -data Sparse t t' where - SpSparse :: Sparse t t' -> Sparse t (TMaybe t') - SpAbsent :: Sparse t TNil - - SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b') - SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b') - SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t') - SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t') - SpScal :: Sparse (TScal t) (TScal t) -deriving instance Show (Sparse t t') - -class ApplySparse f where - applySparse :: Sparse t t' -> f t -> f t' - -instance ApplySparse STy where - applySparse (SpSparse s) t = STMaybe (applySparse s t) - applySparse SpAbsent _ = STNil - applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2) - applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2) - applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t) - applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t) - applySparse SpScal t = t - -instance ApplySparse SMTy where - applySparse (SpSparse s) t = SMTMaybe (applySparse s t) - applySparse SpAbsent _ = SMTNil - applySparse (SpPair s1 s2) (SMTPair t1 t2) = SMTPair (applySparse s1 t1) (applySparse s2 t2) - applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2) - applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t) - applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t) - applySparse SpScal t = t - - -class IsSubType s where - type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint - subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t' - subtTrans :: s a b -> s b c -> s a c - subtFull :: IsSubTypeSubject s f => f t -> s t t - -instance IsSubType (:~:) where - type IsSubTypeSubject (:~:) f = () - subtApply = gcastWith - subtTrans = trans - subtFull _ = Refl - -instance IsSubType Sparse where - type IsSubTypeSubject Sparse f = f ~ SMTy - subtApply = applySparse - - subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2) - subtTrans _ SpAbsent = SpAbsent - subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b) - subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b) - subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2) - subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2) - subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2) - subtTrans SpScal SpScal = SpScal - - subtFull = spDense - -spDense :: SMTy t -> Sparse t t -spDense SMTNil = SpAbsent -spDense (SMTPair t1 t2) = SpPair (spDense t1) (spDense t2) -spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2) -spDense (SMTMaybe t) = SpMaybe (spDense t) -spDense (SMTArr _ t) = SpArr (spDense t) -spDense (SMTScal _) = SpScal - -isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t') -isDense SMTNil SpAbsent = Just Refl -isDense _ SpSparse{} = Nothing -isDense _ SpAbsent = Nothing -isDense (SMTPair t1 t2) (SpPair s1 s2) - | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl - | otherwise = Nothing -isDense (SMTLEither t1 t2) (SpLEither s1 s2) - | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl - | otherwise = Nothing -isDense (SMTMaybe t) (SpMaybe s) - | Just Refl <- isDense t s = Just Refl - | otherwise = Nothing -isDense (SMTArr _ t) (SpArr s) - | Just Refl <- isDense t s = Just Refl - | otherwise = Nothing -isDense (SMTScal _) SpScal = Just Refl - -isAbsent :: Sparse t t' -> Bool -isAbsent (SpSparse s) = isAbsent s -isAbsent SpAbsent = True -isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2 -isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2 -isAbsent (SpMaybe s) = isAbsent s -isAbsent (SpArr s) = isAbsent s -isAbsent SpScal = False - sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t' sparsePlus _ SpAbsent _ _ = ENil ext sparsePlus t sp e1 e2 | Just Refl <- isDense t sp = EPlus ext t e1 e2 @@ -143,11 +46,6 @@ sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS I sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2 -data SBool b where - SF :: SBool False - ST :: SBool True -deriving instance Show (SBool b) - data Injection sp a b where -- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that -- 'sparsePlusS' can provide injections even if the caller doesn't require diff --git a/src/AST/Sparse/Types.hs b/src/AST/Sparse/Types.hs new file mode 100644 index 0000000..10cac4e --- /dev/null +++ b/src/AST/Sparse/Types.hs @@ -0,0 +1,107 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module AST.Sparse.Types where + +import AST.Types + +import Data.Kind (Type, Constraint) +import Data.Type.Equality + + +data Sparse t t' where + SpSparse :: Sparse t t' -> Sparse t (TMaybe t') + SpAbsent :: Sparse t TNil + + SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b') + SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b') + SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t') + SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t') + SpScal :: Sparse (TScal t) (TScal t) +deriving instance Show (Sparse t t') + +class ApplySparse f where + applySparse :: Sparse t t' -> f t -> f t' + +instance ApplySparse STy where + applySparse (SpSparse s) t = STMaybe (applySparse s t) + applySparse SpAbsent _ = STNil + applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t) + applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t) + applySparse SpScal t = t + +instance ApplySparse SMTy where + applySparse (SpSparse s) t = SMTMaybe (applySparse s t) + applySparse SpAbsent _ = SMTNil + applySparse (SpPair s1 s2) (SMTPair t1 t2) = SMTPair (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t) + applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t) + applySparse SpScal t = t + + +class IsSubType s where + type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint + subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t' + subtTrans :: s a b -> s b c -> s a c + subtFull :: IsSubTypeSubject s f => f t -> s t t + +instance IsSubType (:~:) where + type IsSubTypeSubject (:~:) f = () + subtApply = gcastWith + subtTrans = trans + subtFull _ = Refl + +instance IsSubType Sparse where + type IsSubTypeSubject Sparse f = f ~ SMTy + subtApply = applySparse + + subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2) + subtTrans _ SpAbsent = SpAbsent + subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b) + subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b) + subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2) + subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2) + subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2) + subtTrans SpScal SpScal = SpScal + + subtFull = spDense + +spDense :: SMTy t -> Sparse t t +spDense SMTNil = SpAbsent +spDense (SMTPair t1 t2) = SpPair (spDense t1) (spDense t2) +spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2) +spDense (SMTMaybe t) = SpMaybe (spDense t) +spDense (SMTArr _ t) = SpArr (spDense t) +spDense (SMTScal _) = SpScal + +isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t') +isDense SMTNil SpAbsent = Just Refl +isDense _ SpSparse{} = Nothing +isDense _ SpAbsent = Nothing +isDense (SMTPair t1 t2) (SpPair s1 s2) + | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl + | otherwise = Nothing +isDense (SMTLEither t1 t2) (SpLEither s1 s2) + | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl + | otherwise = Nothing +isDense (SMTMaybe t) (SpMaybe s) + | Just Refl <- isDense t s = Just Refl + | otherwise = Nothing +isDense (SMTArr _ t) (SpArr s) + | Just Refl <- isDense t s = Just Refl + | otherwise = Nothing +isDense (SMTScal _) SpScal = Just Refl + +isAbsent :: Sparse t t' -> Bool +isAbsent (SpSparse s) = isAbsent s +isAbsent SpAbsent = True +isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2 +isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2 +isAbsent (SpMaybe s) = isAbsent s +isAbsent (SpArr s) = isAbsent s +isAbsent SpScal = False diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index 3c353d4..2dad17a 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -63,7 +63,7 @@ splitLets' = \sub -> \case ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2) ERecompute x e -> ERecompute x (splitLets' sub e) EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2) - EAccum x t p e1 e2 e3 -> EAccum x t p (splitLets' sub e1) (splitLets' sub e2) (splitLets' sub e3) + EAccum x t p e1 sp e2 e3 -> EAccum x t p (splitLets' sub e1) sp (splitLets' sub e2) (splitLets' sub e3) EZero x t ezi -> EZero x t (splitLets' sub ezi) EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b) EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b) diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index 389dd5a..d498aaa 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -1,18 +1,22 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE TypeOperators #-} -module AST.UnMonoid (unMonoid, zero, plus) where +module AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where import AST +import AST.Sparse.Types import Data --- | Remove 'EZero', 'EPlus' and 'EOneHot' from the program by expanding them --- into their concrete implementations. +-- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by +-- expanding them into their concrete implementations. Also ensure that +-- 'EAccum' has a dense sparsity. unMonoid :: Ex env t -> Ex env t unMonoid = \case EZero _ t e -> zero t e + EDeepZero _ t e -> deepZero t e EPlus _ t a b -> plus t (unMonoid a) (unMonoid b) EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b) @@ -49,7 +53,10 @@ unMonoid = \case ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2) ERecompute _ e -> ERecompute ext (unMonoid e) EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b) - EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e) + EAccum _ t p eidx sp eval eacc -> + accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 -> + acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' -> + EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc)) EError _ t s -> EError ext t s zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t @@ -66,6 +73,27 @@ zero (SMTScal t) _ = case t of STF32 -> EConst ext STF32 0.0 STF64 -> EConst ext STF64 0.0 +deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t +deepZero SMTNil _ = ENil ext +deepZero (SMTPair t1 t2) e = + ELet ext e $ EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ))) + (deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ))) +deepZero (SMTLEither t1 t2) e = + elcase e + (ELNil ext (fromSMTy t1) (fromSMTy t2)) + (ELInl ext (fromSMTy t2) (deepZero t1 (evar IZ))) + (ELInr ext (fromSMTy t1) (deepZero t2 (evar IZ))) +deepZero (SMTMaybe t) e = + emaybe e + (ENothing ext (fromSMTy t)) + (EJust ext (deepZero t (evar IZ))) +deepZero (SMTArr _ t) e = emap (deepZero t (evar IZ)) e +deepZero (SMTScal t) _ = case t of + STI32 -> EConst ext STI32 0 + STI64 -> EConst ext STI64 0 + STF32 -> EConst ext STF32 0.0 + STF64 -> EConst ext STF64 0.0 + plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t plus SMTNil _ _ = ENil ext plus (SMTPair t1 t2) a b = @@ -143,3 +171,78 @@ onehot typ topprj idx arg = case (typ, topprj) of (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $ zero t1 (EVar ext (tZeroInfo t1) IZ)) + +accumulateSparse + :: SMTy t -> Sparse t t' -> Ex env t' + -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' (AcIdxD p t) -> Ex env' b -> Ex env' TNil) + -> Ex env TNil +accumulateSparse topty topsp arg accum = case (topty, topsp) of + (_, s) | Just Refl <- isDense topty s -> + accum WId SAPHere (ENil ext) arg + (SMTScal _, SpScal) -> + accum WId SAPHere (ENil ext) arg -- should be handled by isDense already, but meh + (_, SpSparse s) -> + emaybe arg + (ENil ext) + (accumulateSparse topty s (evar IZ) (\w -> accum (WPop w))) + (_, SpAbsent) -> + ENil ext + (SMTPair t1 t2, SpPair s1 s2) -> + eunPair arg $ \w1 e1 e2 -> + elet (accumulateSparse t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $ + accumulateSparse t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj)) + (SMTLEither t1 t2, SpLEither s1 s2) -> + elcase arg + (ENil ext) + (accumulateSparse t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj))) + (accumulateSparse t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj))) + (SMTMaybe t, SpMaybe s) -> + emaybe arg + (ENil ext) + (accumulateSparse t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj))) + (SMTArr n t, SpArr s) -> + let tn = tTup (sreplicate n tIx) in + elet arg $ + elet (EBuild ext n (EShape ext (evar IZ)) $ + accumulateSparse t s + (EIdx ext (evar (IS IZ)) (EVar ext tn IZ)) + (\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $ + ENil ext + +acPrjCompose + :: SAIDense dense + -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a) + -> SAcPrj p2 b c -> Ex env (AcIdx dense p2 b) + -> (forall p'. SAcPrj p' a c -> Ex env (AcIdx dense p' a) -> r) -> r +acPrjCompose _ SAPHere _ p2 idx2 k = k p2 idx2 +acPrjCompose SAID (SAPFst p1) idx1 p2 idx2 k = + acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPFst p') idx' +acPrjCompose SAID (SAPSnd p1) idx1 p2 idx2 k = + acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPSnd p') idx' +acPrjCompose SAIS (SAPFst p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAIS p1 (efst (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPFst p') (elet idx1 $ EPair ext idx' (esnd (evar IZ))) +acPrjCompose SAIS (SAPSnd p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPSnd p') (elet idx1 $ EPair ext (efst (evar IZ)) idx') +acPrjCompose d (SAPLeft p1) idx1 p2 idx2 k = + acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPLeft p') idx' +acPrjCompose d (SAPRight p1) idx1 p2 idx2 k = + acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPRight p') idx' +acPrjCompose d (SAPJust p1) idx1 p2 idx2 k = + acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPJust p') idx' +acPrjCompose SAID (SAPArrIdx p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAID p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx') +acPrjCompose SAIS (SAPArrIdx p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx') diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index 4501c32..2fd321d 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -307,11 +307,11 @@ idana env expr = case expr of let res = VIPair v2 x2 pure (res, EWith res t e1' e2') - EAccum _ t prj e1 e2 e3 -> do + EAccum _ t prj e1 sp e2 e3 -> do (_, e1') <- idana env e1 (_, e2') <- idana env e2 (_, e3') <- idana env e3 - pure (VINil, EAccum VINil t prj e1' e2' e3') + pure (VINil, EAccum VINil t prj e1' sp e2' e3') EZero _ t e1 -> do -- Approximate the result of EZero to be independent from the zero info diff --git a/src/CHAD.hs b/src/CHAD.hs index 3dedec3..621aa3e 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -34,7 +34,6 @@ module CHAD ( import Data.Functor.Const import Data.Some -import Data.Type.Bool (If) import Data.Type.Equality (type (==), testEquality) import GHC.Stack (HasCallStack) @@ -45,6 +44,7 @@ import AST.Count import AST.Env import AST.Sparse import AST.Weaken.Auto +import CHAD.Accum import CHAD.EnvDescr import CHAD.Types import Data @@ -348,28 +348,8 @@ opt2UnSparse = go . opt2 go t _ = error $ "Primitive operations that return " ++ show t ++ " are scary" ------------------------------------- MONOIDS ----------------------------------- - -d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t)) -d2zeroInfo STNil _ = ENil ext -d2zeroInfo (STPair a b) e = - eunPair e $ \_ e1 e2 -> - EPair ext (d2zeroInfo a e1) (d2zeroInfo b e2) -d2zeroInfo STEither{} _ = ENil ext -d2zeroInfo STLEither{} _ = ENil ext -d2zeroInfo STMaybe{} _ = ENil ext -d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e -d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext -d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program" - - ----------------------------------- SPARSITY ----------------------------------- -subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env') -subenvD1E SETop = SETop -subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub) -subenvD1E (SENo sub) = SENo (subenvD1E sub) - expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a) expandSparse t sp _ e | Just Refl <- isDense (d2M t) sp = e expandSparse t (SpSparse sp) epr e = @@ -499,23 +479,6 @@ assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty" --------------------------------- ACCUMULATORS --------------------------------- -makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) -makeAccumulators _ SNil e = e -makeAccumulators w (t `SCons` envpro) e = - makeAccumulators (WPop w) envpro $ - EWith ext (d2M t) (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e - -uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) -uninvertTup SNil _ e = EPair ext e (ENil ext) -uninvertTup (t `SCons` list) tcore e = - ELet ext (uninvertTup list (STPair tcore t) e) $ - let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding - in EPair ext - (EFst ext (EFst ext (EVar ext recT IZ))) - (EPair ext - (ESnd ext (EVar ext recT IZ)) - (ESnd ext (EFst ext (EVar ext recT IZ)))) - fromArrayValId :: Maybe (ValId t) -> Maybe Int fromArrayValId (Just (VIArr i _)) = Just i fromArrayValId _ = Nothing @@ -788,8 +751,7 @@ drev des accumMap sd = \case (EVar ext (d1 t) (conv1Idx i)) (subenvNone (d2e (select SMerge des))) (let ty = applySparse sd (d2M t) - in accumulateSparse SAI_D (d2M t) sd (EVar ext (fromSMTy ty) IZ) $ \w prj val idx -> - EAccum ext (d2M t) prj idx val (EVar ext (STAccum (d2M t)) (w @> IS accI))) + in EAccum ext (d2M t) SAPHere (ENil ext) sd (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI))) Idx2Me tupI -> Ret BTop @@ -1275,6 +1237,7 @@ drev des accumMap sd = \case EWith{} -> err_accum EZero{} -> err_monoid + EDeepZero{} -> err_monoid EPlus{} -> err_monoid EOneHot{} -> err_monoid @@ -1392,76 +1355,6 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of -- TODO: proper primal-only transform that doesn't depend on D1 = Id drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t) drevPrimal des e - | Refl <- chadD1Id (typeOf e) - , Refl <- chadD1EId (descrList des) + | Refl <- d1Identity (typeOf e) + , Refl <- d1eIdentity (descrList des) = mapExt (const ext) e - where - chadD1Id :: STy a -> D1 a :~: a - chadD1Id STNil = Refl - chadD1Id (STPair a b) | Refl <- chadD1Id a, Refl <- chadD1Id b = Refl - chadD1Id (STEither a b) | Refl <- chadD1Id a, Refl <- chadD1Id b = Refl - chadD1Id (STLEither a b) | Refl <- chadD1Id a, Refl <- chadD1Id b = Refl - chadD1Id (STMaybe a) | Refl <- chadD1Id a = Refl - chadD1Id (STArr _ a) | Refl <- chadD1Id a = Refl - chadD1Id (STScal _) = Refl - chadD1Id STAccum{} = error "accumulators not allowed in source program" - - chadD1EId :: SList STy l -> D1E l :~: l - chadD1EId SNil = Refl - chadD1EId (SCons t l) | Refl <- chadD1Id t, Refl <- chadD1EId l = Refl - -accumulateSparse - :: SStillDense dense -> SMTy t -> Sparse t t' -> Ex env t' - -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' b -> Ex env' (AcIdx dense p t) -> Ex env' TNil) - -> Ex env TNil -accumulateSparse dense topty topsp arg accum = case (dense, topty, topsp) of - (_, _, s) | Just Refl <- isDense topty s -> - accum WId SAPHere arg (ENil ext) - (_, SMTScal _, SpScal) -> - accum WId SAPHere arg (ENil ext) -- should be handled by isDense already, but meh - (_, _, SpSparse s) -> - emaybe arg - (ENil ext) - (accumulateSparse dense topty s (evar IZ) (\w -> accum (WPop w))) - (_, _, SpAbsent) -> - ENil ext - (SAI_D, SMTPair t1 t2, SpPair s1 s2) -> - eunPair arg $ \w1 e1 e2 -> - elet (accumulateSparse dense t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $ - accumulateSparse dense t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj)) - (SAI_S, SMTPair{}, SpPair{}) -> - error "TODO: accumulating into pair inside coproduct unimplemented" - -- There are two different ways this can be accomplished: - -- 1. Ensure we have the requisite ZeroInfo here. This means that an - -- accum-mode variable reference will (if its incoming cotangent is - -- sparse enough) need to store some ZeroInfo fragments computed from - -- the primal (not necessarily the entire primal). Doing this properly, - -- i.e. not just storing a full D1 but only the required ZeroInfo - -- fragments, is possible and not too inefficient but a bit of - -- engineering again. - -- 2. When creating an accumulator, don't initialise it with a generic - -- EZero based on a ZeroInfo, but instead a special "deep zero" based on - -- probably a full D1. This deep zero also initialises Left/Right/Just - -- modelled after the primal. With this, an accumulation needs no zero - -- info whatsoever (!) under the assumption that it receives a cotangent - -- that is compatible with the primal it is propagated back to. - (_, SMTLEither t1 t2, SpLEither s1 s2) -> - elcase arg - (ENil ext) - (accumulateSparse SAI_S t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj))) - (accumulateSparse SAI_S t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj))) - (_, SMTMaybe t, SpMaybe s) -> - emaybe arg - (ENil ext) - (accumulateSparse SAI_S t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj))) - (SAI_D, SMTArr n t, SpArr s) -> - let tn = tTup (sreplicate n tIx) in - elet arg $ - elet (EBuild ext n (EShape ext (evar IZ)) $ - accumulateSparse dense t s - (EIdx ext (evar (IS IZ)) (EVar ext tn IZ)) - (\w prj val idx -> accum (WPop (WPop w)) (SAPArrIdx prj) val (EPair ext (EVar ext tn (w @> IZ)) idx))) $ - ENil ext - (SAI_S, SMTArr{}, SpArr{}) -> - error "TODO: accumulating into array inside coproduct unimplemented" - -- See the pair case above, same reasoning diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs new file mode 100644 index 0000000..8c7794a --- /dev/null +++ b/src/CHAD/Accum.hs @@ -0,0 +1,45 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE TypeOperators #-} +-- | TODO this module is a grab-bag of random utility functions that are shared +-- between CHAD and CHAD.Top. +module CHAD.Accum where + +import AST +import CHAD.Types +import Data +import AST.Env + + +d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t)) +d2zeroInfo STNil _ = ENil ext +d2zeroInfo (STPair a b) e = + eunPair e $ \_ e1 e2 -> + EPair ext (d2zeroInfo a e1) (d2zeroInfo b e2) +d2zeroInfo STEither{} _ = ENil ext +d2zeroInfo STLEither{} _ = ENil ext +d2zeroInfo STMaybe{} _ = ENil ext +d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e +d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext +d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program" + +makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) +makeAccumulators _ SNil e = e +makeAccumulators w (t `SCons` envpro) e = + makeAccumulators (WPop w) envpro $ + EWith ext (d2M t) (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e + +uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) +uninvertTup SNil _ e = EPair ext e (ENil ext) +uninvertTup (t `SCons` list) tcore e = + ELet ext (uninvertTup list (STPair tcore t) e) $ + let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding + in EPair ext + (EFst ext (EFst ext (EVar ext recT IZ))) + (EPair ext + (ESnd ext (EVar ext recT IZ)) + (ESnd ext (EFst ext (EVar ext recT IZ)))) + +subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env') +subenvD1E SETop = SETop +subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub) +subenvD1E (SENo sub) = SENo (subenvD1E sub) diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs index 130174a..484779e 100644 --- a/src/CHAD/Top.hs +++ b/src/CHAD/Top.hs @@ -12,9 +12,12 @@ module CHAD.Top where import Analysis.Identity import AST +import AST.Env +import AST.Sparse import AST.SplitLets import AST.Weaken.Auto import CHAD +import CHAD.Accum import CHAD.EnvDescr import CHAD.Types import Data @@ -43,36 +46,22 @@ accumDescr (t `SCons` env) k = accumDescr env $ \des -> if hasArrays t then k (des `DPush` (t, Nothing, SAccum)) else k (des `DPush` (t, Nothing, SMerge)) -d1Identity :: STy t -> D1 t :~: t -d1Identity = \case - STNil -> Refl - STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl - STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl - STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl - STMaybe t | Refl <- d1Identity t -> Refl - STArr _ t | Refl <- d1Identity t -> Refl - STScal _ -> Refl - STAccum{} -> error "Accumulators not allowed in input program" - -d1eIdentity :: SList STy env -> D1E env :~: env -d1eIdentity SNil = Refl -d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl - reassembleD2E :: Descr env sto + -> D1E env :> env' -> Ex env' (TPair (Tup (D2E (Select env sto "accum"))) (Tup (D2E (Select env sto "merge")))) -> Ex env' (Tup (D2E env)) -reassembleD2E DTop _ = ENil ext -reassembleD2E (des `DPush` (_, _, SAccum)) e = - ELet ext e $ - EPair ext (reassembleD2E des (EPair ext (EFst ext (EFst ext (EVar ext (typeOf e) IZ))) - (ESnd ext (EVar ext (typeOf e) IZ)))) - (ESnd ext (EFst ext (EVar ext (typeOf e) IZ))) -reassembleD2E (des `DPush` (_, _, SMerge)) e = - ELet ext e $ - EPair ext (reassembleD2E des (EPair ext (EFst ext (EVar ext (typeOf e) IZ)) - (EFst ext (ESnd ext (EVar ext (typeOf e) IZ))))) - (ESnd ext (ESnd ext (EVar ext (typeOf e) IZ))) -reassembleD2E (des `DPush` (t, _, SDiscr)) e = EPair ext (reassembleD2E des e) (ezeroD2 t) +reassembleD2E DTop _ _ = ENil ext +reassembleD2E (des `DPush` (_, _, SAccum)) w e = + eunPair e $ \w1 e1 e2 -> + eunPair e1 $ \w2 e11 e12 -> + EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext e11 (weakenExpr w2 e2))) e12 +reassembleD2E (des `DPush` (_, _, SMerge)) w e = + eunPair e $ \w1 e1 e2 -> + eunPair e2 $ \w2 e21 e22 -> + EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext (weakenExpr w2 e1) e21)) e22 +reassembleD2E (des `DPush` (t, _, SDiscr)) w e = + EPair ext (reassembleD2E des (WPop w) e) + (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env))) chad config env (term :: Ex env t) @@ -82,21 +71,22 @@ chad config env (term :: Ex env t) let t1 = STPair (d1 (typeOf term)) (tTup (d2e (select SMerge descr))) tvar = STPair t1 (tTup (d2e (select SAccum descr))) in ELet ext (uninvertTup (d2e (select SAccum descr)) t1 $ - makeAccumulators (select SAccum descr) $ + makeAccumulators (WSink .> wUndoSubenv (subenvD1E (selectSub SAccum descr))) (select SAccum descr) $ weakenExpr (autoWeak (#d (auto1 @(D2 t)) &. #acenv (d2ace (select SAccum descr)) &. #tl (d1e env)) (#d :++: #acenv :++: #tl) (#acenv :++: #d :++: #tl)) $ - freezeRet descr (drev descr VarMap.empty term')) $ + freezeRet descr (drev descr VarMap.empty (spDense (d2M (typeOf term))) term')) $ EPair ext (EFst ext (EFst ext (EVar ext tvar IZ))) - (reassembleD2E descr (EPair ext (ESnd ext (EVar ext tvar IZ)) - (ESnd ext (EFst ext (EVar ext tvar IZ))))) + (reassembleD2E descr (WSink .> WSink) + (EPair ext (ESnd ext (EVar ext tvar IZ)) + (ESnd ext (EFst ext (EVar ext tvar IZ))))) | False <- chcArgArrayAccum config , Refl <- mergeEnvNoAccum env , Refl <- mergeEnvOnlyMerge env - = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty term') + = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty (spDense (d2M (typeOf term))) term') where term' = identityAnalysis env (splitLets term) diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs index 83f013d..8b3a8db 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Types.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} module CHAD.Types where @@ -124,3 +125,18 @@ lemZeroInfoScal STI64 = Refl lemZeroInfoScal STF32 = Refl lemZeroInfoScal STF64 = Refl lemZeroInfoScal STBool = Refl + +d1Identity :: STy t -> D1 t :~: t +d1Identity = \case + STNil -> Refl + STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STMaybe t | Refl <- d1Identity t -> Refl + STArr _ t | Refl <- d1Identity t -> Refl + STScal _ -> Refl + STAccum{} -> error "Accumulators not allowed in input program" + +d1eIdentity :: SList STy env -> D1E env :~: env +d1eIdentity SNil = Refl +d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl diff --git a/src/Compile.hs b/src/Compile.hs index 722b432..a5c4fb7 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -45,6 +45,7 @@ import qualified Prelude import Array import AST import AST.Pretty (ppSTy, ppExpr) +import AST.Sparse.Types (isDense) import Compile.Exec import Data import Interpreter.Rep @@ -1002,95 +1003,7 @@ compile' env = \case rettyname <- emitStruct (STPair (typeOf e2) (fromSMTy t)) return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)] - EAccum _ t prj eidx eval eacc -> do - let -- Assumes v is a value of type (SMTArr n t1), and initialises it to a - -- full zero array with the given zero info (for the type SMTArr n t1). - initZeroArray :: SNat n -> SMTy a -> String -> String -> CompM () - initZeroArray n t1 v vzi = do - shszname <- genName' "inacshsz" - emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n vzi) - newarrName <- allocArray "initZero" Calloc "inacarr" n (fromSMTy t1) (Just (CELit shszname)) (compileArrShapeComponents n vzi) - emit $ SAsg v (CELit newarrName) - forM_ (initZeroFromMemset t1) $ \f1 -> do - ivar <- genName' "i" - ((), initStmts) <- scope $ f1 (v++"["++ivar++"]") (vzi++"["++ivar++"]") - emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) initStmts - - -- If something needs to be done to properly initialise this type to - -- zero after memory has already been initialised to all-zero bytes, - -- returns an action that does so. - -- initZeroFromMemset (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type) - initZeroFromMemset :: SMTy a -> Maybe (String -> String -> CompM ()) - initZeroFromMemset SMTNil = Nothing - initZeroFromMemset (SMTPair t1 t2) = - case (initZeroFromMemset t1, initZeroFromMemset t2) of - (Nothing, Nothing) -> Nothing - (mf1, mf2) -> Just $ \v vzi -> do - forM_ mf1 $ \f1 -> f1 (v++".a") (vzi++".a") - forM_ mf2 $ \f2 -> f2 (v++".b") (vzi++".b") - initZeroFromMemset SMTLEither{} = Nothing - initZeroFromMemset SMTMaybe{} = Nothing - initZeroFromMemset (SMTArr n t1) = Just $ \v vzi -> initZeroArray n t1 v vzi - initZeroFromMemset SMTScal{} = Nothing - - let -- initZeroZI (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type) - initZeroZI :: SMTy a -> String -> String -> CompM () - initZeroZI SMTNil _ _ = return () - initZeroZI (SMTPair t1 t2) v vzi = do - initZeroZI t1 (v++".a") (vzi++".a") - initZeroZI t2 (v++".b") (vzi++".b") - initZeroZI SMTLEither{} v _ = emit $ SAsg (v++".tag") (CELit "0") - initZeroZI SMTMaybe{} v _ = emit $ SAsg (v++".tag") (CELit "0") - initZeroZI (SMTArr n t1) v vzi = initZeroArray n t1 v vzi - initZeroZI (SMTScal sty) v _ = case sty of - STI32 -> emit $ SAsg v (CELit "0") - STI64 -> emit $ SAsg v (CELit "0l") - STF32 -> emit $ SAsg v (CELit "0.0f") - STF64 -> emit $ SAsg v (CELit "0.0") - - let -- Initialise an uninitialised accumulation value, potentially already - -- with the addend, potentially to zero depending on the nature of the - -- projection. - -- 1. If the projection indexes only through dense monoids before - -- reaching SAPHere, the thing cannot be initialised to zero with - -- only an AcIdx; it would need to model a zero after the addend, - -- which is stupid and redundant. In this case, we return Left: - -- (accumulation value) (AcIdx value) (addend value). - -- The addend is copied, not consumed. (We can't reliably _always_ - -- consume it, so it's not worth trying to do it sometimes.) - -- 2. Otherwise, a sparse monoid is found along the way, and we can - -- initalise the dense prefix of the path to zero by setting the - -- indexed-through sparse value to a sparse zero. Afterwards, the - -- main recursion can proceed further. In this case, we return - -- Right: (accumulation value) (AcIdx value) - -- initZeroChunk (type) (projection) (variable of that type to initialise to zero) (variable to an AcIdx for the type) - initZeroChunk :: SMTy a -> SAcPrj p a b - -> Either (String -> String -> String -> CompM ()) -- dense initialisation with addend - (String -> String -> CompM ()) -- zero initialisation of sparse chunk - initZeroChunk izaitoptyp izaitopprj = case (izaitoptyp, izaitopprj) of - -- reached target before the first sparse constructor - (t1 , SAPHere ) -> Left $ \v _ addend -> do - incrementVarAlways "initZeroSparse" Increment (fromSMTy t1) addend - emit $ SAsg v (CELit addend) - -- sparse types - (SMTMaybe{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0") - (SMTLEither{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0") - -- dense types - (SMTPair t1 t2, SAPFst prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do - f (v++".a") (i++".a") - initZeroZI t2 (v++".b") (i++".b") - (SMTPair t1 t2, SAPSnd prj') -> applySkeleton (initZeroChunk t2 prj') $ \f v i -> do - initZeroZI t1 (v++".a") (i++".a") - f (v++".b") (i++".b") - (SMTArr n t1, SAPArrIdx prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do - initZeroArray n t1 v (i++".a.b") - linidxvar <- genName' "li" - emit $ SVarDecl False (repSTy tIx) linidxvar (toLinearIdx n v (i++".a.a")) - f (v++".buf->xs["++linidxvar++"]") (i++".b") - where - applySkeleton (Left densef) skel = Left $ \v i addend -> skel (\v' i' -> densef v' i' addend) v i - applySkeleton (Right sparsef) skel = Right $ \v i -> skel (\v' i' -> sparsef v' i') v i - + EAccum _ t prj eidx sparsity eval eacc | Just Refl <- isDense (acPrjTy prj t) sparsity -> do let -- Add a value (s) into an existing accumulation value (d). If a sparse -- component of d is encountered, s is copied there. add :: SMTy a -> String -> String -> CompM () @@ -1160,67 +1073,55 @@ compile' env = \case accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM () accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend - accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") (i++".a") addend - accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") (i++".b") addend - - accumRef (SMTLEither ta tb) prj0 v i addend = do - let chunkres = case prj0 of SAPLeft prj' -> initZeroChunk ta prj' - SAPRight prj' -> initZeroChunk tb prj' - subv = v ++ (case prj0 of SAPLeft{} -> ".l"; SAPRight{} -> ".r") - tagval = case prj0 of SAPLeft{} -> "1" - SAPRight{} -> "2" - ((), stmtsAdd) <- scope $ case prj0 of SAPLeft prj' -> accumRef ta prj' subv i addend - SAPRight prj' -> accumRef tb prj' subv i addend - case chunkres of - Left densef -> do - ((), stmtsSet) <- scope $ densef subv i addend - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsSet) - stmtsAdd -- TODO: emit check for consistency of tags? - Right sparsef -> do - ((), stmtsInit) <- scope $ sparsef subv i - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsInit) mempty - forM_ stmtsAdd emit + accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") i addend + accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") i addend + + accumRef (SMTLEither ta _) (SAPLeft prj') v i addend = do + when emitChecks $ do + emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1")) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +left)\\n\", " ++ v ++ ".tag); " ++ + "return false;") + mempty + accumRef ta prj' (v++".l") i addend + accumRef (SMTLEither _ tb) (SAPRight prj') v i addend = do + when emitChecks $ do + emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "2")) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +right)\\n\", " ++ v ++ ".tag); " ++ + "return false;") + mempty + accumRef tb prj' (v++".r") i addend accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do - case initZeroChunk tj prj' of - Left densef -> do - ((), stmtsSet1) <- scope $ densef (v++".j") i addend - ((), stmtsAdd1) <- scope $ accumRef tj prj' (v++".j") i addend - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "1")) <> stmtsSet1) - stmtsAdd1 - Right sparsef -> do - ((), stmtsInit1) <- scope $ sparsef (v++".j") i - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty - accumRef tj prj' (v++".j") i addend + when emitChecks $ do + emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1")) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (maybe tag=%d, +just)\\n\", " ++ v ++ ".tag); " ++ + "return false;") + mempty + accumRef tj prj' (v++".j") i addend accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do when emitChecks $ do let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" - forM_ (zip3 [0::Int ..] - (indexTupleComponents n (i++".a.a")) - (compileArrShapeComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do + forM_ (zip [0::Int ..] + (indexTupleComponents n (i++".a"))) $ \(j, ixcomp) -> do let a .||. b = CEBinop a "||" b emit $ SIf (CEBinop ixcomp "<" (CELit "0") .||. - CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))) - .||. - CEBinop shcomp "!=" (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))) + CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))) (pure $ SVerbatim $ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++ - "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++ + "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=(D))\\n\", " ++ v ++ ".buf" ++ concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ - concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++ - concat [", " ++ printCExpr 2 comp "" | comp <- compileArrShapeComponents n (i++".a.b")] ++ + concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a")] ++ "); " ++ "return false;") mempty - accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") addend + accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a")) "]") (i++".b") addend nameidx <- compileAssign "acidx" env eidx nameval <- compileAssign "acval" env eval @@ -1234,6 +1135,9 @@ compile' env = \case return $ CEStruct (repSTy STNil) [] + EAccum{} -> + error "Compile: EAccum with non-trivial sparsity should have been eliminated (use AST.UnMonoid)" + EError _ t s -> do let padleft len c s' = replicate (len - length s) c ++ s' escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c] @@ -1247,6 +1151,7 @@ compile' env = \case return $ CEStruct name [] EZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" + EDeepZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" EPlus{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" EOneHot{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" diff --git a/src/Data.hs b/src/Data.hs index e86aaa6..e6978c8 100644 --- a/src/Data.hs +++ b/src/Data.hs @@ -8,12 +8,13 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module Data (module Data, (:~:)(Refl)) where +module Data (module Data, (:~:)(Refl), If) where import Data.Functor.Product import Data.GADT.Compare import Data.GADT.Show import Data.Some +import Data.Type.Bool (If) import Data.Type.Equality import Unsafe.Coerce (unsafeCoerce) @@ -184,3 +185,8 @@ instance Applicative Bag where instance Semigroup (Bag t) where (<>) = BTwo instance Monoid (Bag t) where mempty = BNone + +data SBool b where + SF :: SBool False + ST :: SBool True +deriving instance Show (SBool b) diff --git a/src/Example.hs b/src/Example.hs index d3f6d0d..b320ead 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -162,8 +162,7 @@ neuralGo = ELet ext (EConst ext STF64 1.0) $ chad defaultConfig knownEnv neural (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False knownEnv argument revderiv of - (primal', (((((), Just (Just dlay1_1'a, Just dlay1_1'b)), Just (Just dlay2_1'a, Just dlay2_1'b)), Just dlay3_1'), Just dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1') - _ -> undefined + (primal', (((((), (dlay1_1'a, dlay1_1'b)), (dlay2_1'a, dlay2_1'b)), dlay3_1'), dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1') (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0 in trace (ppExpr knownEnv revderiv) $ (primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2)) diff --git a/src/Interpreter.hs b/src/Interpreter.hs index b3576ce..ffc2929 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -21,6 +21,7 @@ module Interpreter ( ) where import Control.Monad (foldM, join, when, forM_) +import Data.Bifunctor (bimap) import Data.Bitraversable (bitraverse) import Data.Char (isSpace) import Data.Functor.Identity @@ -35,6 +36,7 @@ import Debug.Trace import Array import AST import AST.Pretty +import AST.Sparse.Types import Data import Interpreter.Rep @@ -158,14 +160,17 @@ interpret'Rec env = \case initval <- interpret' env e1 withAccum t (typeOf e2) initval $ \accum -> interpret' (V (STAccum t) accum `SCons` env) e2 - EAccum _ t p e1 e2 e3 -> do + EAccum _ t p e1 sp e2 e3 -> do idx <- interpret' env e1 val <- interpret' env e2 accum <- interpret' env e3 - accumAddSparseD t p accum idx val + accumAddSparseD t p accum idx sp val EZero _ t ezi -> do zi <- interpret' env ezi return $ zeroM t zi + EDeepZero _ t ezi -> do + zi <- interpret' env ezi + return $ deepZeroM t zi EPlus _ t a b -> do a' <- interpret' env a b' <- interpret' env b @@ -216,6 +221,19 @@ zeroM typ zi = case typ of STF32 -> 0.0 STF64 -> 0.0 +deepZeroM :: SMTy t -> Rep (DeepZeroInfo t) -> Rep t +deepZeroM typ zi = case typ of + SMTNil -> () + SMTPair t1 t2 -> (deepZeroM t1 (fst zi), deepZeroM t2 (snd zi)) + SMTLEither t1 t2 -> fmap (bimap (deepZeroM t1) (deepZeroM t2)) zi + SMTMaybe t -> fmap (deepZeroM t) zi + SMTArr _ t -> arrayMap (deepZeroM t) zi + SMTScal sty -> case sty of + STI32 -> 0 + STI64 -> 0 + STF32 -> 0.0 + STF64 -> 0.0 + addM :: SMTy t -> Rep t -> Rep t -> Rep t addM typ a b = case typ of SMTNil -> () @@ -256,15 +274,6 @@ withAccum t _ initval f = AcM $ do val <- readAc t accum return (out, val) -newAcZero :: SMTy t -> Rep (ZeroInfo t) -> IO (RepAc t) -newAcZero typ zi = case typ of - SMTNil -> return () - SMTPair t1 t2 -> bitraverse (newAcZero t1) (newAcZero t2) zi - SMTLEither{} -> newIORef Nothing - SMTMaybe _ -> newIORef Nothing - SMTArr _ t -> arrayMapM (newAcZero t) zi - SMTScal sty -> numericIsNum sty $ newIORef 0 - newAcDense :: SMTy a -> Rep a -> IO (RepAc a) newAcDense typ val = case typ of SMTNil -> return () @@ -274,22 +283,6 @@ newAcDense typ val = case typ of SMTArr _ t1 -> arrayMapM (newAcDense t1) val SMTScal _ -> newIORef val -newAcSparse :: SMTy a -> SAcPrj p a b -> Rep (AcIdxS p a) -> Rep b -> IO (RepAc a) -newAcSparse typ prj idx val = case (typ, prj) of - (_, SAPHere) -> newAcDense typ val - - (SMTPair t1 t2, SAPFst prj') -> - (,) <$> newAcSparse t1 prj' (fst idx) val <*> newAcZero t2 (snd idx) - (SMTPair t1 t2, SAPSnd prj') -> - (,) <$> newAcZero t1 (fst idx) <*> newAcSparse t2 prj' (snd idx) val - - (SMTLEither t1 _, SAPLeft prj') -> newIORef . Just . Left =<< newAcSparse t1 prj' idx val - (SMTLEither _ t2, SAPRight prj') -> newIORef . Just . Right =<< newAcSparse t2 prj' idx val - - (SMTMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val - - (SMTArr n t, SAPArrIdx prj') -> onehotArray (\idx' -> newAcSparse t prj' idx' val) (newAcZero t) n prj' idx - onehotArray :: Monad m => (Rep (AcIdxS p a) -> m v) -- ^ the "one" -> (Rep (ZeroInfo a) -> m v) -- ^ the "zero" @@ -309,81 +302,67 @@ readAc typ val = case typ of SMTArr _ t -> traverse (readAc t) val SMTScal _ -> readIORef val -accumAddDense :: SMTy a -> RepAc a -> Rep a -> AcM s () -accumAddDense typ ref val = case typ of - SMTNil -> return () - SMTPair t1 t2 -> do - accumAddDense t1 (fst ref) (fst val) - accumAddDense t2 (snd ref) (snd val) - SMTLEither{} -> - case val of - Nothing -> return () - Just (Left val1) -> accumAddSparse typ (SAPLeft SAPHere) ref () val1 - Just (Right val2) -> accumAddSparse typ (SAPRight SAPHere) ref () val2 - SMTMaybe{} -> - case val of - Nothing -> return () - Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val' - SMTArr _ t1 -> - forM_ [0 .. arraySize ref - 1] $ \i -> - accumAddDense t1 (arrayIndexLinear ref i) (arrayIndexLinear val i) - SMTScal sty -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ())) - -accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Rep b -> AcM s () -accumAddSparseD typ prj ref idx val = case (typ, prj) of - (_, SAPHere) -> accumAddDense typ ref val +accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Sparse b c -> Rep c -> AcM s () +accumAddSparseD typ prj ref idx sp val = case (typ, prj) of + (_, SAPHere) -> accumAddDense typ ref sp val - (SMTPair t1 _, SAPFst prj') -> accumAddSparseD t1 prj' (fst ref) idx val - (SMTPair _ t2, SAPSnd prj') -> accumAddSparseD t2 prj' (snd ref) idx val + (SMTPair t1 _, SAPFst prj') -> accumAddSparseD t1 prj' (fst ref) idx sp val + (SMTPair _ t2, SAPSnd prj') -> accumAddSparseD t2 prj' (snd ref) idx sp val (SMTLEither t1 _, SAPLeft prj') -> - realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val) - (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val + realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)") + (\case Left ac1 -> accumAddSparseD t1 prj' ac1 idx sp val Right{} -> error "Mismatched Either in accumAddSparseD (r +l)") (SMTLEither _ t2, SAPRight prj') -> - realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val) - (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val + realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)") + (\case Right ac2 -> accumAddSparseD t2 prj' ac2 idx sp val Left{} -> error "Mismatched Either in accumAddSparseD (l +r)") (SMTMaybe t1, SAPJust prj') -> - realiseMaybeSparse ref (newAcSparse t1 prj' idx val) - (\ac -> accumAddSparse t1 prj' ac idx val) + realiseMaybeSparse ref (error "Accumulating Just into Nothing (EWith requires EDeepZero)") + (\ac -> accumAddSparseD t1 prj' ac idx sp val) (SMTArr n t1, SAPArrIdx prj') -> let (arrindex', idx') = idx arrindex = unTupRepIdx IxNil IxCons n arrindex' arrsh = arrayShape ref linindex = toLinearIndex arrsh arrindex - in accumAddSparseD t1 prj' (arrayIndexLinear ref linindex) idx' val - -accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxS p a) -> Rep b -> AcM s () -accumAddSparse typ prj ref idx val = case (typ, prj) of - (_, SAPHere) -> accumAddDense typ ref val - - (SMTPair t1 _, SAPFst prj') -> accumAddSparse t1 prj' (fst ref) (fst idx) val - (SMTPair _ t2, SAPSnd prj') -> accumAddSparse t2 prj' (snd ref) (snd idx) val - - (SMTLEither t1 _, SAPLeft prj') -> - realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val) - (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val - Right{} -> error "Mismatched Either in accumAddSparse (r +l)") - (SMTLEither _ t2, SAPRight prj') -> - realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val) - (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val - Left{} -> error "Mismatched Either in accumAddSparse (l +r)") - - (SMTMaybe t1, SAPJust prj') -> - realiseMaybeSparse ref (newAcSparse t1 prj' idx val) - (\ac -> accumAddSparse t1 prj' ac idx val) - - (SMTArr n t1, SAPArrIdx prj') -> - let ((arrindex', ziarr), idx') = idx - arrindex = unTupRepIdx IxNil IxCons n arrindex' - arrsh = arrayShape ziarr - linindex = toLinearIndex arrsh arrindex - in accumAddSparse t1 prj' (arrayIndexLinear ref linindex) idx' val + in accumAddSparseD t1 prj' (arrayIndexLinear ref linindex) idx' sp val +accumAddDense :: SMTy a -> RepAc a -> Sparse a b -> Rep b -> AcM s () +accumAddDense typ ref sp val = case (typ, sp) of + (_, _) | isAbsent sp -> return () + (_, SpAbsent) -> return () + (_, SpSparse s) -> + case val of + Nothing -> return () + Just val' -> accumAddDense typ ref s val' + (SMTPair t1 t2, SpPair s1 s2) -> do + accumAddDense t1 (fst ref) s1 (fst val) + accumAddDense t2 (snd ref) s2 (snd val) + (SMTLEither t1 t2, SpLEither s1 s2) -> + case val of + Nothing -> return () + Just (Left val1) -> + realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)") + (\case Left ac1 -> accumAddDense t1 ac1 s1 val1 + Right{} -> error "Mismatched Either in accumAddSparse (r +l)") + Just (Right val2) -> + realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)") + (\case Right ac2 -> accumAddDense t2 ac2 s2 val2 + Left{} -> error "Mismatched Either in accumAddSparse (l +r)") + (SMTMaybe t, SpMaybe s) -> + case val of + Nothing -> return () + Just val' -> + realiseMaybeSparse ref (error "Accumulating Just into Nothing (EAccum requires EDeepZero)") + (\ac -> accumAddDense t ac s val') + (SMTArr _ t1, SpArr s) -> + forM_ [0 .. arraySize ref - 1] $ \i -> + accumAddDense t1 (arrayIndexLinear ref i) s (arrayIndexLinear val i) + (SMTScal sty, SpScal) -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ())) +-- TODO: makeval is always 'error' now. Simplify? realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s () realiseMaybeSparse ref makeval modifyval = -- Try modifying what's already in ref. The 'join' makes the snd diff --git a/src/Language.hs b/src/Language.hs index 63279df..4e6d604 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -17,6 +17,7 @@ module Language ( import Array import AST +import AST.Sparse.Types import AST.Types import CHAD.Types import Data @@ -176,7 +177,10 @@ with :: forall t a env acname. KnownMTy t => NExpr env t -> (Var acname (TAccum with a (n :-> b) = NEWith (knownMTy @t) a n b accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil -accum p a b c = NEAccum knownMTy p a b c +accum p a b c = NEAccum knownMTy p a (spDense (acPrjTy p knownMTy)) b c + +accumS :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil +accumS p a sp b c = NEAccum knownMTy p a sp b c (.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) diff --git a/src/Language/AST.hs b/src/Language/AST.hs index 92792b3..be98ccf 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -21,6 +21,7 @@ import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorM import Array import AST +import AST.Sparse.Types import CHAD.Types import Data @@ -76,7 +77,7 @@ data NExpr env t where -- accumulation effect on monoids NEWith :: SMTy t -> NExpr env t -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a t) - NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil + NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil -- partiality NEError :: STy a -> String -> NExpr env a @@ -221,7 +222,7 @@ fromNamedExpr val = \case NERecompute e -> ERecompute ext (go e) NEWith t a n b -> EWith ext t (go a) (lambda val n b) - NEAccum t p a b c -> EAccum ext t p (go a) (go b) (go c) + NEAccum t p a sp b c -> EAccum ext t p (go a) sp (go b) (go c) NEError t s -> EError ext t s diff --git a/src/Simplify.hs b/src/Simplify.hs index d3b850f..74b6601 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -1,7 +1,9 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE KindSignatures #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE QuasiQuotes #-} @@ -19,13 +21,14 @@ import Control.Monad (ap) import Data.Bifunctor (first) import Data.Function (fix) import Data.Monoid (Any(..)) -import Data.Type.Equality (testEquality) import Debug.Trace import AST import AST.Count import AST.Pretty +import AST.Sparse.Types +import AST.UnMonoid (acPrjCompose) import Data import Simplify.TH @@ -81,22 +84,28 @@ runSM (SM f) = first getAny (f id) smReconstruct :: Ex env t -> SM tenv tt env t (Ex tenv tt) smReconstruct core = SM (\ctx -> (Any False, ctx core)) -tellActed :: SM tenv tt env t () -tellActed = SM (\_ -> (Any True, ())) +class Monad m => ActedMonad m where + tellActed :: m () + hideActed :: m a -> m a + liftActed :: (Any, a) -> m a + +instance ActedMonad ((,) Any) where + tellActed = (Any True, ()) + hideActed (_, x) = (Any False, x) + liftActed = id + +instance ActedMonad (SM tenv tt env t) where + tellActed = SM (\_ -> tellActed) + hideActed (SM f) = SM (\ctx -> hideActed (f ctx)) + liftActed pair = SM (\_ -> pair) -- more convenient in practice -acted :: SM tenv tt env t a -> SM tenv tt env t a +acted :: ActedMonad m => m a -> m a acted m = tellActed >> m within :: (Ex env' t' -> Ex env t) -> SM tenv tt env' t' a -> SM tenv tt env t a within subctx (SM f) = SM $ \ctx -> f (ctx . subctx) -acted' :: (Any, a) -> (Any, a) -acted' (_, x) = (Any True, x) - -liftActed :: (Any, a) -> SM tenv tt env t a -liftActed pair = SM $ \_ -> pair - simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t) simplify' expr | scLogging ?config = do @@ -167,10 +176,10 @@ simplify'Rec = \case ECase _ (ELet _ rhs body) e1 e2 -> acted $ simplify' (ELet ext rhs (ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2))) EIdx0 _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EIdx0 ext body)) EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e))) - EAccum _ t p e1 (ELet _ rhs body) acc -> + EAccum _ t p e1 sp (ELet _ rhs body) acc -> acted $ simplify' $ ELet ext rhs $ - EAccum ext t p (weakenExpr WSink e1) body (weakenExpr WSink acc) + EAccum ext t p (weakenExpr WSink e1) sp body (weakenExpr WSink acc) -- let () = e in () ~> e ELet _ e1 (ENil _) | STNil <- typeOf e1 -> @@ -194,6 +203,9 @@ simplify'Rec = \case EIdx _ (EReplicate1Inner _ _ e2) e3 -> acted $ simplify' $ EIdx ext e2 (EFst ext e3) EIdx _ (EUnit _ e1) _ -> acted $ simplify' $ e1 + -- TODO: more array shape + EShape _ (EBuild _ _ e _) -> acted $ simplify' e + -- TODO: more constant folding EOp _ OIf (EConst _ STBool True) -> acted $ return (EInl ext STNil (ENil ext)) EOp _ OIf (EConst _ STBool False) -> acted $ return (EInr ext STNil (ENil ext)) @@ -222,23 +234,40 @@ simplify'Rec = \case acted $ simplify' $ EUnit ext (substInline (ENil ext) e) -- monoid rules - EAccum _ t p e1 e2 acc -> do - e1' <- within (\e1' -> EAccum ext t p e1' e2 acc ) $ simplify' e1 - e2' <- within (\e2' -> EAccum ext t p e1' e2' acc ) $ simplify' e2 - acc' <- within (\acc' -> EAccum ext t p e1' e2' acc') $ simplify' acc - simplifyOneHotTerm (OneHotTerm SAI_D t p e1' e2') + EAccum _ t p e1 sp e2 acc -> do + e1' <- within (\e1' -> EAccum ext t p e1' sp e2 acc ) $ simplify' e1 + e2' <- within (\e2' -> EAccum ext t p e1' sp e2' acc ) $ simplify' e2 + acc' <- within (\acc' -> EAccum ext t p e1' sp e2' acc') $ simplify' acc + simplifyOHT (OneHotTerm SAID t p e1' sp e2') (acted $ return (ENil ext)) - (\e -> return (EAccum ext t SAPHere (ENil ext) e acc')) - (\(OneHotTerm SAI_D t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc')) + (\sp' (InContext w wrap e) -> do + e' <- within (\e' -> wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc')) $ simplify' e + return (wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc'))) + (\(InContext w wrap (OneHotTerm _ t' p' e1'' sp' e2'')) -> do + -- The acted management here is a hideous mess. + e1''' <- hideActed $ within (\e1''' -> wrap $ EAccum ext t' p' e1''' sp' e2'' (weakenExpr w acc')) $ simplify' e1'' + e2''' <- hideActed $ within (\e2''' -> wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc')) $ simplify' e2'' + return (wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc'))) EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e EOneHot _ t p e1 e2 -> do e1' <- within (\e1' -> EOneHot ext t p e1' e2 ) $ simplify' e1 e2' <- within (\e2' -> EOneHot ext t p e1' e2') $ simplify' e2 - simplifyOneHotTerm (OneHotTerm SAI_S t p e1' e2') + simplifyOHT (OneHotTerm SAIS t p e1' (spDense (acPrjTy p t)) e2') (acted $ return (EZero ext t (zeroInfoFromOneHot t p e1 e2))) - (\e -> acted $ return e) - (\(OneHotTerm SAI_S t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2'')) + (\sp' (InContext _ wrap e) -> + case isDense t sp' of + Just Refl -> do + e' <- hideActed $ within wrap $ simplify' e + return (wrap e') + Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse") + (\(InContext _ wrap (OneHotTerm _ t' p' e1'' sp' e2'')) -> + case isDense (acPrjTy p' t') sp' of + Just Refl -> do + e1''' <- hideActed $ within (\e1''' -> wrap $ EOneHot ext t' p' e1''' e2'') $ simplify' e1'' + e2''' <- hideActed $ within (\e2''' -> wrap $ EOneHot ext t' p' e1''' e2''') $ simplify' e2'' + return (wrap $ EOneHot ext t' p' e1''' e2''') + Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse") -- type-specific equations for plus EPlus _ SMTNil e1 e2 | not (hasAdds e1), not (hasAdds e2) -> @@ -302,8 +331,9 @@ simplify'Rec = \case e1' <- within (\e1' -> EWith ext t e1' e2) (simplify' e1) e2' <- within (\e2' -> EWith ext t e1' e2') (let ?accumInScope = True in simplify' e2) pure (EWith ext t e1' e2') - EZero _ t e -> [simprec| EZero ext t *e |] -- EZero ext t <$> simplify' e - EPlus _ t a b -> [simprec| EPlus ext t *a *b |] -- EPlus ext t <$> simplify' a <*> simplify' b + EZero _ t e -> [simprec| EZero ext t *e |] + EDeepZero _ t e -> [simprec| EDeepZero ext t *e |] + EPlus _ t a b -> [simprec| EPlus ext t *a *b |] EError _ t s -> pure $ EError ext t s cheapExpr :: Expr x env t -> Bool @@ -353,8 +383,9 @@ hasAdds = \case EOp _ _ e -> hasAdds e EWith _ _ a b -> hasAdds a || hasAdds b ERecompute _ e -> hasAdds e - EAccum _ _ _ _ _ _ -> True + EAccum _ _ _ _ _ _ _ -> True EZero _ _ e -> hasAdds e + EDeepZero _ _ e -> hasAdds e EPlus _ _ a b -> hasAdds a || hasAdds b EOneHot _ _ _ a b -> hasAdds a || hasAdds b EError _ _ _ -> False @@ -373,51 +404,161 @@ checkAccumInScope = \case SNil -> False check (STScal _) = False check STAccum{} = True -data OneHotTerm dense env p a b where - OneHotTerm :: SStillDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Ex env b -> OneHotTerm dense env p a b -deriving instance Show (OneHotTerm dense env p a b) - -simplifyOneHotTerm :: OneHotTerm dense env p a b - -> SM tenv tt env t r -- ^ Zero case (onehot is actually zero) - -> (Ex env a -> SM tenv tt env t r) -- ^ Trivial case (no zeros in onehot) - -> (forall p' b'. OneHotTerm dense env p' a b' -> SM tenv tt env t r) - -> SM tenv tt env t r -simplifyOneHotTerm (OneHotTerm dense t1 prj1 idx1 val1) kzero ktriv k = do - val1' <- liftActed $ recogniseMonoid (acPrjTy prj1 t1) val1 - case val1' of - EZero{} -> kzero - EOneHot _ t2 prj2 idx2 val2 - | Just Refl <- testEquality (acPrjTy prj1 t1) t2 -> do - tellActed -- record, whatever happens later, that we've modified something - concatOneHots dense t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 -> - simplifyOneHotTerm (OneHotTerm dense t1 prj12 idx12 val2) kzero ktriv k - _ -> case prj1 of - SAPHere -> ktriv val1 - _ -> k (OneHotTerm dense t1 prj1 idx1 val1) +data OneHotTerm dense env a where + OneHotTerm :: SAIDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Sparse b c -> Ex env c -> OneHotTerm dense env a +deriving instance Show (OneHotTerm dense env a) + +data InContext f env (a :: Ty) where + InContext :: env :> env' -> (forall t. Ex env' t -> Ex env t) -> f env' a -> InContext f env a + +simplifyOHT_recogniseMonoid :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a) +simplifyOHT_recogniseMonoid (OneHotTerm dense t prj idx sp val) = do + val' <- liftActed $ recogniseMonoid (applySparse sp (acPrjTy prj t)) val + return $ OneHotTerm dense t prj idx sp val' + +simplifyOHT_unsparse :: ActedMonad m => OneHotTerm dense env a -> m (InContext (OneHotTerm dense) env a) +simplifyOHT_unsparse (OneHotTerm SAID t prj1 idx1 sp1 val1) = + unsparseOneHotD sp1 val1 $ \w wrap prj2 idx2 sp2 val2 -> + acPrjCompose SAID prj1 (weakenExpr w idx1) prj2 idx2 $ \prj' idx' -> + return $ InContext w wrap (OneHotTerm SAID t prj' idx' sp2 val2) +simplifyOHT_unsparse oht@(OneHotTerm SAIS _ _ _ _ _) = return $ InContext WId id oht + +simplifyOHT_concat :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a) +simplifyOHT_concat (OneHotTerm @dense @_ @_ @_ @env dense t1 prj1 idx1 sp (EOneHot @_ @c @p2 _ t2 prj2 idx2 val)) + | Just Refl <- isDense (acPrjTy prj1 t1) sp = + let idx2' :: Ex env (AcIdx dense p2 c) + idx2' = case dense of + SAID -> reduceAcIdx t2 prj2 idx2 + SAIS -> idx2 + in acPrjCompose dense prj1 idx1 prj2 idx2' $ \prj' idx' -> + acted $ return $ OneHotTerm dense t1 prj' idx' (spDense (acPrjTy prj' t1)) val +simplifyOHT_concat oht = return oht + +-- -- Property not expressed in types: if the Sparse in the input OneHotTerm is +-- -- dense, then the Sparse in the output will also be dense. This property is +-- -- used when simplifying EOneHot, which cannot represent sparsity. +simplifyOHT :: ActedMonad m => OneHotTerm dense env a + -> m r -- ^ Zero case (onehot is actually zero) + -> (forall b. Sparse a b -> InContext Ex env b -> m r) -- ^ Trivial case (no zeros in onehot) + -> (InContext (OneHotTerm dense) env a -> m r) -- ^ Simplified + -> m r +simplifyOHT oht kzero ktriv k = do + -- traceM $ "sOHT: input " ++ show oht + oht1 <- simplifyOHT_recogniseMonoid oht + -- traceM $ "sOHT: recog " ++ show oht1 + InContext w1 wrap1 oht2 <- simplifyOHT_unsparse oht1 + -- traceM $ "sOHT: unspa " ++ show oht2 + oht3 <- simplifyOHT_concat oht2 + -- traceM $ "sOHT: conca " ++ show oht3 + -- traceM "" + case oht3 of + OneHotTerm _ _ _ _ _ EZero{} -> kzero + OneHotTerm _ _ SAPHere _ sp val -> ktriv sp (InContext w1 wrap1 val) + _ -> k (InContext w1 wrap1 oht3) + +-- Sets the acted flag whenever a non-trivial projection is returned or the +-- output Sparse is different from the input Sparse. +unsparseOneHotD :: ActedMonad m => Sparse a a' -> Ex env a' + -> (forall p b c env'. env :> env' -> (forall s. Ex env' s -> Ex env s) + -> SAcPrj p a b -> Ex env' (AcIdxD p a) -> Sparse b c -> Ex env' c -> m r) -> m r +unsparseOneHotD topsp topval k = case (topsp, topval) of + -- eliminate always-Just sparse onehot + (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) -> + acted $ unsparseOneHotD s (EOneHot ext t prj idx val) k + + -- expand the top levels of a onehot for a sparse type into a onehot for the + -- corresponding non-sparse type + (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) -> + unsparseOneHotD s1 (EOneHot ext t1 prj (efst idx) val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPFst spprj) idx' s1' e' + (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) -> + unsparseOneHotD s2 (EOneHot ext t2 prj (esnd idx) val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPSnd spprj) idx' s1' e' + (SpLEither s1 _, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) -> + unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPLeft spprj) idx' s1' e' + (SpLEither _ s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) -> + unsparseOneHotD s2 (EOneHot ext t2 prj idx val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPRight spprj) idx' s1' e' + (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) -> + unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPJust spprj) idx' s1' e' + (SpArr s1, EOneHot _ (SMTArr _ t1) (SAPArrIdx prj) idx val) + | Dict <- styKnown (typeOf idx) -> + unsparseOneHotD s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \w wrap spprj idx' s1' e' -> + acted $ k (w .> WSink) (elet idx . wrap) (SAPArrIdx spprj) (EPair ext (efst (efst (evar (w @> IZ)))) idx') s1' e' + + -- anything else we don't know how to improve + _ -> k WId id SAPHere (ENil ext) topsp topval + +{- +unsparseOneHotS :: ActedMonad m + => Sparse a a' -> Ex env a' + -> (forall b. Sparse a b -> Ex env b -> m r) -> m r +unsparseOneHotS topsp topval k = case (topsp, topval) of + -- order is relevant to make sure we set the acted flag correctly + (SpAbsent, v@ENil{}) -> k SpAbsent v + (SpAbsent, v@EZero{}) -> k SpAbsent v + (SpAbsent, _) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) + (_, EZero{}) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) + (sp, _) | isAbsent sp -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) + + -- the unsparsifying + (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) -> + acted $ unsparseOneHotS s (EOneHot ext t prj idx val) k + + -- recursion + -- TODO: coproducts could safely become projections as they do not need + -- zeroinfo. But that would only work if the coproduct is at the top, because + -- as soon as we hit a product, we need zeroinfo to make it a projection and + -- we don't have that. + (SpSparse s, e) -> k (SpSparse s) e + (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj (efst idx) val) $ \s1' e' -> + acted $ k (SpPair s1' SpAbsent) (EPair ext e' (ENil ext)) + (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) -> + unsparseOneHotS s2 (EOneHot ext t2 prj (esnd idx) val) $ \s2' e' -> + acted $ k (SpPair SpAbsent s2') (EPair ext (ENil ext) e') + (SpLEither s1 s2, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' -> do + case s2 of SpAbsent -> pure () ; _ -> tellActed + k (SpLEither s1' SpAbsent) (ELInl ext STNil e') + (SpLEither s1 s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) -> + unsparseOneHotS s2 (EOneHot ext t2 prj idx val) $ \s2' e' -> do + case s1 of SpAbsent -> pure () ; _ -> tellActed + acted $ k (SpLEither SpAbsent s2') (ELInr ext STNil e') + (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' -> + k (SpMaybe s1') (EJust ext e') + (SpArr s1, EOneHot _ (SMTArr n t1) (SAPArrIdx prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \s1' e' -> + k (SpArr s1') (elet idx $ EOneHot ext (SMTArr n (applySparse s1' _)) (SAPArrIdx SAPHere) (EPair ext (efst (evar IZ)) (ENil ext)) e') + _ -> _ +-} -- | Recognises 'EZero' and 'EOneHot'. recogniseMonoid :: SMTy t -> Ex env t -> (Any, Ex env t) recogniseMonoid _ e@EOneHot{} = return e -recogniseMonoid SMTNil (ENil _) = acted' $ return $ EZero ext SMTNil (ENil ext) +recogniseMonoid SMTNil (ENil _) = acted $ return $ EZero ext SMTNil (ENil ext) recogniseMonoid typ@(SMTPair t1 t2) (EPair _ a b) = ((,) <$> recogniseMonoid t1 a <*> recogniseMonoid t2 b) >>= \case - (EZero _ _ ezi1, EZero _ _ ezi2) -> acted' $ return $ EZero ext typ (EPair ext ezi1 ezi2) - (a', EZero _ _ ezi2) -> acted' $ EOneHot ext typ (SAPFst SAPHere) (EPair ext (ENil ext) ezi2) <$> recogniseMonoid t1 a' - (EZero _ _ ezi1, b') -> acted' $ EOneHot ext typ (SAPSnd SAPHere) (EPair ext ezi1 (ENil ext)) <$> recogniseMonoid t2 b' + (EZero _ _ ezi1, EZero _ _ ezi2) -> acted $ return $ EZero ext typ (EPair ext ezi1 ezi2) + (a', EZero _ _ ezi2) -> acted $ EOneHot ext typ (SAPFst SAPHere) (EPair ext (ENil ext) ezi2) <$> recogniseMonoid t1 a' + (EZero _ _ ezi1, b') -> acted $ EOneHot ext typ (SAPSnd SAPHere) (EPair ext ezi1 (ENil ext)) <$> recogniseMonoid t2 b' (a', b') -> return $ EPair ext a' b' recogniseMonoid typ@(SMTLEither t1 t2) expr = case expr of - ELNil{} -> acted' $ return $ EZero ext typ (ENil ext) - ELInl _ _ e -> acted' $ EOneHot ext typ (SAPLeft SAPHere) (ENil ext) <$> recogniseMonoid t1 e - ELInr _ _ e -> acted' $ EOneHot ext typ (SAPRight SAPHere) (ENil ext) <$> recogniseMonoid t2 e + ELNil{} -> acted $ return $ EZero ext typ (ENil ext) + ELInl _ _ e -> acted $ EOneHot ext typ (SAPLeft SAPHere) (ENil ext) <$> recogniseMonoid t1 e + ELInr _ _ e -> acted $ EOneHot ext typ (SAPRight SAPHere) (ENil ext) <$> recogniseMonoid t2 e _ -> return expr recogniseMonoid typ@(SMTMaybe t1) expr = case expr of - ENothing{} -> acted' $ return $ EZero ext typ (ENil ext) - EJust _ e -> acted' $ EOneHot ext typ (SAPJust SAPHere) (ENil ext) <$> recogniseMonoid t1 e + ENothing{} -> acted $ return $ EZero ext typ (ENil ext) + EJust _ e -> acted $ EOneHot ext typ (SAPJust SAPHere) (ENil ext) <$> recogniseMonoid t1 e _ -> return expr recogniseMonoid typ@(SMTArr SZ t) (EUnit _ e) = - acted' $ do + acted $ do e' <- recogniseMonoid t e return $ ELet ext e' $ @@ -426,61 +567,21 @@ recogniseMonoid typ@(SMTArr SZ t) (EUnit _ e) = (ENil ext)) (EVar ext (fromSMTy t) IZ) recogniseMonoid typ@(SMTScal sty) e@(EConst _ _ x) = case (sty, x) of - (STI32, 0) -> acted' $ return $ EZero ext typ (ENil ext) - (STI64, 0) -> acted' $ return $ EZero ext typ (ENil ext) - (STF32, 0) -> acted' $ return $ EZero ext typ (ENil ext) - (STF64, 0) -> acted' $ return $ EZero ext typ (ENil ext) + (STI32, 0) -> acted $ return $ EZero ext typ (ENil ext) + (STI64, 0) -> acted $ return $ EZero ext typ (ENil ext) + (STF32, 0) -> acted $ return $ EZero ext typ (ENil ext) + (STF64, 0) -> acted $ return $ EZero ext typ (ENil ext) _ -> return e recogniseMonoid _ e = return e -concatOneHots :: SStillDense dense -> SMTy a - -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a) - -> SAcPrj p2 b c -> Ex env (AcIdxS p2 b) - -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx dense p12 a) -> r) -> r -concatOneHots dense t1 prj1 idx1 prj2 idx2 k = case (dense, t1, prj1) of - (SAI_D, _, SAPHere) -> k prj2 (reduceAcIdx t1 prj2 idx2) - (SAI_S, _, SAPHere) -> k prj2 idx2 - - (SAI_D, SMTPair a _, SAPFst prj1') -> - concatOneHots SAI_D a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> - k (SAPFst prj12) idx12 - (SAI_S, SMTPair a _, SAPFst prj1') -> - concatOneHots SAI_S a prj1' (EFst ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> - k (SAPFst prj12) (ELet ext idx1 $ EPair ext idx12 (ESnd ext (EVar ext (typeOf idx1) IZ))) - (SAI_D, SMTPair _ b, SAPSnd prj1') -> - concatOneHots dense b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> - k (SAPSnd prj12) idx12 - (SAI_S, SMTPair _ b, SAPSnd prj1') -> - concatOneHots dense b prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> - k (SAPSnd prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) - - (_, SMTLEither a _, SAPLeft prj1') -> - concatOneHots SAI_S a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12 - (_, SMTLEither _ b, SAPRight prj1') -> - concatOneHots SAI_S b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12 - - (_, SMTMaybe a, SAPJust prj1') -> - concatOneHots SAI_S a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12 - - -- yes, twice the same code, but we need a concrete denseness indicator to - -- reduce AcIdx (the only difference between the dense and sparse versions is - -- whether there extra info also contains an array shape, and this code - -- handles the extra info uniformly) - (SAI_D, SMTArr _ a, SAPArrIdx prj1') -> - concatOneHots dense a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> - k (SAPArrIdx prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) - (SAI_S, SMTArr _ a, SAPArrIdx prj1') -> - concatOneHots dense a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> - k (SAPArrIdx prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) - -reduceAcIdx :: SMTy a -> SAcPrj p a b -> Ex env (AcIdx AI_S p a) -> Ex env (AcIdx AI_D p a) +reduceAcIdx :: SMTy a -> SAcPrj p a b -> Ex env (AcIdxS p a) -> Ex env (AcIdxD p a) reduceAcIdx topty topprj e = case (topty, topprj) of (_, SAPHere) -> ENil ext (SMTPair t1 _, SAPFst p) -> reduceAcIdx t1 p (efst e) (SMTPair _ t2, SAPSnd p) -> reduceAcIdx t2 p (esnd e) - (SMTLEither{}, SAPLeft{}) -> e - (SMTLEither{}, SAPRight{}) -> e - (SMTMaybe{}, SAPJust{}) -> e + (SMTLEither t1 _ , SAPLeft p) -> reduceAcIdx t1 p e + (SMTLEither _ t2, SAPRight p) -> reduceAcIdx t2 p e + (SMTMaybe t1, SAPJust p) -> reduceAcIdx t1 p e (SMTArr _ t, SAPArrIdx p) -> eunPair e $ \_ e1 e2 -> EPair ext (efst e1) (reduceAcIdx t p e2) diff --git a/test/Main.hs b/test/Main.hs index 1b83a2e..d79e63f 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -435,11 +435,22 @@ gen_neural = do lay3 <- genArray tR (ShNil `ShCons` n2) return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil) +term_build0 :: Ex '[TArr N0 R] R +term_build0 = fromNamed $ lambda @(TArr N0 _) #x $ body $ + idx0 $ + build SZ (shape #x) $ #idx :-> #x ! #idx + term_build1_sum :: Ex '[TVec R] R term_build1_sum = fromNamed $ lambda #x $ body $ idx0 $ sum1i $ build (SS SZ) (shape #x) $ #idx :-> #x ! #idx +term_build1_idx :: Ex '[TVec R] R +term_build1_idx = fromNamed $ lambda @(TVec _) #x $ body $ + let_ #n (snd_ (shape #x)) $ + idx0 $ sum1i $ + build1 (#n `idiv` 2) (#i :-> #x ! pair nil (2 * #i)) + term_pairs :: Ex [R, R] R term_pairs = fromNamed $ lambda #x $ lambda #y $ body $ let_ #p (pair #x #y) $ @@ -502,22 +513,22 @@ tests_Compile = testGroup "Compile" ,compileTest "accum (f64,f64)" $ fromNamed $ lambda #b $ lambda #x $ body $ with @(TPair R R) (pair 0.0 0.0) $ #ac :-> - let_ #_ (if_ #b (accum (SAPFst SAPHere) (pair nil nil) 3.0 #ac) nil) $ + let_ #_ (if_ #b (accum (SAPFst SAPHere) nil 3.0 #ac) nil) $ let_ #_ (accum SAPHere nil #x #ac) $ - let_ #_ (accum (SAPSnd SAPHere) (pair nil nil) 4.0 #ac) $ + let_ #_ (accum (SAPSnd SAPHere) nil 4.0 #ac) $ nil ,compileTest "accum (Maybe (f64,f64))" $ fromNamed $ lambda #b $ lambda #x $ body $ - with @(TMaybe (TPair R R)) nothing $ #ac :-> - let_ #_ (if_ #b (accum (SAPJust (SAPFst SAPHere)) (pair nil nil) 3.0 #ac) nil) $ + with @(TMaybe (TPair R R)) (just (pair 0 0)) $ #ac :-> + let_ #_ (if_ #b (accum (SAPJust (SAPFst SAPHere)) nil 3.0 #ac) nil) $ let_ #_ (accum SAPHere nil #x #ac) $ - let_ #_ (accum (SAPJust (SAPSnd SAPHere)) (pair nil nil) 4.0 #ac) $ + let_ #_ (accum (SAPJust (SAPSnd SAPHere)) nil 4.0 #ac) $ nil ,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda @(TVec R) #x $ body $ let_ #len (snd_ (shape #x)) $ with @(TVec R) (build1 #len (#_ :-> 0)) $ #ac :-> - let_ #_ (if_ #b (accum (SAPArrIdx SAPHere) (pair (pair (pair nil 2) (build1 #len (#_ :-> nil))) nil) 6.0 #ac) + let_ #_ (if_ #b (accum (SAPArrIdx SAPHere) (pair (pair nil 2) nil) 6.0 #ac) nil) $ let_ #_ (accum SAPHere nil #x #ac) $ nil @@ -556,9 +567,7 @@ tests_AD = testGroup "AD" ,adTest "build0 const" $ fromNamed $ lambda @R #x $ body $ idx0 $ build SZ nil $ #idx :-> const_ 0.0 - ,adTest "build0" $ fromNamed $ lambda @(TArr N0 _) #x $ body $ - idx0 $ - build SZ (shape #x) $ #idx :-> #x ! #idx + ,adTest "build0" term_build0 ,adTest "build1-sum" term_build1_sum @@ -566,6 +575,8 @@ tests_AD = testGroup "AD" idx0 $ sum1i . sum1i $ build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx + ,adTest "build1-idx" term_build1_idx + ,adTestCon "maximum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $ fromNamed $ lambda @(TMat R) #x $ body $ idx0 $ sum1i $ maximum1i #x -- cgit v1.2.3-70-g09d2