summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-06-16 23:21:55 +0200
committerTom Smeding <tom@tomsmeding.com>2025-06-16 23:21:55 +0200
commit2b1a40b5933b8b0dceaae744e5b70cb604822c9d (patch)
tree652d6d88efd2b0b4502819297333305cec5242c4
parenteed0f2999d6f6c8485ef53deb38f9d0a67b4f88e (diff)
CHAD.hs compiles
-rw-r--r--src/AST.hs24
-rw-r--r--src/AST/Accum.hs36
-rw-r--r--src/AST/UnMonoid.hs2
-rw-r--r--src/CHAD.hs167
-rw-r--r--src/CHAD/Top.hs1
-rw-r--r--src/CHAD/Types/ToTan.hs18
-rw-r--r--src/Interpreter.hs39
-rw-r--r--src/Language.hs2
-rw-r--r--src/Language/AST.hs2
-rw-r--r--src/Simplify.hs106
10 files changed, 261 insertions, 136 deletions
diff --git a/src/AST.hs b/src/AST.hs
index b2ddbb4..c24e3e7 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -92,12 +92,12 @@ data Expr x env t where
-- accumulation effect on monoids
EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t)
- EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env a -> Expr x env (TAccum t) -> Expr x env TNil
+ EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Expr x env a -> Expr x env (TAccum t) -> Expr x env TNil
-- monoidal operations (to be desugared to regular operations after simplification)
EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t
EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t
- EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env a -> Expr x env t
+ EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t
-- interface of abstract monoidal types
ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b)
@@ -523,6 +523,14 @@ eunPair e k =
(EFst ext (evar IZ))
(ESnd ext (evar IZ))
+efst :: Ex env (TPair a b) -> Ex env a
+efst (EPair _ e1 _) = e1
+efst e = EFst ext e
+
+esnd :: Ex env (TPair a b) -> Ex env b
+esnd (EPair _ _ e2) = e2
+esnd e = ESnd ext e
+
elet :: Ex env a -> (KnownTy a => Ex (a : env) b) -> Ex env b
elet rhs body
| Dict <- styKnown (typeOf rhs)
@@ -543,3 +551,15 @@ elcase e a b c
evar :: KnownTy a => Idx env a -> Ex env a
evar = EVar ext knownTy
+
+makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t)
+makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ)
+ where
+ -- invariant: expression argument is duplicable
+ go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t)
+ go SMTNil _ = ENil ext
+ go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e))
+ go SMTLEither{} _ = ENil ext
+ go SMTMaybe{} _ = ENil ext
+ go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e
+ go SMTScal{} _ = ENil ext
diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs
index 1101cc0..158b4d9 100644
--- a/src/AST/Accum.hs
+++ b/src/AST/Accum.hs
@@ -1,6 +1,7 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module AST.Accum where
@@ -32,21 +33,36 @@ data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where
-- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t)
deriving instance Show (SAcPrj p a b)
-type family AcIdx p t where
- AcIdx APHere t = TNil
- AcIdx (APFst p) (TPair a b) = TPair (AcIdx p a) (ZeroInfo b)
- AcIdx (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx p b)
- AcIdx (APLeft p) (TLEither a b) = AcIdx p a
- AcIdx (APRight p) (TLEither a b) = AcIdx p b
- AcIdx (APJust p) (TMaybe a) = AcIdx p a
- AcIdx (APArrIdx p) (TArr n a) =
+type data StillDense = AI_D | AI_S
+data SStillDense dense where
+ SAI_D :: SStillDense AI_D
+ SAI_S :: SStillDense AI_S
+deriving instance Show (SStillDense dense)
+
+type family AcIdx dense p t where
+ AcIdx dense APHere t = TNil
+ AcIdx AI_D (APFst p) (TPair a b) = AcIdx AI_D p a
+ AcIdx AI_D (APSnd p) (TPair a b) = AcIdx AI_D p b
+ AcIdx AI_S (APFst p) (TPair a b) = TPair (AcIdx AI_S p a) (ZeroInfo b)
+ AcIdx AI_S (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AI_S p b)
+ AcIdx dense (APLeft p) (TLEither a b) = AcIdx AI_S p a
+ AcIdx dense (APRight p) (TLEither a b) = AcIdx AI_S p b
+ AcIdx dense (APJust p) (TMaybe a) = AcIdx AI_S p a
+ AcIdx AI_D (APArrIdx p) (TArr n a) = TPair (Tup (Replicate n TIx)) (AcIdx AI_D p a)
+ AcIdx AI_S (APArrIdx p) (TArr n a) =
-- ((index, shapes info), recursive info)
TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a)))
- (AcIdx p a)
- -- AcIdx (APArrSlice m) (TArr n a) =
+ (AcIdx AI_S p a)
+ -- AcIdx AI_D (APArrSlice m) (TArr n a) =
+ -- -- index
+ -- Tup (Replicate m TIx)
+ -- AcIdx AI_S (APArrSlice m) (TArr n a) =
-- -- (index, array shape)
-- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx))
+type AcIdxD p t = AcIdx AI_D p t
+type AcIdxS p t = AcIdx AI_S p t
+
acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b
acPrjTy SAPHere t = t
acPrjTy (SAPFst prj) (SMTPair t _) = acPrjTy prj t
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index ac4d733..389dd5a 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -105,7 +105,7 @@ plus (SMTArr _ t) a b =
a b
plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b)
-onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env t
+onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t
onehot typ topprj idx arg = case (typ, topprj) of
(_, SAPHere) ->
ELet ext arg $
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 7cd4c26..3dedec3 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -362,12 +362,6 @@ d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e
d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext
d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program"
-zeroTup :: SList STy env0 -> D1E env0 :> env -> Ex env (Tup (D2E env0))
-zeroTup SNil _ = ENil ext
-zeroTup (t `SCons` env) w =
- EPair ext (zeroTup env (WPop w))
- (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ))))
-
----------------------------------- SPARSITY -----------------------------------
@@ -780,7 +774,7 @@ drev des accumMap (SpSparse sd) =
subtape
e1
sub'
- (emaybe (evar IZ)
+ (emaybe (EVar ext (STMaybe (applySparse sd (d2 (typeOf e)))) IZ)
(inj2 (ENil ext))
(inj1 (weakenExpr (WCopy WSink) e2)))
}
@@ -794,7 +788,8 @@ drev des accumMap sd = \case
(EVar ext (d1 t) (conv1Idx i))
(subenvNone (d2e (select SMerge des)))
(let ty = applySparse sd (d2M t)
- in EAccum ext (d2M t) (_ sd) (ENil ext) (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI)))
+ in accumulateSparse SAI_D (d2M t) sd (EVar ext (fromSMTy ty) IZ) $ \w prj val idx ->
+ EAccum ext (d2M t) prj idx val (EVar ext (STAccum (d2M t)) (w @> IS accI)))
Idx2Me tupI ->
Ret BTop
@@ -1227,43 +1222,45 @@ drev des accumMap sd = \case
(EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ)))
(EVar ext (tTup (sreplicate n tIx)) IZ))
sub
- (ELet ext (EOneHot ext (SMTArr n (applySparse sd' (d2M eltty))) (SAPArrIdx SAPHere)
- (EPair ext (EPair ext (EVar ext tIxN (IS IZ))
- (EBuild ext n (EVar ext tIxN (IS (IS IZ))) (ENil ext))) (ENil ext))
- (inj1 $ EVar ext (applySparse sd (d2 eltty)) IZ)) $
+ (ELet ext
+ (EOneHot ext (SMTArr n (applySparse sd' (d2M eltty)))
+ (SAPArrIdx SAPHere)
+ (EPair ext
+ (EPair ext (EVar ext tIxN (IS IZ))
+ (EBuild ext n (EVar ext tIxN (IS (IS IZ))) $
+ makeZeroInfo (applySparse sd' (d2M eltty)) (inj2 (ENil ext))))
+ (ENil ext))
+ (inj1 $ EVar ext (applySparse sd (d2 eltty)) IZ)) $
weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
}
EShape _ e
- -- Allowed to ignore e2 here because the output of EShape is discrete,
- -- hence we'd be passing a zero cotangent to e2 anyway.
- | Ret e0 subtape e1 _ _ <- drev des accumMap e
- , STArr n _ <- typeOf e
+ -- Allowed to differentiate e as primal because the output of EShape is
+ -- discrete, hence we'd be passing a zero cotangent to e anyway.
+ | STArr n _ <- typeOf e
, Refl <- indexTupD1Id n ->
- Ret e0
- subtape
- (EShape ext e1)
- (subenvNone (select SMerge des))
+ Ret BTop
+ SETop
+ (EShape ext (drevPrimal des e))
+ (subenvNone (d2eM (select SMerge des)))
(ENil ext)
ESum1Inner _ e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e
+ | SpArr sd' <- sd
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e
, STArr (SS n) t <- typeOf e ->
Ret (e0 `BPush` (STArr (SS n) t, e1)
`BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ)))
(SEYesR (SENo subtape))
(ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ)))
sub
- (EMaybe ext
- (zeroTup (subList (select SMerge des) sub))
- (ELet ext (EJust ext (EReplicate1Inner ext
- (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS (IS IZ))))
- (EVar ext (STArr n (d2 t)) IZ))) $
- weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
- (EVar ext (d2 (STArr n t)) IZ))
+ (ELet ext (EReplicate1Inner ext
+ (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ)))
+ (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink)) e2)
- EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e
- EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e
+ EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e
+ EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e
-- These should be the next to be implemented, I think
EFold1Inner{} -> err_unsupported "EFold1Inner"
@@ -1286,35 +1283,35 @@ drev des accumMap sd = \case
err_monoid = error "Monoid operations unsupported in the source program"
err_unsupported s = error $ "CHAD: unsupported " ++ s
- deriv_extremum :: ScalIsNumeric t' ~ True
- => (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t')))
- -> Sparse (TArr n (D2s t')) sd'
- -> Expr ValId env (TArr (S n) (TScal t')) -> Ret env sto sd' (TArr n (TScal t'))
- deriv_extremum extremum e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e
- , at@(STArr (SS n) t@(STScal st)) <- typeOf e
- , let at' = STArr n t
- , let tIxN = tTup (sreplicate (SS n) tIx) =
- Ret (e0 `BPush` (at, e1)
- `BPush` (at', extremum (EVar ext at IZ)))
- (SEYesR (SEYesR subtape))
- (EVar ext at' IZ)
- sub
- (EMaybe ext
- (zeroTup (subList (select SMerge des) sub))
- (ELet ext (EJust ext
- (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS (IS IZ))))) $
- eif (EOp ext (OEq st) (EPair ext
- (EIdx ext (EVar ext at (IS (IS (IS (IS IZ))))) (EVar ext tIxN IZ))
- (EIdx ext (EVar ext at' (IS (IS (IS IZ)))) (EFst ext (EVar ext tIxN IZ)))))
- (EIdx ext (EVar ext (STArr n (d2 t)) (IS IZ)) (EFst ext (EVar ext tIxN IZ)))
- (ezeroD2 t))) $
- weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2)
- (EVar ext (d2 at') IZ))
-
contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs)
contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub))
+deriv_extremum :: (?config :: CHADConfig, ScalIsNumeric t ~ True)
+ => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t)))
+ -> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
+ -> Sparse (D2s t) sd
+ -> Expr ValId env (TArr (S n) (TScal t)) -> Ret env sto (TArr n sd) (TArr n (TScal t))
+deriv_extremum extremum des accumMap sd e
+ | at@(STArr (SS n) t@(STScal st)) <- typeOf e
+ , let at' = STArr n t
+ , let tIxN = tTup (sreplicate (SS n) tIx) =
+ sparsePlusS ST ST (d2M t) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ ->
+ case drev des accumMap (SpArr sd') e of { Ret e0 subtape e1 sub e2 ->
+ Ret (e0 `BPush` (at, e1)
+ `BPush` (at', extremum (EVar ext at IZ)))
+ (SEYesR (SEYesR subtape))
+ (EVar ext at' IZ)
+ sub
+ (ELet ext
+ (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $
+ eif (EOp ext (OEq st) (EPair ext
+ (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ))
+ (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ)))))
+ (inj1 $ EIdx ext (EVar ext (STArr n (applySparse sd (d2 t))) (IS IZ)) (EFst ext (EVar ext tIxN IZ)))
+ (inj2 (ENil ext))) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
+ }
+
data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s)
data RetScoped env0 sto a s sd t =
@@ -1379,7 +1376,7 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of
&. #ac (auto1 @(TAccum (D2 a)))
&. #tl (d2ace (select SAccum des))
in
- RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub SpDense $
+ RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub (spDense (d2M argty)) $
let primalIdx = autoWeak library #p (#d :++: (#body :++: #p) :++: #tl) @> IZ in
EWith ext (d2M argty) (EZero ext (d2M argty) (d2zeroInfo argty (EVar ext (d1 argty) primalIdx))) $
weakenExpr (autoWeak library
@@ -1412,3 +1409,59 @@ drevPrimal des e
chadD1EId :: SList STy l -> D1E l :~: l
chadD1EId SNil = Refl
chadD1EId (SCons t l) | Refl <- chadD1Id t, Refl <- chadD1EId l = Refl
+
+accumulateSparse
+ :: SStillDense dense -> SMTy t -> Sparse t t' -> Ex env t'
+ -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' b -> Ex env' (AcIdx dense p t) -> Ex env' TNil)
+ -> Ex env TNil
+accumulateSparse dense topty topsp arg accum = case (dense, topty, topsp) of
+ (_, _, s) | Just Refl <- isDense topty s ->
+ accum WId SAPHere arg (ENil ext)
+ (_, SMTScal _, SpScal) ->
+ accum WId SAPHere arg (ENil ext) -- should be handled by isDense already, but meh
+ (_, _, SpSparse s) ->
+ emaybe arg
+ (ENil ext)
+ (accumulateSparse dense topty s (evar IZ) (\w -> accum (WPop w)))
+ (_, _, SpAbsent) ->
+ ENil ext
+ (SAI_D, SMTPair t1 t2, SpPair s1 s2) ->
+ eunPair arg $ \w1 e1 e2 ->
+ elet (accumulateSparse dense t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $
+ accumulateSparse dense t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj))
+ (SAI_S, SMTPair{}, SpPair{}) ->
+ error "TODO: accumulating into pair inside coproduct unimplemented"
+ -- There are two different ways this can be accomplished:
+ -- 1. Ensure we have the requisite ZeroInfo here. This means that an
+ -- accum-mode variable reference will (if its incoming cotangent is
+ -- sparse enough) need to store some ZeroInfo fragments computed from
+ -- the primal (not necessarily the entire primal). Doing this properly,
+ -- i.e. not just storing a full D1 but only the required ZeroInfo
+ -- fragments, is possible and not too inefficient but a bit of
+ -- engineering again.
+ -- 2. When creating an accumulator, don't initialise it with a generic
+ -- EZero based on a ZeroInfo, but instead a special "deep zero" based on
+ -- probably a full D1. This deep zero also initialises Left/Right/Just
+ -- modelled after the primal. With this, an accumulation needs no zero
+ -- info whatsoever (!) under the assumption that it receives a cotangent
+ -- that is compatible with the primal it is propagated back to.
+ (_, SMTLEither t1 t2, SpLEither s1 s2) ->
+ elcase arg
+ (ENil ext)
+ (accumulateSparse SAI_S t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj)))
+ (accumulateSparse SAI_S t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj)))
+ (_, SMTMaybe t, SpMaybe s) ->
+ emaybe arg
+ (ENil ext)
+ (accumulateSparse SAI_S t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj)))
+ (SAI_D, SMTArr n t, SpArr s) ->
+ let tn = tTup (sreplicate n tIx) in
+ elet arg $
+ elet (EBuild ext n (EShape ext (evar IZ)) $
+ accumulateSparse dense t s
+ (EIdx ext (evar (IS IZ)) (EVar ext tn IZ))
+ (\w prj val idx -> accum (WPop (WPop w)) (SAPArrIdx prj) val (EPair ext (EVar ext tn (w @> IZ)) idx))) $
+ ENil ext
+ (SAI_S, SMTArr{}, SpArr{}) ->
+ error "TODO: accumulating into array inside coproduct unimplemented"
+ -- See the pair case above, same reasoning
diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs
index 261ddfe..130174a 100644
--- a/src/CHAD/Top.hs
+++ b/src/CHAD/Top.hs
@@ -15,7 +15,6 @@ import AST
import AST.SplitLets
import AST.Weaken.Auto
import CHAD
-import CHAD.Accum
import CHAD.EnvDescr
import CHAD.Types
import Data
diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs
index 8476712..888fed4 100644
--- a/src/CHAD/Types/ToTan.hs
+++ b/src/CHAD/Types/ToTan.hs
@@ -19,9 +19,7 @@ toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) =
toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t)
toTan typ primal der = case typ of
STNil -> der
- STPair t1 t2 -> case der of
- Nothing -> bimap (zeroTan t1) (zeroTan t2) primal
- Just (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal
+ STPair t1 t2 -> bimap (\p1 -> toTan t1 p1 (fst der)) (\p2 -> toTan t2 p2 (snd der)) primal
STEither t1 t2 -> case der of
Nothing -> bimap (zeroTan t1) (zeroTan t2) primal
Just d -> case (primal, d) of
@@ -34,14 +32,12 @@ toTan typ primal der = case typ of
(Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d))
_ -> error "Primal and cotangent disagree on LEither alternative"
STMaybe t -> liftA2 (toTan t) primal der
- STArr _ t -> case der of
- Nothing -> arrayMap (zeroTan t) primal
- Just d
- | arrayShape primal == arrayShape d ->
- arrayGenerateLin (arrayShape primal) $ \i ->
- toTan t (arrayIndexLinear primal i) (arrayIndexLinear d i)
- | otherwise ->
- error "Primal and cotangent disagree on array shape"
+ STArr _ t
+ | arrayShape primal == arrayShape der ->
+ arrayGenerateLin (arrayShape primal) $ \i ->
+ toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
+ | otherwise ->
+ error "Primal and cotangent disagree on array shape"
STScal sty -> case sty of
STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der
STAccum{} -> error "Accumulators not allowed in input program"
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 803a24a..b3576ce 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -162,7 +162,7 @@ interpret'Rec env = \case
idx <- interpret' env e1
val <- interpret' env e2
accum <- interpret' env e3
- accumAddSparse t p accum idx val
+ accumAddSparseD t p accum idx val
EZero _ t ezi -> do
zi <- interpret' env ezi
return $ zeroM t zi
@@ -239,7 +239,7 @@ addM typ a b = case typ of
| otherwise -> error "Plus of inconsistently shaped arrays"
SMTScal sty -> numericIsNum sty $ a + b
-onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdx p a) -> Rep b -> Rep a
+onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdxS p a) -> Rep b -> Rep a
onehotM SAPHere _ _ val = val
onehotM (SAPFst prj) (SMTPair a b) idx val = (onehotM prj a (fst idx) val, zeroM b (snd idx))
onehotM (SAPSnd prj) (SMTPair a b) idx val = (zeroM a (fst idx), onehotM prj b (snd idx) val)
@@ -274,7 +274,7 @@ newAcDense typ val = case typ of
SMTArr _ t1 -> arrayMapM (newAcDense t1) val
SMTScal _ -> newIORef val
-newAcSparse :: SMTy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep b -> IO (RepAc a)
+newAcSparse :: SMTy a -> SAcPrj p a b -> Rep (AcIdxS p a) -> Rep b -> IO (RepAc a)
newAcSparse typ prj idx val = case (typ, prj) of
(_, SAPHere) -> newAcDense typ val
@@ -291,9 +291,9 @@ newAcSparse typ prj idx val = case (typ, prj) of
(SMTArr n t, SAPArrIdx prj') -> onehotArray (\idx' -> newAcSparse t prj' idx' val) (newAcZero t) n prj' idx
onehotArray :: Monad m
- => (Rep (AcIdx p a) -> m v) -- ^ the "one"
+ => (Rep (AcIdxS p a) -> m v) -- ^ the "one"
-> (Rep (ZeroInfo a) -> m v) -- ^ the "zero"
- -> SNat n -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> m (Array n v)
+ -> SNat n -> SAcPrj p a b -> Rep (AcIdxS (APArrIdx p) (TArr n a)) -> m (Array n v)
onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) =
let arrindex = unTupRepIdx IxNil IxCons n arrindex'
arrsh = arrayShape ziarr
@@ -329,7 +329,34 @@ accumAddDense typ ref val = case typ of
accumAddDense t1 (arrayIndexLinear ref i) (arrayIndexLinear val i)
SMTScal sty -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ()))
-accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep b -> AcM s ()
+accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Rep b -> AcM s ()
+accumAddSparseD typ prj ref idx val = case (typ, prj) of
+ (_, SAPHere) -> accumAddDense typ ref val
+
+ (SMTPair t1 _, SAPFst prj') -> accumAddSparseD t1 prj' (fst ref) idx val
+ (SMTPair _ t2, SAPSnd prj') -> accumAddSparseD t2 prj' (snd ref) idx val
+
+ (SMTLEither t1 _, SAPLeft prj') ->
+ realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val)
+ (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val
+ Right{} -> error "Mismatched Either in accumAddSparseD (r +l)")
+ (SMTLEither _ t2, SAPRight prj') ->
+ realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val)
+ (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val
+ Left{} -> error "Mismatched Either in accumAddSparseD (l +r)")
+
+ (SMTMaybe t1, SAPJust prj') ->
+ realiseMaybeSparse ref (newAcSparse t1 prj' idx val)
+ (\ac -> accumAddSparse t1 prj' ac idx val)
+
+ (SMTArr n t1, SAPArrIdx prj') ->
+ let (arrindex', idx') = idx
+ arrindex = unTupRepIdx IxNil IxCons n arrindex'
+ arrsh = arrayShape ref
+ linindex = toLinearIndex arrsh arrindex
+ in accumAddSparseD t1 prj' (arrayIndexLinear ref linindex) idx' val
+
+accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxS p a) -> Rep b -> AcM s ()
accumAddSparse typ prj ref idx val = case (typ, prj) of
(_, SAPHere) -> accumAddDense typ ref val
diff --git a/src/Language.hs b/src/Language.hs
index 7a780a0..63279df 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -175,7 +175,7 @@ recompute = NERecompute
with :: forall t a env acname. KnownMTy t => NExpr env t -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a t)
with a (n :-> b) = NEWith (knownMTy @t) a n b
-accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdx p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil
+accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil
accum p a b c = NEAccum knownMTy p a b c
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
index 7e074df..92792b3 100644
--- a/src/Language/AST.hs
+++ b/src/Language/AST.hs
@@ -76,7 +76,7 @@ data NExpr env t where
-- accumulation effect on monoids
NEWith :: SMTy t -> NExpr env t -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a t)
- NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdx p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil
+ NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil
-- partiality
NEError :: STy a -> String -> NExpr env a
diff --git a/src/Simplify.hs b/src/Simplify.hs
index e110206..d3b850f 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -226,19 +226,19 @@ simplify'Rec = \case
e1' <- within (\e1' -> EAccum ext t p e1' e2 acc ) $ simplify' e1
e2' <- within (\e2' -> EAccum ext t p e1' e2' acc ) $ simplify' e2
acc' <- within (\acc' -> EAccum ext t p e1' e2' acc') $ simplify' acc
- simplifyOneHotTerm (OneHotTerm t p e1' e2')
+ simplifyOneHotTerm (OneHotTerm SAI_D t p e1' e2')
(acted $ return (ENil ext))
(\e -> return (EAccum ext t SAPHere (ENil ext) e acc'))
- (\(OneHotTerm t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc'))
+ (\(OneHotTerm SAI_D t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc'))
EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e
EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e
EOneHot _ t p e1 e2 -> do
e1' <- within (\e1' -> EOneHot ext t p e1' e2 ) $ simplify' e1
e2' <- within (\e2' -> EOneHot ext t p e1' e2') $ simplify' e2
- simplifyOneHotTerm (OneHotTerm t p e1' e2')
+ simplifyOneHotTerm (OneHotTerm SAI_S t p e1' e2')
(acted $ return (EZero ext t (zeroInfoFromOneHot t p e1 e2)))
(\e -> acted $ return e)
- (\(OneHotTerm t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2''))
+ (\(OneHotTerm SAI_S t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2''))
-- type-specific equations for plus
EPlus _ SMTNil e1 e2 | not (hasAdds e1), not (hasAdds e2) ->
@@ -373,27 +373,27 @@ checkAccumInScope = \case SNil -> False
check (STScal _) = False
check STAccum{} = True
-data OneHotTerm env p a b where
- OneHotTerm :: SMTy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env b -> OneHotTerm env p a b
-deriving instance Show (OneHotTerm env p a b)
+data OneHotTerm dense env p a b where
+ OneHotTerm :: SStillDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Ex env b -> OneHotTerm dense env p a b
+deriving instance Show (OneHotTerm dense env p a b)
-simplifyOneHotTerm :: OneHotTerm env p a b
+simplifyOneHotTerm :: OneHotTerm dense env p a b
-> SM tenv tt env t r -- ^ Zero case (onehot is actually zero)
-> (Ex env a -> SM tenv tt env t r) -- ^ Trivial case (no zeros in onehot)
- -> (forall p' b'. OneHotTerm env p' a b' -> SM tenv tt env t r)
+ -> (forall p' b'. OneHotTerm dense env p' a b' -> SM tenv tt env t r)
-> SM tenv tt env t r
-simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 val1) kzero ktriv k = do
+simplifyOneHotTerm (OneHotTerm dense t1 prj1 idx1 val1) kzero ktriv k = do
val1' <- liftActed $ recogniseMonoid (acPrjTy prj1 t1) val1
case val1' of
EZero{} -> kzero
EOneHot _ t2 prj2 idx2 val2
| Just Refl <- testEquality (acPrjTy prj1 t1) t2 -> do
tellActed -- record, whatever happens later, that we've modified something
- concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 ->
- simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val2) kzero ktriv k
+ concatOneHots dense t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 ->
+ simplifyOneHotTerm (OneHotTerm dense t1 prj12 idx12 val2) kzero ktriv k
_ -> case prj1 of
SAPHere -> ktriv val1
- _ -> k (OneHotTerm t1 prj1 idx1 val1)
+ _ -> k (OneHotTerm dense t1 prj1 idx1 val1)
-- | Recognises 'EZero' and 'EOneHot'.
recogniseMonoid :: SMTy t -> Ex env t -> (Any, Ex env t)
@@ -433,52 +433,66 @@ recogniseMonoid typ@(SMTScal sty) e@(EConst _ _ x) = case (sty, x) of
_ -> return e
recogniseMonoid _ e = return e
-concatOneHots :: SMTy a
- -> SAcPrj p1 a b -> Ex env (AcIdx p1 a)
- -> SAcPrj p2 b c -> Ex env (AcIdx p2 b)
- -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx p12 a) -> r) -> r
-concatOneHots t1 prj1 idx1 prj2 idx2 k = case (t1, prj1) of
- (_, SAPHere) -> k prj2 idx2
-
- (SMTPair a _, SAPFst prj1') ->
- concatOneHots a prj1' (EFst ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
+concatOneHots :: SStillDense dense -> SMTy a
+ -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a)
+ -> SAcPrj p2 b c -> Ex env (AcIdxS p2 b)
+ -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx dense p12 a) -> r) -> r
+concatOneHots dense t1 prj1 idx1 prj2 idx2 k = case (dense, t1, prj1) of
+ (SAI_D, _, SAPHere) -> k prj2 (reduceAcIdx t1 prj2 idx2)
+ (SAI_S, _, SAPHere) -> k prj2 idx2
+
+ (SAI_D, SMTPair a _, SAPFst prj1') ->
+ concatOneHots SAI_D a prj1' idx1 prj2 idx2 $ \prj12 idx12 ->
+ k (SAPFst prj12) idx12
+ (SAI_S, SMTPair a _, SAPFst prj1') ->
+ concatOneHots SAI_S a prj1' (EFst ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
k (SAPFst prj12) (ELet ext idx1 $ EPair ext idx12 (ESnd ext (EVar ext (typeOf idx1) IZ)))
- (SMTPair _ b, SAPSnd prj1') ->
- concatOneHots b prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
+ (SAI_D, SMTPair _ b, SAPSnd prj1') ->
+ concatOneHots dense b prj1' idx1 prj2 idx2 $ \prj12 idx12 ->
+ k (SAPSnd prj12) idx12
+ (SAI_S, SMTPair _ b, SAPSnd prj1') ->
+ concatOneHots dense b prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
k (SAPSnd prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12)
- (SMTLEither a _, SAPLeft prj1') ->
- concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12
- (SMTLEither _ b, SAPRight prj1') ->
- concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12
+ (_, SMTLEither a _, SAPLeft prj1') ->
+ concatOneHots SAI_S a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12
+ (_, SMTLEither _ b, SAPRight prj1') ->
+ concatOneHots SAI_S b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12
- (SMTMaybe a, SAPJust prj1') ->
- concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12
+ (_, SMTMaybe a, SAPJust prj1') ->
+ concatOneHots SAI_S a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12
- (SMTArr _ a, SAPArrIdx prj1') ->
- concatOneHots a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
+ -- yes, twice the same code, but we need a concrete denseness indicator to
+ -- reduce AcIdx (the only difference between the dense and sparse versions is
+ -- whether there extra info also contains an array shape, and this code
+ -- handles the extra info uniformly)
+ (SAI_D, SMTArr _ a, SAPArrIdx prj1') ->
+ concatOneHots dense a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
+ k (SAPArrIdx prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12)
+ (SAI_S, SMTArr _ a, SAPArrIdx prj1') ->
+ concatOneHots dense a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
k (SAPArrIdx prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12)
-zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env (ZeroInfo t)
+reduceAcIdx :: SMTy a -> SAcPrj p a b -> Ex env (AcIdx AI_S p a) -> Ex env (AcIdx AI_D p a)
+reduceAcIdx topty topprj e = case (topty, topprj) of
+ (_, SAPHere) -> ENil ext
+ (SMTPair t1 _, SAPFst p) -> reduceAcIdx t1 p (efst e)
+ (SMTPair _ t2, SAPSnd p) -> reduceAcIdx t2 p (esnd e)
+ (SMTLEither{}, SAPLeft{}) -> e
+ (SMTLEither{}, SAPRight{}) -> e
+ (SMTMaybe{}, SAPJust{}) -> e
+ (SMTArr _ t, SAPArrIdx p) ->
+ eunPair e $ \_ e1 e2 ->
+ EPair ext (efst e1) (reduceAcIdx t p e2)
+
+zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t)
zeroInfoFromOneHot = \ty prj eidx e -> ELet ext eidx $ go ty prj (EVar ext (typeOf eidx) IZ) (weakenExpr WSink e)
where
-- invariant: AcIdx expression is duplicable
- go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env (ZeroInfo t)
+ go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t)
go t SAPHere _ e = makeZeroInfo t e
go (SMTPair t1 _) (SAPFst prj) eidx e = EPair ext (go t1 prj (EFst ext eidx) e) (ESnd ext eidx)
go (SMTPair _ t2) (SAPSnd prj) eidx e = EPair ext (EFst ext eidx) (go t2 prj (ESnd ext eidx) e)
go SMTLEither{} _ _ _ = ENil ext
go SMTMaybe{} _ _ _ = ENil ext
go SMTArr{} SAPArrIdx{} eidx _ = ESnd ext (EFst ext eidx)
-
-makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t)
-makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ)
- where
- -- invariant: expression argument is duplicable
- go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t)
- go SMTNil _ = ENil ext
- go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e))
- go SMTLEither{} _ = ENil ext
- go SMTMaybe{} _ = ENil ext
- go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e
- go SMTScal{} _ = ENil ext