From 625c2c28d49dbdceb8864554acdfe1776d5333e0 Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Mon, 2 Sep 2024 20:39:03 +0200
Subject: Autoweak!

---
 src/AST.hs        |  1 -
 src/AST/Weaken.hs | 51 ++++++++++++++++++++++++++++++---------------------
 src/CHAD.hs       | 53 ++++++++++++++++++++++++++++++++++++++---------------
 3 files changed, 68 insertions(+), 37 deletions(-)

(limited to 'src')

diff --git a/src/AST.hs b/src/AST.hs
index d9f5ef7..e33c10b 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -15,7 +15,6 @@
 {-# LANGUAGE EmptyCase #-}
 module AST (module AST, module AST.Weaken) where
 
-import Data.Bifunctor (first)
 import Data.Functor.Const
 import Data.Kind (Type)
 import Data.Int
diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs
index 6c66b07..42cdbd5 100644
--- a/src/AST/Weaken.hs
+++ b/src/AST/Weaken.hs
@@ -1,6 +1,7 @@
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE EmptyCase #-}
 {-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE FunctionalDependencies #-}
 {-# LANGUAGE GADTs #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
 {-# LANGUAGE PolyKinds #-}
@@ -97,24 +98,6 @@ wCopies :: SList f bs -> env1 :> env2 -> Append bs env1 :> Append bs env2
 wCopies SNil w = w
 wCopies (SCons _ spine) w = WCopy (wCopies spine w)
 
-wStack :: forall env b1 b2. b1 :> b2 -> Append b1 env :> Append b2 env
-wStack WId = WId
-wStack WSink = WSink
-wStack (WCopy w) = WCopy (wStack @env w)
-wStack (WPop w) = WPop (wStack @env w)
-wStack (WThen w1 w2) = WThen (wStack @env w1) (wStack @env w2)
-wStack (WClosed s) = wSinks s
-wStack (WIdx i) = WIdx (goIdx i)
-  where
-    goIdx :: Idx b t -> Idx (Append b env) t
-    goIdx IZ = IZ
-    goIdx (IS i') = IS (goIdx i')
-wStack (WPick @t @_ @env1 @env2 (pre :: SList (Const ()) pre) w)
-  | Refl <- lemAppendAssoc @pre @env2 @env
-  , Refl <- lemAppendAssoc @pre @(t : env1) @env
-  = WPick @t @_ pre (wStack @env w)
-wStack WSwap{} = error "OOPS"
-
 wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env
 wRaiseAbove SNil env = WClosed (slistMap (\_ -> Const ()) env)
 wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env)
@@ -127,6 +110,7 @@ type family Lookup name list where
 data Layout (segments :: [(Symbol, [t])]) (env :: [t]) where
   LSeg :: forall name segments. KnownSymbol name => Layout segments (Lookup name segments)
   (:++:) :: Layout segments env1 -> Layout segments env2 -> Layout segments (Append env1 env2)
+infixr :++:
 
 data SSegments (segments :: [(Symbol, [t])]) where
   SSegNil :: SSegments '[]
@@ -142,6 +126,31 @@ instance (KnownSymbol name, KnownListSpine ts, KnownSegments list)
       => KnownSegments ('(name, ts) : list) where
   knownSegments = SSegCons symbolSing knownListSpine knownSegments
 
+class ToSegments k a | a -> k where
+  type SegmentsOf k a :: [(Symbol, [k])]
+  toSegments :: a -> SSegments (SegmentsOf k a)
+
+instance ToSegments k (SSegments (segments :: [(Symbol, [k])])) where
+  type SegmentsOf k (SSegments segments) = segments
+  toSegments = id
+
+data GivenSegment name ts = forall f. KnownSymbol name => Seg (SList f ts)
+                          | (KnownSymbol name, KnownListSpine ts) => Seg'
+
+instance ToSegments k (GivenSegment name (ts :: [k])) where
+  type SegmentsOf k (GivenSegment name (ts :: [k])) = '[ '(name, ts)]
+  toSegments (Seg list) = SSegCons symbolSing (slistMap (\_ -> Const ()) list) SSegNil
+  toSegments Seg' = SSegCons symbolSing knownListSpine SSegNil
+
+infixr $..
+($..) :: (ToSegments k a, ToSegments k b) => a -> b -> SSegments (Append (SegmentsOf k a) (SegmentsOf k b))
+x $.. y = ssegmentsAppend (toSegments x) (toSegments y)
+  where
+    ssegmentsAppend :: SSegments a -> SSegments b -> SSegments (Append a b)
+    ssegmentsAppend SSegNil l2 = l2
+    ssegmentsAppend (SSegCons name list l1) l2 = SSegCons name list (ssegmentsAppend l1 l2)
+
+-- | If the found segment is a TopSeg, returns Nothing.
 segmentLookup :: forall segments name. SSegments segments -> SSymbol name -> SList (Const ()) (Lookup name segments)
 segmentLookup = \segs name -> case go segs name of
                            Just ts -> ts
@@ -179,9 +188,9 @@ linLayoutEnv :: SSegments segments -> LinLayout segments env -> SList (Const ())
 linLayoutEnv _ LinEnd = SNil
 linLayoutEnv segs (LinApp name lin) = sappend (segmentLookup segs name) (linLayoutEnv segs lin)
 
-autoWeak :: forall segments env1 env2. KnownSegments segments
-         => Layout segments env1 -> Layout segments env2 -> env1 :> env2
-autoWeak ly1 ly2 = sortLinLayouts knownSegments (lineariseLayout ly1) (lineariseLayout ly2)
+autoWeak :: forall segments env1 env2.
+            SSegments segments -> Layout segments env1 -> Layout segments env2 -> env1 :> env2
+autoWeak segs ly1 ly2 = sortLinLayouts segs (lineariseLayout ly1) (lineariseLayout ly2)
 
 pullDown :: SSegments segments -> SSymbol name -> LinLayout segments env
          -> r  -- Name was not found in source
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 45d2d08..97632c7 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -585,19 +585,20 @@ weakenRets w (Rets binds list) =
   in Rets binds' (slistMap (weakenRetPair (bindingsBinds binds) w) list)
 
 rebaseRetPair :: forall env b1 b2 env0 sto t f.
-                 SList f env0 -> SList f b1 -> SList f b2
+                 Descr env0 sto -> SList f b1 -> SList f b2
               -> RetPair env0 sto (Append b1 env) b2 t -> RetPair env0 sto env (Append b2 b1) t
-rebaseRetPair env b1 b2 (RetPair p sub d)
+rebaseRetPair descr b1 b2 (RetPair p sub d)
   | Refl <- lemAppendAssoc @b2 @b1 @env =
-      RetPair p sub (weakenExpr (autoWeak @['("d", '[D2 t]), '("b2", b2), '("b1", b1), '("tl", D2AcE (Select env0 sto "accum"))]
+      RetPair p sub (weakenExpr (autoWeak
+                                  (Seg' @"d" @'[D2 t] $.. Seg @"b2" b2 $.. Seg @"b1" b1 $.. Seg @"tl" (d2ace (select SAccum descr)))
                                   (LSeg @"d" :++: (LSeg @"b2" :++: LSeg @"tl"))
                                   (LSeg @"d" :++: ((LSeg @"b2" :++: LSeg @"b1") :++: LSeg @"tl")))
                                 d)
 
-retConcat :: forall env0 sto list. SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list
-retConcat SNil = Rets BTop SNil
-retConcat (SCons (Ret (b :: Bindings _ _ shbinds) p sub d) list)
-  | Rets binds1 pairs1 <- retConcat list
+retConcat :: forall env0 sto list. Descr env0 sto -> SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list
+retConcat _ SNil = Rets BTop SNil
+retConcat descr (SCons (Ret (b :: Bindings _ _ shbinds) p sub d) list)
+  | Rets binds1 pairs1 <- retConcat descr list
   , Rets (binds :: Bindings _ _ shbinds2) pairs <- weakenRets (sinkWithBindings b) (Rets binds1 pairs1)
   , Refl <- lemAppendAssoc @shbinds2 @shbinds @(D1E env0)
   , Refl <- lemAppendAssoc @shbinds2 @shbinds @(D2AcE (Select env0 sto "accum"))
@@ -605,7 +606,7 @@ retConcat (SCons (Ret (b :: Bindings _ _ shbinds) p sub d) list)
          (SCons (RetPair (weakenExpr (sinkWithBindings binds) p)
                          sub
                          (weakenExpr (WCopy (sinkWithBindings binds)) d))
-                (slistMap (rebaseRetPair (bindingsBinds b) (bindingsBinds binds)) pairs))
+                (slistMap (rebaseRetPair descr (bindingsBinds b) (bindingsBinds binds)) pairs))
 
 d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t)
 d1op (OAdd t) e = EOp ext (OAdd t) e
@@ -731,7 +732,12 @@ drev des = \case
         (weakenExpr wbody0' body1)
         subBoth
         (ELet ext
-           (weakenExpr (WCopy (wStack @(D2AcE (Select env sto "accum")) (wRaiseAbove (bindingsBinds body0) (SCons (typeOf rhs1) (bindingsBinds rhs0)))))
+           (weakenExpr (autoWeak (Seg' @"d" @'[D2 t]
+                                  $.. Seg @"body" (bindingsBinds body0)
+                                  $.. Seg @"rhs" (SCons (typeOf rhs1) (bindingsBinds rhs0))
+                                  $.. Seg @"tl" (d2ace (select SAccum des)))
+                                 (LSeg @"d" :++: LSeg @"body" :++: LSeg @"tl")
+                                 (LSeg @"d" :++: LSeg @"body" :++: LSeg @"rhs" :++: LSeg @"tl"))
                        body2') $
          ELet ext
            (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $
@@ -742,7 +748,7 @@ drev des = \case
 
   EPair _ a b
     | Rets binds (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil)
-        <- retConcat $ drev des a `SCons` drev des b `SCons` SNil
+        <- retConcat des $ drev des a `SCons` drev des b `SCons` SNil
     , let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) ->
     subenvPlus (select SMerge des) subA subB $ \subBoth _ _ plus_A_B ->
     Ret binds
@@ -834,8 +840,14 @@ drev des = \case
                     ELet ext
                       (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_binds : D2 t : t_primal_ty : Append e_binds (D2AcE (Select env sto "accum"))) (sappend (bindingsBinds a0) prerebinds) @> IS IZ)) $
                     ELet ext
-                      (weakenExpr (wStack @(D2AcE (Select env sto "accum")) $
-                                     WCopy (wRaiseAbove (sappend (bindingsBinds a0) prerebinds) (tapeA `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0) .> wRaiseAbove (bindingsBinds a0) prerebinds))
+                      (weakenExpr (autoWeak (Seg' @"d" @'[D2 t]
+                                             $.. Seg @"a0" (bindingsBinds a0)
+                                             $.. Seg @"prea0" prerebinds
+                                             $.. Seg @"recon" (tapeA `SCons` d2 (typeOf a) `SCons` SNil)
+                                             $.. Seg @"binds" (tPrimal `SCons` bindingsBinds e0)
+                                             $.. Seg @"tl" (d2ace (select SAccum des)))
+                                            (LSeg @"d" :++: LSeg @"a0" :++: LSeg @"tl")
+                                            (LSeg @"d" :++: (LSeg @"a0" :++: LSeg @"prea0") :++: LSeg @"recon" :++: LSeg @"binds" :++: LSeg @"tl"))
                                   a2') $
                     EPair ext
                      (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $
@@ -847,8 +859,14 @@ drev des = \case
                     ELet ext
                       (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_binds : D2 t : t_primal_ty : Append e_binds (D2AcE (Select env sto "accum"))) (sappend (bindingsBinds b0) prerebinds) @> IS IZ)) $
                     ELet ext
-                      (weakenExpr (wStack @(D2AcE (Select env sto "accum")) $
-                                     WCopy (wRaiseAbove (sappend (bindingsBinds b0) prerebinds) (tapeB `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0) .> wRaiseAbove (bindingsBinds b0) prerebinds))
+                      (weakenExpr (autoWeak (Seg' @"d" @'[D2 t]
+                                             $.. Seg @"b0" (bindingsBinds b0)
+                                             $.. Seg @"preb0" prerebinds
+                                             $.. Seg @"recon" (tapeB `SCons` d2 (typeOf a) `SCons` SNil)
+                                             $.. Seg @"binds" (tPrimal `SCons` bindingsBinds e0)
+                                             $.. Seg @"tl" (d2ace (select SAccum des)))
+                                            (LSeg @"d" :++: LSeg @"b0" :++: LSeg @"tl")
+                                            (LSeg @"d" :++: (LSeg @"b0" :++: LSeg @"preb0") :++: LSeg @"recon" :++: LSeg @"binds" :++: LSeg @"tl"))
                                   b2') $
                     EPair ext
                       (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $
@@ -900,7 +918,12 @@ drev des = \case
     Ret (bconcat (ne0 `BPush` (tIx, ne1))
                  (fst (weakenBindings weakenExpr (WCopy (wSinks (bindingsBinds ne0))) ve0)))
         (EBuild1 ext
-           (weakenExpr (wStack @(D1E env) (wSinks (bindingsBinds ve0) .> WSink @TIx @ne_binds))
+           (weakenExpr (autoWeak (Seg @"ve0" (bindingsBinds ve0)
+                                  $.. Seg' @"i" @'[TIx]
+                                  $.. Seg @"ne0" (bindingsBinds ne0)
+                                  $.. Seg @"tl" (sD1eEnv des))
+                                 (LSeg @"ne0" :++: LSeg @"tl")
+                                 ((LSeg @"ve0" :++: LSeg @"i" :++: LSeg @"ne0") :++: LSeg @"tl"))
                        ne1)
            (subst (\_ t i -> case splitIdx @(TIx : D1E env) (bindingsBinds e0) i of
                                Left ibind ->
-- 
cgit v1.2.3-70-g09d2