diff options
Diffstat (limited to 'src/CHAD.hs')
| -rw-r--r-- | src/CHAD.hs | 172 | 
1 files changed, 102 insertions, 70 deletions
| diff --git a/src/CHAD.hs b/src/CHAD.hs index 6a4d5f5..8db0410 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -36,7 +36,7 @@ import Data.Type.Bool (If)  import Data.Type.Equality (type (==))  import GHC.Stack (HasCallStack) -import Analysis.Identity (ValId(..)) +import Analysis.Identity (ValId(..), validSplitEither)  import AST  import AST.Bindings  import AST.Count @@ -294,18 +294,18 @@ data Idx2 env sto t    | Idx2Di (Idx (Select env sto "discr") t)  conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t -conv2Idx (DPush _   (_, SAccum)) IZ = Idx2Ac IZ -conv2Idx (DPush _   (_, SMerge)) IZ = Idx2Me IZ -conv2Idx (DPush _   (_, SDiscr)) IZ = Idx2Di IZ -conv2Idx (DPush des (_, SAccum)) (IS i) = +conv2Idx (DPush _   (_, _, SAccum)) IZ = Idx2Ac IZ +conv2Idx (DPush _   (_, _, SMerge)) IZ = Idx2Me IZ +conv2Idx (DPush _   (_, _, SDiscr)) IZ = Idx2Di IZ +conv2Idx (DPush des (_, _, SAccum)) (IS i) =    case conv2Idx des i of Idx2Ac j -> Idx2Ac (IS j)                           Idx2Me j -> Idx2Me j                           Idx2Di j -> Idx2Di j -conv2Idx (DPush des (_, SMerge)) (IS i) = +conv2Idx (DPush des (_, _, SMerge)) (IS i) =    case conv2Idx des i of Idx2Ac j -> Idx2Ac j                           Idx2Me j -> Idx2Me (IS j)                           Idx2Di j -> Idx2Di j -conv2Idx (DPush des (_, SDiscr)) (IS i) = +conv2Idx (DPush des (_, _, SDiscr)) (IS i) =    case conv2Idx des i of Idx2Ac j -> Idx2Ac j                           Idx2Me j -> Idx2Me j                           Idx2Di j -> Idx2Di (IS j) @@ -376,6 +376,10 @@ assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty"  --------------------------------- ACCUMULATORS --------------------------------- +fromArrayValId :: Maybe (ValId t) -> Maybe Int +fromArrayValId (Just (VIArr i _)) = Just i +fromArrayValId _ = Nothing +  accumPromote :: forall dt env sto proxy r.                  proxy dt               -> Descr env sto @@ -384,8 +388,7 @@ accumPromote :: forall dt env sto proxy r.                   => Descr env stoRepl                        -- ^ A revised environment description that switches                        -- arrays (used in the OccEnv) that are currently on -                      -- "merge" storage, to "accum" storage. Any other "merge" -                      -- entries are deleted. +                      -- "merge" storage, to "accum" storage.                   -> SList STy envPro                        -- ^ New entries on top of the original dual environment,                        -- that house the accumulators for the promoted arrays in @@ -393,6 +396,12 @@ accumPromote :: forall dt env sto proxy r.                   -> Subenv (Select env sto "merge") envPro                        -- ^ The promoted entries were merge entries in the                        -- original environment. +                 -> Subenv (D2AcE (Select env stoRepl "accum")) (D2AcE (Select env sto "accum")) +                      -- ^ All entries that were accumulators are still +                      -- accumulators. +                 -> VarMap Int (D2AcE (Select env stoRepl "accum")) +                      -- ^ Accumulator map for _only_ the the newly allocated +                      -- accumulators.                   -> (forall shbinds.                              SList STy shbinds                           -> (D2 dt : Append shbinds (D2AcE (Select env stoRepl "accum"))) @@ -402,57 +411,70 @@ accumPromote :: forall dt env sto proxy r.                        -- extended with some accumulators.                   -> r)               -> r -accumPromote _ DTop k = k DTop SNil SETop (\_ -> WId) -accumPromote pdty (descr `DPush` (t :: STy t, sto)) k = -  accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub wf -> -    case sto of -      -- Accumulators are left as-is -      SAccum -> -        k (storepl `DPush` (t, SAccum)) -          envpro -          prosub -          (\shbinds -> -            autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum descr))) -                     (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) -                     (#pro :++: #d :++: #shb :++: #acc :++: #tl) -            .> WCopy (wf shbinds) -            .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum storepl))) -                        (#d :++: #shb :++: #acc :++: #tl) -                        (#acc :++: (#d :++: #shb :++: #tl))) - -      SMerge -> case t of -        -- Discrete values are left as-is -        _ | isDiscrete t -> -          k (storepl `DPush` (t, SDiscr)) -            envpro -            (SENo prosub) -            wf - -        -- Values with "merge" storage are promoted to an accumulator in envPro -        _ -> -          k (storepl `DPush` (t, SAccum)) -            (t `SCons` envpro) -            (SEYes prosub) -            (\(shbinds :: SList _ shbinds) -> -              let shbindsC = slistMap (\_ -> Const ()) shbinds -              in -              -- wf: -              --                 D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum"))  :>                Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) -              -- WCopy wf: -              --   TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum"))  :>  TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) -              --                       WPICK: ^                                                                 THESE TWO  || -              -- goal:                        |                                                                 ARE EQUAL  || -              --   D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum"))  :>  TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) -              WCopy (wf shbinds) -              .> WPick @(TAccum t) @(D2 dt : shbinds) (Const () `SCons` shbindsC) -                   (WId @(D2AcE (Select env1 stoRepl "accum")))) +accumPromote _ DTop k = k DTop SNil SETop SETop VarMap.empty (\_ -> WId) +accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of +  -- Accumulators are left as-is +  SAccum -> +    accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> +      k (storepl `DPush` (t, vid, SAccum)) +        envpro +        prosub +        (SEYes accrevsub) +        (VarMap.sink1 accumMap) +        (\shbinds -> +          autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum descr))) +                   (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) +                   (#pro :++: #d :++: #shb :++: #acc :++: #tl) +          .> WCopy (wf shbinds) +          .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum storepl))) +                      (#d :++: #shb :++: #acc :++: #tl) +                      (#acc :++: (#d :++: #shb :++: #tl))) -      -- Discrete values are left as-is, nothing to do -      SDiscr -> -        k (storepl `DPush` (t, SDiscr)) +  SMerge -> case t of +    -- Discrete values are left as-is +    _ | isDiscrete t -> +      accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap' wf -> +        k (storepl `DPush` (t, vid, SDiscr))            envpro -          prosub +          (SENo prosub) +          accrevsub +          accumMap'            wf + +    -- Values with "merge" storage are promoted to an accumulator in envPro +    _ -> +      accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> +        k (storepl `DPush` (t, vid, SAccum)) +          (t `SCons` envpro) +          (SEYes prosub) +          (SENo accrevsub) +          (let accumMap' = VarMap.sink1 accumMap +           in case fromArrayValId vid of +                Just i -> VarMap.insert i (STAccum t) IZ accumMap' +                Nothing -> accumMap') +          (\(shbinds :: SList _ shbinds) -> +            let shbindsC = slistMap (\_ -> Const ()) shbinds +            in +            -- wf: +            --                 D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum"))  :>                Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) +            -- WCopy wf: +            --   TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum"))  :>  TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) +            --                       WPICK: ^                                                                 THESE TWO  || +            -- goal:                        |                                                                 ARE EQUAL  || +            --   D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum"))  :>  TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) +            WCopy (wf shbinds) +            .> WPick @(TAccum t) @(D2 dt : shbinds) (Const () `SCons` shbindsC) +                 (WId @(D2AcE (Select env1 stoRepl "accum")))) + +  -- Discrete values are left as-is, nothing to do +  SDiscr -> +    accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> +      k (storepl `DPush` (t, vid, SDiscr)) +        envpro +        prosub +        accrevsub +        accumMap +        wf    where      isDiscrete :: STy t' -> Bool      isDiscrete = \case @@ -561,7 +583,7 @@ freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) =  drev :: forall env sto t.          (?config :: CHADConfig) -     => Descr env sto -> VarMap Int env +     => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))       -> Expr ValId env t -> Ret env sto t  drev des accumMap = \case    EVar _ t i -> @@ -590,7 +612,7 @@ drev des accumMap = \case    ELet _ (rhs :: Expr _ _ a) body      | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des accumMap rhs      , ChosenStorage storage <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge -    , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des accumMap (typeOf rhs) storage body +    , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) body      , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0      , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env)      , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) -> @@ -687,8 +709,9 @@ drev des accumMap = \case      , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap e      , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && hasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge      , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && hasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge -    , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des accumMap t1 storage1 a -    , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des accumMap t2 storage2 b +    , let (bindids1, bindids2) = validSplitEither (extOf e) +    , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des accumMap t1 storage1 bindids1 a +    , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des accumMap t2 storage2 bindids2 b      , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))      , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))      , let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA) @@ -819,8 +842,9 @@ drev des accumMap = \case      deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') ->      let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in      subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> -    accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub wPro -> -    case drev (prodes `DPush` (shty, SDiscr)) (VarMap.sink1 (VarMap.subMap usedSub accumMap)) e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> +    accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> +    let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in +    case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 ->      case assertSubenvEmpty sub of { Refl ->      let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in      let collectexpr = bindingsCollect e0 subtapeE in @@ -1055,17 +1079,22 @@ deriving instance Show (RetScoped env0 sto a s t)  drevScoped :: forall a s env sto t.                (?config :: CHADConfig) -           => Descr env sto -> VarMap Int env -> STy a -> Storage s +           => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) +           -> STy a -> Storage s -> Maybe (ValId a)             -> Expr ValId (a : env) t             -> RetScoped env sto a s t -drevScoped des accumMap argty argsto expr -  | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argsto)) (VarMap.sink1 accumMap) expr -  = case argsto of -      SMerge -> +drevScoped des accumMap argty argsto argids expr = case argsto of +  SMerge +    | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr ->          case sub of            SEYes sub' -> RetScoped e0 subtape e1 sub' e2            SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (EZero ext argty)) -      SAccum -> + +  SAccum +    | let accumMap' = case argids of +                        Just (VIArr i _) -> VarMap.insert i (STAccum argty) IZ (VarMap.sink1 accumMap) +                        _ -> VarMap.sink1 accumMap +    , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' expr ->          RetScoped e0 subtape e1 sub $            EWith ext argty (EZero ext argty) $              weakenExpr (autoWeak (#d (auto1 @(D2 t)) @@ -1075,4 +1104,7 @@ drevScoped des accumMap argty argsto expr                                   (#d :++: #body :++: #ac :++: #tl)                                   (#ac :++: #d :++: #body :++: #tl))                         e2 -      SDiscr -> RetScoped e0 subtape e1 sub e2 + +  SDiscr +    | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr -> +        RetScoped e0 subtape e1 sub e2 | 
