diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-09-02 20:39:03 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-09-02 20:39:03 +0200 |
commit | 625c2c28d49dbdceb8864554acdfe1776d5333e0 (patch) | |
tree | 8449edb529017252cb08257059387306595c8996 | |
parent | 7d44dcc2ca2c5c16e1ab4737ef6b2877214767ed (diff) |
Autoweak!
-rw-r--r-- | src/AST.hs | 1 | ||||
-rw-r--r-- | src/AST/Weaken.hs | 51 | ||||
-rw-r--r-- | src/CHAD.hs | 53 |
3 files changed, 68 insertions, 37 deletions
@@ -15,7 +15,6 @@ {-# LANGUAGE EmptyCase #-} module AST (module AST, module AST.Weaken) where -import Data.Bifunctor (first) import Data.Functor.Const import Data.Kind (Type) import Data.Int diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs index 6c66b07..42cdbd5 100644 --- a/src/AST/Weaken.hs +++ b/src/AST/Weaken.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE EmptyCase #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE PolyKinds #-} @@ -97,24 +98,6 @@ wCopies :: SList f bs -> env1 :> env2 -> Append bs env1 :> Append bs env2 wCopies SNil w = w wCopies (SCons _ spine) w = WCopy (wCopies spine w) -wStack :: forall env b1 b2. b1 :> b2 -> Append b1 env :> Append b2 env -wStack WId = WId -wStack WSink = WSink -wStack (WCopy w) = WCopy (wStack @env w) -wStack (WPop w) = WPop (wStack @env w) -wStack (WThen w1 w2) = WThen (wStack @env w1) (wStack @env w2) -wStack (WClosed s) = wSinks s -wStack (WIdx i) = WIdx (goIdx i) - where - goIdx :: Idx b t -> Idx (Append b env) t - goIdx IZ = IZ - goIdx (IS i') = IS (goIdx i') -wStack (WPick @t @_ @env1 @env2 (pre :: SList (Const ()) pre) w) - | Refl <- lemAppendAssoc @pre @env2 @env - , Refl <- lemAppendAssoc @pre @(t : env1) @env - = WPick @t @_ pre (wStack @env w) -wStack WSwap{} = error "OOPS" - wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env wRaiseAbove SNil env = WClosed (slistMap (\_ -> Const ()) env) wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env) @@ -127,6 +110,7 @@ type family Lookup name list where data Layout (segments :: [(Symbol, [t])]) (env :: [t]) where LSeg :: forall name segments. KnownSymbol name => Layout segments (Lookup name segments) (:++:) :: Layout segments env1 -> Layout segments env2 -> Layout segments (Append env1 env2) +infixr :++: data SSegments (segments :: [(Symbol, [t])]) where SSegNil :: SSegments '[] @@ -142,6 +126,31 @@ instance (KnownSymbol name, KnownListSpine ts, KnownSegments list) => KnownSegments ('(name, ts) : list) where knownSegments = SSegCons symbolSing knownListSpine knownSegments +class ToSegments k a | a -> k where + type SegmentsOf k a :: [(Symbol, [k])] + toSegments :: a -> SSegments (SegmentsOf k a) + +instance ToSegments k (SSegments (segments :: [(Symbol, [k])])) where + type SegmentsOf k (SSegments segments) = segments + toSegments = id + +data GivenSegment name ts = forall f. KnownSymbol name => Seg (SList f ts) + | (KnownSymbol name, KnownListSpine ts) => Seg' + +instance ToSegments k (GivenSegment name (ts :: [k])) where + type SegmentsOf k (GivenSegment name (ts :: [k])) = '[ '(name, ts)] + toSegments (Seg list) = SSegCons symbolSing (slistMap (\_ -> Const ()) list) SSegNil + toSegments Seg' = SSegCons symbolSing knownListSpine SSegNil + +infixr $.. +($..) :: (ToSegments k a, ToSegments k b) => a -> b -> SSegments (Append (SegmentsOf k a) (SegmentsOf k b)) +x $.. y = ssegmentsAppend (toSegments x) (toSegments y) + where + ssegmentsAppend :: SSegments a -> SSegments b -> SSegments (Append a b) + ssegmentsAppend SSegNil l2 = l2 + ssegmentsAppend (SSegCons name list l1) l2 = SSegCons name list (ssegmentsAppend l1 l2) + +-- | If the found segment is a TopSeg, returns Nothing. segmentLookup :: forall segments name. SSegments segments -> SSymbol name -> SList (Const ()) (Lookup name segments) segmentLookup = \segs name -> case go segs name of Just ts -> ts @@ -179,9 +188,9 @@ linLayoutEnv :: SSegments segments -> LinLayout segments env -> SList (Const ()) linLayoutEnv _ LinEnd = SNil linLayoutEnv segs (LinApp name lin) = sappend (segmentLookup segs name) (linLayoutEnv segs lin) -autoWeak :: forall segments env1 env2. KnownSegments segments - => Layout segments env1 -> Layout segments env2 -> env1 :> env2 -autoWeak ly1 ly2 = sortLinLayouts knownSegments (lineariseLayout ly1) (lineariseLayout ly2) +autoWeak :: forall segments env1 env2. + SSegments segments -> Layout segments env1 -> Layout segments env2 -> env1 :> env2 +autoWeak segs ly1 ly2 = sortLinLayouts segs (lineariseLayout ly1) (lineariseLayout ly2) pullDown :: SSegments segments -> SSymbol name -> LinLayout segments env -> r -- Name was not found in source diff --git a/src/CHAD.hs b/src/CHAD.hs index 45d2d08..97632c7 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -585,19 +585,20 @@ weakenRets w (Rets binds list) = in Rets binds' (slistMap (weakenRetPair (bindingsBinds binds) w) list) rebaseRetPair :: forall env b1 b2 env0 sto t f. - SList f env0 -> SList f b1 -> SList f b2 + Descr env0 sto -> SList f b1 -> SList f b2 -> RetPair env0 sto (Append b1 env) b2 t -> RetPair env0 sto env (Append b2 b1) t -rebaseRetPair env b1 b2 (RetPair p sub d) +rebaseRetPair descr b1 b2 (RetPair p sub d) | Refl <- lemAppendAssoc @b2 @b1 @env = - RetPair p sub (weakenExpr (autoWeak @['("d", '[D2 t]), '("b2", b2), '("b1", b1), '("tl", D2AcE (Select env0 sto "accum"))] + RetPair p sub (weakenExpr (autoWeak + (Seg' @"d" @'[D2 t] $.. Seg @"b2" b2 $.. Seg @"b1" b1 $.. Seg @"tl" (d2ace (select SAccum descr))) (LSeg @"d" :++: (LSeg @"b2" :++: LSeg @"tl")) (LSeg @"d" :++: ((LSeg @"b2" :++: LSeg @"b1") :++: LSeg @"tl"))) d) -retConcat :: forall env0 sto list. SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list -retConcat SNil = Rets BTop SNil -retConcat (SCons (Ret (b :: Bindings _ _ shbinds) p sub d) list) - | Rets binds1 pairs1 <- retConcat list +retConcat :: forall env0 sto list. Descr env0 sto -> SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list +retConcat _ SNil = Rets BTop SNil +retConcat descr (SCons (Ret (b :: Bindings _ _ shbinds) p sub d) list) + | Rets binds1 pairs1 <- retConcat descr list , Rets (binds :: Bindings _ _ shbinds2) pairs <- weakenRets (sinkWithBindings b) (Rets binds1 pairs1) , Refl <- lemAppendAssoc @shbinds2 @shbinds @(D1E env0) , Refl <- lemAppendAssoc @shbinds2 @shbinds @(D2AcE (Select env0 sto "accum")) @@ -605,7 +606,7 @@ retConcat (SCons (Ret (b :: Bindings _ _ shbinds) p sub d) list) (SCons (RetPair (weakenExpr (sinkWithBindings binds) p) sub (weakenExpr (WCopy (sinkWithBindings binds)) d)) - (slistMap (rebaseRetPair (bindingsBinds b) (bindingsBinds binds)) pairs)) + (slistMap (rebaseRetPair descr (bindingsBinds b) (bindingsBinds binds)) pairs)) d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t) d1op (OAdd t) e = EOp ext (OAdd t) e @@ -731,7 +732,12 @@ drev des = \case (weakenExpr wbody0' body1) subBoth (ELet ext - (weakenExpr (WCopy (wStack @(D2AcE (Select env sto "accum")) (wRaiseAbove (bindingsBinds body0) (SCons (typeOf rhs1) (bindingsBinds rhs0))))) + (weakenExpr (autoWeak (Seg' @"d" @'[D2 t] + $.. Seg @"body" (bindingsBinds body0) + $.. Seg @"rhs" (SCons (typeOf rhs1) (bindingsBinds rhs0)) + $.. Seg @"tl" (d2ace (select SAccum des))) + (LSeg @"d" :++: LSeg @"body" :++: LSeg @"tl") + (LSeg @"d" :++: LSeg @"body" :++: LSeg @"rhs" :++: LSeg @"tl")) body2') $ ELet ext (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ @@ -742,7 +748,7 @@ drev des = \case EPair _ a b | Rets binds (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) - <- retConcat $ drev des a `SCons` drev des b `SCons` SNil + <- retConcat des $ drev des a `SCons` drev des b `SCons` SNil , let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) -> subenvPlus (select SMerge des) subA subB $ \subBoth _ _ plus_A_B -> Ret binds @@ -834,8 +840,14 @@ drev des = \case ELet ext (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_binds : D2 t : t_primal_ty : Append e_binds (D2AcE (Select env sto "accum"))) (sappend (bindingsBinds a0) prerebinds) @> IS IZ)) $ ELet ext - (weakenExpr (wStack @(D2AcE (Select env sto "accum")) $ - WCopy (wRaiseAbove (sappend (bindingsBinds a0) prerebinds) (tapeA `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0) .> wRaiseAbove (bindingsBinds a0) prerebinds)) + (weakenExpr (autoWeak (Seg' @"d" @'[D2 t] + $.. Seg @"a0" (bindingsBinds a0) + $.. Seg @"prea0" prerebinds + $.. Seg @"recon" (tapeA `SCons` d2 (typeOf a) `SCons` SNil) + $.. Seg @"binds" (tPrimal `SCons` bindingsBinds e0) + $.. Seg @"tl" (d2ace (select SAccum des))) + (LSeg @"d" :++: LSeg @"a0" :++: LSeg @"tl") + (LSeg @"d" :++: (LSeg @"a0" :++: LSeg @"prea0") :++: LSeg @"recon" :++: LSeg @"binds" :++: LSeg @"tl")) a2') $ EPair ext (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $ @@ -847,8 +859,14 @@ drev des = \case ELet ext (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_binds : D2 t : t_primal_ty : Append e_binds (D2AcE (Select env sto "accum"))) (sappend (bindingsBinds b0) prerebinds) @> IS IZ)) $ ELet ext - (weakenExpr (wStack @(D2AcE (Select env sto "accum")) $ - WCopy (wRaiseAbove (sappend (bindingsBinds b0) prerebinds) (tapeB `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0) .> wRaiseAbove (bindingsBinds b0) prerebinds)) + (weakenExpr (autoWeak (Seg' @"d" @'[D2 t] + $.. Seg @"b0" (bindingsBinds b0) + $.. Seg @"preb0" prerebinds + $.. Seg @"recon" (tapeB `SCons` d2 (typeOf a) `SCons` SNil) + $.. Seg @"binds" (tPrimal `SCons` bindingsBinds e0) + $.. Seg @"tl" (d2ace (select SAccum des))) + (LSeg @"d" :++: LSeg @"b0" :++: LSeg @"tl") + (LSeg @"d" :++: (LSeg @"b0" :++: LSeg @"preb0") :++: LSeg @"recon" :++: LSeg @"binds" :++: LSeg @"tl")) b2') $ EPair ext (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $ @@ -900,7 +918,12 @@ drev des = \case Ret (bconcat (ne0 `BPush` (tIx, ne1)) (fst (weakenBindings weakenExpr (WCopy (wSinks (bindingsBinds ne0))) ve0))) (EBuild1 ext - (weakenExpr (wStack @(D1E env) (wSinks (bindingsBinds ve0) .> WSink @TIx @ne_binds)) + (weakenExpr (autoWeak (Seg @"ve0" (bindingsBinds ve0) + $.. Seg' @"i" @'[TIx] + $.. Seg @"ne0" (bindingsBinds ne0) + $.. Seg @"tl" (sD1eEnv des)) + (LSeg @"ne0" :++: LSeg @"tl") + ((LSeg @"ve0" :++: LSeg @"i" :++: LSeg @"ne0") :++: LSeg @"tl")) ne1) (subst (\_ t i -> case splitIdx @(TIx : D1E env) (bindingsBinds e0) i of Left ibind -> |