diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-09-02 22:14:47 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-09-02 22:14:47 +0200 | 
| commit | bddd67f4bfff1dcfac2cc0ad4d824788afed91e5 (patch) | |
| tree | a530787c51a87294601109986796ded4f9988fa9 | |
| parent | ad2aeb613af02c9f59b0301540473742914b47e7 (diff) | |
accumPromote
| -rw-r--r-- | src/AST/Weaken/Auto.hs | 25 | ||||
| -rw-r--r-- | src/CHAD.hs | 101 | 
2 files changed, 70 insertions, 56 deletions
| diff --git a/src/AST/Weaken/Auto.hs b/src/AST/Weaken/Auto.hs index 0deec71..eecb6f3 100644 --- a/src/AST/Weaken/Auto.hs +++ b/src/AST/Weaken/Auto.hs @@ -18,8 +18,7 @@  {-# OPTIONS_GHC -Wno-partial-type-signatures #-}  module AST.Weaken.Auto (    autoWeak, -  GivenSegment(..), -  ($..), +  ($..), auto,    Layout(..),  ) where @@ -51,25 +50,15 @@ data SSegments (segments :: [(Symbol, [t])]) where    SSegNil :: SSegments '[]    SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list) -class ToSegments k a | a -> k where -  type SegmentsOf k a :: [(Symbol, [k])] -  toSegments :: a -> SSegments (SegmentsOf k a) +instance (KnownSymbol name, name ~ name', segs ~ '[ '(name', ts)]) => IsLabel name (SList f ts -> SSegments segs) where +  fromLabel = \spine -> SSegCons symbolSing (slistMap (\_ -> Const ()) spine) SSegNil -instance ToSegments k (SSegments (segments :: [(Symbol, [k])])) where -  type SegmentsOf k (SSegments segments) = segments -  toSegments = id - -data GivenSegment name ts = forall f. KnownSymbol name => Seg (SList f ts) -                          | (KnownSymbol name, KnownListSpine ts) => Seg' - -instance ToSegments k (GivenSegment name (ts :: [k])) where -  type SegmentsOf k (GivenSegment name (ts :: [k])) = '[ '(name, ts)] -  toSegments (Seg list) = SSegCons symbolSing (slistMap (\_ -> Const ()) list) SSegNil -  toSegments Seg' = SSegCons symbolSing knownListSpine SSegNil +auto :: KnownListSpine list => SList (Const ()) list +auto = knownListSpine  infixr $.. -($..) :: (ToSegments k a, ToSegments k b) => a -> b -> SSegments (Append (SegmentsOf k a) (SegmentsOf k b)) -x $.. y = ssegmentsAppend (toSegments x) (toSegments y) +($..) :: SSegments segs1 -> SSegments segs2 -> SSegments (Append segs1 segs2) +($..) = ssegmentsAppend    where      ssegmentsAppend :: SSegments a -> SSegments b -> SSegments (Append a b)      ssegmentsAppend SSegNil l2 = l2 diff --git a/src/CHAD.hs b/src/CHAD.hs index e99859c..aedda5b 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -408,8 +408,8 @@ zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0))  zeroTup SNil = ENil ext  zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t) -accumPromote :: forall t env sto proxy r. -                proxy t +accumPromote :: forall dt env sto proxy r. +                proxy dt               -> Descr env sto               -> OccEnv env               -> (forall stoRepl envPro. @@ -425,8 +425,8 @@ accumPromote :: forall t env sto proxy r.                        -- the original environment.                   -> (forall shbinds.                              SList STy shbinds -                         -> (D2 t : Append shbinds (D2AcE (Select env stoRepl "accum"))) -                            :> Append envPro (D2 t : Append shbinds (D2AcE (Select env sto "accum")))) +                         -> (D2 dt : Append shbinds (D2AcE (Select env stoRepl "accum"))) +                            :> Append envPro (D2 dt : Append shbinds (D2AcE (Select env sto "accum"))))                        -- ^ A weakening that converts a computation in the                        -- revised environment to one in the original environment                        -- extended with some accumulators. @@ -434,9 +434,24 @@ accumPromote :: forall t env sto proxy r.               -> r  accumPromote _ DTop _ k = k DTop SETop SNil (\_ -> WId)  accumPromote _ descr OccEnd k = k descr (subenvAll (select SMerge descr)) SNil (\_ -> WId) -accumPromote pty (descr `DPush` (t, sto)) (occenv `OccPush` occ) k = -  accumPromote pty descr occenv $ \(storepl :: Descr env1 stoRepl) mergesub (envpro :: SList _ envPro) wf -> +accumPromote pdty (descr `DPush` (t :: STy t, sto)) (occenv `OccPush` occ) k = +  accumPromote pdty descr occenv $ \(storepl :: Descr env1 stoRepl) mergesub (envpro :: SList _ envPro) wf ->      case (t, sto, occ) of +      -- Accumulators are left as-is +      (_, SAccum, _) -> +        k (storepl `DPush` (t, SAccum)) +          mergesub +          envpro +          (\shbinds -> +            autoWeak (#pro envpro $.. #d (auto @'[D2 dt]) $.. #shb shbinds $.. #acc (auto @'[D2Ac t]) $.. #tl (d2ace (select SAccum descr))) +                     (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) +                     (#pro :++: #d :++: #shb :++: #acc :++: #tl) +            .> WCopy (wf shbinds) +            .> autoWeak (#d (auto @'[D2 dt]) $.. #shb shbinds $.. #acc (auto @'[D2Ac t]) $.. #tl (d2ace (select SAccum storepl))) +                        (#d :++: #shb :++: #acc :++: #tl) +                        (#acc :++: (#d :++: #shb :++: #tl))) + +      -- Arrays with "merge" storage and non-zero usage are promoted to an accumulator in envPro        (STArr (arrn :: SNat arrn) (arrt :: STy arrt), SMerge, Occ _ c) | c > Zero ->          k (storepl `DPush` (t, SAccum))            (SENo mergesub) @@ -452,16 +467,30 @@ accumPromote pty (descr `DPush` (t, sto)) (occenv `OccPush` occ) k =              -- goal:                        |                                                                 ARE EQUAL  ||              --   D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum"))  :>  TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))              WCopy (wf shbinds) -            .> WPick @(TAccum arrn arrt) @(D2 t : shbinds) (Const () `SCons` shbindsC) +            .> WPick @(TAccum arrn arrt) @(D2 dt : shbinds) (Const () `SCons` shbindsC)                   (WId @(D2AcE (Select env1 stoRepl "accum")))) -      (_, SAccum, _) -> -        k (storepl `DPush` (t, SAccum)) -          mergesub -          envpro -          (\shbinds -> _ (wf shbinds)) +      -- Used "merge" values must either _be_ an array (and hence be caught by +      -- the prior case), or contain _no_ arrays at all (TODO: generalise this) +      (_, SMerge, Occ _ c) | c > Zero, containsTArr t -> +        error $ "Closure variable of 'build'-like thing contains a composite type containing an array: " ++ show t -      _ -> _ +      -- What's left are normal "merge" values that don't contain arrays; those +      -- remain as-is +      (_, SMerge, _) -> +        k (storepl `DPush` (t, SMerge)) +          (SEYes mergesub) +          envpro +          wf +  where +    containsTArr :: STy t' -> Bool +    containsTArr = \case +      STNil -> False +      STPair a b -> containsTArr a || containsTArr b +      STEither a b -> containsTArr a || containsTArr b +      STArr{} -> True +      STScal{} -> False +      STAccum{} -> error "An accumulator in merge storage?"  -- | @env'@ is a subset of @env@: each element of @env@ is either included in  -- @env'@ ('SEYes') or not included in @env'@ ('SENo'). @@ -501,10 +530,6 @@ subenvNone :: SList f env -> Subenv env '[]  subenvNone SNil = SETop  subenvNone (SCons _ env) = SENo (subenvNone env) -subenvAll :: SList f env -> Subenv env env -subenvAll SNil = SETop -subenvAll (SCons _ env) = SEYes (subenvAll env) -  subenvOnehot :: SList f env -> Idx env t -> Subenv env '[t]  subenvOnehot (SCons _ env) IZ = SEYes (subenvNone env)  subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i) @@ -592,7 +617,7 @@ rebaseRetPair :: forall env b1 b2 env0 sto t f.  rebaseRetPair descr b1 b2 (RetPair p sub d)    | Refl <- lemAppendAssoc @b2 @b1 @env =        RetPair p sub (weakenExpr (autoWeak -                                  (Seg' @"d" @'[D2 t] $.. Seg @"b2" b2 $.. Seg @"b1" b1 $.. Seg @"tl" (d2ace (select SAccum descr))) +                                  (#d (auto @'[D2 t]) $.. #b2 b2 $.. #b1 b1 $.. #tl (d2ace (select SAccum descr)))                                    (#d :++: (#b2 :++: #tl))                                    (#d :++: ((#b2 :++: #b1) :++: #tl)))                                  d) @@ -734,10 +759,10 @@ drev des = \case          (weakenExpr wbody0' body1)          subBoth          (ELet ext -           (weakenExpr (autoWeak (Seg' @"d" @'[D2 t] -                                  $.. Seg @"body" (bindingsBinds body0) -                                  $.. Seg @"rhs" (SCons (typeOf rhs1) (bindingsBinds rhs0)) -                                  $.. Seg @"tl" (d2ace (select SAccum des))) +           (weakenExpr (autoWeak (#d (auto @'[D2 t]) +                                  $.. #body (bindingsBinds body0) +                                  $.. #rhs (SCons (typeOf rhs1) (bindingsBinds rhs0)) +                                  $.. #tl (d2ace (select SAccum des)))                                   (#d :++: #body :++: #tl)                                   (#d :++: #body :++: #rhs :++: #tl))                         body2') $ @@ -842,12 +867,12 @@ drev des = \case                      ELet ext                        (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_binds : D2 t : t_primal_ty : Append e_binds (D2AcE (Select env sto "accum"))) (sappend (bindingsBinds a0) prerebinds) @> IS IZ)) $                      ELet ext -                      (weakenExpr (autoWeak (Seg' @"d" @'[D2 t] -                                             $.. Seg @"a0" (bindingsBinds a0) -                                             $.. Seg @"prea0" prerebinds -                                             $.. Seg @"recon" (tapeA `SCons` d2 (typeOf a) `SCons` SNil) -                                             $.. Seg @"binds" (tPrimal `SCons` bindingsBinds e0) -                                             $.. Seg @"tl" (d2ace (select SAccum des))) +                      (weakenExpr (autoWeak (#d (auto @'[D2 t]) +                                             $.. #a0 (bindingsBinds a0) +                                             $.. #prea0 prerebinds +                                             $.. #recon (tapeA `SCons` d2 (typeOf a) `SCons` SNil) +                                             $.. #binds (tPrimal `SCons` bindingsBinds e0) +                                             $.. #tl (d2ace (select SAccum des)))                                              (#d :++: #a0 :++: #tl)                                              (#d :++: (#a0 :++: #prea0) :++: #recon :++: #binds :++: #tl))                                    a2') $ @@ -861,12 +886,12 @@ drev des = \case                      ELet ext                        (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_binds : D2 t : t_primal_ty : Append e_binds (D2AcE (Select env sto "accum"))) (sappend (bindingsBinds b0) prerebinds) @> IS IZ)) $                      ELet ext -                      (weakenExpr (autoWeak (Seg' @"d" @'[D2 t] -                                             $.. Seg @"b0" (bindingsBinds b0) -                                             $.. Seg @"preb0" prerebinds -                                             $.. Seg @"recon" (tapeB `SCons` d2 (typeOf a) `SCons` SNil) -                                             $.. Seg @"binds" (tPrimal `SCons` bindingsBinds e0) -                                             $.. Seg @"tl" (d2ace (select SAccum des))) +                      (weakenExpr (autoWeak (#d (auto @'[D2 t]) +                                             $.. #b0 (bindingsBinds b0) +                                             $.. #preb0 prerebinds +                                             $.. #recon (tapeB `SCons` d2 (typeOf a) `SCons` SNil) +                                             $.. #binds (tPrimal `SCons` bindingsBinds e0) +                                             $.. #tl (d2ace (select SAccum des)))                                              (#d :++: #b0 :++: #tl)                                              (#d :++: (#b0 :++: #preb0) :++: #recon :++: #binds :++: #tl))                                    b2') $ @@ -920,10 +945,10 @@ drev des = \case      Ret (bconcat (ne0 `BPush` (tIx, ne1))                   (fst (weakenBindings weakenExpr (WCopy (wSinks (bindingsBinds ne0))) ve0)))          (EBuild1 ext -           (weakenExpr (autoWeak (Seg @"ve0" (bindingsBinds ve0) -                                  $.. Seg' @"i" @'[TIx] -                                  $.. Seg @"ne0" (bindingsBinds ne0) -                                  $.. Seg @"tl" (sD1eEnv des)) +           (weakenExpr (autoWeak (#ve0 (bindingsBinds ve0) +                                  $.. #i (auto @'[TIx]) +                                  $.. #ne0 (bindingsBinds ne0) +                                  $.. #tl (sD1eEnv des))                                   (#ne0 :++: #tl)                                   ((#ve0 :++: #i :++: #ne0) :++: #tl))                         ne1) | 
