diff options
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r-- | src/CHAD.hs | 101 |
1 files changed, 63 insertions, 38 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index e99859c..aedda5b 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -408,8 +408,8 @@ zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0)) zeroTup SNil = ENil ext zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t) -accumPromote :: forall t env sto proxy r. - proxy t +accumPromote :: forall dt env sto proxy r. + proxy dt -> Descr env sto -> OccEnv env -> (forall stoRepl envPro. @@ -425,8 +425,8 @@ accumPromote :: forall t env sto proxy r. -- the original environment. -> (forall shbinds. SList STy shbinds - -> (D2 t : Append shbinds (D2AcE (Select env stoRepl "accum"))) - :> Append envPro (D2 t : Append shbinds (D2AcE (Select env sto "accum")))) + -> (D2 dt : Append shbinds (D2AcE (Select env stoRepl "accum"))) + :> Append envPro (D2 dt : Append shbinds (D2AcE (Select env sto "accum")))) -- ^ A weakening that converts a computation in the -- revised environment to one in the original environment -- extended with some accumulators. @@ -434,9 +434,24 @@ accumPromote :: forall t env sto proxy r. -> r accumPromote _ DTop _ k = k DTop SETop SNil (\_ -> WId) accumPromote _ descr OccEnd k = k descr (subenvAll (select SMerge descr)) SNil (\_ -> WId) -accumPromote pty (descr `DPush` (t, sto)) (occenv `OccPush` occ) k = - accumPromote pty descr occenv $ \(storepl :: Descr env1 stoRepl) mergesub (envpro :: SList _ envPro) wf -> +accumPromote pdty (descr `DPush` (t :: STy t, sto)) (occenv `OccPush` occ) k = + accumPromote pdty descr occenv $ \(storepl :: Descr env1 stoRepl) mergesub (envpro :: SList _ envPro) wf -> case (t, sto, occ) of + -- Accumulators are left as-is + (_, SAccum, _) -> + k (storepl `DPush` (t, SAccum)) + mergesub + envpro + (\shbinds -> + autoWeak (#pro envpro $.. #d (auto @'[D2 dt]) $.. #shb shbinds $.. #acc (auto @'[D2Ac t]) $.. #tl (d2ace (select SAccum descr))) + (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) + (#pro :++: #d :++: #shb :++: #acc :++: #tl) + .> WCopy (wf shbinds) + .> autoWeak (#d (auto @'[D2 dt]) $.. #shb shbinds $.. #acc (auto @'[D2Ac t]) $.. #tl (d2ace (select SAccum storepl))) + (#d :++: #shb :++: #acc :++: #tl) + (#acc :++: (#d :++: #shb :++: #tl))) + + -- Arrays with "merge" storage and non-zero usage are promoted to an accumulator in envPro (STArr (arrn :: SNat arrn) (arrt :: STy arrt), SMerge, Occ _ c) | c > Zero -> k (storepl `DPush` (t, SAccum)) (SENo mergesub) @@ -452,16 +467,30 @@ accumPromote pty (descr `DPush` (t, sto)) (occenv `OccPush` occ) k = -- goal: | ARE EQUAL || -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) WCopy (wf shbinds) - .> WPick @(TAccum arrn arrt) @(D2 t : shbinds) (Const () `SCons` shbindsC) + .> WPick @(TAccum arrn arrt) @(D2 dt : shbinds) (Const () `SCons` shbindsC) (WId @(D2AcE (Select env1 stoRepl "accum")))) - (_, SAccum, _) -> - k (storepl `DPush` (t, SAccum)) - mergesub - envpro - (\shbinds -> _ (wf shbinds)) + -- Used "merge" values must either _be_ an array (and hence be caught by + -- the prior case), or contain _no_ arrays at all (TODO: generalise this) + (_, SMerge, Occ _ c) | c > Zero, containsTArr t -> + error $ "Closure variable of 'build'-like thing contains a composite type containing an array: " ++ show t - _ -> _ + -- What's left are normal "merge" values that don't contain arrays; those + -- remain as-is + (_, SMerge, _) -> + k (storepl `DPush` (t, SMerge)) + (SEYes mergesub) + envpro + wf + where + containsTArr :: STy t' -> Bool + containsTArr = \case + STNil -> False + STPair a b -> containsTArr a || containsTArr b + STEither a b -> containsTArr a || containsTArr b + STArr{} -> True + STScal{} -> False + STAccum{} -> error "An accumulator in merge storage?" -- | @env'@ is a subset of @env@: each element of @env@ is either included in -- @env'@ ('SEYes') or not included in @env'@ ('SENo'). @@ -501,10 +530,6 @@ subenvNone :: SList f env -> Subenv env '[] subenvNone SNil = SETop subenvNone (SCons _ env) = SENo (subenvNone env) -subenvAll :: SList f env -> Subenv env env -subenvAll SNil = SETop -subenvAll (SCons _ env) = SEYes (subenvAll env) - subenvOnehot :: SList f env -> Idx env t -> Subenv env '[t] subenvOnehot (SCons _ env) IZ = SEYes (subenvNone env) subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i) @@ -592,7 +617,7 @@ rebaseRetPair :: forall env b1 b2 env0 sto t f. rebaseRetPair descr b1 b2 (RetPair p sub d) | Refl <- lemAppendAssoc @b2 @b1 @env = RetPair p sub (weakenExpr (autoWeak - (Seg' @"d" @'[D2 t] $.. Seg @"b2" b2 $.. Seg @"b1" b1 $.. Seg @"tl" (d2ace (select SAccum descr))) + (#d (auto @'[D2 t]) $.. #b2 b2 $.. #b1 b1 $.. #tl (d2ace (select SAccum descr))) (#d :++: (#b2 :++: #tl)) (#d :++: ((#b2 :++: #b1) :++: #tl))) d) @@ -734,10 +759,10 @@ drev des = \case (weakenExpr wbody0' body1) subBoth (ELet ext - (weakenExpr (autoWeak (Seg' @"d" @'[D2 t] - $.. Seg @"body" (bindingsBinds body0) - $.. Seg @"rhs" (SCons (typeOf rhs1) (bindingsBinds rhs0)) - $.. Seg @"tl" (d2ace (select SAccum des))) + (weakenExpr (autoWeak (#d (auto @'[D2 t]) + $.. #body (bindingsBinds body0) + $.. #rhs (SCons (typeOf rhs1) (bindingsBinds rhs0)) + $.. #tl (d2ace (select SAccum des))) (#d :++: #body :++: #tl) (#d :++: #body :++: #rhs :++: #tl)) body2') $ @@ -842,12 +867,12 @@ drev des = \case ELet ext (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_binds : D2 t : t_primal_ty : Append e_binds (D2AcE (Select env sto "accum"))) (sappend (bindingsBinds a0) prerebinds) @> IS IZ)) $ ELet ext - (weakenExpr (autoWeak (Seg' @"d" @'[D2 t] - $.. Seg @"a0" (bindingsBinds a0) - $.. Seg @"prea0" prerebinds - $.. Seg @"recon" (tapeA `SCons` d2 (typeOf a) `SCons` SNil) - $.. Seg @"binds" (tPrimal `SCons` bindingsBinds e0) - $.. Seg @"tl" (d2ace (select SAccum des))) + (weakenExpr (autoWeak (#d (auto @'[D2 t]) + $.. #a0 (bindingsBinds a0) + $.. #prea0 prerebinds + $.. #recon (tapeA `SCons` d2 (typeOf a) `SCons` SNil) + $.. #binds (tPrimal `SCons` bindingsBinds e0) + $.. #tl (d2ace (select SAccum des))) (#d :++: #a0 :++: #tl) (#d :++: (#a0 :++: #prea0) :++: #recon :++: #binds :++: #tl)) a2') $ @@ -861,12 +886,12 @@ drev des = \case ELet ext (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_binds : D2 t : t_primal_ty : Append e_binds (D2AcE (Select env sto "accum"))) (sappend (bindingsBinds b0) prerebinds) @> IS IZ)) $ ELet ext - (weakenExpr (autoWeak (Seg' @"d" @'[D2 t] - $.. Seg @"b0" (bindingsBinds b0) - $.. Seg @"preb0" prerebinds - $.. Seg @"recon" (tapeB `SCons` d2 (typeOf a) `SCons` SNil) - $.. Seg @"binds" (tPrimal `SCons` bindingsBinds e0) - $.. Seg @"tl" (d2ace (select SAccum des))) + (weakenExpr (autoWeak (#d (auto @'[D2 t]) + $.. #b0 (bindingsBinds b0) + $.. #preb0 prerebinds + $.. #recon (tapeB `SCons` d2 (typeOf a) `SCons` SNil) + $.. #binds (tPrimal `SCons` bindingsBinds e0) + $.. #tl (d2ace (select SAccum des))) (#d :++: #b0 :++: #tl) (#d :++: (#b0 :++: #preb0) :++: #recon :++: #binds :++: #tl)) b2') $ @@ -920,10 +945,10 @@ drev des = \case Ret (bconcat (ne0 `BPush` (tIx, ne1)) (fst (weakenBindings weakenExpr (WCopy (wSinks (bindingsBinds ne0))) ve0))) (EBuild1 ext - (weakenExpr (autoWeak (Seg @"ve0" (bindingsBinds ve0) - $.. Seg' @"i" @'[TIx] - $.. Seg @"ne0" (bindingsBinds ne0) - $.. Seg @"tl" (sD1eEnv des)) + (weakenExpr (autoWeak (#ve0 (bindingsBinds ve0) + $.. #i (auto @'[TIx]) + $.. #ne0 (bindingsBinds ne0) + $.. #tl (sD1eEnv des)) (#ne0 :++: #tl) ((#ve0 :++: #i :++: #ne0) :++: #tl)) ne1) |