diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2024-09-02 17:49:54 +0200 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2024-09-02 17:50:12 +0200 |
commit | 7d44dcc2ca2c5c16e1ab4737ef6b2877214767ed (patch) | |
tree | 42e8b9292403f9ce3a6f04a15ebd62a766880339 | |
parent | 1f7ed2ee02222108684cfde8078e7a182f734a61 (diff) |
WIP autoWeak
-rw-r--r-- | chad-fast.cabal | 2 | ||||
-rw-r--r-- | src/AST.hs | 5 | ||||
-rw-r--r-- | src/AST/Weaken.hs | 139 | ||||
-rw-r--r-- | src/CHAD.hs | 82 | ||||
-rw-r--r-- | src/Data.hs | 6 | ||||
-rw-r--r-- | src/Lemmas.hs | 7 |
6 files changed, 218 insertions, 23 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index 2e7ee22..ca3a2aa 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -23,7 +23,7 @@ library Simplify other-modules: build-depends: - base >= 4.14 && < 4.19, + base >= 4.19 && < 4.21, containers, template-haskell, transformers, @@ -261,11 +261,6 @@ idx2int :: Idx env t -> Int idx2int IZ = 0 idx2int (IS n) = 1 + idx2int n -splitIdx :: forall env2 env1 t f. SList f env1 -> Idx (Append env1 env2) t -> Either (Idx env1 t) (Idx env2 t) -splitIdx SNil i = Right i -splitIdx (SCons _ _) IZ = Left IZ -splitIdx (SCons _ l) (IS i) = first IS (splitIdx l i) - class KnownScalTy t where knownScalTy :: SScalTy t instance KnownScalTy TI32 where knownScalTy = STI32 instance KnownScalTy TI64 where knownScalTy = STI64 diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs index 78577ee..6c66b07 100644 --- a/src/AST/Weaken.hs +++ b/src/AST/Weaken.hs @@ -1,23 +1,34 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE EmptyCase #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeAbstractions #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE PartialTypeSignatures #-} +{-# OPTIONS -Wno-partial-type-signatures #-} + -- The reason why this is a separate module with "little" in it: {-# LANGUAGE AllowAmbiguousTypes #-} -module AST.Weaken where +module AST.Weaken (module AST.Weaken, Append) where +import Data.Bifunctor (first) import Data.Functor.Const import Data.Kind (Type) +import GHC.TypeLits +import Unsafe.Coerce import Data +import Lemmas type Idx :: [k] -> k -> Type @@ -26,6 +37,11 @@ data Idx env t where IS :: Idx env t -> Idx (a : env) t deriving instance Show (Idx env t) +splitIdx :: forall env2 env1 t f. SList f env1 -> Idx (Append env1 env2) t -> Either (Idx env1 t) (Idx env2 t) +splitIdx SNil i = Right i +splitIdx (SCons _ _) IZ = Left IZ +splitIdx (SCons _ l) (IS i) = first IS (splitIdx l i) + data env :> env' where WId :: env :> env WSink :: forall t env. env :> (t : env) @@ -34,28 +50,38 @@ data env :> env' where WThen :: env1 :> env2 -> env2 :> env3 -> env1 :> env3 WClosed :: SList (Const ()) env -> '[] :> env WIdx :: Idx env t -> (t : env) :> env + WPick :: forall t pre env env'. SList (Const ()) pre -> env :> env' + -> Append pre (t : env) :> t : Append pre env' + WSwap :: SList (Const ()) as -> SList (Const ()) bs -> SList (Const ()) env + -> Append as (Append bs env) :> Append bs (Append as env) deriving instance Show (env :> env') +infix 4 :> infixr 2 @> (@>) :: env :> env' -> Idx env t -> Idx env' t WId @> i = i WSink @> i = IS i WCopy _ @> IZ = IZ -WCopy w @> (IS i) = IS (w @> i) +WCopy w @> IS i = IS (w @> i) WPop w @> i = w @> IS i WThen w1 w2 @> i = w2 @> w1 @> i WClosed _ @> i = case i of {} WIdx j @> IZ = j WIdx _ @> IS i = i +WPick SNil w @> i = WCopy w @> i +WPick (_ `SCons` _) _ @> IZ = IS IZ +WPick @t (_ `SCons` pre) w @> IS i = WCopy WSink .> WPick @t pre w @> i +WSwap (as :: SList _ as) (bs :: SList _ bs) (env :: SList _ env) @> i = + case splitIdx @(Append bs env) as i of + Left i' -> wSinks bs .> wRaiseAbove as env @> i' + Right j -> case splitIdx @env bs j of + Left j' -> wRaiseAbove bs (sappend as env) @> j' + Right k -> wSinks bs .> wSinks as @> k infixr 3 .> (.>) :: env2 :> env3 -> env1 :> env2 -> env1 :> env3 (.>) = flip WThen -type family Append a b where - Append '[] l = l - Append (x : xs) l = x : Append xs l - class KnownListSpine list where knownListSpine :: SList (Const ()) list instance KnownListSpine '[] where knownListSpine = SNil instance KnownListSpine list => KnownListSpine (t : list) where knownListSpine = SCons (Const ()) knownListSpine @@ -83,7 +109,108 @@ wStack (WIdx i) = WIdx (goIdx i) goIdx :: Idx b t -> Idx (Append b env) t goIdx IZ = IZ goIdx (IS i') = IS (goIdx i') +wStack (WPick @t @_ @env1 @env2 (pre :: SList (Const ()) pre) w) + | Refl <- lemAppendAssoc @pre @env2 @env + , Refl <- lemAppendAssoc @pre @(t : env1) @env + = WPick @t @_ pre (wStack @env w) +wStack WSwap{} = error "OOPS" wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env wRaiseAbove SNil env = WClosed (slistMap (\_ -> Const ()) env) wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env) + + +type family Lookup name list where + Lookup name ('(name, x) : _) = x + Lookup name (_ : list) = Lookup name list + +data Layout (segments :: [(Symbol, [t])]) (env :: [t]) where + LSeg :: forall name segments. KnownSymbol name => Layout segments (Lookup name segments) + (:++:) :: Layout segments env1 -> Layout segments env2 -> Layout segments (Append env1 env2) + +data SSegments (segments :: [(Symbol, [t])]) where + SSegNil :: SSegments '[] + SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list) + +class KnownSegments (segments :: [(Symbol, [t])]) where + knownSegments :: SSegments segments + +instance KnownSegments '[] where + knownSegments = SSegNil + +instance (KnownSymbol name, KnownListSpine ts, KnownSegments list) + => KnownSegments ('(name, ts) : list) where + knownSegments = SSegCons symbolSing knownListSpine knownSegments + +segmentLookup :: forall segments name. SSegments segments -> SSymbol name -> SList (Const ()) (Lookup name segments) +segmentLookup = \segs name -> case go segs name of + Just ts -> ts + Nothing -> error $ "Segment not found: " ++ fromSSymbol name + where + go :: forall segs'. SSegments segs' -> SSymbol name -> Maybe (SList (Const ()) (Lookup name segs')) + go SSegNil _ = Nothing + go (SSegCons n@(SSymbol @n) (ts :: SList _ ts) (sseg :: SSegments rest)) name@SSymbol = + case sameSymbol n name of + Just Refl -> + case go sseg name of + Nothing -> Just ts + Just _ -> error $ "Duplicate segment with name " ++ fromSSymbol name + Nothing -> + case unsafeCoerce Refl :: (Lookup name ('(n, ts) : rest) :~: Lookup name rest) of + Refl -> go sseg name + +data LinLayout (segments :: [(Symbol, [t])]) (env :: [t]) where + LinEnd :: LinLayout segments '[] + LinApp :: SSymbol name -> LinLayout segments env -> LinLayout segments (Append (Lookup name segments) env) + +linLayoutAppend :: LinLayout segments env1 -> LinLayout segments env2 -> LinLayout segments (Append env1 env2) +linLayoutAppend LinEnd lin = lin +linLayoutAppend (LinApp (name :: SSymbol name) (lin1 :: LinLayout segments env1')) (lin2 :: LinLayout _ env2) + | Refl <- lemAppendAssoc @(Lookup name segments) @env1' @env2 + = LinApp name (linLayoutAppend lin1 lin2) + +lineariseLayout :: Layout segments env -> LinLayout segments env +lineariseLayout (LSeg @name :: Layout _ seg) + | Refl <- lemAppendNil @seg + = LinApp (symbolSing @name) LinEnd +lineariseLayout (ly1 :++: ly2) = lineariseLayout ly1 `linLayoutAppend` lineariseLayout ly2 + +linLayoutEnv :: SSegments segments -> LinLayout segments env -> SList (Const ()) env +linLayoutEnv _ LinEnd = SNil +linLayoutEnv segs (LinApp name lin) = sappend (segmentLookup segs name) (linLayoutEnv segs lin) + +autoWeak :: forall segments env1 env2. KnownSegments segments + => Layout segments env1 -> Layout segments env2 -> env1 :> env2 +autoWeak ly1 ly2 = sortLinLayouts knownSegments (lineariseLayout ly1) (lineariseLayout ly2) + +pullDown :: SSegments segments -> SSymbol name -> LinLayout segments env + -> r -- Name was not found in source + -> (forall env'. LinLayout segments env' -> env :> Append (Lookup name segments) env' -> r) + -> r +pullDown segs name@SSymbol linlayout kNotFound k = + case linlayout of + LinEnd -> kNotFound + LinApp n'@SSymbol lin + | Just Refl <- sameSymbol name n' -> k lin WId + | otherwise -> + pullDown segs name lin kNotFound $ \(lin' :: LinLayout _ env') w -> + k (LinApp n' lin') (WSwap (segmentLookup segs n') (segmentLookup segs name) (linLayoutEnv segs lin') + .> wCopies (segmentLookup segs n') w) + +sortLinLayouts :: forall segments env1 env2. + SSegments segments + -> LinLayout segments env1 -> LinLayout segments env2 -> env1 :> env2 +sortLinLayouts _ LinEnd LinEnd = WId +sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail2) + | Just Refl <- sameSymbol name1 name2 = wCopies (segmentLookup segs name1) (sortLinLayouts segs tail1 tail2) + | otherwise = + pullDown segs name2 lin1 + (wSinks (segmentLookup segs name2) .> sortLinLayouts segs lin1 tail2) + (\tail1' w -> + -- We've pulled down name2 in lin1 so that it's at the head; the + -- resulting modified tail is tail1'. Thus now we have (name2 : tail1') + -- vs (name2 : tail2). Thus we continue sorting tail1' vs tail2, and + -- wCopies the name2 on top of that. + wCopies (segmentLookup segs name2) (sortLinLayouts segs tail1' tail2) .> w) +sortLinLayouts _ LinEnd LinApp{} = error "Unequal number of segments: more in target than in source" +sortLinLayouts _ LinApp{} LinEnd = error "Unequal number of segments: more in source than in target" diff --git a/src/CHAD.hs b/src/CHAD.hs index a6dd9ff..45d2d08 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -27,10 +27,12 @@ module CHAD ( ) where import Data.Bifunctor (first, second) +import Data.Functor.Const import Data.Kind (Type) import GHC.TypeLits (Symbol) import AST +import AST.Count import Data import Lemmas @@ -142,10 +144,6 @@ sreverse :: SList f ts -> SList f (Reverse ts) sreverse SNil = SNil sreverse (SCons t ts) = ssnoc (sreverse ts) t -sappend :: SList f l1 -> SList f l2 -> SList f (Append l1 l2) -sappend SNil l = l -sappend (SCons x xs) l = SCons x (sappend xs l) - stapeUnfoldings :: SList STy ts -> SList STy (TapeUnfoldings ts) stapeUnfoldings SNil = SNil stapeUnfoldings (SCons _ ts) = SCons (tapeTy ts) (stapeUnfoldings ts) @@ -408,6 +406,63 @@ zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0)) zeroTup SNil = ENil ext zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t) +accumPromote :: forall t env sto proxy r. + proxy t + -> Descr env sto + -> OccEnv env + -> (forall stoRepl envPro. + Descr env stoRepl + -- ^ A revised environment description that switches + -- arrays (used in the OccEnv) that are currently on + -- "merge" storage, to "accum" storage. + -> Subenv (Select env sto "merge") (Select env stoRepl "merge") + -- ^ The new storage has fewer "merge"-storage entries. + -> SList STy envPro + -- ^ New entries on top of the original dual environment, + -- that house the accumulators for the promoted arrays in + -- the original environment. + -> (forall shbinds. + SList STy shbinds + -> (D2 t : Append shbinds (D2AcE (Select env stoRepl "accum"))) + :> Append envPro (D2 t : Append shbinds (D2AcE (Select env sto "accum")))) + -- ^ A weakening that converts a computation in the + -- revised environment to one in the original environment + -- extended with some accumulators. + -> r) + -> r +accumPromote _ DTop _ k = k DTop SETop SNil (\_ -> WId) +accumPromote _ descr OccEnd k = k descr (subenvAll (select SMerge descr)) SNil (\_ -> WId) +accumPromote pty (descr `DPush` (t, sto)) (occenv `OccPush` occ) k = + accumPromote pty descr occenv $ \(storepl :: Descr env1 stoRepl) mergesub (envpro :: SList _ envPro) wf -> + case (t, sto, occ) of + (STArr (arrn :: SNat arrn) (arrt :: STy arrt), SMerge, Occ _ c) | c > Zero -> + k (storepl `DPush` (t, SAccum)) + (SENo mergesub) + (STAccum arrn arrt `SCons` envpro) + (\(shbinds :: SList _ shbinds) -> + let shbindsC = slistMap (\_ -> Const ()) shbinds + in + -- wf: + -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) + -- WCopy wf: + -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) + -- WPICK: ^ THESE TWO || + -- goal: | ARE EQUAL || + -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) + WCopy (wf shbinds) + .> WPick @(TAccum arrn arrt) @(D2 t : shbinds) (Const () `SCons` shbindsC) + (WId @(D2AcE (Select env1 stoRepl "accum")))) + + (_, SAccum, _) -> + k (storepl `DPush` (t, SAccum)) + mergesub + envpro + (\shbinds -> _ (wf shbinds)) + + _ -> _ + +-- | @env'@ is a subset of @env@: each element of @env@ is either included in +-- @env'@ ('SEYes') or not included in @env'@ ('SENo'). data Subenv env env' where SETop :: Subenv '[] '[] SEYes :: Subenv env env' -> Subenv (t : env) (t : env') @@ -440,11 +495,15 @@ subList SNil SETop = SNil subList (SCons x xs) (SEYes sub) = SCons x (subList xs sub) subList (SCons _ xs) (SENo sub) = subList xs sub -subenvNone :: SList STy env -> Subenv env '[] +subenvNone :: SList f env -> Subenv env '[] subenvNone SNil = SETop subenvNone (SCons _ env) = SENo (subenvNone env) -subenvOnehot :: SList STy env -> Idx env t -> Subenv env '[t] +subenvAll :: SList f env -> Subenv env env +subenvAll SNil = SETop +subenvAll (SCons _ env) = SEYes (subenvAll env) + +subenvOnehot :: SList f env -> Idx env t -> Subenv env '[t] subenvOnehot (SCons _ env) IZ = SEYes (subenvNone env) subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i) subenvOnehot SNil i = case i of {} @@ -525,10 +584,15 @@ weakenRets w (Rets binds list) = let (binds', _) = weakenBindings weakenExpr w binds in Rets binds' (slistMap (weakenRetPair (bindingsBinds binds) w) list) -rebaseRetPair :: forall env b1 b2 env0 sto t f. SList f b1 -> SList f b2 -> RetPair env0 sto (Append b1 env) b2 t -> RetPair env0 sto env (Append b2 b1) t -rebaseRetPair b1 b2 (RetPair p sub d) +rebaseRetPair :: forall env b1 b2 env0 sto t f. + SList f env0 -> SList f b1 -> SList f b2 + -> RetPair env0 sto (Append b1 env) b2 t -> RetPair env0 sto env (Append b2 b1) t +rebaseRetPair env b1 b2 (RetPair p sub d) | Refl <- lemAppendAssoc @b2 @b1 @env = - RetPair p sub (weakenExpr (WCopy (wStack @(D2AcE (Select env0 sto "accum")) (wRaiseAbove b2 b1))) d) + RetPair p sub (weakenExpr (autoWeak @['("d", '[D2 t]), '("b2", b2), '("b1", b1), '("tl", D2AcE (Select env0 sto "accum"))] + (LSeg @"d" :++: (LSeg @"b2" :++: LSeg @"tl")) + (LSeg @"d" :++: ((LSeg @"b2" :++: LSeg @"b1") :++: LSeg @"tl"))) + d) retConcat :: forall env0 sto list. SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list retConcat SNil = Rets BTop SNil diff --git a/src/Data.hs b/src/Data.hs index c3381d5..a3f4c3c 100644 --- a/src/Data.hs +++ b/src/Data.hs @@ -8,6 +8,8 @@ {-# LANGUAGE TypeOperators #-} module Data where +import Lemmas (Append) + data SList f l where SNil :: SList f '[] @@ -19,6 +21,10 @@ slistMap :: (forall t. f t -> g t) -> SList f list -> SList g list slistMap _ SNil = SNil slistMap f (SCons x list) = SCons (f x) (slistMap f list) +sappend :: SList f l1 -> SList f l2 -> SList f (Append l1 l2) +sappend SNil l = l +sappend (SCons x xs) l = SCons x (sappend xs l) + data Nat = Z | S Nat deriving (Show, Eq, Ord) diff --git a/src/Lemmas.hs b/src/Lemmas.hs index cb62155..31a43ed 100644 --- a/src/Lemmas.hs +++ b/src/Lemmas.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DataKinds #-} -{-# LANGUAGE TypeOperators #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} {-# LANGUAGE AllowAmbiguousTypes #-} module Lemmas (module Lemmas, (:~:)(Refl)) where @@ -8,8 +9,10 @@ module Lemmas (module Lemmas, (:~:)(Refl)) where import Data.Type.Equality import Unsafe.Coerce (unsafeCoerce) -import AST.Weaken +type family Append a b where + Append '[] l = l + Append (x : xs) l = x : Append xs l lemAppendNil :: Append a '[] :~: a lemAppendNil = unsafeCoerce Refl |