diff options
| -rw-r--r-- | chad-fast.cabal | 1 | ||||
| -rw-r--r-- | src/AST.hs | 22 | ||||
| -rw-r--r-- | src/AST/Weaken.hs | 129 | ||||
| -rw-r--r-- | src/AST/Weaken/Auto.hs | 151 | ||||
| -rw-r--r-- | src/CHAD.hs | 22 | 
5 files changed, 175 insertions, 150 deletions
| diff --git a/chad-fast.cabal b/chad-fast.cabal index ca3a2aa..19c2852 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -14,6 +14,7 @@ library      AST.Count      AST.Pretty      AST.Weaken +    AST.Weaken.Auto      CHAD      -- Compile      Data @@ -1,18 +1,18 @@  {-# LANGUAGE DataKinds #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE DeriveFunctor #-}  {-# LANGUAGE DeriveFoldable #-} +{-# LANGUAGE DeriveFunctor #-}  {-# LANGUAGE DeriveTraversable #-}  {-# LANGUAGE EmptyCase #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-}  module AST (module AST, module AST.Weaken) where  import Data.Functor.Const diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs index 42cdbd5..aa88c8e 100644 --- a/src/AST/Weaken.hs +++ b/src/AST/Weaken.hs @@ -1,17 +1,12 @@  {-# LANGUAGE DataKinds #-}  {-# LANGUAGE EmptyCase #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE FunctionalDependencies #-}  {-# LANGUAGE GADTs #-} -{-# LANGUAGE MultiParamTypeClasses #-}  {-# LANGUAGE PolyKinds #-} -{-# LANGUAGE RankNTypes #-}  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE StandaloneDeriving #-}  {-# LANGUAGE StandaloneKindSignatures #-}  {-# LANGUAGE TypeAbstractions #-}  {-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-}  {-# LANGUAGE PartialTypeSignatures #-} @@ -25,8 +20,6 @@ module AST.Weaken (module AST.Weaken, Append) where  import Data.Bifunctor (first)  import Data.Functor.Const  import Data.Kind (Type) -import GHC.TypeLits -import Unsafe.Coerce  import Data  import Lemmas @@ -101,125 +94,3 @@ wCopies (SCons _ spine) w = WCopy (wCopies spine w)  wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env  wRaiseAbove SNil env = WClosed (slistMap (\_ -> Const ()) env)  wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env) - - -type family Lookup name list where -  Lookup name ('(name, x) : _) = x -  Lookup name (_ : list) = Lookup name list - -data Layout (segments :: [(Symbol, [t])]) (env :: [t]) where -  LSeg :: forall name segments. KnownSymbol name => Layout segments (Lookup name segments) -  (:++:) :: Layout segments env1 -> Layout segments env2 -> Layout segments (Append env1 env2) -infixr :++: - -data SSegments (segments :: [(Symbol, [t])]) where -  SSegNil :: SSegments '[] -  SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list) - -class KnownSegments (segments :: [(Symbol, [t])]) where -  knownSegments :: SSegments segments - -instance KnownSegments '[] where -  knownSegments = SSegNil - -instance (KnownSymbol name, KnownListSpine ts, KnownSegments list) -      => KnownSegments ('(name, ts) : list) where -  knownSegments = SSegCons symbolSing knownListSpine knownSegments - -class ToSegments k a | a -> k where -  type SegmentsOf k a :: [(Symbol, [k])] -  toSegments :: a -> SSegments (SegmentsOf k a) - -instance ToSegments k (SSegments (segments :: [(Symbol, [k])])) where -  type SegmentsOf k (SSegments segments) = segments -  toSegments = id - -data GivenSegment name ts = forall f. KnownSymbol name => Seg (SList f ts) -                          | (KnownSymbol name, KnownListSpine ts) => Seg' - -instance ToSegments k (GivenSegment name (ts :: [k])) where -  type SegmentsOf k (GivenSegment name (ts :: [k])) = '[ '(name, ts)] -  toSegments (Seg list) = SSegCons symbolSing (slistMap (\_ -> Const ()) list) SSegNil -  toSegments Seg' = SSegCons symbolSing knownListSpine SSegNil - -infixr $.. -($..) :: (ToSegments k a, ToSegments k b) => a -> b -> SSegments (Append (SegmentsOf k a) (SegmentsOf k b)) -x $.. y = ssegmentsAppend (toSegments x) (toSegments y) -  where -    ssegmentsAppend :: SSegments a -> SSegments b -> SSegments (Append a b) -    ssegmentsAppend SSegNil l2 = l2 -    ssegmentsAppend (SSegCons name list l1) l2 = SSegCons name list (ssegmentsAppend l1 l2) - --- | If the found segment is a TopSeg, returns Nothing. -segmentLookup :: forall segments name. SSegments segments -> SSymbol name -> SList (Const ()) (Lookup name segments) -segmentLookup = \segs name -> case go segs name of -                           Just ts -> ts -                           Nothing -> error $ "Segment not found: " ++ fromSSymbol name -  where -    go :: forall segs'. SSegments segs' -> SSymbol name -> Maybe (SList (Const ()) (Lookup name segs')) -    go SSegNil _ = Nothing -    go (SSegCons n@(SSymbol @n) (ts :: SList _ ts) (sseg :: SSegments rest)) name@SSymbol = -      case sameSymbol n name of -        Just Refl -> -          case go sseg name of -            Nothing -> Just ts -            Just _ -> error $ "Duplicate segment with name " ++ fromSSymbol name -        Nothing -> -          case unsafeCoerce Refl :: (Lookup name ('(n, ts) : rest) :~: Lookup name rest) of -            Refl -> go sseg name - -data LinLayout (segments :: [(Symbol, [t])]) (env :: [t]) where -  LinEnd :: LinLayout segments '[] -  LinApp :: SSymbol name -> LinLayout segments env -> LinLayout segments (Append (Lookup name segments) env) - -linLayoutAppend :: LinLayout segments env1 -> LinLayout segments env2 -> LinLayout segments (Append env1 env2) -linLayoutAppend LinEnd lin = lin -linLayoutAppend (LinApp (name :: SSymbol name) (lin1 :: LinLayout segments env1')) (lin2 :: LinLayout _ env2) -  | Refl <- lemAppendAssoc @(Lookup name segments) @env1' @env2 -  = LinApp name (linLayoutAppend lin1 lin2) - -lineariseLayout :: Layout segments env -> LinLayout segments env -lineariseLayout (LSeg @name :: Layout _ seg) -  | Refl <- lemAppendNil @seg -  = LinApp (symbolSing @name) LinEnd -lineariseLayout (ly1 :++: ly2) = lineariseLayout ly1 `linLayoutAppend` lineariseLayout ly2 - -linLayoutEnv :: SSegments segments -> LinLayout segments env -> SList (Const ()) env -linLayoutEnv _ LinEnd = SNil -linLayoutEnv segs (LinApp name lin) = sappend (segmentLookup segs name) (linLayoutEnv segs lin) - -autoWeak :: forall segments env1 env2. -            SSegments segments -> Layout segments env1 -> Layout segments env2 -> env1 :> env2 -autoWeak segs ly1 ly2 = sortLinLayouts segs (lineariseLayout ly1) (lineariseLayout ly2) - -pullDown :: SSegments segments -> SSymbol name -> LinLayout segments env -         -> r  -- Name was not found in source -         -> (forall env'. LinLayout segments env' -> env :> Append (Lookup name segments) env' -> r) -         -> r -pullDown segs name@SSymbol linlayout kNotFound k = -  case linlayout of -    LinEnd -> kNotFound -    LinApp n'@SSymbol lin -      | Just Refl <- sameSymbol name n' -> k lin WId -      | otherwise -> -          pullDown segs name lin kNotFound $ \(lin' :: LinLayout _ env') w -> -            k (LinApp n' lin') (WSwap (segmentLookup segs n') (segmentLookup segs name) (linLayoutEnv segs lin') -                                  .> wCopies (segmentLookup segs n') w) - -sortLinLayouts :: forall segments env1 env2. -                  SSegments segments -               -> LinLayout segments env1 -> LinLayout segments env2 -> env1 :> env2 -sortLinLayouts _ LinEnd LinEnd = WId -sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail2) -  | Just Refl <- sameSymbol name1 name2 = wCopies (segmentLookup segs name1) (sortLinLayouts segs tail1 tail2) -  | otherwise = -      pullDown segs name2 lin1 -        (wSinks (segmentLookup segs name2) .> sortLinLayouts segs lin1 tail2) -        (\tail1' w -> -          -- We've pulled down name2 in lin1 so that it's at the head; the -          -- resulting modified tail is tail1'. Thus now we have (name2 : tail1') -          -- vs (name2 : tail2). Thus we continue sorting tail1' vs tail2, and -          -- wCopies the name2 on top of that. -          wCopies (segmentLookup segs name2) (sortLinLayouts segs tail1' tail2) .> w) -sortLinLayouts _ LinEnd LinApp{} = error "Unequal number of segments: more in target than in source" -sortLinLayouts _ LinApp{} LinEnd = error "Unequal number of segments: more in source than in target" diff --git a/src/AST/Weaken/Auto.hs b/src/AST/Weaken/Auto.hs new file mode 100644 index 0000000..93116b8 --- /dev/null +++ b/src/AST/Weaken/Auto.hs @@ -0,0 +1,151 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} + +{-# LANGUAGE AllowAmbiguousTypes #-} + +{-# LANGUAGE PartialTypeSignatures #-} +{-# OPTIONS_GHC -Wno-partial-type-signatures #-} +module AST.Weaken.Auto ( +  autoWeak, +  GivenSegment(..), +  ($..), +  Layout(..), +) where + +import Data.Functor.Const +import GHC.OverloadedLabels +import GHC.TypeLits +import Unsafe.Coerce (unsafeCoerce) + +import AST.Weaken +import Data +import Lemmas + + +type family Lookup name list where +  Lookup name ('(name, x) : _) = x +  Lookup name (_ : list) = Lookup name list + + +data Layout (segments :: [(Symbol, [t])]) (env :: [t]) where +  LSeg :: forall name segments. KnownSymbol name => Layout segments (Lookup name segments) +  (:++:) :: Layout segments env1 -> Layout segments env2 -> Layout segments (Append env1 env2) +infixr :++: + +instance (KnownSymbol name, seg ~ Lookup name segments) => IsLabel name (Layout segments seg) where +  fromLabel = LSeg @name @segments + + +data SSegments (segments :: [(Symbol, [t])]) where +  SSegNil :: SSegments '[] +  SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list) + +class ToSegments k a | a -> k where +  type SegmentsOf k a :: [(Symbol, [k])] +  toSegments :: a -> SSegments (SegmentsOf k a) + +instance ToSegments k (SSegments (segments :: [(Symbol, [k])])) where +  type SegmentsOf k (SSegments segments) = segments +  toSegments = id + +data GivenSegment name ts = forall f. KnownSymbol name => Seg (SList f ts) +                          | (KnownSymbol name, KnownListSpine ts) => Seg' + +instance ToSegments k (GivenSegment name (ts :: [k])) where +  type SegmentsOf k (GivenSegment name (ts :: [k])) = '[ '(name, ts)] +  toSegments (Seg list) = SSegCons symbolSing (slistMap (\_ -> Const ()) list) SSegNil +  toSegments Seg' = SSegCons symbolSing knownListSpine SSegNil + +infixr $.. +($..) :: (ToSegments k a, ToSegments k b) => a -> b -> SSegments (Append (SegmentsOf k a) (SegmentsOf k b)) +x $.. y = ssegmentsAppend (toSegments x) (toSegments y) +  where +    ssegmentsAppend :: SSegments a -> SSegments b -> SSegments (Append a b) +    ssegmentsAppend SSegNil l2 = l2 +    ssegmentsAppend (SSegCons name list l1) l2 = SSegCons name list (ssegmentsAppend l1 l2) + + +-- | If the found segment is a TopSeg, returns Nothing. +segmentLookup :: forall segments name. SSegments segments -> SSymbol name -> SList (Const ()) (Lookup name segments) +segmentLookup = \segs name -> case go segs name of +                           Just ts -> ts +                           Nothing -> error $ "Segment not found: " ++ fromSSymbol name +  where +    go :: forall segs'. SSegments segs' -> SSymbol name -> Maybe (SList (Const ()) (Lookup name segs')) +    go SSegNil _ = Nothing +    go (SSegCons n@(SSymbol @n) (ts :: SList _ ts) (sseg :: SSegments rest)) name@SSymbol = +      case sameSymbol n name of +        Just Refl -> +          case go sseg name of +            Nothing -> Just ts +            Just _ -> error $ "Duplicate segment with name " ++ fromSSymbol name +        Nothing -> +          case unsafeCoerce Refl :: (Lookup name ('(n, ts) : rest) :~: Lookup name rest) of +            Refl -> go sseg name + +data LinLayout (segments :: [(Symbol, [t])]) (env :: [t]) where +  LinEnd :: LinLayout segments '[] +  LinApp :: SSymbol name -> LinLayout segments env -> LinLayout segments (Append (Lookup name segments) env) + +linLayoutAppend :: LinLayout segments env1 -> LinLayout segments env2 -> LinLayout segments (Append env1 env2) +linLayoutAppend LinEnd lin = lin +linLayoutAppend (LinApp (name :: SSymbol name) (lin1 :: LinLayout segments env1')) (lin2 :: LinLayout _ env2) +  | Refl <- lemAppendAssoc @(Lookup name segments) @env1' @env2 +  = LinApp name (linLayoutAppend lin1 lin2) + +lineariseLayout :: Layout segments env -> LinLayout segments env +lineariseLayout (LSeg @name :: Layout _ seg) +  | Refl <- lemAppendNil @seg +  = LinApp (symbolSing @name) LinEnd +lineariseLayout (ly1 :++: ly2) = lineariseLayout ly1 `linLayoutAppend` lineariseLayout ly2 + +linLayoutEnv :: SSegments segments -> LinLayout segments env -> SList (Const ()) env +linLayoutEnv _ LinEnd = SNil +linLayoutEnv segs (LinApp name lin) = sappend (segmentLookup segs name) (linLayoutEnv segs lin) + +pullDown :: SSegments segments -> SSymbol name -> LinLayout segments env +         -> r  -- Name was not found in source +         -> (forall env'. LinLayout segments env' -> env :> Append (Lookup name segments) env' -> r) +         -> r +pullDown segs name@SSymbol linlayout kNotFound k = +  case linlayout of +    LinEnd -> kNotFound +    LinApp n'@SSymbol lin +      | Just Refl <- sameSymbol name n' -> k lin WId +      | otherwise -> +          pullDown segs name lin kNotFound $ \(lin' :: LinLayout _ env') w -> +            k (LinApp n' lin') (WSwap (segmentLookup segs n') (segmentLookup segs name) (linLayoutEnv segs lin') +                                  .> wCopies (segmentLookup segs n') w) + +sortLinLayouts :: forall segments env1 env2. +                  SSegments segments +               -> LinLayout segments env1 -> LinLayout segments env2 -> env1 :> env2 +sortLinLayouts _ LinEnd LinEnd = WId +sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail2) +  | Just Refl <- sameSymbol name1 name2 = wCopies (segmentLookup segs name1) (sortLinLayouts segs tail1 tail2) +  | otherwise = +      pullDown segs name2 lin1 +        (wSinks (segmentLookup segs name2) .> sortLinLayouts segs lin1 tail2) +        (\tail1' w -> +          -- We've pulled down name2 in lin1 so that it's at the head; the +          -- resulting modified tail is tail1'. Thus now we have (name2 : tail1') +          -- vs (name2 : tail2). Thus we continue sorting tail1' vs tail2, and +          -- wCopies the name2 on top of that. +          wCopies (segmentLookup segs name2) (sortLinLayouts segs tail1' tail2) .> w) +sortLinLayouts _ LinEnd LinApp{} = error "Unequal number of segments: more in target than in source" +sortLinLayouts _ LinApp{} LinEnd = error "Unequal number of segments: more in source than in target" + +autoWeak :: forall segments env1 env2. +            SSegments segments -> Layout segments env1 -> Layout segments env2 -> env1 :> env2 +autoWeak segs ly1 ly2 = sortLinLayouts segs (lineariseLayout ly1) (lineariseLayout ly2) diff --git a/src/CHAD.hs b/src/CHAD.hs index 97632c7..e99859c 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -2,6 +2,7 @@  {-# LANGUAGE EmptyCase #-}  {-# LANGUAGE GADTs #-}  {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedLabels #-}  {-# LANGUAGE PolyKinds #-}  {-# LANGUAGE QuantifiedConstraints #-}  {-# LANGUAGE RankNTypes #-} @@ -33,6 +34,7 @@ import GHC.TypeLits (Symbol)  import AST  import AST.Count +import AST.Weaken.Auto  import Data  import Lemmas @@ -591,8 +593,8 @@ rebaseRetPair descr b1 b2 (RetPair p sub d)    | Refl <- lemAppendAssoc @b2 @b1 @env =        RetPair p sub (weakenExpr (autoWeak                                    (Seg' @"d" @'[D2 t] $.. Seg @"b2" b2 $.. Seg @"b1" b1 $.. Seg @"tl" (d2ace (select SAccum descr))) -                                  (LSeg @"d" :++: (LSeg @"b2" :++: LSeg @"tl")) -                                  (LSeg @"d" :++: ((LSeg @"b2" :++: LSeg @"b1") :++: LSeg @"tl"))) +                                  (#d :++: (#b2 :++: #tl)) +                                  (#d :++: ((#b2 :++: #b1) :++: #tl)))                                  d)  retConcat :: forall env0 sto list. Descr env0 sto -> SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list @@ -736,8 +738,8 @@ drev des = \case                                    $.. Seg @"body" (bindingsBinds body0)                                    $.. Seg @"rhs" (SCons (typeOf rhs1) (bindingsBinds rhs0))                                    $.. Seg @"tl" (d2ace (select SAccum des))) -                                 (LSeg @"d" :++: LSeg @"body" :++: LSeg @"tl") -                                 (LSeg @"d" :++: LSeg @"body" :++: LSeg @"rhs" :++: LSeg @"tl")) +                                 (#d :++: #body :++: #tl) +                                 (#d :++: #body :++: #rhs :++: #tl))                         body2') $           ELet ext             (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ @@ -846,8 +848,8 @@ drev des = \case                                               $.. Seg @"recon" (tapeA `SCons` d2 (typeOf a) `SCons` SNil)                                               $.. Seg @"binds" (tPrimal `SCons` bindingsBinds e0)                                               $.. Seg @"tl" (d2ace (select SAccum des))) -                                            (LSeg @"d" :++: LSeg @"a0" :++: LSeg @"tl") -                                            (LSeg @"d" :++: (LSeg @"a0" :++: LSeg @"prea0") :++: LSeg @"recon" :++: LSeg @"binds" :++: LSeg @"tl")) +                                            (#d :++: #a0 :++: #tl) +                                            (#d :++: (#a0 :++: #prea0) :++: #recon :++: #binds :++: #tl))                                    a2') $                      EPair ext                       (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $ @@ -865,8 +867,8 @@ drev des = \case                                               $.. Seg @"recon" (tapeB `SCons` d2 (typeOf a) `SCons` SNil)                                               $.. Seg @"binds" (tPrimal `SCons` bindingsBinds e0)                                               $.. Seg @"tl" (d2ace (select SAccum des))) -                                            (LSeg @"d" :++: LSeg @"b0" :++: LSeg @"tl") -                                            (LSeg @"d" :++: (LSeg @"b0" :++: LSeg @"preb0") :++: LSeg @"recon" :++: LSeg @"binds" :++: LSeg @"tl")) +                                            (#d :++: #b0 :++: #tl) +                                            (#d :++: (#b0 :++: #preb0) :++: #recon :++: #binds :++: #tl))                                    b2') $                      EPair ext                        (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $ @@ -922,8 +924,8 @@ drev des = \case                                    $.. Seg' @"i" @'[TIx]                                    $.. Seg @"ne0" (bindingsBinds ne0)                                    $.. Seg @"tl" (sD1eEnv des)) -                                 (LSeg @"ne0" :++: LSeg @"tl") -                                 ((LSeg @"ve0" :++: LSeg @"i" :++: LSeg @"ne0") :++: LSeg @"tl")) +                                 (#ne0 :++: #tl) +                                 ((#ve0 :++: #i :++: #ne0) :++: #tl))                         ne1)             (subst (\_ t i -> case splitIdx @(TIx : D1E env) (bindingsBinds e0) i of                                 Left ibind -> | 
