summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--chad-fast.cabal1
-rw-r--r--src/AST.hs22
-rw-r--r--src/AST/Weaken.hs129
-rw-r--r--src/AST/Weaken/Auto.hs151
-rw-r--r--src/CHAD.hs22
5 files changed, 175 insertions, 150 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index ca3a2aa..19c2852 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -14,6 +14,7 @@ library
AST.Count
AST.Pretty
AST.Weaken
+ AST.Weaken.Auto
CHAD
-- Compile
Data
diff --git a/src/AST.hs b/src/AST.hs
index e33c10b..15e6d43 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -1,18 +1,18 @@
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE QuantifiedConstraints #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
+{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE EmptyCase #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
module AST (module AST, module AST.Weaken) where
import Data.Functor.Const
diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs
index 42cdbd5..aa88c8e 100644
--- a/src/AST/Weaken.hs
+++ b/src/AST/Weaken.hs
@@ -1,17 +1,12 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyCase #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
-{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE PartialTypeSignatures #-}
@@ -25,8 +20,6 @@ module AST.Weaken (module AST.Weaken, Append) where
import Data.Bifunctor (first)
import Data.Functor.Const
import Data.Kind (Type)
-import GHC.TypeLits
-import Unsafe.Coerce
import Data
import Lemmas
@@ -101,125 +94,3 @@ wCopies (SCons _ spine) w = WCopy (wCopies spine w)
wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env
wRaiseAbove SNil env = WClosed (slistMap (\_ -> Const ()) env)
wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env)
-
-
-type family Lookup name list where
- Lookup name ('(name, x) : _) = x
- Lookup name (_ : list) = Lookup name list
-
-data Layout (segments :: [(Symbol, [t])]) (env :: [t]) where
- LSeg :: forall name segments. KnownSymbol name => Layout segments (Lookup name segments)
- (:++:) :: Layout segments env1 -> Layout segments env2 -> Layout segments (Append env1 env2)
-infixr :++:
-
-data SSegments (segments :: [(Symbol, [t])]) where
- SSegNil :: SSegments '[]
- SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list)
-
-class KnownSegments (segments :: [(Symbol, [t])]) where
- knownSegments :: SSegments segments
-
-instance KnownSegments '[] where
- knownSegments = SSegNil
-
-instance (KnownSymbol name, KnownListSpine ts, KnownSegments list)
- => KnownSegments ('(name, ts) : list) where
- knownSegments = SSegCons symbolSing knownListSpine knownSegments
-
-class ToSegments k a | a -> k where
- type SegmentsOf k a :: [(Symbol, [k])]
- toSegments :: a -> SSegments (SegmentsOf k a)
-
-instance ToSegments k (SSegments (segments :: [(Symbol, [k])])) where
- type SegmentsOf k (SSegments segments) = segments
- toSegments = id
-
-data GivenSegment name ts = forall f. KnownSymbol name => Seg (SList f ts)
- | (KnownSymbol name, KnownListSpine ts) => Seg'
-
-instance ToSegments k (GivenSegment name (ts :: [k])) where
- type SegmentsOf k (GivenSegment name (ts :: [k])) = '[ '(name, ts)]
- toSegments (Seg list) = SSegCons symbolSing (slistMap (\_ -> Const ()) list) SSegNil
- toSegments Seg' = SSegCons symbolSing knownListSpine SSegNil
-
-infixr $..
-($..) :: (ToSegments k a, ToSegments k b) => a -> b -> SSegments (Append (SegmentsOf k a) (SegmentsOf k b))
-x $.. y = ssegmentsAppend (toSegments x) (toSegments y)
- where
- ssegmentsAppend :: SSegments a -> SSegments b -> SSegments (Append a b)
- ssegmentsAppend SSegNil l2 = l2
- ssegmentsAppend (SSegCons name list l1) l2 = SSegCons name list (ssegmentsAppend l1 l2)
-
--- | If the found segment is a TopSeg, returns Nothing.
-segmentLookup :: forall segments name. SSegments segments -> SSymbol name -> SList (Const ()) (Lookup name segments)
-segmentLookup = \segs name -> case go segs name of
- Just ts -> ts
- Nothing -> error $ "Segment not found: " ++ fromSSymbol name
- where
- go :: forall segs'. SSegments segs' -> SSymbol name -> Maybe (SList (Const ()) (Lookup name segs'))
- go SSegNil _ = Nothing
- go (SSegCons n@(SSymbol @n) (ts :: SList _ ts) (sseg :: SSegments rest)) name@SSymbol =
- case sameSymbol n name of
- Just Refl ->
- case go sseg name of
- Nothing -> Just ts
- Just _ -> error $ "Duplicate segment with name " ++ fromSSymbol name
- Nothing ->
- case unsafeCoerce Refl :: (Lookup name ('(n, ts) : rest) :~: Lookup name rest) of
- Refl -> go sseg name
-
-data LinLayout (segments :: [(Symbol, [t])]) (env :: [t]) where
- LinEnd :: LinLayout segments '[]
- LinApp :: SSymbol name -> LinLayout segments env -> LinLayout segments (Append (Lookup name segments) env)
-
-linLayoutAppend :: LinLayout segments env1 -> LinLayout segments env2 -> LinLayout segments (Append env1 env2)
-linLayoutAppend LinEnd lin = lin
-linLayoutAppend (LinApp (name :: SSymbol name) (lin1 :: LinLayout segments env1')) (lin2 :: LinLayout _ env2)
- | Refl <- lemAppendAssoc @(Lookup name segments) @env1' @env2
- = LinApp name (linLayoutAppend lin1 lin2)
-
-lineariseLayout :: Layout segments env -> LinLayout segments env
-lineariseLayout (LSeg @name :: Layout _ seg)
- | Refl <- lemAppendNil @seg
- = LinApp (symbolSing @name) LinEnd
-lineariseLayout (ly1 :++: ly2) = lineariseLayout ly1 `linLayoutAppend` lineariseLayout ly2
-
-linLayoutEnv :: SSegments segments -> LinLayout segments env -> SList (Const ()) env
-linLayoutEnv _ LinEnd = SNil
-linLayoutEnv segs (LinApp name lin) = sappend (segmentLookup segs name) (linLayoutEnv segs lin)
-
-autoWeak :: forall segments env1 env2.
- SSegments segments -> Layout segments env1 -> Layout segments env2 -> env1 :> env2
-autoWeak segs ly1 ly2 = sortLinLayouts segs (lineariseLayout ly1) (lineariseLayout ly2)
-
-pullDown :: SSegments segments -> SSymbol name -> LinLayout segments env
- -> r -- Name was not found in source
- -> (forall env'. LinLayout segments env' -> env :> Append (Lookup name segments) env' -> r)
- -> r
-pullDown segs name@SSymbol linlayout kNotFound k =
- case linlayout of
- LinEnd -> kNotFound
- LinApp n'@SSymbol lin
- | Just Refl <- sameSymbol name n' -> k lin WId
- | otherwise ->
- pullDown segs name lin kNotFound $ \(lin' :: LinLayout _ env') w ->
- k (LinApp n' lin') (WSwap (segmentLookup segs n') (segmentLookup segs name) (linLayoutEnv segs lin')
- .> wCopies (segmentLookup segs n') w)
-
-sortLinLayouts :: forall segments env1 env2.
- SSegments segments
- -> LinLayout segments env1 -> LinLayout segments env2 -> env1 :> env2
-sortLinLayouts _ LinEnd LinEnd = WId
-sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail2)
- | Just Refl <- sameSymbol name1 name2 = wCopies (segmentLookup segs name1) (sortLinLayouts segs tail1 tail2)
- | otherwise =
- pullDown segs name2 lin1
- (wSinks (segmentLookup segs name2) .> sortLinLayouts segs lin1 tail2)
- (\tail1' w ->
- -- We've pulled down name2 in lin1 so that it's at the head; the
- -- resulting modified tail is tail1'. Thus now we have (name2 : tail1')
- -- vs (name2 : tail2). Thus we continue sorting tail1' vs tail2, and
- -- wCopies the name2 on top of that.
- wCopies (segmentLookup segs name2) (sortLinLayouts segs tail1' tail2) .> w)
-sortLinLayouts _ LinEnd LinApp{} = error "Unequal number of segments: more in target than in source"
-sortLinLayouts _ LinApp{} LinEnd = error "Unequal number of segments: more in source than in target"
diff --git a/src/AST/Weaken/Auto.hs b/src/AST/Weaken/Auto.hs
new file mode 100644
index 0000000..93116b8
--- /dev/null
+++ b/src/AST/Weaken/Auto.hs
@@ -0,0 +1,151 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE FunctionalDependencies #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE KindSignatures #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeAbstractions #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+
+{-# LANGUAGE AllowAmbiguousTypes #-}
+
+{-# LANGUAGE PartialTypeSignatures #-}
+{-# OPTIONS_GHC -Wno-partial-type-signatures #-}
+module AST.Weaken.Auto (
+ autoWeak,
+ GivenSegment(..),
+ ($..),
+ Layout(..),
+) where
+
+import Data.Functor.Const
+import GHC.OverloadedLabels
+import GHC.TypeLits
+import Unsafe.Coerce (unsafeCoerce)
+
+import AST.Weaken
+import Data
+import Lemmas
+
+
+type family Lookup name list where
+ Lookup name ('(name, x) : _) = x
+ Lookup name (_ : list) = Lookup name list
+
+
+data Layout (segments :: [(Symbol, [t])]) (env :: [t]) where
+ LSeg :: forall name segments. KnownSymbol name => Layout segments (Lookup name segments)
+ (:++:) :: Layout segments env1 -> Layout segments env2 -> Layout segments (Append env1 env2)
+infixr :++:
+
+instance (KnownSymbol name, seg ~ Lookup name segments) => IsLabel name (Layout segments seg) where
+ fromLabel = LSeg @name @segments
+
+
+data SSegments (segments :: [(Symbol, [t])]) where
+ SSegNil :: SSegments '[]
+ SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list)
+
+class ToSegments k a | a -> k where
+ type SegmentsOf k a :: [(Symbol, [k])]
+ toSegments :: a -> SSegments (SegmentsOf k a)
+
+instance ToSegments k (SSegments (segments :: [(Symbol, [k])])) where
+ type SegmentsOf k (SSegments segments) = segments
+ toSegments = id
+
+data GivenSegment name ts = forall f. KnownSymbol name => Seg (SList f ts)
+ | (KnownSymbol name, KnownListSpine ts) => Seg'
+
+instance ToSegments k (GivenSegment name (ts :: [k])) where
+ type SegmentsOf k (GivenSegment name (ts :: [k])) = '[ '(name, ts)]
+ toSegments (Seg list) = SSegCons symbolSing (slistMap (\_ -> Const ()) list) SSegNil
+ toSegments Seg' = SSegCons symbolSing knownListSpine SSegNil
+
+infixr $..
+($..) :: (ToSegments k a, ToSegments k b) => a -> b -> SSegments (Append (SegmentsOf k a) (SegmentsOf k b))
+x $.. y = ssegmentsAppend (toSegments x) (toSegments y)
+ where
+ ssegmentsAppend :: SSegments a -> SSegments b -> SSegments (Append a b)
+ ssegmentsAppend SSegNil l2 = l2
+ ssegmentsAppend (SSegCons name list l1) l2 = SSegCons name list (ssegmentsAppend l1 l2)
+
+
+-- | If the found segment is a TopSeg, returns Nothing.
+segmentLookup :: forall segments name. SSegments segments -> SSymbol name -> SList (Const ()) (Lookup name segments)
+segmentLookup = \segs name -> case go segs name of
+ Just ts -> ts
+ Nothing -> error $ "Segment not found: " ++ fromSSymbol name
+ where
+ go :: forall segs'. SSegments segs' -> SSymbol name -> Maybe (SList (Const ()) (Lookup name segs'))
+ go SSegNil _ = Nothing
+ go (SSegCons n@(SSymbol @n) (ts :: SList _ ts) (sseg :: SSegments rest)) name@SSymbol =
+ case sameSymbol n name of
+ Just Refl ->
+ case go sseg name of
+ Nothing -> Just ts
+ Just _ -> error $ "Duplicate segment with name " ++ fromSSymbol name
+ Nothing ->
+ case unsafeCoerce Refl :: (Lookup name ('(n, ts) : rest) :~: Lookup name rest) of
+ Refl -> go sseg name
+
+data LinLayout (segments :: [(Symbol, [t])]) (env :: [t]) where
+ LinEnd :: LinLayout segments '[]
+ LinApp :: SSymbol name -> LinLayout segments env -> LinLayout segments (Append (Lookup name segments) env)
+
+linLayoutAppend :: LinLayout segments env1 -> LinLayout segments env2 -> LinLayout segments (Append env1 env2)
+linLayoutAppend LinEnd lin = lin
+linLayoutAppend (LinApp (name :: SSymbol name) (lin1 :: LinLayout segments env1')) (lin2 :: LinLayout _ env2)
+ | Refl <- lemAppendAssoc @(Lookup name segments) @env1' @env2
+ = LinApp name (linLayoutAppend lin1 lin2)
+
+lineariseLayout :: Layout segments env -> LinLayout segments env
+lineariseLayout (LSeg @name :: Layout _ seg)
+ | Refl <- lemAppendNil @seg
+ = LinApp (symbolSing @name) LinEnd
+lineariseLayout (ly1 :++: ly2) = lineariseLayout ly1 `linLayoutAppend` lineariseLayout ly2
+
+linLayoutEnv :: SSegments segments -> LinLayout segments env -> SList (Const ()) env
+linLayoutEnv _ LinEnd = SNil
+linLayoutEnv segs (LinApp name lin) = sappend (segmentLookup segs name) (linLayoutEnv segs lin)
+
+pullDown :: SSegments segments -> SSymbol name -> LinLayout segments env
+ -> r -- Name was not found in source
+ -> (forall env'. LinLayout segments env' -> env :> Append (Lookup name segments) env' -> r)
+ -> r
+pullDown segs name@SSymbol linlayout kNotFound k =
+ case linlayout of
+ LinEnd -> kNotFound
+ LinApp n'@SSymbol lin
+ | Just Refl <- sameSymbol name n' -> k lin WId
+ | otherwise ->
+ pullDown segs name lin kNotFound $ \(lin' :: LinLayout _ env') w ->
+ k (LinApp n' lin') (WSwap (segmentLookup segs n') (segmentLookup segs name) (linLayoutEnv segs lin')
+ .> wCopies (segmentLookup segs n') w)
+
+sortLinLayouts :: forall segments env1 env2.
+ SSegments segments
+ -> LinLayout segments env1 -> LinLayout segments env2 -> env1 :> env2
+sortLinLayouts _ LinEnd LinEnd = WId
+sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail2)
+ | Just Refl <- sameSymbol name1 name2 = wCopies (segmentLookup segs name1) (sortLinLayouts segs tail1 tail2)
+ | otherwise =
+ pullDown segs name2 lin1
+ (wSinks (segmentLookup segs name2) .> sortLinLayouts segs lin1 tail2)
+ (\tail1' w ->
+ -- We've pulled down name2 in lin1 so that it's at the head; the
+ -- resulting modified tail is tail1'. Thus now we have (name2 : tail1')
+ -- vs (name2 : tail2). Thus we continue sorting tail1' vs tail2, and
+ -- wCopies the name2 on top of that.
+ wCopies (segmentLookup segs name2) (sortLinLayouts segs tail1' tail2) .> w)
+sortLinLayouts _ LinEnd LinApp{} = error "Unequal number of segments: more in target than in source"
+sortLinLayouts _ LinApp{} LinEnd = error "Unequal number of segments: more in source than in target"
+
+autoWeak :: forall segments env1 env2.
+ SSegments segments -> Layout segments env1 -> Layout segments env2 -> env1 :> env2
+autoWeak segs ly1 ly2 = sortLinLayouts segs (lineariseLayout ly1) (lineariseLayout ly2)
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 97632c7..e99859c 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -2,6 +2,7 @@
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
@@ -33,6 +34,7 @@ import GHC.TypeLits (Symbol)
import AST
import AST.Count
+import AST.Weaken.Auto
import Data
import Lemmas
@@ -591,8 +593,8 @@ rebaseRetPair descr b1 b2 (RetPair p sub d)
| Refl <- lemAppendAssoc @b2 @b1 @env =
RetPair p sub (weakenExpr (autoWeak
(Seg' @"d" @'[D2 t] $.. Seg @"b2" b2 $.. Seg @"b1" b1 $.. Seg @"tl" (d2ace (select SAccum descr)))
- (LSeg @"d" :++: (LSeg @"b2" :++: LSeg @"tl"))
- (LSeg @"d" :++: ((LSeg @"b2" :++: LSeg @"b1") :++: LSeg @"tl")))
+ (#d :++: (#b2 :++: #tl))
+ (#d :++: ((#b2 :++: #b1) :++: #tl)))
d)
retConcat :: forall env0 sto list. Descr env0 sto -> SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list
@@ -736,8 +738,8 @@ drev des = \case
$.. Seg @"body" (bindingsBinds body0)
$.. Seg @"rhs" (SCons (typeOf rhs1) (bindingsBinds rhs0))
$.. Seg @"tl" (d2ace (select SAccum des)))
- (LSeg @"d" :++: LSeg @"body" :++: LSeg @"tl")
- (LSeg @"d" :++: LSeg @"body" :++: LSeg @"rhs" :++: LSeg @"tl"))
+ (#d :++: #body :++: #tl)
+ (#d :++: #body :++: #rhs :++: #tl))
body2') $
ELet ext
(ELet ext (ESnd ext (EVar ext bodyResType IZ)) $
@@ -846,8 +848,8 @@ drev des = \case
$.. Seg @"recon" (tapeA `SCons` d2 (typeOf a) `SCons` SNil)
$.. Seg @"binds" (tPrimal `SCons` bindingsBinds e0)
$.. Seg @"tl" (d2ace (select SAccum des)))
- (LSeg @"d" :++: LSeg @"a0" :++: LSeg @"tl")
- (LSeg @"d" :++: (LSeg @"a0" :++: LSeg @"prea0") :++: LSeg @"recon" :++: LSeg @"binds" :++: LSeg @"tl"))
+ (#d :++: #a0 :++: #tl)
+ (#d :++: (#a0 :++: #prea0) :++: #recon :++: #binds :++: #tl))
a2') $
EPair ext
(expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $
@@ -865,8 +867,8 @@ drev des = \case
$.. Seg @"recon" (tapeB `SCons` d2 (typeOf a) `SCons` SNil)
$.. Seg @"binds" (tPrimal `SCons` bindingsBinds e0)
$.. Seg @"tl" (d2ace (select SAccum des)))
- (LSeg @"d" :++: LSeg @"b0" :++: LSeg @"tl")
- (LSeg @"d" :++: (LSeg @"b0" :++: LSeg @"preb0") :++: LSeg @"recon" :++: LSeg @"binds" :++: LSeg @"tl"))
+ (#d :++: #b0 :++: #tl)
+ (#d :++: (#b0 :++: #preb0) :++: #recon :++: #binds :++: #tl))
b2') $
EPair ext
(expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $
@@ -922,8 +924,8 @@ drev des = \case
$.. Seg' @"i" @'[TIx]
$.. Seg @"ne0" (bindingsBinds ne0)
$.. Seg @"tl" (sD1eEnv des))
- (LSeg @"ne0" :++: LSeg @"tl")
- ((LSeg @"ve0" :++: LSeg @"i" :++: LSeg @"ne0") :++: LSeg @"tl"))
+ (#ne0 :++: #tl)
+ ((#ve0 :++: #i :++: #ne0) :++: #tl))
ne1)
(subst (\_ t i -> case splitIdx @(TIx : D1E env) (bindingsBinds e0) i of
Left ibind ->