summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/AST/Weaken/Auto.hs25
-rw-r--r--src/CHAD.hs101
2 files changed, 70 insertions, 56 deletions
diff --git a/src/AST/Weaken/Auto.hs b/src/AST/Weaken/Auto.hs
index 0deec71..eecb6f3 100644
--- a/src/AST/Weaken/Auto.hs
+++ b/src/AST/Weaken/Auto.hs
@@ -18,8 +18,7 @@
{-# OPTIONS_GHC -Wno-partial-type-signatures #-}
module AST.Weaken.Auto (
autoWeak,
- GivenSegment(..),
- ($..),
+ ($..), auto,
Layout(..),
) where
@@ -51,25 +50,15 @@ data SSegments (segments :: [(Symbol, [t])]) where
SSegNil :: SSegments '[]
SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list)
-class ToSegments k a | a -> k where
- type SegmentsOf k a :: [(Symbol, [k])]
- toSegments :: a -> SSegments (SegmentsOf k a)
+instance (KnownSymbol name, name ~ name', segs ~ '[ '(name', ts)]) => IsLabel name (SList f ts -> SSegments segs) where
+ fromLabel = \spine -> SSegCons symbolSing (slistMap (\_ -> Const ()) spine) SSegNil
-instance ToSegments k (SSegments (segments :: [(Symbol, [k])])) where
- type SegmentsOf k (SSegments segments) = segments
- toSegments = id
-
-data GivenSegment name ts = forall f. KnownSymbol name => Seg (SList f ts)
- | (KnownSymbol name, KnownListSpine ts) => Seg'
-
-instance ToSegments k (GivenSegment name (ts :: [k])) where
- type SegmentsOf k (GivenSegment name (ts :: [k])) = '[ '(name, ts)]
- toSegments (Seg list) = SSegCons symbolSing (slistMap (\_ -> Const ()) list) SSegNil
- toSegments Seg' = SSegCons symbolSing knownListSpine SSegNil
+auto :: KnownListSpine list => SList (Const ()) list
+auto = knownListSpine
infixr $..
-($..) :: (ToSegments k a, ToSegments k b) => a -> b -> SSegments (Append (SegmentsOf k a) (SegmentsOf k b))
-x $.. y = ssegmentsAppend (toSegments x) (toSegments y)
+($..) :: SSegments segs1 -> SSegments segs2 -> SSegments (Append segs1 segs2)
+($..) = ssegmentsAppend
where
ssegmentsAppend :: SSegments a -> SSegments b -> SSegments (Append a b)
ssegmentsAppend SSegNil l2 = l2
diff --git a/src/CHAD.hs b/src/CHAD.hs
index e99859c..aedda5b 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -408,8 +408,8 @@ zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0))
zeroTup SNil = ENil ext
zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t)
-accumPromote :: forall t env sto proxy r.
- proxy t
+accumPromote :: forall dt env sto proxy r.
+ proxy dt
-> Descr env sto
-> OccEnv env
-> (forall stoRepl envPro.
@@ -425,8 +425,8 @@ accumPromote :: forall t env sto proxy r.
-- the original environment.
-> (forall shbinds.
SList STy shbinds
- -> (D2 t : Append shbinds (D2AcE (Select env stoRepl "accum")))
- :> Append envPro (D2 t : Append shbinds (D2AcE (Select env sto "accum"))))
+ -> (D2 dt : Append shbinds (D2AcE (Select env stoRepl "accum")))
+ :> Append envPro (D2 dt : Append shbinds (D2AcE (Select env sto "accum"))))
-- ^ A weakening that converts a computation in the
-- revised environment to one in the original environment
-- extended with some accumulators.
@@ -434,9 +434,24 @@ accumPromote :: forall t env sto proxy r.
-> r
accumPromote _ DTop _ k = k DTop SETop SNil (\_ -> WId)
accumPromote _ descr OccEnd k = k descr (subenvAll (select SMerge descr)) SNil (\_ -> WId)
-accumPromote pty (descr `DPush` (t, sto)) (occenv `OccPush` occ) k =
- accumPromote pty descr occenv $ \(storepl :: Descr env1 stoRepl) mergesub (envpro :: SList _ envPro) wf ->
+accumPromote pdty (descr `DPush` (t :: STy t, sto)) (occenv `OccPush` occ) k =
+ accumPromote pdty descr occenv $ \(storepl :: Descr env1 stoRepl) mergesub (envpro :: SList _ envPro) wf ->
case (t, sto, occ) of
+ -- Accumulators are left as-is
+ (_, SAccum, _) ->
+ k (storepl `DPush` (t, SAccum))
+ mergesub
+ envpro
+ (\shbinds ->
+ autoWeak (#pro envpro $.. #d (auto @'[D2 dt]) $.. #shb shbinds $.. #acc (auto @'[D2Ac t]) $.. #tl (d2ace (select SAccum descr)))
+ (#acc :++: (#pro :++: #d :++: #shb :++: #tl))
+ (#pro :++: #d :++: #shb :++: #acc :++: #tl)
+ .> WCopy (wf shbinds)
+ .> autoWeak (#d (auto @'[D2 dt]) $.. #shb shbinds $.. #acc (auto @'[D2Ac t]) $.. #tl (d2ace (select SAccum storepl)))
+ (#d :++: #shb :++: #acc :++: #tl)
+ (#acc :++: (#d :++: #shb :++: #tl)))
+
+ -- Arrays with "merge" storage and non-zero usage are promoted to an accumulator in envPro
(STArr (arrn :: SNat arrn) (arrt :: STy arrt), SMerge, Occ _ c) | c > Zero ->
k (storepl `DPush` (t, SAccum))
(SENo mergesub)
@@ -452,16 +467,30 @@ accumPromote pty (descr `DPush` (t, sto)) (occenv `OccPush` occ) k =
-- goal: | ARE EQUAL ||
-- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
WCopy (wf shbinds)
- .> WPick @(TAccum arrn arrt) @(D2 t : shbinds) (Const () `SCons` shbindsC)
+ .> WPick @(TAccum arrn arrt) @(D2 dt : shbinds) (Const () `SCons` shbindsC)
(WId @(D2AcE (Select env1 stoRepl "accum"))))
- (_, SAccum, _) ->
- k (storepl `DPush` (t, SAccum))
- mergesub
- envpro
- (\shbinds -> _ (wf shbinds))
+ -- Used "merge" values must either _be_ an array (and hence be caught by
+ -- the prior case), or contain _no_ arrays at all (TODO: generalise this)
+ (_, SMerge, Occ _ c) | c > Zero, containsTArr t ->
+ error $ "Closure variable of 'build'-like thing contains a composite type containing an array: " ++ show t
- _ -> _
+ -- What's left are normal "merge" values that don't contain arrays; those
+ -- remain as-is
+ (_, SMerge, _) ->
+ k (storepl `DPush` (t, SMerge))
+ (SEYes mergesub)
+ envpro
+ wf
+ where
+ containsTArr :: STy t' -> Bool
+ containsTArr = \case
+ STNil -> False
+ STPair a b -> containsTArr a || containsTArr b
+ STEither a b -> containsTArr a || containsTArr b
+ STArr{} -> True
+ STScal{} -> False
+ STAccum{} -> error "An accumulator in merge storage?"
-- | @env'@ is a subset of @env@: each element of @env@ is either included in
-- @env'@ ('SEYes') or not included in @env'@ ('SENo').
@@ -501,10 +530,6 @@ subenvNone :: SList f env -> Subenv env '[]
subenvNone SNil = SETop
subenvNone (SCons _ env) = SENo (subenvNone env)
-subenvAll :: SList f env -> Subenv env env
-subenvAll SNil = SETop
-subenvAll (SCons _ env) = SEYes (subenvAll env)
-
subenvOnehot :: SList f env -> Idx env t -> Subenv env '[t]
subenvOnehot (SCons _ env) IZ = SEYes (subenvNone env)
subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i)
@@ -592,7 +617,7 @@ rebaseRetPair :: forall env b1 b2 env0 sto t f.
rebaseRetPair descr b1 b2 (RetPair p sub d)
| Refl <- lemAppendAssoc @b2 @b1 @env =
RetPair p sub (weakenExpr (autoWeak
- (Seg' @"d" @'[D2 t] $.. Seg @"b2" b2 $.. Seg @"b1" b1 $.. Seg @"tl" (d2ace (select SAccum descr)))
+ (#d (auto @'[D2 t]) $.. #b2 b2 $.. #b1 b1 $.. #tl (d2ace (select SAccum descr)))
(#d :++: (#b2 :++: #tl))
(#d :++: ((#b2 :++: #b1) :++: #tl)))
d)
@@ -734,10 +759,10 @@ drev des = \case
(weakenExpr wbody0' body1)
subBoth
(ELet ext
- (weakenExpr (autoWeak (Seg' @"d" @'[D2 t]
- $.. Seg @"body" (bindingsBinds body0)
- $.. Seg @"rhs" (SCons (typeOf rhs1) (bindingsBinds rhs0))
- $.. Seg @"tl" (d2ace (select SAccum des)))
+ (weakenExpr (autoWeak (#d (auto @'[D2 t])
+ $.. #body (bindingsBinds body0)
+ $.. #rhs (SCons (typeOf rhs1) (bindingsBinds rhs0))
+ $.. #tl (d2ace (select SAccum des)))
(#d :++: #body :++: #tl)
(#d :++: #body :++: #rhs :++: #tl))
body2') $
@@ -842,12 +867,12 @@ drev des = \case
ELet ext
(EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_binds : D2 t : t_primal_ty : Append e_binds (D2AcE (Select env sto "accum"))) (sappend (bindingsBinds a0) prerebinds) @> IS IZ)) $
ELet ext
- (weakenExpr (autoWeak (Seg' @"d" @'[D2 t]
- $.. Seg @"a0" (bindingsBinds a0)
- $.. Seg @"prea0" prerebinds
- $.. Seg @"recon" (tapeA `SCons` d2 (typeOf a) `SCons` SNil)
- $.. Seg @"binds" (tPrimal `SCons` bindingsBinds e0)
- $.. Seg @"tl" (d2ace (select SAccum des)))
+ (weakenExpr (autoWeak (#d (auto @'[D2 t])
+ $.. #a0 (bindingsBinds a0)
+ $.. #prea0 prerebinds
+ $.. #recon (tapeA `SCons` d2 (typeOf a) `SCons` SNil)
+ $.. #binds (tPrimal `SCons` bindingsBinds e0)
+ $.. #tl (d2ace (select SAccum des)))
(#d :++: #a0 :++: #tl)
(#d :++: (#a0 :++: #prea0) :++: #recon :++: #binds :++: #tl))
a2') $
@@ -861,12 +886,12 @@ drev des = \case
ELet ext
(EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_binds : D2 t : t_primal_ty : Append e_binds (D2AcE (Select env sto "accum"))) (sappend (bindingsBinds b0) prerebinds) @> IS IZ)) $
ELet ext
- (weakenExpr (autoWeak (Seg' @"d" @'[D2 t]
- $.. Seg @"b0" (bindingsBinds b0)
- $.. Seg @"preb0" prerebinds
- $.. Seg @"recon" (tapeB `SCons` d2 (typeOf a) `SCons` SNil)
- $.. Seg @"binds" (tPrimal `SCons` bindingsBinds e0)
- $.. Seg @"tl" (d2ace (select SAccum des)))
+ (weakenExpr (autoWeak (#d (auto @'[D2 t])
+ $.. #b0 (bindingsBinds b0)
+ $.. #preb0 prerebinds
+ $.. #recon (tapeB `SCons` d2 (typeOf a) `SCons` SNil)
+ $.. #binds (tPrimal `SCons` bindingsBinds e0)
+ $.. #tl (d2ace (select SAccum des)))
(#d :++: #b0 :++: #tl)
(#d :++: (#b0 :++: #preb0) :++: #recon :++: #binds :++: #tl))
b2') $
@@ -920,10 +945,10 @@ drev des = \case
Ret (bconcat (ne0 `BPush` (tIx, ne1))
(fst (weakenBindings weakenExpr (WCopy (wSinks (bindingsBinds ne0))) ve0)))
(EBuild1 ext
- (weakenExpr (autoWeak (Seg @"ve0" (bindingsBinds ve0)
- $.. Seg' @"i" @'[TIx]
- $.. Seg @"ne0" (bindingsBinds ne0)
- $.. Seg @"tl" (sD1eEnv des))
+ (weakenExpr (autoWeak (#ve0 (bindingsBinds ve0)
+ $.. #i (auto @'[TIx])
+ $.. #ne0 (bindingsBinds ne0)
+ $.. #tl (sD1eEnv des))
(#ne0 :++: #tl)
((#ve0 :++: #i :++: #ne0) :++: #tl))
ne1)