summaryrefslogtreecommitdiff
path: root/src/CHAD
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD')
-rw-r--r--src/CHAD/Accum.hs52
-rw-r--r--src/CHAD/EnvDescr.hs20
-rw-r--r--src/CHAD/Top.hs53
-rw-r--r--src/CHAD/Types.hs45
-rw-r--r--src/CHAD/Types/ToTan.hs18
5 files changed, 130 insertions, 58 deletions
diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs
index d8a71b5..7212232 100644
--- a/src/CHAD/Accum.hs
+++ b/src/CHAD/Accum.hs
@@ -1,18 +1,54 @@
-{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE TypeOperators #-}
+-- | TODO this module is a grab-bag of random utility functions that are shared
+-- between CHAD and CHAD.Top.
module CHAD.Accum where
import AST
import CHAD.Types
import Data
+import AST.Env
+d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t))
+d2zeroInfo STNil _ = ENil ext
+d2zeroInfo (STPair a b) e =
+ eunPair e $ \_ e1 e2 ->
+ EPair ext (d2zeroInfo a e1) (d2zeroInfo b e2)
+d2zeroInfo STEither{} _ = ENil ext
+d2zeroInfo STLEither{} _ = ENil ext
+d2zeroInfo STMaybe{} _ = ENil ext
+d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e
+d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext
+d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program"
-makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
-makeAccumulators SNil e = e
-makeAccumulators (t `SCons` envpro) e | Refl <- lemZeroInfoD2 t =
- makeAccumulators envpro $
- EWith ext (d2M t) (EZero ext (d2M t) (ENil ext)) e
+d2deepZeroInfo :: STy t -> Ex env (D1 t) -> Ex env (DeepZeroInfo (D2 t))
+d2deepZeroInfo STNil _ = ENil ext
+d2deepZeroInfo (STPair a b) e =
+ eunPair e $ \_ e1 e2 ->
+ EPair ext (d2deepZeroInfo a e1) (d2deepZeroInfo b e2)
+d2deepZeroInfo (STEither a b) e =
+ ECase ext e
+ (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ)))
+ (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ)))
+d2deepZeroInfo (STLEither a b) e =
+ elcase e
+ (ELNil ext (tDeepZeroInfo (d2M a)) (tDeepZeroInfo (d2M b)))
+ (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ)))
+ (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ)))
+d2deepZeroInfo (STMaybe a) e =
+ emaybe e
+ (ENothing ext (tDeepZeroInfo (d2M a)))
+ (EJust ext (d2deepZeroInfo a (EVar ext (d1 a) IZ)))
+d2deepZeroInfo (STArr _ t) e = emap (d2deepZeroInfo t (EVar ext (d1 t) IZ)) e
+d2deepZeroInfo (STScal t) _ | Refl <- lemDeepZeroInfoScal t = ENil ext
+d2deepZeroInfo STAccum{} _ = error "accumulators not allowed in source program"
+
+makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
+makeAccumulators _ SNil e = e
+makeAccumulators w (t `SCons` envpro) e =
+ makeAccumulators (WPop w) envpro $
+ EWith ext (d2M t) (EDeepZero ext (d2M t) (d2deepZeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e
uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list))
uninvertTup SNil _ e = EPair ext e (ENil ext)
@@ -25,3 +61,7 @@ uninvertTup (t `SCons` list) tcore e =
(ESnd ext (EVar ext recT IZ))
(ESnd ext (EFst ext (EVar ext recT IZ))))
+subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env')
+subenvD1E SETop = SETop
+subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub)
+subenvD1E (SENo sub) = SENo (subenvD1E sub)
diff --git a/src/CHAD/EnvDescr.hs b/src/CHAD/EnvDescr.hs
index 4c287d7..49ae0e6 100644
--- a/src/CHAD/EnvDescr.hs
+++ b/src/CHAD/EnvDescr.hs
@@ -52,12 +52,12 @@ subDescr :: Descr env sto -> Subenv env env'
-> r)
-> r
subDescr DTop SETop k = k DTop SETop SETop SETop
-subDescr (des `DPush` (t, vid, sto)) (SEYes sub) k =
+subDescr (des `DPush` (t, vid, sto)) (SEYesR sub) k =
subDescr des sub $ \des' submerge subaccum subd1e ->
case sto of
- SMerge -> k (des' `DPush` (t, vid, sto)) (SEYes submerge) subaccum (SEYes subd1e)
- SAccum -> k (des' `DPush` (t, vid, sto)) submerge (SEYes subaccum) (SEYes subd1e)
- SDiscr -> k (des' `DPush` (t, vid, sto)) submerge subaccum (SEYes subd1e)
+ SMerge -> k (des' `DPush` (t, vid, sto)) (SEYesR submerge) subaccum (SEYesR subd1e)
+ SAccum -> k (des' `DPush` (t, vid, sto)) submerge (SEYesR subaccum) (SEYesR subd1e)
+ SDiscr -> k (des' `DPush` (t, vid, sto)) submerge subaccum (SEYesR subd1e)
subDescr (des `DPush` (_, _, sto)) (SENo sub) k =
subDescr des sub $ \des' submerge subaccum subd1e ->
case sto of
@@ -82,3 +82,15 @@ select s@SDiscr (DPush des (_, _, SMerge)) = select s des
select s@SAccum (DPush des (_, _, SDiscr)) = select s des
select s@SMerge (DPush des (_, _, SDiscr)) = select s des
select s@SDiscr (DPush des (t, _, SDiscr)) = SCons t (select s des)
+
+selectSub :: Storage s -> Descr env sto -> Subenv env (Select env sto s)
+selectSub _ DTop = SETop
+selectSub s@SAccum (DPush des (_, _, SAccum)) = SEYesR (selectSub s des)
+selectSub s@SMerge (DPush des (_, _, SAccum)) = SENo (selectSub s des)
+selectSub s@SDiscr (DPush des (_, _, SAccum)) = SENo (selectSub s des)
+selectSub s@SAccum (DPush des (_, _, SMerge)) = SENo (selectSub s des)
+selectSub s@SMerge (DPush des (_, _, SMerge)) = SEYesR (selectSub s des)
+selectSub s@SDiscr (DPush des (_, _, SMerge)) = SENo (selectSub s des)
+selectSub s@SAccum (DPush des (_, _, SDiscr)) = SENo (selectSub s des)
+selectSub s@SMerge (DPush des (_, _, SDiscr)) = SENo (selectSub s des)
+selectSub s@SDiscr (DPush des (_, _, SDiscr)) = SEYesR (selectSub s des)
diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs
index 261ddfe..484779e 100644
--- a/src/CHAD/Top.hs
+++ b/src/CHAD/Top.hs
@@ -12,6 +12,8 @@ module CHAD.Top where
import Analysis.Identity
import AST
+import AST.Env
+import AST.Sparse
import AST.SplitLets
import AST.Weaken.Auto
import CHAD
@@ -44,36 +46,22 @@ accumDescr (t `SCons` env) k = accumDescr env $ \des ->
if hasArrays t then k (des `DPush` (t, Nothing, SAccum))
else k (des `DPush` (t, Nothing, SMerge))
-d1Identity :: STy t -> D1 t :~: t
-d1Identity = \case
- STNil -> Refl
- STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
- STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
- STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
- STMaybe t | Refl <- d1Identity t -> Refl
- STArr _ t | Refl <- d1Identity t -> Refl
- STScal _ -> Refl
- STAccum{} -> error "Accumulators not allowed in input program"
-
-d1eIdentity :: SList STy env -> D1E env :~: env
-d1eIdentity SNil = Refl
-d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl
-
reassembleD2E :: Descr env sto
+ -> D1E env :> env'
-> Ex env' (TPair (Tup (D2E (Select env sto "accum"))) (Tup (D2E (Select env sto "merge"))))
-> Ex env' (Tup (D2E env))
-reassembleD2E DTop _ = ENil ext
-reassembleD2E (des `DPush` (_, _, SAccum)) e =
- ELet ext e $
- EPair ext (reassembleD2E des (EPair ext (EFst ext (EFst ext (EVar ext (typeOf e) IZ)))
- (ESnd ext (EVar ext (typeOf e) IZ))))
- (ESnd ext (EFst ext (EVar ext (typeOf e) IZ)))
-reassembleD2E (des `DPush` (_, _, SMerge)) e =
- ELet ext e $
- EPair ext (reassembleD2E des (EPair ext (EFst ext (EVar ext (typeOf e) IZ))
- (EFst ext (ESnd ext (EVar ext (typeOf e) IZ)))))
- (ESnd ext (ESnd ext (EVar ext (typeOf e) IZ)))
-reassembleD2E (des `DPush` (t, _, SDiscr)) e = EPair ext (reassembleD2E des e) (ezeroD2 t)
+reassembleD2E DTop _ _ = ENil ext
+reassembleD2E (des `DPush` (_, _, SAccum)) w e =
+ eunPair e $ \w1 e1 e2 ->
+ eunPair e1 $ \w2 e11 e12 ->
+ EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext e11 (weakenExpr w2 e2))) e12
+reassembleD2E (des `DPush` (_, _, SMerge)) w e =
+ eunPair e $ \w1 e1 e2 ->
+ eunPair e2 $ \w2 e21 e22 ->
+ EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext (weakenExpr w2 e1) e21)) e22
+reassembleD2E (des `DPush` (t, _, SDiscr)) w e =
+ EPair ext (reassembleD2E des (WPop w) e)
+ (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ))))
chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env)))
chad config env (term :: Ex env t)
@@ -83,21 +71,22 @@ chad config env (term :: Ex env t)
let t1 = STPair (d1 (typeOf term)) (tTup (d2e (select SMerge descr)))
tvar = STPair t1 (tTup (d2e (select SAccum descr)))
in ELet ext (uninvertTup (d2e (select SAccum descr)) t1 $
- makeAccumulators (select SAccum descr) $
+ makeAccumulators (WSink .> wUndoSubenv (subenvD1E (selectSub SAccum descr))) (select SAccum descr) $
weakenExpr (autoWeak (#d (auto1 @(D2 t))
&. #acenv (d2ace (select SAccum descr))
&. #tl (d1e env))
(#d :++: #acenv :++: #tl)
(#acenv :++: #d :++: #tl)) $
- freezeRet descr (drev descr VarMap.empty term')) $
+ freezeRet descr (drev descr VarMap.empty (spDense (d2M (typeOf term))) term')) $
EPair ext (EFst ext (EFst ext (EVar ext tvar IZ)))
- (reassembleD2E descr (EPair ext (ESnd ext (EVar ext tvar IZ))
- (ESnd ext (EFst ext (EVar ext tvar IZ)))))
+ (reassembleD2E descr (WSink .> WSink)
+ (EPair ext (ESnd ext (EVar ext tvar IZ))
+ (ESnd ext (EFst ext (EVar ext tvar IZ)))))
| False <- chcArgArrayAccum config
, Refl <- mergeEnvNoAccum env
, Refl <- mergeEnvOnlyMerge env
- = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty term')
+ = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty (spDense (d2M (typeOf term))) term')
where
term' = identityAnalysis env (splitLets term)
diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs
index 974669d..44ac20e 100644
--- a/src/CHAD/Types.hs
+++ b/src/CHAD/Types.hs
@@ -1,8 +1,10 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module CHAD.Types where
+import AST.Accum
import AST.Types
import Data
@@ -18,11 +20,11 @@ type family D1 t where
type family D2 t where
D2 TNil = TNil
- D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b))
+ D2 (TPair a b) = TPair (D2 a) (D2 b)
D2 (TEither a b) = TLEither (D2 a) (D2 b)
D2 (TLEither a b) = TLEither (D2 a) (D2 b)
D2 (TMaybe t) = TMaybe (D2 t)
- D2 (TArr n t) = TMaybe (TArr n (D2 t))
+ D2 (TArr n t) = TArr n (D2 t)
D2 (TScal t) = D2s t
type family D2s t where
@@ -60,11 +62,11 @@ d1e (t `SCons` env) = d1 t `SCons` d1e env
d2M :: STy t -> SMTy (D2 t)
d2M STNil = SMTNil
-d2M (STPair a b) = SMTMaybe (SMTPair (d2M a) (d2M b))
+d2M (STPair a b) = SMTPair (d2M a) (d2M b)
d2M (STEither a b) = SMTLEither (d2M a) (d2M b)
d2M (STLEither a b) = SMTLEither (d2M a) (d2M b)
d2M (STMaybe t) = SMTMaybe (d2M t)
-d2M (STArr n t) = SMTMaybe (SMTArr n (d2M t))
+d2M (STArr n t) = SMTArr n (d2M t)
d2M (STScal t) = case t of
STI32 -> SMTNil
STI64 -> SMTNil
@@ -95,6 +97,8 @@ data CHADConfig = CHADConfig
chcCaseArrayAccum :: Bool
, -- | Introduce top-level arguments containing arrays in accumulator mode.
chcArgArrayAccum :: Bool
+ , -- | Place with-blocks around array variable scopes, and redirect accumulations there.
+ chcSmartWith :: Bool
}
deriving (Show)
@@ -103,12 +107,14 @@ defaultConfig = CHADConfig
{ chcLetArrayAccum = False
, chcCaseArrayAccum = False
, chcArgArrayAccum = False
+ , chcSmartWith = False
}
chcSetAccum :: CHADConfig -> CHADConfig
chcSetAccum c = c { chcLetArrayAccum = True
, chcCaseArrayAccum = True
- , chcArgArrayAccum = True }
+ , chcArgArrayAccum = True
+ , chcSmartWith = True }
------------------------------------ LEMMAS ------------------------------------
@@ -116,3 +122,32 @@ chcSetAccum c = c { chcLetArrayAccum = True
indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx))
indexTupD1Id SZ = Refl
indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl
+
+lemZeroInfoScal :: SScalTy t -> ZeroInfo (D2s t) :~: TNil
+lemZeroInfoScal STI32 = Refl
+lemZeroInfoScal STI64 = Refl
+lemZeroInfoScal STF32 = Refl
+lemZeroInfoScal STF64 = Refl
+lemZeroInfoScal STBool = Refl
+
+lemDeepZeroInfoScal :: SScalTy t -> DeepZeroInfo (D2s t) :~: TNil
+lemDeepZeroInfoScal STI32 = Refl
+lemDeepZeroInfoScal STI64 = Refl
+lemDeepZeroInfoScal STF32 = Refl
+lemDeepZeroInfoScal STF64 = Refl
+lemDeepZeroInfoScal STBool = Refl
+
+d1Identity :: STy t -> D1 t :~: t
+d1Identity = \case
+ STNil -> Refl
+ STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
+ STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
+ STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
+ STMaybe t | Refl <- d1Identity t -> Refl
+ STArr _ t | Refl <- d1Identity t -> Refl
+ STScal _ -> Refl
+ STAccum{} -> error "Accumulators not allowed in input program"
+
+d1eIdentity :: SList STy env -> D1E env :~: env
+d1eIdentity SNil = Refl
+d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl
diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs
index 8476712..888fed4 100644
--- a/src/CHAD/Types/ToTan.hs
+++ b/src/CHAD/Types/ToTan.hs
@@ -19,9 +19,7 @@ toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) =
toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t)
toTan typ primal der = case typ of
STNil -> der
- STPair t1 t2 -> case der of
- Nothing -> bimap (zeroTan t1) (zeroTan t2) primal
- Just (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal
+ STPair t1 t2 -> bimap (\p1 -> toTan t1 p1 (fst der)) (\p2 -> toTan t2 p2 (snd der)) primal
STEither t1 t2 -> case der of
Nothing -> bimap (zeroTan t1) (zeroTan t2) primal
Just d -> case (primal, d) of
@@ -34,14 +32,12 @@ toTan typ primal der = case typ of
(Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d))
_ -> error "Primal and cotangent disagree on LEither alternative"
STMaybe t -> liftA2 (toTan t) primal der
- STArr _ t -> case der of
- Nothing -> arrayMap (zeroTan t) primal
- Just d
- | arrayShape primal == arrayShape d ->
- arrayGenerateLin (arrayShape primal) $ \i ->
- toTan t (arrayIndexLinear primal i) (arrayIndexLinear d i)
- | otherwise ->
- error "Primal and cotangent disagree on array shape"
+ STArr _ t
+ | arrayShape primal == arrayShape der ->
+ arrayGenerateLin (arrayShape primal) $ \i ->
+ toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
+ | otherwise ->
+ error "Primal and cotangent disagree on array shape"
STScal sty -> case sty of
STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der
STAccum{} -> error "Accumulators not allowed in input program"