diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-06-16 23:21:55 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-06-16 23:21:55 +0200 |
commit | 2b1a40b5933b8b0dceaae744e5b70cb604822c9d (patch) | |
tree | 652d6d88efd2b0b4502819297333305cec5242c4 | |
parent | eed0f2999d6f6c8485ef53deb38f9d0a67b4f88e (diff) |
CHAD.hs compiles
-rw-r--r-- | src/AST.hs | 24 | ||||
-rw-r--r-- | src/AST/Accum.hs | 36 | ||||
-rw-r--r-- | src/AST/UnMonoid.hs | 2 | ||||
-rw-r--r-- | src/CHAD.hs | 167 | ||||
-rw-r--r-- | src/CHAD/Top.hs | 1 | ||||
-rw-r--r-- | src/CHAD/Types/ToTan.hs | 18 | ||||
-rw-r--r-- | src/Interpreter.hs | 39 | ||||
-rw-r--r-- | src/Language.hs | 2 | ||||
-rw-r--r-- | src/Language/AST.hs | 2 | ||||
-rw-r--r-- | src/Simplify.hs | 106 |
10 files changed, 261 insertions, 136 deletions
@@ -92,12 +92,12 @@ data Expr x env t where -- accumulation effect on monoids EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t) - EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env a -> Expr x env (TAccum t) -> Expr x env TNil + EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Expr x env a -> Expr x env (TAccum t) -> Expr x env TNil -- monoidal operations (to be desugared to regular operations after simplification) EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t - EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env a -> Expr x env t + EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t -- interface of abstract monoidal types ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b) @@ -523,6 +523,14 @@ eunPair e k = (EFst ext (evar IZ)) (ESnd ext (evar IZ)) +efst :: Ex env (TPair a b) -> Ex env a +efst (EPair _ e1 _) = e1 +efst e = EFst ext e + +esnd :: Ex env (TPair a b) -> Ex env b +esnd (EPair _ _ e2) = e2 +esnd e = ESnd ext e + elet :: Ex env a -> (KnownTy a => Ex (a : env) b) -> Ex env b elet rhs body | Dict <- styKnown (typeOf rhs) @@ -543,3 +551,15 @@ elcase e a b c evar :: KnownTy a => Idx env a -> Ex env a evar = EVar ext knownTy + +makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) +makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ) + where + -- invariant: expression argument is duplicable + go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) + go SMTNil _ = ENil ext + go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e)) + go SMTLEither{} _ = ENil ext + go SMTMaybe{} _ = ENil ext + go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e + go SMTScal{} _ = ENil ext diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs index 1101cc0..158b4d9 100644 --- a/src/AST/Accum.hs +++ b/src/AST/Accum.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeData #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} module AST.Accum where @@ -32,21 +33,36 @@ data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t) deriving instance Show (SAcPrj p a b) -type family AcIdx p t where - AcIdx APHere t = TNil - AcIdx (APFst p) (TPair a b) = TPair (AcIdx p a) (ZeroInfo b) - AcIdx (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx p b) - AcIdx (APLeft p) (TLEither a b) = AcIdx p a - AcIdx (APRight p) (TLEither a b) = AcIdx p b - AcIdx (APJust p) (TMaybe a) = AcIdx p a - AcIdx (APArrIdx p) (TArr n a) = +type data StillDense = AI_D | AI_S +data SStillDense dense where + SAI_D :: SStillDense AI_D + SAI_S :: SStillDense AI_S +deriving instance Show (SStillDense dense) + +type family AcIdx dense p t where + AcIdx dense APHere t = TNil + AcIdx AI_D (APFst p) (TPair a b) = AcIdx AI_D p a + AcIdx AI_D (APSnd p) (TPair a b) = AcIdx AI_D p b + AcIdx AI_S (APFst p) (TPair a b) = TPair (AcIdx AI_S p a) (ZeroInfo b) + AcIdx AI_S (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AI_S p b) + AcIdx dense (APLeft p) (TLEither a b) = AcIdx AI_S p a + AcIdx dense (APRight p) (TLEither a b) = AcIdx AI_S p b + AcIdx dense (APJust p) (TMaybe a) = AcIdx AI_S p a + AcIdx AI_D (APArrIdx p) (TArr n a) = TPair (Tup (Replicate n TIx)) (AcIdx AI_D p a) + AcIdx AI_S (APArrIdx p) (TArr n a) = -- ((index, shapes info), recursive info) TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a))) - (AcIdx p a) - -- AcIdx (APArrSlice m) (TArr n a) = + (AcIdx AI_S p a) + -- AcIdx AI_D (APArrSlice m) (TArr n a) = + -- -- index + -- Tup (Replicate m TIx) + -- AcIdx AI_S (APArrSlice m) (TArr n a) = -- -- (index, array shape) -- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx)) +type AcIdxD p t = AcIdx AI_D p t +type AcIdxS p t = AcIdx AI_S p t + acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b acPrjTy SAPHere t = t acPrjTy (SAPFst prj) (SMTPair t _) = acPrjTy prj t diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index ac4d733..389dd5a 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -105,7 +105,7 @@ plus (SMTArr _ t) a b = a b plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b) -onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env t +onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t onehot typ topprj idx arg = case (typ, topprj) of (_, SAPHere) -> ELet ext arg $ diff --git a/src/CHAD.hs b/src/CHAD.hs index 7cd4c26..3dedec3 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -362,12 +362,6 @@ d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program" -zeroTup :: SList STy env0 -> D1E env0 :> env -> Ex env (Tup (D2E env0)) -zeroTup SNil _ = ENil ext -zeroTup (t `SCons` env) w = - EPair ext (zeroTup env (WPop w)) - (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) - ----------------------------------- SPARSITY ----------------------------------- @@ -780,7 +774,7 @@ drev des accumMap (SpSparse sd) = subtape e1 sub' - (emaybe (evar IZ) + (emaybe (EVar ext (STMaybe (applySparse sd (d2 (typeOf e)))) IZ) (inj2 (ENil ext)) (inj1 (weakenExpr (WCopy WSink) e2))) } @@ -794,7 +788,8 @@ drev des accumMap sd = \case (EVar ext (d1 t) (conv1Idx i)) (subenvNone (d2e (select SMerge des))) (let ty = applySparse sd (d2M t) - in EAccum ext (d2M t) (_ sd) (ENil ext) (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI))) + in accumulateSparse SAI_D (d2M t) sd (EVar ext (fromSMTy ty) IZ) $ \w prj val idx -> + EAccum ext (d2M t) prj idx val (EVar ext (STAccum (d2M t)) (w @> IS accI))) Idx2Me tupI -> Ret BTop @@ -1227,43 +1222,45 @@ drev des accumMap sd = \case (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ)) sub - (ELet ext (EOneHot ext (SMTArr n (applySparse sd' (d2M eltty))) (SAPArrIdx SAPHere) - (EPair ext (EPair ext (EVar ext tIxN (IS IZ)) - (EBuild ext n (EVar ext tIxN (IS (IS IZ))) (ENil ext))) (ENil ext)) - (inj1 $ EVar ext (applySparse sd (d2 eltty)) IZ)) $ + (ELet ext + (EOneHot ext (SMTArr n (applySparse sd' (d2M eltty))) + (SAPArrIdx SAPHere) + (EPair ext + (EPair ext (EVar ext tIxN (IS IZ)) + (EBuild ext n (EVar ext tIxN (IS (IS IZ))) $ + makeZeroInfo (applySparse sd' (d2M eltty)) (inj2 (ENil ext)))) + (ENil ext)) + (inj1 $ EVar ext (applySparse sd (d2 eltty)) IZ)) $ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) } EShape _ e - -- Allowed to ignore e2 here because the output of EShape is discrete, - -- hence we'd be passing a zero cotangent to e2 anyway. - | Ret e0 subtape e1 _ _ <- drev des accumMap e - , STArr n _ <- typeOf e + -- Allowed to differentiate e as primal because the output of EShape is + -- discrete, hence we'd be passing a zero cotangent to e anyway. + | STArr n _ <- typeOf e , Refl <- indexTupD1Id n -> - Ret e0 - subtape - (EShape ext e1) - (subenvNone (select SMerge des)) + Ret BTop + SETop + (EShape ext (drevPrimal des e)) + (subenvNone (d2eM (select SMerge des))) (ENil ext) ESum1Inner _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e + | SpArr sd' <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e , STArr (SS n) t <- typeOf e -> Ret (e0 `BPush` (STArr (SS n) t, e1) `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ))) (SEYesR (SENo subtape)) (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ))) sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EJust ext (EReplicate1Inner ext - (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS (IS IZ)))) - (EVar ext (STArr n (d2 t)) IZ))) $ - weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) - (EVar ext (d2 (STArr n t)) IZ)) + (ELet ext (EReplicate1Inner ext + (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) + (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e2) - EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e - EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e + EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e + EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e -- These should be the next to be implemented, I think EFold1Inner{} -> err_unsupported "EFold1Inner" @@ -1286,35 +1283,35 @@ drev des accumMap sd = \case err_monoid = error "Monoid operations unsupported in the source program" err_unsupported s = error $ "CHAD: unsupported " ++ s - deriv_extremum :: ScalIsNumeric t' ~ True - => (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t'))) - -> Sparse (TArr n (D2s t')) sd' - -> Expr ValId env (TArr (S n) (TScal t')) -> Ret env sto sd' (TArr n (TScal t')) - deriv_extremum extremum e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e - , at@(STArr (SS n) t@(STScal st)) <- typeOf e - , let at' = STArr n t - , let tIxN = tTup (sreplicate (SS n) tIx) = - Ret (e0 `BPush` (at, e1) - `BPush` (at', extremum (EVar ext at IZ))) - (SEYesR (SEYesR subtape)) - (EVar ext at' IZ) - sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EJust ext - (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS (IS IZ))))) $ - eif (EOp ext (OEq st) (EPair ext - (EIdx ext (EVar ext at (IS (IS (IS (IS IZ))))) (EVar ext tIxN IZ)) - (EIdx ext (EVar ext at' (IS (IS (IS IZ)))) (EFst ext (EVar ext tIxN IZ))))) - (EIdx ext (EVar ext (STArr n (d2 t)) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) - (ezeroD2 t))) $ - weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2) - (EVar ext (d2 at') IZ)) - contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs) contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub)) +deriv_extremum :: (?config :: CHADConfig, ScalIsNumeric t ~ True) + => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t))) + -> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) + -> Sparse (D2s t) sd + -> Expr ValId env (TArr (S n) (TScal t)) -> Ret env sto (TArr n sd) (TArr n (TScal t)) +deriv_extremum extremum des accumMap sd e + | at@(STArr (SS n) t@(STScal st)) <- typeOf e + , let at' = STArr n t + , let tIxN = tTup (sreplicate (SS n) tIx) = + sparsePlusS ST ST (d2M t) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ -> + case drev des accumMap (SpArr sd') e of { Ret e0 subtape e1 sub e2 -> + Ret (e0 `BPush` (at, e1) + `BPush` (at', extremum (EVar ext at IZ))) + (SEYesR (SEYesR subtape)) + (EVar ext at' IZ) + sub + (ELet ext + (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $ + eif (EOp ext (OEq st) (EPair ext + (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ)) + (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ))))) + (inj1 $ EIdx ext (EVar ext (STArr n (applySparse sd (d2 t))) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) + (inj2 (ENil ext))) $ + weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + } + data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s) data RetScoped env0 sto a s sd t = @@ -1379,7 +1376,7 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of &. #ac (auto1 @(TAccum (D2 a))) &. #tl (d2ace (select SAccum des)) in - RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub SpDense $ + RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub (spDense (d2M argty)) $ let primalIdx = autoWeak library #p (#d :++: (#body :++: #p) :++: #tl) @> IZ in EWith ext (d2M argty) (EZero ext (d2M argty) (d2zeroInfo argty (EVar ext (d1 argty) primalIdx))) $ weakenExpr (autoWeak library @@ -1412,3 +1409,59 @@ drevPrimal des e chadD1EId :: SList STy l -> D1E l :~: l chadD1EId SNil = Refl chadD1EId (SCons t l) | Refl <- chadD1Id t, Refl <- chadD1EId l = Refl + +accumulateSparse + :: SStillDense dense -> SMTy t -> Sparse t t' -> Ex env t' + -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' b -> Ex env' (AcIdx dense p t) -> Ex env' TNil) + -> Ex env TNil +accumulateSparse dense topty topsp arg accum = case (dense, topty, topsp) of + (_, _, s) | Just Refl <- isDense topty s -> + accum WId SAPHere arg (ENil ext) + (_, SMTScal _, SpScal) -> + accum WId SAPHere arg (ENil ext) -- should be handled by isDense already, but meh + (_, _, SpSparse s) -> + emaybe arg + (ENil ext) + (accumulateSparse dense topty s (evar IZ) (\w -> accum (WPop w))) + (_, _, SpAbsent) -> + ENil ext + (SAI_D, SMTPair t1 t2, SpPair s1 s2) -> + eunPair arg $ \w1 e1 e2 -> + elet (accumulateSparse dense t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $ + accumulateSparse dense t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj)) + (SAI_S, SMTPair{}, SpPair{}) -> + error "TODO: accumulating into pair inside coproduct unimplemented" + -- There are two different ways this can be accomplished: + -- 1. Ensure we have the requisite ZeroInfo here. This means that an + -- accum-mode variable reference will (if its incoming cotangent is + -- sparse enough) need to store some ZeroInfo fragments computed from + -- the primal (not necessarily the entire primal). Doing this properly, + -- i.e. not just storing a full D1 but only the required ZeroInfo + -- fragments, is possible and not too inefficient but a bit of + -- engineering again. + -- 2. When creating an accumulator, don't initialise it with a generic + -- EZero based on a ZeroInfo, but instead a special "deep zero" based on + -- probably a full D1. This deep zero also initialises Left/Right/Just + -- modelled after the primal. With this, an accumulation needs no zero + -- info whatsoever (!) under the assumption that it receives a cotangent + -- that is compatible with the primal it is propagated back to. + (_, SMTLEither t1 t2, SpLEither s1 s2) -> + elcase arg + (ENil ext) + (accumulateSparse SAI_S t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj))) + (accumulateSparse SAI_S t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj))) + (_, SMTMaybe t, SpMaybe s) -> + emaybe arg + (ENil ext) + (accumulateSparse SAI_S t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj))) + (SAI_D, SMTArr n t, SpArr s) -> + let tn = tTup (sreplicate n tIx) in + elet arg $ + elet (EBuild ext n (EShape ext (evar IZ)) $ + accumulateSparse dense t s + (EIdx ext (evar (IS IZ)) (EVar ext tn IZ)) + (\w prj val idx -> accum (WPop (WPop w)) (SAPArrIdx prj) val (EPair ext (EVar ext tn (w @> IZ)) idx))) $ + ENil ext + (SAI_S, SMTArr{}, SpArr{}) -> + error "TODO: accumulating into array inside coproduct unimplemented" + -- See the pair case above, same reasoning diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs index 261ddfe..130174a 100644 --- a/src/CHAD/Top.hs +++ b/src/CHAD/Top.hs @@ -15,7 +15,6 @@ import AST import AST.SplitLets import AST.Weaken.Auto import CHAD -import CHAD.Accum import CHAD.EnvDescr import CHAD.Types import Data diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs index 8476712..888fed4 100644 --- a/src/CHAD/Types/ToTan.hs +++ b/src/CHAD/Types/ToTan.hs @@ -19,9 +19,7 @@ toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) = toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t) toTan typ primal der = case typ of STNil -> der - STPair t1 t2 -> case der of - Nothing -> bimap (zeroTan t1) (zeroTan t2) primal - Just (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal + STPair t1 t2 -> bimap (\p1 -> toTan t1 p1 (fst der)) (\p2 -> toTan t2 p2 (snd der)) primal STEither t1 t2 -> case der of Nothing -> bimap (zeroTan t1) (zeroTan t2) primal Just d -> case (primal, d) of @@ -34,14 +32,12 @@ toTan typ primal der = case typ of (Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d)) _ -> error "Primal and cotangent disagree on LEither alternative" STMaybe t -> liftA2 (toTan t) primal der - STArr _ t -> case der of - Nothing -> arrayMap (zeroTan t) primal - Just d - | arrayShape primal == arrayShape d -> - arrayGenerateLin (arrayShape primal) $ \i -> - toTan t (arrayIndexLinear primal i) (arrayIndexLinear d i) - | otherwise -> - error "Primal and cotangent disagree on array shape" + STArr _ t + | arrayShape primal == arrayShape der -> + arrayGenerateLin (arrayShape primal) $ \i -> + toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i) + | otherwise -> + error "Primal and cotangent disagree on array shape" STScal sty -> case sty of STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der STAccum{} -> error "Accumulators not allowed in input program" diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 803a24a..b3576ce 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -162,7 +162,7 @@ interpret'Rec env = \case idx <- interpret' env e1 val <- interpret' env e2 accum <- interpret' env e3 - accumAddSparse t p accum idx val + accumAddSparseD t p accum idx val EZero _ t ezi -> do zi <- interpret' env ezi return $ zeroM t zi @@ -239,7 +239,7 @@ addM typ a b = case typ of | otherwise -> error "Plus of inconsistently shaped arrays" SMTScal sty -> numericIsNum sty $ a + b -onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdx p a) -> Rep b -> Rep a +onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdxS p a) -> Rep b -> Rep a onehotM SAPHere _ _ val = val onehotM (SAPFst prj) (SMTPair a b) idx val = (onehotM prj a (fst idx) val, zeroM b (snd idx)) onehotM (SAPSnd prj) (SMTPair a b) idx val = (zeroM a (fst idx), onehotM prj b (snd idx) val) @@ -274,7 +274,7 @@ newAcDense typ val = case typ of SMTArr _ t1 -> arrayMapM (newAcDense t1) val SMTScal _ -> newIORef val -newAcSparse :: SMTy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep b -> IO (RepAc a) +newAcSparse :: SMTy a -> SAcPrj p a b -> Rep (AcIdxS p a) -> Rep b -> IO (RepAc a) newAcSparse typ prj idx val = case (typ, prj) of (_, SAPHere) -> newAcDense typ val @@ -291,9 +291,9 @@ newAcSparse typ prj idx val = case (typ, prj) of (SMTArr n t, SAPArrIdx prj') -> onehotArray (\idx' -> newAcSparse t prj' idx' val) (newAcZero t) n prj' idx onehotArray :: Monad m - => (Rep (AcIdx p a) -> m v) -- ^ the "one" + => (Rep (AcIdxS p a) -> m v) -- ^ the "one" -> (Rep (ZeroInfo a) -> m v) -- ^ the "zero" - -> SNat n -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> m (Array n v) + -> SNat n -> SAcPrj p a b -> Rep (AcIdxS (APArrIdx p) (TArr n a)) -> m (Array n v) onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) = let arrindex = unTupRepIdx IxNil IxCons n arrindex' arrsh = arrayShape ziarr @@ -329,7 +329,34 @@ accumAddDense typ ref val = case typ of accumAddDense t1 (arrayIndexLinear ref i) (arrayIndexLinear val i) SMTScal sty -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ())) -accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep b -> AcM s () +accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Rep b -> AcM s () +accumAddSparseD typ prj ref idx val = case (typ, prj) of + (_, SAPHere) -> accumAddDense typ ref val + + (SMTPair t1 _, SAPFst prj') -> accumAddSparseD t1 prj' (fst ref) idx val + (SMTPair _ t2, SAPSnd prj') -> accumAddSparseD t2 prj' (snd ref) idx val + + (SMTLEither t1 _, SAPLeft prj') -> + realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val) + (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val + Right{} -> error "Mismatched Either in accumAddSparseD (r +l)") + (SMTLEither _ t2, SAPRight prj') -> + realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val) + (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val + Left{} -> error "Mismatched Either in accumAddSparseD (l +r)") + + (SMTMaybe t1, SAPJust prj') -> + realiseMaybeSparse ref (newAcSparse t1 prj' idx val) + (\ac -> accumAddSparse t1 prj' ac idx val) + + (SMTArr n t1, SAPArrIdx prj') -> + let (arrindex', idx') = idx + arrindex = unTupRepIdx IxNil IxCons n arrindex' + arrsh = arrayShape ref + linindex = toLinearIndex arrsh arrindex + in accumAddSparseD t1 prj' (arrayIndexLinear ref linindex) idx' val + +accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxS p a) -> Rep b -> AcM s () accumAddSparse typ prj ref idx val = case (typ, prj) of (_, SAPHere) -> accumAddDense typ ref val diff --git a/src/Language.hs b/src/Language.hs index 7a780a0..63279df 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -175,7 +175,7 @@ recompute = NERecompute with :: forall t a env acname. KnownMTy t => NExpr env t -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a t) with a (n :-> b) = NEWith (knownMTy @t) a n b -accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdx p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil +accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil accum p a b c = NEAccum knownMTy p a b c diff --git a/src/Language/AST.hs b/src/Language/AST.hs index 7e074df..92792b3 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -76,7 +76,7 @@ data NExpr env t where -- accumulation effect on monoids NEWith :: SMTy t -> NExpr env t -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a t) - NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdx p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil + NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil -- partiality NEError :: STy a -> String -> NExpr env a diff --git a/src/Simplify.hs b/src/Simplify.hs index e110206..d3b850f 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -226,19 +226,19 @@ simplify'Rec = \case e1' <- within (\e1' -> EAccum ext t p e1' e2 acc ) $ simplify' e1 e2' <- within (\e2' -> EAccum ext t p e1' e2' acc ) $ simplify' e2 acc' <- within (\acc' -> EAccum ext t p e1' e2' acc') $ simplify' acc - simplifyOneHotTerm (OneHotTerm t p e1' e2') + simplifyOneHotTerm (OneHotTerm SAI_D t p e1' e2') (acted $ return (ENil ext)) (\e -> return (EAccum ext t SAPHere (ENil ext) e acc')) - (\(OneHotTerm t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc')) + (\(OneHotTerm SAI_D t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc')) EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e EOneHot _ t p e1 e2 -> do e1' <- within (\e1' -> EOneHot ext t p e1' e2 ) $ simplify' e1 e2' <- within (\e2' -> EOneHot ext t p e1' e2') $ simplify' e2 - simplifyOneHotTerm (OneHotTerm t p e1' e2') + simplifyOneHotTerm (OneHotTerm SAI_S t p e1' e2') (acted $ return (EZero ext t (zeroInfoFromOneHot t p e1 e2))) (\e -> acted $ return e) - (\(OneHotTerm t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2'')) + (\(OneHotTerm SAI_S t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2'')) -- type-specific equations for plus EPlus _ SMTNil e1 e2 | not (hasAdds e1), not (hasAdds e2) -> @@ -373,27 +373,27 @@ checkAccumInScope = \case SNil -> False check (STScal _) = False check STAccum{} = True -data OneHotTerm env p a b where - OneHotTerm :: SMTy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env b -> OneHotTerm env p a b -deriving instance Show (OneHotTerm env p a b) +data OneHotTerm dense env p a b where + OneHotTerm :: SStillDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Ex env b -> OneHotTerm dense env p a b +deriving instance Show (OneHotTerm dense env p a b) -simplifyOneHotTerm :: OneHotTerm env p a b +simplifyOneHotTerm :: OneHotTerm dense env p a b -> SM tenv tt env t r -- ^ Zero case (onehot is actually zero) -> (Ex env a -> SM tenv tt env t r) -- ^ Trivial case (no zeros in onehot) - -> (forall p' b'. OneHotTerm env p' a b' -> SM tenv tt env t r) + -> (forall p' b'. OneHotTerm dense env p' a b' -> SM tenv tt env t r) -> SM tenv tt env t r -simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 val1) kzero ktriv k = do +simplifyOneHotTerm (OneHotTerm dense t1 prj1 idx1 val1) kzero ktriv k = do val1' <- liftActed $ recogniseMonoid (acPrjTy prj1 t1) val1 case val1' of EZero{} -> kzero EOneHot _ t2 prj2 idx2 val2 | Just Refl <- testEquality (acPrjTy prj1 t1) t2 -> do tellActed -- record, whatever happens later, that we've modified something - concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 -> - simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val2) kzero ktriv k + concatOneHots dense t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 -> + simplifyOneHotTerm (OneHotTerm dense t1 prj12 idx12 val2) kzero ktriv k _ -> case prj1 of SAPHere -> ktriv val1 - _ -> k (OneHotTerm t1 prj1 idx1 val1) + _ -> k (OneHotTerm dense t1 prj1 idx1 val1) -- | Recognises 'EZero' and 'EOneHot'. recogniseMonoid :: SMTy t -> Ex env t -> (Any, Ex env t) @@ -433,52 +433,66 @@ recogniseMonoid typ@(SMTScal sty) e@(EConst _ _ x) = case (sty, x) of _ -> return e recogniseMonoid _ e = return e -concatOneHots :: SMTy a - -> SAcPrj p1 a b -> Ex env (AcIdx p1 a) - -> SAcPrj p2 b c -> Ex env (AcIdx p2 b) - -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx p12 a) -> r) -> r -concatOneHots t1 prj1 idx1 prj2 idx2 k = case (t1, prj1) of - (_, SAPHere) -> k prj2 idx2 - - (SMTPair a _, SAPFst prj1') -> - concatOneHots a prj1' (EFst ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> +concatOneHots :: SStillDense dense -> SMTy a + -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a) + -> SAcPrj p2 b c -> Ex env (AcIdxS p2 b) + -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx dense p12 a) -> r) -> r +concatOneHots dense t1 prj1 idx1 prj2 idx2 k = case (dense, t1, prj1) of + (SAI_D, _, SAPHere) -> k prj2 (reduceAcIdx t1 prj2 idx2) + (SAI_S, _, SAPHere) -> k prj2 idx2 + + (SAI_D, SMTPair a _, SAPFst prj1') -> + concatOneHots SAI_D a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> + k (SAPFst prj12) idx12 + (SAI_S, SMTPair a _, SAPFst prj1') -> + concatOneHots SAI_S a prj1' (EFst ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> k (SAPFst prj12) (ELet ext idx1 $ EPair ext idx12 (ESnd ext (EVar ext (typeOf idx1) IZ))) - (SMTPair _ b, SAPSnd prj1') -> - concatOneHots b prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> + (SAI_D, SMTPair _ b, SAPSnd prj1') -> + concatOneHots dense b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> + k (SAPSnd prj12) idx12 + (SAI_S, SMTPair _ b, SAPSnd prj1') -> + concatOneHots dense b prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> k (SAPSnd prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) - (SMTLEither a _, SAPLeft prj1') -> - concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12 - (SMTLEither _ b, SAPRight prj1') -> - concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12 + (_, SMTLEither a _, SAPLeft prj1') -> + concatOneHots SAI_S a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12 + (_, SMTLEither _ b, SAPRight prj1') -> + concatOneHots SAI_S b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12 - (SMTMaybe a, SAPJust prj1') -> - concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12 + (_, SMTMaybe a, SAPJust prj1') -> + concatOneHots SAI_S a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12 - (SMTArr _ a, SAPArrIdx prj1') -> - concatOneHots a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> + -- yes, twice the same code, but we need a concrete denseness indicator to + -- reduce AcIdx (the only difference between the dense and sparse versions is + -- whether there extra info also contains an array shape, and this code + -- handles the extra info uniformly) + (SAI_D, SMTArr _ a, SAPArrIdx prj1') -> + concatOneHots dense a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> + k (SAPArrIdx prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) + (SAI_S, SMTArr _ a, SAPArrIdx prj1') -> + concatOneHots dense a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> k (SAPArrIdx prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) -zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env (ZeroInfo t) +reduceAcIdx :: SMTy a -> SAcPrj p a b -> Ex env (AcIdx AI_S p a) -> Ex env (AcIdx AI_D p a) +reduceAcIdx topty topprj e = case (topty, topprj) of + (_, SAPHere) -> ENil ext + (SMTPair t1 _, SAPFst p) -> reduceAcIdx t1 p (efst e) + (SMTPair _ t2, SAPSnd p) -> reduceAcIdx t2 p (esnd e) + (SMTLEither{}, SAPLeft{}) -> e + (SMTLEither{}, SAPRight{}) -> e + (SMTMaybe{}, SAPJust{}) -> e + (SMTArr _ t, SAPArrIdx p) -> + eunPair e $ \_ e1 e2 -> + EPair ext (efst e1) (reduceAcIdx t p e2) + +zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t) zeroInfoFromOneHot = \ty prj eidx e -> ELet ext eidx $ go ty prj (EVar ext (typeOf eidx) IZ) (weakenExpr WSink e) where -- invariant: AcIdx expression is duplicable - go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env (ZeroInfo t) + go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t) go t SAPHere _ e = makeZeroInfo t e go (SMTPair t1 _) (SAPFst prj) eidx e = EPair ext (go t1 prj (EFst ext eidx) e) (ESnd ext eidx) go (SMTPair _ t2) (SAPSnd prj) eidx e = EPair ext (EFst ext eidx) (go t2 prj (ESnd ext eidx) e) go SMTLEither{} _ _ _ = ENil ext go SMTMaybe{} _ _ _ = ENil ext go SMTArr{} SAPArrIdx{} eidx _ = ESnd ext (EFst ext eidx) - -makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) -makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ) - where - -- invariant: expression argument is duplicable - go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) - go SMTNil _ = ENil ext - go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e)) - go SMTLEither{} _ = ENil ext - go SMTMaybe{} _ = ENil ext - go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e - go SMTScal{} _ = ENil ext |