{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# OPTIONS -Wno-partial-type-signatures #-}

-- The reason why this is a separate module with "little" in it:
{-# LANGUAGE AllowAmbiguousTypes #-}

module AST.Weaken (module AST.Weaken, Append) where

import Data.Bifunctor (first)
import Data.Functor.Const
import Data.Kind (Type)
import GHC.TypeLits
import Unsafe.Coerce

import Data
import Lemmas


type Idx :: [k] -> k -> Type
-- | A typed de Bruijn index: evidence that @t@ occurs in @env@. 'IZ' refers
-- to the head of the list; 'IS' skips over one element.
data Idx env t where
  IZ :: Idx (t : env) t
  IS :: Idx env t -> Idx (a : env) t
deriving instance Show (Idx env t)

-- | Decide whether an index into @Append env1 env2@ points into the prefix
-- @env1@ ('Left') or into the suffix @env2@ ('Right'). Only the spine of
-- @env1@ is inspected. @env2@ is the first type variable so that call sites
-- can supply it with a type application (as '@>' does).
splitIdx :: forall env2 env1 t f.
            SList f env1 -> Idx (Append env1 env2) t -> Either (Idx env1 t) (Idx env2 t)
splitIdx SNil i = Right i
splitIdx (SCons _ _) IZ = Left IZ
splitIdx (SCons _ l) (IS i) = first IS (splitIdx l i)

-- | A weakening from @env@ to @env'@: a description of how every index into
-- @env@ is mapped to an index into @env'@ (interpreted by '@>').
data env :> env' where
  -- | Identity.
  WId :: env :> env
  -- | Push one new element on top; every index is shifted by one ('IS').
  WSink :: forall t env. env :> (t : env)
  -- | Keep the head in place and weaken the tail.
  WCopy :: env :> env' -> (t : env) :> (t : env')
  -- | Index @i@ is mapped by the inner weakening as @IS i@.
  WPop :: (t : env) :> env' -> env :> env'
  -- | Sequential composition: first the left weakening, then the right one.
  WThen :: env1 :> env2 -> env2 :> env3 -> env1 :> env3
  -- | Weaken the empty environment into any environment (given its spine).
  WClosed :: SList (Const ()) env -> '[] :> env
  -- | Redirect the head to the given index; tail indices lose one 'IS'.
  WIdx :: Idx env t -> (t : env) :> env
  -- | Move the element @t@ that sits directly below the prefix @pre@ to the
  -- top, weakening the part below it with the given weakening.
  WPick :: forall t pre env env'.
           SList (Const ()) pre -> env :> env'
        -> Append pre (t : env) :> t : Append pre env'
  -- | Exchange the prefixes @as@ and @bs@ that sit on top of @env@.
  WSwap :: SList (Const ()) as -> SList (Const ()) bs -> SList (Const ()) env
        -> Append as (Append bs env) :> Append bs (Append as env)
deriving instance Show (env :> env')

infix 4 :>

infixr 2 @>

-- | Apply a weakening to an index.
(@>) :: env :> env' -> Idx env t -> Idx env' t
WId @> i = i
WSink @> i = IS i
WCopy _ @> IZ = IZ
WCopy w @> IS i = IS (w @> i)
WPop w @> i = w @> IS i
WThen w1 w2 @> i = w2 @> w1 @> i
WClosed _ @> i = case i of {}
WIdx j @> IZ = j
WIdx _ @> IS i = i
WPick SNil w @> i = WCopy w @> i
WPick (_ `SCons` _) _ @> IZ = IS IZ
WPick @t (_ `SCons` pre) w @> IS i = WCopy WSink .> WPick @t pre w @> i
WSwap (as :: SList _ as) (bs :: SList _ bs) (env :: SList _ env) @> i =
  case splitIdx @(Append bs env) as i of
    -- The index points into @as@: keep it, add @env@ below and @bs@ on top.
    Left i' -> wSinks bs .> wRaiseAbove as env @> i'
    Right j -> case splitIdx @env bs j of
      -- The index points into @bs@: keep it and add @as ++ env@ below.
      Left j' -> wRaiseAbove bs (sappend as env) @> j'
      -- The index points into @env@: shift it past both prefixes.
      Right k -> wSinks bs .> wSinks as @> k

infixr 3 .>

-- | Compose weakenings in function-composition order: @w2 .> w1@ applies
-- @w1@ first and then @w2@.
(.>) :: env2 :> env3 -> env1 :> env2 -> env1 :> env3
(.>) = flip WThen

-- | Type-level lists whose spine is statically known.
class KnownListSpine list where
  knownListSpine :: SList (Const ()) list

instance KnownListSpine '[] where
  knownListSpine = SNil

instance KnownListSpine list => KnownListSpine (t : list) where
  knownListSpine = SCons (Const ()) knownListSpine

-- | 'wSinks' with the spine of @list@ taken from its 'KnownListSpine'
-- instance; @list@ is supplied by type application.
wSinks' :: forall list env. KnownListSpine list => env :> Append list env
wSinks' = wSinks (knownListSpine :: SList (Const ()) list)

-- | Push the whole list @bs@ of new elements on top of the environment.
wSinks :: forall env bs f. SList f bs -> env :> Append bs env
wSinks SNil = WId
wSinks (SCons _ spine) = WSink .> wSinks spine

-- | Keep the prefix @bs@ in place and weaken what lies below it.
wCopies :: SList f bs -> env1 :> env2 -> Append bs env1 :> Append bs env2
wCopies SNil w = w
wCopies (SCons _ spine) w = WCopy (wCopies spine w)

-- | Lift a weakening @b1 :> b2@ to one that has @env@ appended below both
-- sides. @env@ is chosen with a type application; its spine is never needed.
wStack :: forall env b1 b2. b1 :> b2 -> Append b1 env :> Append b2 env
wStack WId = WId
wStack WSink = WSink
wStack (WCopy w) = WCopy (wStack @env w)
wStack (WPop w) = WPop (wStack @env w)
wStack (WThen w1 w2) = WThen (wStack @env w1) (wStack @env w2)
wStack (WClosed s) = wSinks s
wStack (WIdx i) = WIdx (goIdx i)
  where
    goIdx :: Idx b t -> Idx (Append b env) t
    goIdx IZ = IZ
    goIdx (IS i') = IS (goIdx i')
wStack (WPick @t @_ @env1 @env2 (pre :: SList (Const ()) pre) w)
  | Refl <- lemAppendAssoc @pre @env2 @env
  , Refl <- lemAppendAssoc @pre @(t : env1) @env
  = WPick @t @_ pre (wStack @env w)
-- Stacking a 'WSwap' directly would need the spine of @env@ (for the
-- constructor's third argument), which is not available here. Instead the
-- swap is rewritten into an equivalent weakening that uses only 'WId',
-- 'WCopy', 'WPick' and 'WThen', all of which are handled above.
wStack (WSwap as bs e) = wStack @env (swapByPicks as bs e)
  where
    -- Equivalent of @WSwap xs ys base@: each element of @ys@ is picked from
    -- directly below @xs@ and moved to the top, one element at a time.
    swapByPicks :: forall as' bs' e'.
                   SList (Const ()) as' -> SList (Const ()) bs' -> SList (Const ()) e'
                -> Append as' (Append bs' e') :> Append bs' (Append as' e')
    swapByPicks _ SNil _ = WId
    swapByPicks xs (SCons (_ :: Const () b) (ys :: SList (Const ()) ys)) base =
      WCopy (swapByPicks xs ys base)
        .> WPick @b @as' @(Append ys e') @(Append ys e') xs WId

-- | Insert the environment @env@ below @env1@; indices into @env1@ are kept.
wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env
wRaiseAbove SNil env = WClosed (slistMap (\_ -> Const ()) env)
wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env)

-- | Type-level lookup of @name@ in an association list of named segments.
-- There is no equation for the empty list, so an absent name leaves the
-- family stuck.
type family Lookup name list where
  Lookup name ('(name, x) : _) = x
  Lookup name (_ : list) = Lookup name list

-- | An environment described as a concatenation of named segments taken from
-- the table @segments@: 'LSeg' names one segment, ':++:' concatenates.
data Layout (segments :: [(Symbol, [t])]) (env :: [t]) where
  LSeg :: forall name segments. KnownSymbol name => Layout segments (Lookup name segments)
  (:++:) :: Layout segments env1 -> Layout segments env2 -> Layout segments (Append env1 env2)

-- | Runtime representation of a segment table: each entry's name together
-- with the spine of its element list.
data SSegments (segments :: [(Symbol, [t])]) where
  SSegNil :: SSegments '[]
  SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list)

-- | Segment tables whose runtime representation can be produced statically.
class KnownSegments (segments :: [(Symbol, [t])]) where
  knownSegments :: SSegments segments

instance KnownSegments '[] where
  knownSegments = SSegNil

instance (KnownSymbol name, KnownListSpine ts, KnownSegments list)
      => KnownSegments ('(name, ts) : list) where
  knownSegments = SSegCons symbolSing knownListSpine knownSegments

-- | Look up the spine of the segment called @name@. Calls 'error' when the
-- name is absent from the table or occurs in it more than once.
segmentLookup :: forall segments name.
                 SSegments segments -> SSymbol name -> SList (Const ()) (Lookup name segments)
segmentLookup = \segs name -> case go segs name of
    Just ts -> ts
    Nothing -> error $ "Segment not found: " ++ fromSSymbol name
  where
    go :: forall segs'. SSegments segs' -> SSymbol name -> Maybe (SList (Const ()) (Lookup name segs'))
    go SSegNil _ = Nothing
    go (SSegCons n@(SSymbol @n) (ts :: SList _ ts) (sseg :: SSegments rest)) name@SSymbol =
      case sameSymbol n name of
        -- Found it; keep searching the rest only to reject duplicates.
        Just Refl -> case go sseg name of
          Nothing -> Just ts
          Just _ -> error $ "Duplicate segment with name " ++ fromSSymbol name
        Nothing ->
          -- 'sameSymbol' returned 'Nothing', so @n@ and @name@ differ. The
          -- first equation of the closed family 'Lookup' therefore cannot
          -- match, and 'Lookup' reduces by its second equation. GHC cannot
          -- learn this apartness from a runtime check, hence the coercion.
          case unsafeCoerce Refl :: (Lookup name ('(n, ts) : rest) :~: Lookup name rest) of
            Refl -> go sseg name

-- | A 'Layout' flattened into a right-nested list of segment names.
data LinLayout (segments :: [(Symbol, [t])]) (env :: [t]) where
  LinEnd :: LinLayout segments '[]
  LinApp :: SSymbol name -> LinLayout segments env -> LinLayout segments (Append (Lookup name segments) env)

-- | Concatenate two linearised layouts.
linLayoutAppend :: LinLayout segments env1 -> LinLayout segments env2 -> LinLayout segments (Append env1 env2)
linLayoutAppend LinEnd lin = lin
linLayoutAppend (LinApp (name :: SSymbol name) (lin1 :: LinLayout segments env1')) (lin2 :: LinLayout _ env2)
  | Refl <- lemAppendAssoc @(Lookup name segments) @env1' @env2
  = LinApp name (linLayoutAppend lin1 lin2)

-- | Flatten a 'Layout' into a 'LinLayout' describing the same environment.
lineariseLayout :: Layout segments env -> LinLayout segments env
lineariseLayout (LSeg @name :: Layout _ seg)
  | Refl <- lemAppendNil @seg
  = LinApp (symbolSing @name) LinEnd
lineariseLayout (ly1 :++: ly2) = lineariseLayout ly1 `linLayoutAppend` lineariseLayout ly2

-- | The spine of the environment described by a linearised layout.
linLayoutEnv :: SSegments segments -> LinLayout segments env -> SList (Const ()) env
linLayoutEnv _ LinEnd = SNil
linLayoutEnv segs (LinApp name lin) = sappend (segmentLookup segs name) (linLayoutEnv segs lin)

-- | Derive the weakening from the environment described by the first layout
-- to the one described by the second, matching segments by name. Every
-- segment of the source must occur in the target (otherwise 'error');
-- segments that occur only in the target become fresh elements.
autoWeak :: forall segments env1 env2. KnownSegments segments
         => Layout segments env1 -> Layout segments env2 -> env1 :> env2
autoWeak ly1 ly2 = sortLinLayouts knownSegments (lineariseLayout ly1) (lineariseLayout ly2)

-- | Search @linlayout@ for the first occurrence of the segment @name@. If it
-- is absent, return @kNotFound@. Otherwise call @k@ with the layout minus
-- that segment, together with a weakening that moves the segment's elements
-- to the top (swapping them past every segment that was in front of it).
pullDown :: SSegments segments -> SSymbol name -> LinLayout segments env
         -> r  -- Result when @name@ does not occur in @linlayout@
         -> (forall env'. LinLayout segments env' -> env :> Append (Lookup name segments) env' -> r)
         -> r
pullDown segs name@SSymbol linlayout kNotFound k =
  case linlayout of
    LinEnd -> kNotFound
    LinApp n'@SSymbol lin
      | Just Refl <- sameSymbol name n' -> k lin WId
      | otherwise ->
          pullDown segs name lin kNotFound $ \(lin' :: LinLayout _ env') w ->
            -- First keep the @n'@ segment on top while @w@ brings @name@ up
            -- directly below it; then swap the two segments.
            k (LinApp n' lin')
              (WSwap (segmentLookup segs n') (segmentLookup segs name) (linLayoutEnv segs lin')
                 .> wCopies (segmentLookup segs n') w)

-- | Build a weakening between two linearised layouts by walking the target
-- from left to right:
--
-- * a target segment that heads the source is kept in place ('wCopies');
-- * one found deeper in the source is first pulled to the top ('pullDown');
-- * one absent from the source is added as fresh elements.
--
-- Once the source is exhausted, all remaining target segments are fresh.
-- Source segments that do not occur in the target cause an 'error'.
sortLinLayouts :: forall segments env1 env2.
                  SSegments segments -> LinLayout segments env1 -> LinLayout segments env2 -> env1 :> env2
sortLinLayouts _ LinEnd LinEnd = WId
-- The source is exhausted, so none of the remaining target segments occur in
-- it. Add them all as fresh elements, consistently with the "not found in
-- source" case below (this used to be an 'error').
sortLinLayouts segs LinEnd lin2 = WClosed (linLayoutEnv segs lin2)
sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail2)
  | Just Refl <- sameSymbol name1 name2 =
      wCopies (segmentLookup segs name1) (sortLinLayouts segs tail1 tail2)
  | otherwise =
      pullDown segs name2 lin1
        (wSinks (segmentLookup segs name2) .> sortLinLayouts segs lin1 tail2)
        (\tail1' w ->
           -- We've pulled down name2 in lin1 so that it's at the head; the
           -- resulting modified tail is tail1'. Thus now we have (name2 : tail1')
           -- vs (name2 : tail2). Thus we continue sorting tail1' vs tail2, and
           -- wCopies the name2 on top of that.
           wCopies (segmentLookup segs name2) (sortLinLayouts segs tail1' tail2) .> w)
sortLinLayouts _ LinApp{} LinEnd = error "Unequal number of segments: more in source than in target"