summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
blob: 241825eff88b5c9f478a07bc040fd16db96e48cb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- I want to bring various type variables in scope using type annotations in
-- patterns, but I don't want to have to mention all the other type parameters
-- of the types in question as well then. Partial type signatures (with '_') are
-- useful here.
{-# LANGUAGE PartialTypeSignatures #-}
{-# OPTIONS -Wno-partial-type-signatures #-}
module CHAD (
  drev,
  freezeRet,
  CHADConfig(..),
  defaultConfig,
  Storage(..),
  Descr(..),
  Select,
) where

import Data.Functor.Const
import Data.Some
import Data.Type.Bool (If)
import Data.Type.Equality (type (==), testEquality)
import GHC.Stack (HasCallStack)

import Analysis.Identity (ValId(..), validSplitEither)
import AST
import AST.Bindings
import AST.Count
import AST.Env
import AST.Sparse
import AST.Weaken.Auto
import CHAD.EnvDescr
import CHAD.Types
import Data
import qualified Data.VarMap as VarMap
import Data.VarMap (VarMap)
import Lemmas


------------------------------ TAPES AND BINDINGS ------------------------------

-- | The type of a tape that stores one value of each type in @binds@, as a
-- right-nested chain of pairs terminated by 'TNil'. The head of the list is
-- the first component of the outermost pair.
type family Tape binds where
  Tape '[] = TNil
  Tape (t : ts) = TPair t (Tape ts)

-- | Singleton witness for 'Tape'. It builds the 'STy' of the tape for a list
-- of binding types: one 'STPair' per element, ending in 'STNil'.
tapeTy :: SList STy binds -> STy (Tape binds)
tapeTy = \case
  SNil -> STNil
  t `SCons` rest -> STPair t (tapeTy rest)

-- | Build an expression that packs the bindings selected by the 'Subenv'
-- into a 'Tape'. Each selected binding is read with an 'EVar' at its
-- position in @env2@, located through the weakening @w@. Unselected
-- bindings are skipped.
bindingsCollectTape :: SList STy binds -> Subenv binds tapebinds
                    -> binds :> env2 -> Ex env2 (Tape tapebinds)
bindingsCollectTape SNil SETop _ = ENil ext
bindingsCollectTape (t `SCons` binds) (SEYesR sub) w =
  -- The current head is at 'IZ' under @w@. Composing with 'WSink' makes
  -- the weakening skip past it for the remaining bindings.
  EPair ext (EVar ext t (w @> IZ))
            (bindingsCollectTape binds sub (w .> WSink))
bindingsCollectTape (_ `SCons` binds) (SENo sub) w =
  bindingsCollectTape binds sub (w .> WSink)

-- bindingsCollectTape' :: forall f env binds tapebinds env2. Bindings f env binds -> Subenv binds tapebinds
--                      -> Append binds env :> env2 -> Ex env2 (Tape tapebinds)
-- bindingsCollectTape' binds sub w
--   | Refl <- lemAppendNil @binds
--   = bindingsCollectTape (bindingsBinds binds) sub (w .> wCopies @_ @_ @'[] (bindingsBinds binds) (WClosed @env))

-- | The 'Tape' type of every proper suffix of @binds@. For example, for
-- @[a, b, c]@ this is @[Tape [b, c], Tape [c], Tape []]@.
--
-- The order is from large to small, i.e. the reverse of what we want,
-- because in a Bindings the head of the list is the bottom-most entry.
type family TapeUnfoldings binds where
  TapeUnfoldings '[] = '[]
  TapeUnfoldings (t : ts) = Tape ts : TapeUnfoldings ts

-- | Type-level list reversal. Each element is appended to the end of the
-- reversed tail with 'Append'.
type family Reverse l where
  Reverse '[] = '[]
  Reverse (t : ts) = Append (Reverse ts) '[t]

-- | An expression that is always 'snd' of the top-most variable ('IZ').
-- Keeping this shape explicit lets 'shiftUnfolder' rebuild it at a
-- different type. Use 'fromUnfExpr' to turn it into a real expression.
data UnfExpr env t where
  UnfExSnd :: STy s -> STy t -> UnfExpr (TPair s t : env) t

-- | Embed an 'UnfExpr' into the full expression language. The result
-- projects the second component out of the pair held in the top-most
-- variable.
fromUnfExpr :: UnfExpr env t -> Ex env t
fromUnfExpr = \case
  UnfExSnd fstTy sndTy -> ESnd ext (EVar ext (STPair fstTy sndTy) IZ)

-- | Two let-stacks that together unpack a 'Tape ts' into separate bindings:
--
-- - A bunch of 'snd' expressions. They take us from knowing that there is a
--   'Tape ts' in the environment to having 'Reverse (TapeUnfoldings ts)' in
--   the environment. For simplicity we assume the tape is at IZ; this is
--   fixed in reconstructBindings.
-- - In the extended environment, another bunch of let bindings, one for each
--   type in 'ts'. These project the fsts out of what we introduced above.
--   They are 'fst' expressions, but there is no need to know that
--   statically.
data Reconstructor env ts =
  Reconstructor
    (Bindings UnfExpr (Tape ts : env) (Reverse (TapeUnfoldings ts)))
    (Bindings Ex (Append (Reverse (TapeUnfoldings ts)) (Tape ts : env)) ts)

-- | Add a single element at the end of a singleton list. This mirrors
-- @Append ts '[t]@ at the type level.
ssnoc :: SList f ts -> f t -> SList f (Append ts '[t])
ssnoc xs lastElem = case xs of
  SNil -> lastElem `SCons` SNil
  x `SCons` rest -> x `SCons` ssnoc rest lastElem

-- | Reverse a singleton list, following the type family 'Reverse': the
-- reversed tail, with the head snoc'ed onto its end.
sreverse :: SList f ts -> SList f (Reverse ts)
sreverse = \case
  SNil -> SNil
  x `SCons` rest -> sreverse rest `ssnoc` x

-- | Singleton witness for 'TapeUnfoldings': the tape type of each proper
-- suffix of the input list, largest suffix first.
stapeUnfoldings :: SList STy ts -> SList STy (TapeUnfoldings ts)
stapeUnfoldings = \case
  SNil -> SNil
  _ `SCons` rest -> tapeTy rest `SCons` stapeUnfoldings rest

-- | Puts a 'snd' at the top of an unfolder stack and grows the context
-- variable by one.
--
-- The new binding is 'snd' of the grown tape variable (of type
-- @Tape (t : ts)@), which yields a @Tape ts@. The existing 'snd' bindings
-- are retyped to live in the grown context.
shiftUnfolder
  :: STy t
  -> SList STy ts
  -> Bindings UnfExpr (Tape ts : env) list
  -> Bindings UnfExpr (Tape (t : ts) : env) (Append list '[Tape ts])
shiftUnfolder newTy ts BTop = BPush BTop (tapeTy ts, UnfExSnd newTy (tapeTy ts))
shiftUnfolder newTy ts (BPush b (t, UnfExSnd itemTy _)) =
  -- Recurse on 'b' and retype the 'snd'. We need to unfold 'b' once so that
  -- an 'Append' in the types expands and things simplify just enough.
  --
  -- We have an equality 'Append binds x1 ~ a : x2', where 'binds' is the
  -- list of bindings produced by 'b'. We want to conclude from this that
  -- 'binds ~ a : x3' for some 'x3'. GHC only does that once it knows that
  -- 'binds ~ y : ys', because then the 'Append' can expand one step, after
  -- which 'y ~ a' as desired. The 'case' unfolds 'b' one step.
  BPush (shiftUnfolder newTy ts b) (t, case b of BTop    -> UnfExSnd itemTy t
                                                 BPush{} -> UnfExSnd itemTy t)

-- | Extend a reconstructor for @ts@ to one for @t : ts@. Two things change:
--
-- - The unfolder stack gets one more 'snd', via 'shiftUnfolder'.
-- - The builder stack gets one more binding. It projects @t@ out of the
--   full tape with 'fst'.
growRecon :: forall env t ts. STy t -> SList STy ts -> Reconstructor env ts -> Reconstructor env (t : ts)
growRecon t ts (Reconstructor unfbs bs)
  -- These lemmas only re-associate type-level 'Append's so that the
  -- weakenings below typecheck.
  | Refl <- lemAppendNil @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts])
  , Refl <- lemAppendAssoc @ts @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts]) @(Tape (t : ts) : env)
  , Refl <- lemAppendAssoc @(Reverse (TapeUnfoldings ts)) @'[Tape ts] @env
  = Reconstructor
      (shiftUnfolder t ts unfbs)
      -- Add a 'fst' at the bottom of the builder stack.
      -- First we have to weaken most of 'bs' to skip one more binding in the
      -- unfolder stack above it.
      (BPush (fst (weakenBindings weakenExpr
                      (wCopies (sappend (sreverse (stapeUnfoldings ts)) (SCons (tapeTy ts) SNil))
                               (WSink :: env :> (Tape (t : ts) : env))) bs))
             -- The new binding reads the full tape. The tape is found by
             -- skipping all 'ts' bindings and all unfolder bindings.
             (t
             ,EFst ext $ EVar ext (tapeTy (SCons t ts)) $
               wSinks @(Tape (t : ts) : env)
                 (sappend ts
                          (sappend (sappend (sreverse (stapeUnfoldings ts))
                                            (SCons (tapeTy ts) SNil))
                                   SNil))
               @> IZ))

-- | Build the reconstructor for a list of binding types. It starts from the
-- empty reconstructor and grows it once per element, beginning with the
-- last element of the list.
buildReconstructor :: SList STy ts -> Reconstructor env ts
buildReconstructor = \case
  SNil -> Reconstructor BTop BTop
  t `SCons` rest -> growRecon t rest (buildReconstructor rest)

-- STRATEGY FOR reconstructBindings
--
-- Notation: 'e' is the tape variable. The x_i are the unfolding bindings
-- ('snd' projections). The y_i are the reconstructed bindings ('fst'
-- projections).
--
-- binds = []
-- e : ()
--
-- binds = [c]
-- e : (c, ())
-- x0 = snd e  : ()
-- y1 = fst e  : c
--
-- binds = [b, c]
-- e : (b, (c, ()))
-- x1 = snd e  : (c, ())
-- x0 = snd x1 : ()
-- y1 = fst x1 : c
-- y2 = fst e  : b
--
-- binds = [a, b, c]
-- e : (a, (b, (c, ())))
-- x2 = snd e  : (b, (c, ()))
-- x1 = snd x2 : (c, ())
-- x0 = snd x1 : ()
-- y1 = fst x1 : c
-- y2 = fst x2 : b
-- y3 = fst e  : a

-- | Given that in 'env' we can find a 'Tape binds', i.e. a tuple containing
-- all the things in the list 'binds', create a let stack that extracts all
-- values from that tuple. In effect this "restores" the environment
-- described by 'binds'. The idea is that elsewhere we took a slice of the
-- environment and saved it all in a tuple, to be restored later.
--
-- We incidentally also add a bunch of additional bindings, namely
-- 'Reverse (TapeUnfoldings binds)'. The calling code just has to skip those
-- in whatever it wants to do. The second component of the result lists
-- their types.
--
-- The tape may sit at any index @tape@. The reconstructor is built assuming
-- index 'IZ' and then weakened with 'WIdx'.
reconstructBindings :: SList STy binds -> Idx env (Tape binds)
                    -> (Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds)))
                       ,SList STy (Reverse (TapeUnfoldings binds)))
reconstructBindings binds tape =
  let Reconstructor unf build = buildReconstructor binds
  in (fst $ weakenBindings weakenExpr (WIdx tape)
             (bconcat (mapBindings fromUnfExpr unf) build)
     ,sreverse (stapeUnfoldings binds))


---------------------------------- DERIVATIVES ---------------------------------

-- | Primal (D1) version of a primitive operation: the same operation,
-- re-emitted on the primal argument.
--
-- Each constructor is matched separately so that, inside each branch, GHC
-- knows the concrete argument and result types of the operation. The
-- @D1 a@ and @D1 t@ types need that knowledge to line up with the
-- operation's own types.
d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t)
d1op op arg = case op of
  OAdd{} -> EOp ext op arg
  OMul{} -> EOp ext op arg
  ONeg{} -> EOp ext op arg
  OLt{} -> EOp ext op arg
  OLe{} -> EOp ext op arg
  OEq{} -> EOp ext op arg
  ONot -> EOp ext op arg
  OAnd -> EOp ext op arg
  OOr -> EOp ext op arg
  OIf -> EOp ext op arg
  ORound64 -> EOp ext op arg
  OToFl64 -> EOp ext op arg
  ORecip{} -> EOp ext op arg
  OExp{} -> EOp ext op arg
  OLog{} -> EOp ext op arg
  OIDiv{} -> EOp ext op arg
  OMod{} -> EOp ext op arg

-- | Reverse derivative of a primitive operation. It maps the cotangent of
-- the result (@D2 t@) to the cotangent of the argument (@D2 a@).
--
-- * 'Linear': the derivative does not depend on the primal input.
-- * 'Nonlinear': it also receives the primal argument (@D1 a@).
--
-- Both primal and dual must be duplicable expressions, because the builder
-- functions may use them more than once.
data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))
              | Nonlinear (forall env. Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a))

-- | Reverse derivative of each primitive operation, as a 'D2Op'.
--
-- Operands of integral or Boolean scalar type have 'ENil' cotangents (see
-- 'd2opUnArrangeInt' and 'd2opBinArrangeInt' below). Comparisons ignore the
-- result cotangent and return zero cotangents for both operands.
d2op :: SOp a t -> D2Op a t
d2op op = case op of
  OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EPair ext d d
  -- d(x*y) = y*dx + x*dy: each operand's cotangent is the other operand
  -- times the result cotangent.
  OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d ->
    EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))
              (EOp ext (OMul t) (EPair ext (EFst ext e) d))
  ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d
  OLt t -> Linear $ \_ -> pairZero t
  OLe t -> Linear $ \_ -> pairZero t
  OEq t -> Linear $ \_ -> pairZero t
  ONot -> Linear $ \_ -> ENil ext
  OAnd -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
  OOr -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
  OIf -> Linear $ \_ -> ENil ext
  -- The result cotangent is ignored; the F64 argument gets a zero cotangent.
  ORound64 -> Linear $ \_ -> EZero ext (SMTScal STF64) (ENil ext)
  OToFl64 -> Linear $ \_ -> ENil ext
  -- d(1/x) = -(1/(x*x)) * dx
  ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d)
  -- d(exp x) = exp x * dx
  OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d)
  -- d(log x) = (1/x) * dx
  OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d)
  OIDiv t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
  OMod t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
  where
    -- Zero cotangents for both scalar operands of a comparison.
    pairZero :: SScalTy a -> Ex env (D2 (TPair (TScal a) (TScal a)))
    pairZero t = ziNil t $ EPair ext (EZero ext (d2M (STScal t)) (ENil ext))
                                     (EZero ext (d2M (STScal t)) (ENil ext))
      where
        -- Case analysis on the scalar type brings
        -- 'ZeroInfo (D2s a) ~ TNil' into scope. The 'ENil' zero-info
        -- arguments above need it.
        ziNil :: SScalTy a -> (ZeroInfo (D2s a) ~ TNil => r) -> r
        ziNil STI32 k = k
        ziNil STI64 k = k
        ziNil STF32 k = k
        ziNil STF64 k = k
        ziNil STBool k = k

    -- Unary operations: non-float operands get an 'ENil' cotangent. For
    -- floats, run 'float' with 'D2s a ~ TScal a' in scope.
    d2opUnArrangeInt :: SScalTy a
                     -> (D2s a ~ TScal a => D2Op (TScal a) t)
                     -> D2Op (TScal a) t
    d2opUnArrangeInt ty float = case ty of
      STI32 -> Linear $ \_ -> ENil ext
      STI64 -> Linear $ \_ -> ENil ext
      STF32 -> float
      STF64 -> float
      STBool -> Linear $ \_ -> ENil ext

    -- Binary operations: same as above, but for both operands of a pair.
    d2opBinArrangeInt :: SScalTy a
                      -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t)
                      -> D2Op (TPair (TScal a) (TScal a)) t
    d2opBinArrangeInt ty float = case ty of
      STI32 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
      STI64 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
      STF32 -> float
      STF64 -> float
      STBool -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)

    -- Bring the float-specific type equalities into scope.
    floatingD2 :: ScalIsFloating a ~ True
               => SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r
    floatingD2 STF32 k = k
    floatingD2 STF64 k = k

    -- Bring the integral-specific type equalities into scope.
    integralD2 :: ScalIsIntegral a ~ True
               => SScalTy a -> ((D2s a ~ TNil, ScalIsNumeric a ~ True) => r) -> r
    integralD2 STI32 k = k
    integralD2 STI64 k = k

-- | The D1 types of all entries described by an environment description.
desD1E :: Descr env sto -> SList STy (D1E env)
desD1E descr = d1e (descrList descr)

-- d1W :: env :> env' -> D1E env :> D1E env'
-- d1W WId = WId
-- d1W WSink = WSink
-- d1W (WCopy w) = WCopy (d1W w)
-- d1W (WPop w) = WPop (d1W w)
-- d1W (WThen u w) = WThen (d1W u) (d1W w)

-- | Move an index into @env@ to the matching index into @D1E env@. The
-- position is unchanged; only the types are mapped through D1.
conv1Idx :: Idx env t -> Idx (D1E env) (D1 t)
conv1Idx = \case
  IZ -> IZ
  IS rest -> IS (conv1Idx rest)

-- | An index into @env@, converted to the dual side. Depending on the
-- storage of the referenced variable, it points into one of:
--
-- * the accumulator environment ('Idx2Ac'), at an accumulator type;
-- * the merge environment ('Idx2Me'), at the cotangent type;
-- * the discrete environment ('Idx2Di'), at the original type.
data Idx2 env sto t
  = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t)))
  | Idx2Me (Idx (D2E (Select env sto "merge")) (D2 t))
  | Idx2Di (Idx (Select env sto "discr") t)

-- | Convert an index into @env@ into an 'Idx2', using the storage recorded
-- in the environment description.
--
-- While walking down the description, only entries with the same storage
-- kind as the target shift the result ('IS'). Entries of other kinds live
-- in a different environment, so they do not count.
conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t
conv2Idx (DPush _   (_, _, SAccum)) IZ = Idx2Ac IZ
conv2Idx (DPush _   (_, _, SMerge)) IZ = Idx2Me IZ
conv2Idx (DPush _   (_, _, SDiscr)) IZ = Idx2Di IZ
conv2Idx (DPush des (_, _, SAccum)) (IS i) =
  case conv2Idx des i of Idx2Ac j -> Idx2Ac (IS j)
                         Idx2Me j -> Idx2Me j
                         Idx2Di j -> Idx2Di j
conv2Idx (DPush des (_, _, SMerge)) (IS i) =
  case conv2Idx des i of Idx2Ac j -> Idx2Ac j
                         Idx2Me j -> Idx2Me (IS j)
                         Idx2Di j -> Idx2Di j
conv2Idx (DPush des (_, _, SDiscr)) (IS i) =
  case conv2Idx des i of Idx2Ac j -> Idx2Ac j
                         Idx2Me j -> Idx2Me j
                         Idx2Di j -> Idx2Di (IS j)
conv2Idx DTop i = case i of {}

-- | Expand a sparse cotangent for the result of a primitive operation back
-- to the dense @D2 b@. The type is taken from 'opt2', presumably the
-- operation's result type (confirm).
--
-- Only the shapes handled in 'go' are supported: scalars, nil, and pairs
-- of those. Anything else is an internal error.
opt2UnSparse :: SOp a b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b)
opt2UnSparse = go . opt2
  where
    go :: STy b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b)
    -- Absent: materialise a zero, or 'ENil' for non-float scalars.
    go (STScal STI32) SpAbsent = \_ -> ENil ext
    go (STScal STI64) SpAbsent = \_ -> ENil ext
    go (STScal STF32) SpAbsent = \_ -> EZero ext (SMTScal STF32) (ENil ext)
    go (STScal STF64) SpAbsent = \_ -> EZero ext (SMTScal STF64) (ENil ext)
    go (STScal STBool) SpAbsent = \_ -> ENil ext
    go (STScal STF32) SpScal = id
    go (STScal STF64) SpScal = id
    go STNil _ = \_ -> ENil ext
    go (STPair t1 t2) (SpPair s1 s2) = \e -> eunPair e $ \_ e1 e2 -> EPair ext (go t1 s1 e1) (go t2 s2 e2)
    go t _ = error $ "Primitive operations that return " ++ show t ++ " are scary"


------------------------------------ MONOIDS -----------------------------------

-- | Compute, from a primal value of type @D1 t@, the zero-info that 'EZero'
-- needs for the cotangent type of @t@.
--
-- Only pairs and arrays carry information; arrays get one zero-info per
-- element via 'emap'. Sum types, maybes, nil and scalars need none ('ENil').
-- Accumulator types cannot occur in source programs.
d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t))
d2zeroInfo STNil _ = ENil ext
d2zeroInfo (STPair a b) e =
  eunPair e $ \_ e1 e2 ->
    EPair ext (d2zeroInfo a e1) (d2zeroInfo b e2)
d2zeroInfo STEither{} _ = ENil ext
d2zeroInfo STLEither{} _ = ENil ext
d2zeroInfo STMaybe{} _ = ENil ext
d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e
d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext
d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program"

-- | A 'Tup' of zero cotangents, one for every variable in @env0@. Each zero
-- is built from the variable's primal value, found in @env@ via the
-- weakening.
zeroTup :: SList STy env0 -> D1E env0 :> env -> Ex env (Tup (D2E env0))
zeroTup SNil _ = ENil ext
zeroTup (t `SCons` env) w =
  EPair ext (zeroTup env (WPop w))
            (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ))))


----------------------------------- SPARSITY -----------------------------------

-- | Lift a sub-environment selection through D1. The keep/drop decision for
-- each entry stays the same; only the element types are mapped.
subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env')
subenvD1E = \case
  SETop -> SETop
  SEYesR rest -> SEYesR (subenvD1E rest)
  SENo rest -> SENo (subenvD1E rest)

-- | Expand a sparse cotangent @e :: b@ back to the dense @D2 a@
-- representation. The 'Sparse' witness describes how @b@ relates to
-- @D2 a@.
--
-- The primal value @epr :: D1 a@ has two uses:
--
-- * building zeros (via 'd2zeroInfo') for the parts that are absent;
-- * checking which side of a sum type a cotangent belongs to.
--
-- A mismatch between the primal and cotangent constructors becomes an
-- 'EError'.
expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a)
-- Already dense: nothing to do.
expandSparse t sp _ e | Just Refl <- isDense (d2M t) sp = e
-- Possibly absent: 'Nothing' means zero; otherwise expand the payload.
expandSparse t (SpSparse sp) epr e =
  EMaybe ext
    (EZero ext (d2M t) (d2zeroInfo t epr))
    (expandSparse t sp (weakenExpr WSink epr) (EVar ext (applySparse sp (d2 t)) IZ))
    e
-- Absent entirely: materialise a zero.
expandSparse t SpAbsent epr _ = EZero ext (d2M t) (d2zeroInfo t epr)
expandSparse (STPair t1 t2) (SpPair s1 s2) epr e =
  eunPair epr $ \w1 epr1 epr2 ->
  eunPair (weakenExpr w1 e) $ \w2 e1 e2 ->
    EPair ext (expandSparse t1 s1 (weakenExpr w2 epr1) e1)
              (expandSparse t2 s2 (weakenExpr w2 epr2) e2)
-- The sparse cotangent is inspected with 'ELCase' (nil / left / right). A
-- present cotangent must be on the same side as the primal value.
expandSparse (STEither t1 t2) (SpLEither s1 s2) epr e =
  ELCase ext e
    (EZero ext (d2M (STEither t1 t2)) (ENil ext))
    (ECase ext (weakenExpr WSink epr)
       (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ))))
       (EError ext (d2 (STEither t1 t2)) "expspa r<-dl"))
    (ECase ext (weakenExpr WSink epr)
       (EError ext (d2 (STEither t1 t2)) "expspa l<-dr")
       (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ)))))
-- Same as above, but the primal value is itself a linear either, so it is
-- also inspected with 'ELCase'.
expandSparse (STLEither t1 t2) (SpLEither s1 s2) epr e =
  ELCase ext e
    (EZero ext (d2M (STEither t1 t2)) (ENil ext))
    (ELCase ext (weakenExpr WSink epr)
       (EError ext (d2 (STEither t1 t2)) "expspa ln<-dl")
       (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ))))
       (EError ext (d2 (STEither t1 t2)) "expspa lr<-dl"))
    (ELCase ext (weakenExpr WSink epr)
       (EError ext (d2 (STEither t1 t2)) "expspa ln<-dr")
       (EError ext (d2 (STEither t1 t2)) "expspa ll<-dr")
       (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ)))))
-- A present cotangent requires a 'Just' primal; the primal payload is
-- extracted (erroring on 'Nothing') before being passed down.
expandSparse (STMaybe t) (SpMaybe s) epr e =
  EMaybe ext
    (ENothing ext (d2 t))
    (let epr' = EMaybe ext (EError ext (d1 t) "expspa n<-dj") (EVar ext (d1 t) IZ) epr
     in EJust ext (expandSparse t s (weakenExpr WSink epr') (EVar ext (applySparse s (d2 t)) IZ)))
    e
-- Arrays are expanded elementwise. Inside the zipped function, the primal
-- element is at 'IS IZ' and the sparse cotangent element is at 'IZ'.
expandSparse (STArr _ t) (SpArr s) epr e =
  ezipWith (expandSparse t s (EVar ext (d1 t) (IS IZ)) (EVar ext (applySparse s (d2 t)) IZ)) epr e
expandSparse (STScal STF32) SpScal _ e = e
expandSparse (STScal STF64) SpScal _ e = e
expandSparse (STAccum{}) _ _ _ = error "accumulators not allowed in source program"

-- | Compute the union of two sparse sub-environments of @env@ and pass it
-- to the continuation, together with three things:
--
-- * an injection from the first input's tuple into the union's tuple;
-- * the same for the second input;
-- * a function that adds a tuple of each input into a tuple of the union.
--
-- The 'SBool' flags control whether each injection must be produced. For
-- example, the @SF@ case below passes 'Noinj' for the second input.
subenvPlus :: SBool req1 -> SBool req2
           -> SList SMTy env
           -> SubenvS env env1 -> SubenvS env env2
           -> (forall env3. SubenvS env env3
                         -> Injection req1 (Tup env1) (Tup env3)
                         -> Injection req2 (Tup env2) (Tup env3)
                         -> (forall e. Ex e (Tup env1) -> Ex e (Tup env2) -> Ex e (Tup env3))
                         -> r)
           -> r
subenvPlus _ _ SNil SETop SETop k = k SETop (Inj id) (Inj id) (\_ _ -> ENil ext)

-- Neither side has this entry: it stays absent in the union.
subenvPlus req1 req2 (SCons _ env) (SENo sub1) (SENo sub2) k =
  subenvPlus req1 req2 env sub1 sub2 $ \sub3 s31 s32 pl ->
    k (SENo sub3) s31 s32 pl

-- Only the left side has this entry, and no injection from the right side
-- is required: keep the left side's sparsity unchanged.
subenvPlus req1 SF (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k =
  subenvPlus req1 SF env sub1 sub2 $ \sub3 minj13 _ pl ->
    k (SEYes sp1 sub3)
      (withInj minj13 $ \inj13 ->
        \e1 -> eunPair e1 $ \_ e1a e1b ->
          EPair ext (inj13 e1a) e1b)
      Noinj
      (\e1 e2 ->
        ELet ext e1 $
          EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
                        (weakenExpr WSink e2))
                    (ESnd ext (EVar ext (typeOf e1) IZ)))
-- Only the left side has this entry, but the right side must be injectable:
-- wrap the entry in 'SpSparse' (a Maybe), so the right side can inject
-- 'ENothing' for it.
subenvPlus req1 ST (SCons t env) (SEYes sp1 sub1) (SENo sub2) k =
  subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl ->
    k (SEYes (SpSparse sp1) sub3)
      (withInj minj13 $ \inj13 ->
        \e1 -> eunPair e1 $ \_ e1a e1b ->
          EPair ext (inj13 e1a) (EJust ext e1b))
      (Inj $ \e2 -> EPair ext (inj23 e2) (ENothing ext (applySparse sp1 (fromSMTy t))))
      (\e1 e2 ->
        ELet ext e1 $
          EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
                        (weakenExpr WSink e2))
                    (EJust ext (ESnd ext (EVar ext (typeOf e1) IZ))))

-- Only the right side has this entry: swap the roles of the two sides and
-- reuse the cases above.
subenvPlus req1 req2 (SCons t env) sub1@SENo{} sub2@SEYes{} k =
  subenvPlus req2 req1 (SCons t env) sub2 sub1 $ \sub3 minj23 minj13 pl ->
    k sub3 minj13 minj23 (flip pl)

-- Both sides have this entry: merge the two sparsities with 'sparsePlusS'
-- and add the values with the 'plus' it provides.
subenvPlus req1 req2 (SCons t env) (SEYes sp1 sub1) (SEYes sp2 sub2) k =
  subenvPlus req1 req2 env sub1 sub2 $ \sub3 minj13 minj23 pl ->
  sparsePlusS req1 req2 t sp1 sp2 $ \sp3 mTinj13 mTinj23 plus ->
    k (SEYes sp3 sub3)
      (withInj2 minj13 mTinj13 $ \inj13 tinj13 ->
        \e1 -> eunPair e1 $ \_ e1a e1b ->
          EPair ext (inj13 e1a) (tinj13 e1b))
      (withInj2 minj23 mTinj23 $ \inj23 tinj23 ->
        \e2 -> eunPair e2 $ \_ e2a e2b ->
          EPair ext (inj23 e2a) (tinj23 e2b))
      (\e1 e2 ->
        ELet ext e1 $
        ELet ext (weakenExpr WSink e2) $
          EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ)))
                        (EFst ext (EVar ext (typeOf e2) IZ)))
                    (plus
                      (ESnd ext (EVar ext (typeOf e1) (IS IZ)))
                      (ESnd ext (EVar ext (typeOf e2) IZ))))

-- | Turn a tuple of (possibly sparse) contributions for a subset of the
-- variables of @env0@ into a dense 'Tup' of cotangents for all of them.
--
-- Selected variables are expanded with 'expandSparse'; unselected ones
-- become zero. The weakening @w@ locates the primal values (@D1E env0@)
-- needed to build those zeros.
expandSubenvZeros :: D1E env0 :> env -> SList STy env0 -> SubenvS (D2E env0) contribs
                  -> Ex env (Tup contribs) -> Ex env (Tup (D2E env0))
expandSubenvZeros _ SNil SETop _ = ENil ext
expandSubenvZeros w (SCons t ts) (SEYes sp sub) e =
  eunPair e $ \w1 e1 e2 ->
    EPair ext
      (expandSubenvZeros (w1 .> WPop w) ts sub e1)
      (expandSparse t sp (EVar ext (d1 t) (w1 .> w @> IZ)) e2)
expandSubenvZeros w (SCons t ts) (SENo sub) e =
  EPair ext
    (expandSubenvZeros (WPop w) ts sub e)
    (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ))))

-- | Check at runtime that a sub-environment selects no entries, and return
-- a proof that the selected environment is empty. Calls 'error' if any
-- entry is selected.
assertSubenvEmpty :: HasCallStack => Subenv' s env env' -> env' :~: '[]
assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl
assertSubenvEmpty SETop = Refl
assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty"


--------------------------------- ACCUMULATORS ---------------------------------

-- | Wrap @e@ in one 'EWith' per variable in @envPro@. Each 'EWith' starts
-- from a zero of that variable's cotangent type, built from its primal
-- value (found via @w@).
--
-- NOTE(review): presumably 'EWith' introduces an accumulator for the body
-- and also returns its final value. That would explain the 'InvTup' result
-- type; confirm against the definition of 'EWith'.
makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
makeAccumulators _ SNil e = e
makeAccumulators w (t `SCons` envpro) e =
  makeAccumulators (WPop w) envpro $
    EWith ext (d2M t) (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e

-- | Rearrange a value of type @InvTup core list@ into a pair of the core
-- value and a 'Tup' of the list elements.
--
-- Each step peels off one element @t@ by treating @(core, t)@ as the new
-- core. It then moves @t@ from next to the core into the tuple.
uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list))
uninvertTup SNil _ e = EPair ext e (ENil ext)
uninvertTup (t `SCons` list) tcore e =
  ELet ext (uninvertTup list (STPair tcore t) e) $
    let recT = STPair (STPair tcore t) (tTup list)  -- type of the RHS of that let binding
    in EPair ext
         (EFst ext (EFst ext (EVar ext recT IZ)))
         (EPair ext
            (ESnd ext (EVar ext recT IZ))
            (ESnd ext (EFst ext (EVar ext recT IZ))))

-- | Extract the array identifier from a value identity if it is 'VIArr'.
-- Any other identity, or no identity at all, gives 'Nothing'.
fromArrayValId :: Maybe (ValId t) -> Maybe Int
fromArrayValId = \case
  Just (VIArr arrId _) -> Just arrId
  _ -> Nothing

-- | Promote the non-discrete "merge"-stored entries of an environment
-- description to accumulators, and call the continuation with the results.
--
-- Entries change as follows:
--
-- * "accum" entries stay accumulators.
-- * "discr" entries stay discrete.
-- * Discrete "merge" entries (see 'isDiscrete') become "discr".
-- * All other "merge" entries become "accum". Their accumulators are
--   collected in @envPro@.
accumPromote :: forall dt env sto proxy r.
                proxy dt
             -> Descr env sto
             -> (forall stoRepl envPro.
                    (Select env stoRepl "merge" ~ '[])
                 => Descr env stoRepl
                      -- ^ A revised environment description. Every
                      -- non-discrete entry that is currently on "merge"
                      -- storage is switched to "accum" storage, and discrete
                      -- "merge" entries become "discr". Existing "accum" and
                      -- "discr" entries are unchanged.
                 -> SList STy envPro
                      -- ^ New entries on top of the original dual environment.
                      -- They house the accumulators for the promoted entries
                      -- of the original environment.
                 -> Subenv (Select env sto "merge") envPro
                      -- ^ The promoted entries were merge entries in the
                      -- original environment.
                 -> Subenv (D2AcE (Select env stoRepl "accum")) (D2AcE (Select env sto "accum"))
                      -- ^ All entries that were accumulators are still
                      -- accumulators.
                 -> VarMap Int (D2AcE (Select env stoRepl "accum"))
                      -- ^ Accumulator map for _only_ the newly allocated
                      -- accumulators. Entries are only added for promoted
                      -- variables whose 'ValId' is an array identifier.
                 -> (forall shbinds.
                            SList STy shbinds
                         -> (dt : Append shbinds (D2AcE (Select env stoRepl "accum")))
                            :> Append (D2AcE envPro) (dt : Append shbinds (D2AcE (Select env sto "accum"))))
                      -- ^ A weakening that converts a computation in the
                      -- revised environment to one in the original environment
                      -- extended with some accumulators.
                 -> r)
             -> r
accumPromote _ DTop k = k DTop SNil SETop SETop VarMap.empty (\_ -> WId)
accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
  -- Accumulators are left as-is
  SAccum ->
    accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf ->
      k (storepl `DPush` (t, vid, SAccum))
        envpro
        prosub
        (SEYesR accrevsub)
        (VarMap.sink1 accumMap)
        -- The weakening is read right to left. First move this accumulator
        -- in front of the cotangent and shared bindings. Then apply 'wf'
        -- underneath it ('WCopy'). Finally move it back behind the promoted
        -- accumulators, the cotangent and the shared bindings.
        (\shbinds ->
          autoWeak (#pro (d2ace envpro) &. #d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr)))
                   (#acc :++: (#pro :++: #d :++: #shb :++: #tl))
                   (#pro :++: #d :++: #shb :++: #acc :++: #tl)
          .> WCopy (wf shbinds)
          .> autoWeak (#d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl)))
                      (#d :++: #shb :++: #acc :++: #tl)
                      (#acc :++: (#d :++: #shb :++: #tl)))

  SMerge -> case t of
    -- Discrete values are left as-is
    _ | isDiscrete t ->
      accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap' wf ->
        k (storepl `DPush` (t, vid, SDiscr))
          envpro
          (SENo prosub)
          accrevsub
          accumMap'
          wf

    -- Values with "merge" storage are promoted to an accumulator in envPro
    _ ->
      accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf ->
        k (storepl `DPush` (t, vid, SAccum))
          (t `SCons` envpro)
          (SEYesR prosub)
          (SENo accrevsub)
          (let accumMap' = VarMap.sink1 accumMap
           in case fromArrayValId vid of
                Just i -> VarMap.insert i (STAccum (d2M t)) IZ accumMap'
                Nothing -> accumMap')
          (\(shbinds :: SList _ shbinds) ->
            let shbindsC = slistMap (\_ -> Const ()) shbinds
            in
            -- wf:
            --                 D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum"))  :>                Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
            -- WCopy wf:
            --   TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum"))  :>  TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
            --                       WPICK: ^                                                                 THESE TWO  ||
            -- goal:                        |                                                                 ARE EQUAL  ||
            --   D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum"))  :>  TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
            WCopy (wf shbinds)
            .> WPick @(TAccum (D2 t)) @(dt : shbinds) (Const () `SCons` shbindsC)
                 (WId @(D2AcE (Select env1 stoRepl "accum"))))

  -- Discrete values are left as-is, nothing to do
  SDiscr ->
    accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf ->
      k (storepl `DPush` (t, vid, SDiscr))
        envpro
        prosub
        accrevsub
        accumMap
        wf
  where
    -- A type is discrete if it contains no floating-point scalars and no
    -- accumulators.
    isDiscrete :: STy t' -> Bool
    isDiscrete = \case
      STNil -> True
      STPair a b -> isDiscrete a && isDiscrete b
      STEither a b -> isDiscrete a && isDiscrete b
      STLEither a b -> isDiscrete a && isDiscrete b
      STMaybe a -> isDiscrete a
      STArr _ a -> isDiscrete a
      STScal st -> case st of
        STI32 -> True
        STI64 -> True
        STF32 -> False
        STF64 -> False
        STBool -> True
      STAccum{} -> False


---------------------------- RETURN TRIPLE FROM CHAD ---------------------------

-- | The result of CHAD-transforming an expression of type @t@ whose
-- incoming cotangent has type @sd@.
data Ret env0 sto sd t =
  forall shbinds tapebinds contribs.
    Ret (Bindings Ex (D1E env0) shbinds)  -- shared binds
        (Subenv shbinds tapebinds)  -- which shared binds are stored on the tape
        (Ex (Append shbinds (D1E env0)) (D1 t))  -- primal result, under the shared binds
        (SubenvS (D2E (Select env0 sto "merge")) contribs)  -- merge-cotangents the dual produces
        (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs))  -- dual
deriving instance Show (Ret env0 sto sd t)

-- | A type-level-only pair of two 'Ty's, declared with @type data@.
-- 'RetPair' and 'SingleRet' use it as @MkTyTyPair sd t@, which pairs the
-- cotangent type @sd@ of a result with its primal type @t@.
type data TyTyPair = MkTyTyPair Ty Ty

-- | A 'Ret' in which the primal expression, the contribution subenvironment
-- and the dual expression are bundled into a 'RetPair' indexed by
-- @MkTyTyPair sd t@. 'toSingleRet' produces values of this type and
-- 'retConcat' consumes lists of them.
data SingleRet env0 sto (pair :: TyTyPair) =
  forall shbinds tapebinds.
    SingleRet
      (Bindings Ex (D1E env0) shbinds)  -- shared binds
      (Subenv shbinds tapebinds)  -- which shared binds are stored on the tape
      (RetPair env0 sto (D1E env0) shbinds tapebinds pair)

-- pattern Ret1 :: forall env0 sto Bindings Ex (D1E env0) shbinds
--              -> Subenv shbinds tapebinds
--              -> Ex (Append shbinds (D1E env0)) (D1 t)
--              -> SubenvS (D2E (Select env0 sto "merge")) contribs
--              -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)
--              -> SingleRet env0 sto (MkTyTyPair sd t)
-- pattern Ret1 e0 subtape e1 sub e2 = SingleRet e0 subtape (RetPair e1 sub e2)
-- {-# COMPLETE Ret1 #-}

-- | The part of a 'Ret' that lives underneath the shared bindings. It holds
-- three things:
--
-- * the primal expression, scoped over @shbinds@ on top of @env@;
-- * the subenvironment of "merge" variables that receive contributions;
-- * the dual expression, scoped over the incoming cotangent @sd@, the tape
--   @tapebinds@ and the "accum" accumulators.
--
-- The base environment @env@ is a separate index from @env0@, so a pair can
-- be moved underneath more bindings. 'weakenRetPair' and 'rebaseRetPair'
-- do this.
data RetPair env0 sto env shbinds tapebinds (pair :: TyTyPair) where
  RetPair :: forall sd t contribs  -- existentials
                    env0 sto env shbinds tapebinds.  -- universals
             Ex (Append shbinds env) (D1 t)  -- primal
          -> SubenvS (D2E (Select env0 sto "merge")) contribs  -- contributing merge variables
          -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)  -- dual
          -> RetPair env0 sto env shbinds tapebinds (MkTyTyPair sd t)
deriving instance Show (RetPair env0 sto env shbinds tapebinds pair)

-- | Several 'RetPair's that share a single combined set of bindings (on top
-- of @env@) and a single combined tape subenvironment. 'retConcat' produces
-- this; the type-level @list@ has one 'TyTyPair' per element.
data Rets env0 sto env list =
  forall shbinds tapebinds.
    Rets (Bindings Ex env shbinds)
         (Subenv shbinds tapebinds)
         (SList (RetPair env0 sto env shbinds tapebinds) list)
deriving instance Show (Rets env0 sto env list)

-- | Repackage a 'Ret' as a 'SingleRet'. The primal expression, contribution
-- subenvironment and dual expression are grouped into one 'RetPair'; the
-- shared bindings and tape subenvironment are kept as they are.
toSingleRet :: Ret env0 sto sd t -> SingleRet env0 sto (MkTyTyPair sd t)
toSingleRet ret = case ret of
  Ret shared tapeSub primal contribSub dual ->
    SingleRet shared tapeSub (RetPair primal contribSub dual)

-- | Change the base environment underneath a 'RetPair's bindings.
--
-- Only the primal expression is scoped over the base environment, so only it
-- is weakened. The weakening is lifted past the @shbinds@ listed in the
-- first argument. The dual expression does not mention the base
-- environment and is returned unchanged.
weakenRetPair :: SList STy shbinds -> env :> env'
              -> RetPair env0 sto env shbinds tapebinds pair -> RetPair env0 sto env' shbinds tapebinds pair
weakenRetPair shbindsList w retPair = case retPair of
  RetPair primal contribSub dual ->
    let liftedW = weakenOver shbindsList w
    in RetPair (weakenExpr liftedW primal) contribSub dual

-- | Change the base environment of a 'Rets'.
--
-- The shared bindings are weakened directly; the returned weakening is not
-- needed here and is dropped. Each 'RetPair' is weakened past those shared
-- bindings using 'weakenRetPair'. The tape subenvironment is unaffected.
weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list
weakenRets w rets = case rets of
  Rets shared tapeSub pairs ->
    Rets (fst (weakenBindings weakenExpr w shared))
         tapeSub
         (slistMap (weakenRetPair (bindingsBinds shared) w) pairs)

-- | Rebase a 'RetPair' onto combined bindings.
--
-- Input: the pair sits under bindings @b2@, on top of the environment
-- @Append b1 env@.
-- Output: the same pair sits under the combined bindings @Append b2 b1@, on
-- top of @env@, with the combined tape @Append tapebinds2 tapebinds1@.
--
-- The primal expression is reused unchanged; only the 'Append'
-- associativity lemma is needed. The dual expression is weakened so that
-- the extra tape variables from @tapebinds1@ are inserted, unreferenced,
-- between its own tape variables and the accumulators.
--
-- The 'Descr' is used only to obtain the types of the "accum" accumulators.
rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto pair f.
                 Descr env0 sto
              -> SList f b1 -> SList f b2
              -> Subenv b1 tapebinds1 -> Subenv b2 tapebinds2
              -> RetPair env0 sto (Append b1 env) b2 tapebinds2 pair
              -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) pair
rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair @sd e1 sub e2)
  -- Append b2 (Append b1 env) equals Append (Append b2 b1) env, so e1 fits as-is.
  | Refl <- lemAppendAssoc @b2 @b1 @env =
      RetPair e1 sub
              -- dual: cotangent, tape2, accumulators  becomes
              --       cotangent, (tape2 then tape1), accumulators
              (weakenExpr (autoWeak
                            (#d (auto1 @sd)
                             &. #t2 (subList b2 subtape2)
                             &. #t1 (subList b1 subtape1)
                             &. #tl (d2ace (select SAccum descr)))
                            (#d :++: (#t2 :++: #tl))
                            (#d :++: ((#t2 :++: #t1) :++: #tl)))
                e2)

-- | Combine a list of independent 'SingleRet's into one 'Rets' whose shared
-- bindings and tape are the concatenation of those of the elements.
--
-- The function recurses on the list:
--
-- * The tail is combined first.
-- * The combined tail is weakened to sit underneath the head's bindings.
-- * The result is @bconcat@ of the head's bindings and the tail's bindings.
-- * The head's primal expression is weakened past the tail's bindings.
-- * The head's dual expression is weakened past the tail's tape.
-- * Each tail pair is moved onto the combined bindings with 'rebaseRetPair'.
retConcat :: forall env0 sto list. Descr env0 sto -> SList (SingleRet env0 sto) list -> Rets env0 sto (D1E env0) list
retConcat _ SNil = Rets BTop SETop SNil
retConcat descr (SCons (SingleRet (e0 :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) (RetPair e1 sub e2)) list)
  -- Combine the tail and put it on top of this element's shared bindings.
  | Rets (binds :: Bindings _ _ shbinds2) (subtape2 :: Subenv _ tapebinds2) pairs
      <- weakenRets (sinkWithBindings e0) (retConcat descr list)
  , Refl <- lemAppendAssoc @shbinds2 @shbinds1 @(D1E env0)
  , Refl <- lemAppendAssoc @tapebinds2 @tapebinds1 @(D2AcE (Select env0 sto "accum"))
  = Rets (bconcat e0 binds)
         (subenvConcat subtape subtape2)
         -- Head: primal skips the tail's bindings; the dual keeps its cotangent
         -- (WCopy) and skips the tail's tape variables.
         (SCons (RetPair (weakenExpr (sinkWithBindings binds) e1)
                         sub
                         (weakenExpr (WCopy (sinkWithSubenv subtape2)) e2))
                -- Tail: re-express each pair relative to the combined bindings/tape.
                (slistMap (rebaseRetPair descr (bindingsBinds e0) (bindingsBinds binds)
                                               subtape subtape2)
                          pairs))

-- | Turn a 'Ret' with a dense cotangent (@D2 t@) into one closed-off
-- expression. The expression is scoped over:
--
-- * the incoming cotangent,
-- * the "accum" accumulators,
-- * the D1 environment.
--
-- It returns a pair of the primal result and the full tuple of "merge"
-- cotangents.
--
-- The shared bindings are let-bound once and used by both halves of the
-- pair. The dual's @contribs@ tuple is then expanded to the full merge tuple
-- with 'expandSubenvZeros'. Judging by the helper's name, this presumably
-- fills in zeros for merge variables not selected by @sub@ (TODO confirm).
freezeRet :: Descr env sto
          -> Ret env sto (D2 t) t
          -> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge"))))
freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ _ t) =
  -- Shared bindings, moved on top of the cotangent and the accumulators;
  -- wInsertD2Ac is the matching weakening for anything scoped over them.
  let (e0', wInsertD2Ac) = weakenBindings weakenExpr (WSink .> wSinks (d2ace (select SAccum descr))) e0
      -- Dual, re-scoped so that the D1 environment sits below the accumulators.
      e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2
      -- Type of the contributions tuple actually produced by the dual.
      tContribs = tTup (slistMap fromSMTy (subList (d2eM (select SMerge descr)) sub))
      -- Named environment segments used by the autoWeak calls below.
      library = #d (auto1 @(D2 t))
                &. #tape (subList (bindingsBinds e0) subtape)
                &. #shbinds (bindingsBinds e0)
                &. #d2ace (d2ace (select SAccum descr))
                &. #tl (desD1E descr)
                &. #contribs (SCons tContribs SNil)
  in letBinds e0' $
       EPair ext
         -- primal result
         (weakenExpr wInsertD2Ac e1)
         -- dual: tape variables are looked up among the let-bound shared binds
         (ELet ext (weakenExpr (autoWeak library
                                         (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl)
                                         (#shbinds :++: #d :++: #d2ace :++: #tl))
                      e2') $
          expandSubenvZeros
            (autoWeak library #tl (#contribs :++: #shbinds :++: #d :++: #d2ace :++: #tl)
             .> wUndoSubenv (subenvD1E (selectSub SMerge descr)))
            (select SMerge descr) sub (EVar ext tContribs IZ))


---------------------------- THE CHAD TRANSFORMATION ---------------------------

drev :: forall env sto sd t.
        (?config :: CHADConfig)
     => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
     -> Sparse (D2 t) sd
     -> Expr ValId env t -> Ret env sto sd t
drev des _ sd | isAbsent sd =
  \e ->
    Ret BTop
        SETop
        (drevPrimal des e)
        (subenvNone (d2e (select SMerge des)))
        (ENil ext)
drev _ _ SpAbsent = error "Absent should be isAbsent"

drev des accumMap (SpSparse sd) =
  \e ->
    case drev des accumMap sd e of { Ret e0 subtape e1 sub e2 ->
    subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ ->
      Ret e0
          subtape
          e1
          sub'
          (emaybe (evar IZ)
            (inj2 (ENil ext))
            (inj1 (weakenExpr (WCopy WSink) e2)))
    }

drev des accumMap sd = \case
  EVar _ t i ->
    case conv2Idx des i of
      Idx2Ac accI ->
        Ret BTop
            SETop
            (EVar ext (d1 t) (conv1Idx i))
            (subenvNone (d2e (select SMerge des)))
            (let ty = applySparse sd (d2M t)
             in EAccum ext (d2M t) (_ sd) (ENil ext) (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI)))

      Idx2Me tupI ->
        Ret BTop
            SETop
            (EVar ext (d1 t) (conv1Idx i))
            (subenvOnehot (d2e (select SMerge des)) tupI sd)
            (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t)) IZ))

      Idx2Di _ ->
        Ret BTop
            SETop
            (EVar ext (d1 t) (conv1Idx i))
            (subenvNone (d2e (select SMerge des)))
            (ENil ext)

  ELet _ (rhs :: Expr _ _ a) body
    | ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge
    , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody sdBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) sd body
    , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) rhs1 subRHS rhs2 <- drev des accumMap sdBody rhs
    , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0
    , Refl <- lemAppendAssoc @body_shbinds @'[D1 a] @rhs_shbinds
    , Refl <- lemAppendAssoc @body_shbinds @(D1 a : rhs_shbinds) @(D1E env)
    , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum"))
    ->
    subenvPlus SF SF (d2eM (select SMerge des)) subRHS subBody $ \subBoth _ _ plus_RHS_Body ->
    let bodyResType = STPair (contribTupTy des subBody) (applySparse sdBody (d2 (typeOf rhs))) in
    Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0')
        (subenvConcat subtapeRHS subtapeBody)
        (weakenExpr wbody0' body1)
        subBoth
        (ELet ext (weakenExpr (autoWeak (#d (auto1 @sd)
                                         &. #body (subList (bindingsBinds body0 `sappend` SCons (d1 (typeOf rhs)) SNil) subtapeBody)
                                         &. #rhs (subList (bindingsBinds rhs0) subtapeRHS)
                                         &. #tl (d2ace (select SAccum des)))
                                        (#d :++: #body :++: #tl)
                                        (#d :++: (#body :++: #rhs) :++: #tl))
                              body2) $
         ELet ext
           (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $
             weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $
         plus_RHS_Body
           (EVar ext (contribTupTy des subRHS) IZ)
           (EFst ext (EVar ext bodyResType (IS IZ))))

  EPair _ a b
    | SpPair sd1 sd2 <- sd
    , Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil)
        <- retConcat des $ toSingleRet (drev des accumMap sd1 a) `SCons` toSingleRet (drev des accumMap sd2 b) `SCons` SNil
    , let dt = STPair (applySparse sd1 (d2 (typeOf a))) (applySparse sd2 (d2 (typeOf b))) ->
    subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B ->
    Ret binds
        subtape
        (EPair ext a1 b1)
        subBoth
        (ELet ext (ELet ext (EFst ext (EVar ext dt IZ))
                   (weakenExpr (WCopy WSink) a2)) $
         ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ)))
                   (weakenExpr (WCopy (WSink .> WSink)) b2)) $
         plus_A_B
           (EVar ext (contribTupTy des subA) (IS IZ))
           (EVar ext (contribTupTy des subB) IZ))

  EFst _ e
    | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair sd SpAbsent) e
    , STPair t1 _ <- typeOf e ->
    Ret e0
        subtape
        (EFst ext e1)
        sub
        (ELet ext (EPair ext (EVar ext (applySparse sd (d2 t1)) IZ) (ENil ext)) $
           weakenExpr (WCopy WSink) e2)

  ESnd _ e
    | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair SpAbsent sd) e
    , STPair _ t2 <- typeOf e ->
    Ret e0
        subtape
        (ESnd ext e1)
        sub
        (ELet ext (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t2)) IZ)) $
           weakenExpr (WCopy WSink) e2)

  -- Don't need to handle ENil, because its cotangent is always absent!
  -- ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (d2e (select SMerge des))) (ENil ext)

  EInl _ t2 e
    | SpLEither sd1 sd2 <- sd
    , Ret e0 subtape e1 sub e2 <- drev des accumMap sd1 e ->
    subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ ->
    Ret e0
        subtape
        (EInl ext (d1 t2) e1)
        sub'
        (ELCase ext
           (EVar ext (STLEither (applySparse sd1 (d2 (typeOf e))) (applySparse sd2 (d2 t2))) IZ)
           (inj2 $ ENil ext)
           (inj1 $ weakenExpr (WCopy WSink) e2)
           (EError ext (contribTupTy des sub') "inl<-dinr"))

  EInr _ t1 e
    | SpLEither sd1 sd2 <- sd
    , Ret e0 subtape e1 sub e2 <- drev des accumMap sd2 e ->
    subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ ->
    Ret e0
        subtape
        (EInr ext (d1 t1) e1)
        sub'
        (ELCase ext
           (EVar ext (STLEither (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 (typeOf e)))) IZ)
           (inj2 $ ENil ext)
           (EError ext (contribTupTy des sub') "inr<-dinl")
           (inj1 $ weakenExpr (WCopy WSink) e2))

  ECase _ e (a :: Expr _ _ t) b
    | STEither (t1 :: STy a) (t2 :: STy b) <- typeOf e
    , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && hasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge
    , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && hasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge
    , let (bindids1, bindids2) = validSplitEither (extOf e)
    , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA sd1 a2
          <- drevScoped des accumMap t1 storage1 bindids1 sd a
    , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB sd2 b2
          <- drevScoped des accumMap t2 storage2 bindids2 sd b
    , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap (SpLEither sd1 sd2) e
    , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
    , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
    , let subtapeListA = subList (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA
    , let subtapeListB = subList (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB
    , let tapeA = tapeTy subtapeListA
    , let tapeB = tapeTy subtapeListB
    , let collectA = bindingsCollectTape @_ @_ @(Append rhs_a_binds (D1 a : Append e_binds (D1E env)))
                                         (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA
    , let collectB = bindingsCollectTape @_ @_ @(Append rhs_b_binds (D1 b : Append e_binds (D1E env)))
                                         (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB
    , (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB)
    , let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0
    , let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0
    , Refl <- lemAppendNil @(Append rhs_a_binds '[D1 a])
    , Refl <- lemAppendNil @(Append rhs_b_binds '[D1 b])
    , Refl <- lemAppendAssoc @rhs_a_binds @'[D1 a] @(D1E env)
    , Refl <- lemAppendAssoc @rhs_b_binds @'[D1 b] @(D1E env)
    , let wa0'' = wa0' .> wCopies (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) (WClosed @(D1E env))
    , let wb0'' = wb0' .> wCopies (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) (WClosed @(D1E env))
    ->
    subenvPlus ST ST (d2eM (select SMerge des)) subA subB $ \subAB (Inj sAB_A) (Inj sAB_B) _ ->
    subenvPlus SF SF (d2eM (select SMerge des)) subAB subE $ \subOut _ _ plus_AB_E ->
    Ret (e0 `BPush`
         (tPrimal,
            ECase ext e1
              (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0''))))
              (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0''))))))
        (SEYesR subtapeE)
        (EFst ext (EVar ext tPrimal IZ))
        subOut
        (elet
           (ECase ext (ESnd ext (EVar ext tPrimal (IS IZ)))
              (let (rebinds, prerebinds) = reconstructBindings subtapeListA IZ
               in letBinds rebinds $
                    ELet ext
                      (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_a_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListA prerebinds) @> IS IZ)) $
                    elet
                      (weakenExpr (autoWeak (#d (auto1 @sd)
                                             &. #ta0 subtapeListA
                                             &. #prea0 prerebinds
                                             &. #recon (tapeA `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil)
                                             &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE)
                                             &. #tl (d2ace (select SAccum des)))
                                            (#d :++: #ta0 :++: #tl)
                                            (#d :++: (#ta0 :++: #prea0) :++: #recon :++: #binds :++: #tl))
                                  a2) $
                    EPair ext (sAB_A $ EFst ext (evar IZ))
                              (ELInl ext (applySparse sd2 (d2 t2)) (ESnd ext (evar IZ))))
              (let (rebinds, prerebinds) = reconstructBindings subtapeListB IZ
               in letBinds rebinds $
                    ELet ext
                      (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_b_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListB prerebinds) @> IS IZ)) $
                    elet
                      (weakenExpr (autoWeak (#d (auto1 @sd)
                                             &. #tb0 subtapeListB
                                             &. #preb0 prerebinds
                                             &. #recon (tapeB `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil)
                                             &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE)
                                             &. #tl (d2ace (select SAccum des)))
                                            (#d :++: #tb0 :++: #tl)
                                            (#d :++: (#tb0 :++: #preb0) :++: #recon :++: #binds :++: #tl))
                                  b2) $
                    EPair ext (sAB_B $ EFst ext (evar IZ))
                              (ELInr ext (applySparse sd1 (d2 t1)) (ESnd ext (evar IZ))))) $
         plus_AB_E
           (EFst ext (evar IZ))
           (ELet ext (ESnd ext (evar IZ)) $
              weakenExpr (WCopy (wSinks' @[_,_,_])) e2))

  EConst _ t val ->
    Ret BTop
        SETop
        (EConst ext t val)
        (subenvNone (d2e (select SMerge des)))
        (ENil ext)

  EOp _ op e
    | Ret e0 subtape e1 sub e2 <- drev des accumMap (spDense (d2M (opt1 op))) e ->
    case d2op op of
      Linear d2opfun ->
        Ret e0
            subtape
            (d1op op e1)
            sub
            (ELet ext (d2opfun (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ)))
               (weakenExpr (WCopy WSink) e2))
      Nonlinear d2opfun ->
        Ret (e0 `BPush` (d1 (typeOf e), e1))
            (SEYesR subtape)
            (d1op op $ EVar ext (d1 (typeOf e)) IZ)
            sub
            (ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ))
                               (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ)))
               (weakenExpr (WCopy (wSinks' @[_,_])) e2))

  ECustom _ _ tb storety srce pr du a b
    -- allowed to ignore a2 because 'a' is the part of the input that is inactive
    | Ret b0 bsubtape b1 bsub b2 <- drev des accumMap (spDense (d2M tb)) b ->
    case isDense (d2M (typeOf srce)) sd of
      Just Refl ->
        Ret (b0 `BPush` (d1 (typeOf a), weakenExpr (sinkWithBindings b0) (drevPrimal des a))
                `BPush` (typeOf b1, weakenExpr WSink b1)
                `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr))
                `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ)))
            (SEYesR (SENo (SENo (SENo bsubtape))))
            (EFst ext (EVar ext (typeOf pr) (IS IZ)))
            bsub
            (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $
               weakenExpr (WCopy (WSink .> WSink)) b2)

      Nothing ->
        Ret (b0 `BPush` (d1 (typeOf a), weakenExpr (sinkWithBindings b0) (drevPrimal des a))
                `BPush` (typeOf b1, weakenExpr WSink b1)
                `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)))
            (SEYesR (SENo (SENo bsubtape)))
            (EFst ext (EVar ext (typeOf pr) IZ))
            bsub
            (ELet ext (ESnd ext (EVar ext (typeOf pr) (IS IZ))) $  -- tape
             ELet ext (expandSparse (typeOf srce) sd  -- expanded incoming cotangent
                                    (EFst ext (EVar ext (typeOf pr) (IS (IS IZ))))
                                    (EVar ext (applySparse sd (d2 (typeOf srce))) (IS IZ))) $
             ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $
               weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) b2)

  ERecompute _ e ->
    deleteUnused (descrList des) (occCountAll e) $ \usedSub ->
    let smallE = unsafeWeakenWithSubenv usedSub e in
    subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->
    case drev usedDes (VarMap.subMap subAccumUsed accumMap) sd smallE of { Ret e0 subtape _ sub e2 ->
    let subMergeUsed' = subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E subMergeUsed) in
    Ret (collectBindings (desD1E des) subD1eUsed)
        (subenvAll (desD1E usedDes))
        (weakenExpr (wSinks (desD1E usedDes)) $ drevPrimal des e)
        (subenvCompose subMergeUsed' sub)
        (letBinds (fst (weakenBindings weakenExpr (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $
           weakenExpr
             (autoWeak (#d (auto1 @sd)
                        &. #shbinds (bindingsBinds e0)
                        &. #tape (subList (bindingsBinds e0) subtape)
                        &. #d1env (desD1E usedDes)
                        &. #tl' (d2ace (select SAccum usedDes))
                        &. #tl (d2ace (select SAccum des)))
                       (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: LPreW #tl' #tl (wUndoSubenv subAccumUsed))
                       (#shbinds :++: #d :++: #d1env :++: #tl))
             e2)
    }

  EError _ t s ->
    Ret BTop
        SETop
        (EError ext (d1 t) s)
        (subenvNone (d2e (select SMerge des)))
        (ENil ext)

  EConstArr _ n t val ->
    Ret BTop
        SETop
        (EConstArr ext n t val)
        (subenvNone (d2e (select SMerge des)))
        (ENil ext)

  EBuild _ (ndim :: SNat ndim) she (orige :: Expr _ _ eltty)
    | SpArr @_ @sdElt sdElt <- sd
    , let eltty = typeOf orige
    , shty :: STy shty <- tTup (sreplicate ndim tIx)
    , Refl <- indexTupD1Id ndim ->
    deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') ->
    let e = unsafeWeakenWithSubenv (SEYesR usedSub) orige in
    subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed ->
    accumPromote sdElt usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro ->
    let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in
    case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro sdElt e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 ->
    case assertSubenvEmpty sub of { Refl ->
    case lemAppendNil @e_binds of { Refl ->
    let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in
    let collectexpr = bindingsCollectTape (bindingsBinds e0) subtapeE in
    Ret (BTop `BPush` (shty, drevPrimal des she)
              `BPush` (STArr ndim (STPair (d1 eltty) tapety)
                      ,EBuild ext ndim
                         (EVar ext shty IZ)
                         (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil)
                                                                              &. #sh (shty `SCons` SNil)
                                                                              &. #d1env (desD1E des)
                                                                              &. #d1env' (desD1E usedDes))
                                                                             (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
                                                                             (#ix :++: #sh :++: #d1env))
                                                                   e0)) $
                            let w = autoWeak (#ix (shty `SCons` SNil)
                                              &. #sh (shty `SCons` SNil)
                                              &. #e0 (bindingsBinds e0)
                                              &. #d1env (desD1E des)
                                              &. #d1env' (desD1E usedDes))
                                             (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
                                             (#e0 :++: #ix :++: #sh :++: #d1env)
                                w' = w .> wCopies (bindingsBinds e0) (WClosed @(shty : D1E env'))
                            in EPair ext (weakenExpr w e1) (collectexpr w')))
              `BPush` (STArr ndim tapety, emap (ESnd ext (EVar ext (STPair (d1 eltty) tapety) IZ))
                                               (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ)))
        (SEYesR (SENo (SEYesR SETop)))
        (emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ))
              (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ)))
        (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub)))
        (let sinkOverEnvPro = wSinks @(sd : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in
         ESnd ext $
           uninvertTup (d2e envPro) (STArr ndim STNil) $
             -- TODO: what's happening here is that because of the sparsity
             -- rewrite, makeAccumulators needs primals where it previously
             -- didn't. The build derivative is currently not saving those
             -- primals, so the hole below cannot currently be filled. The
             -- appropriate primals (waves hands) need to be stored, so that a
             -- weakening can be provided here.
             makeAccumulators @_ @_ @(TArr ndim TNil) (_ (subenvCompose subMergeUsed proSub)) envPro $
               EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $
                 -- the cotangent for this element
                 ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ))
                                    (EVar ext shty IZ)) $
                 -- the tape for this element
                 ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ))
                                    (EVar ext shty (IS IZ))) $
                 let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ
                 in letBinds rebinds $
                      weakenExpr (autoWeak (#d (auto1 @sdElt)
                                            &. #pro (d2ace envPro)
                                            &. #etape (subList (bindingsBinds e0) subtapeE)
                                            &. #prerebinds prerebinds
                                            &. #tape (auto1 @(Tape e_tape))
                                            &. #ix (auto1 @shty)
                                            &. #darr (auto1 @(TArr ndim sdElt))
                                            &. #tapearr (auto1 @(TArr ndim (Tape e_tape)))
                                            &. #sh (auto1 @shty)
                                            &. #d2acUsed (d2ace (select SAccum usedDes))
                                            &. #d2acEnv (d2ace (select SAccum des)))
                                           (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed))
                                           ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #d2acEnv)
                                  .> wPro (subList (bindingsBinds e0) subtapeE))
                                 e2)
    }}}

  EUnit _ e
    | SpArr sdElt <- sd
    , Ret e0 subtape e1 sub e2 <- drev des accumMap sdElt e ->
    Ret e0
        subtape
        (EUnit ext e1)
        sub
        (ELet ext (EIdx0 ext (EVar ext (STArr SZ (applySparse sdElt (d2 (typeOf e)))) IZ)) $
           weakenExpr (WCopy WSink) e2)

  EReplicate1Inner _ en e
    -- We're allowed to ignore en2 here because the output of 'ei' is discrete.
    | Rets binds subtape (RetPair en1 _ _ `SCons` RetPair e1 sub e2 `SCons` SNil)
        <- retConcat des $ drev des accumMap en `SCons` drev des accumMap e `SCons` SNil
    , let STArr ndim eltty = typeOf e ->
    Ret binds
        subtape
        (EReplicate1Inner ext en1 e1)
        sub
        (EMaybe ext
          (zeroTup (subList (select SMerge des) sub))
          (ELet ext (EJust ext (EFold1Inner ext Commut
                        (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ))
                        (ezeroD2 eltty)
                        (EVar ext (STArr (SS ndim) (d2 eltty)) IZ))) $
            weakenExpr (WCopy (WSink .> WSink)) e2)
          (EVar ext (d2 (STArr (SS ndim) eltty)) IZ))

  EIdx0 _ e
    | Ret e0 subtape e1 sub e2 <- drev des accumMap e
    , STArr _ t <- typeOf e ->
    Ret e0
        subtape
        (EIdx0 ext e1)
        sub
        (ELet ext (EJust ext (EUnit ext (EVar ext (d2 t) IZ))) $
         weakenExpr (WCopy WSink) e2)

  EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead"
  {-
  EIdx1 _ e ei
    -- We're allowed to ignore ei2 here because the output of 'ei' is discrete.
    | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)
        <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil
    , STArr (SS n) eltty <- typeOf e ->
    Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1)
               `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ)))
        (SEYesR (SENo subtape))
        (EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ))
                   (weakenExpr (WSink .> WSink) ei1))
        sub
        (ELet ext (ebuildUp1 n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ)))
                               (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ)))
                               (EVar ext (STArr n (d2 eltty)) (IS IZ))) $
         weakenExpr (WCopy (WSink .> WSink)) e2)
  -}

  EIdx _ e ei
    -- We're allowed to ignore ei2 here because the output of 'ei' is discrete.
    | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)
        <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil
    , STArr n eltty <- typeOf e
    , Refl <- indexTupD1Id n
    , Refl <- lemZeroInfoD2 eltty
    , let tIxN = tTup (sreplicate n tIx)  ->
    Ret (binds `BPush` (STArr n (d1 eltty), e1)
               `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ))
               `BPush` (tIxN, weakenExpr (WSink .> WSink) ei1))
        (SEYesR (SEYesR (SENo subtape)))
        (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ)))
                  (EVar ext (tTup (sreplicate n tIx)) IZ))
        sub
        (ELet ext (EOneHot ext (d2M (STArr n eltty)) (SAPJust (SAPArrIdx SAPHere))
                             (EPair ext (EPair ext (EVar ext tIxN (IS IZ))
                                                   (EBuild ext n (EVar ext tIxN (IS (IS IZ))) (ENil ext)))
                                        (ENil ext))
                             (EVar ext (d2 eltty) IZ)) $
         weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)

  EShape _ e
    -- Allowed to ignore e2 here because the output of EShape is discrete,
    -- hence we'd be passing a zero cotangent to e2 anyway.
    | Ret e0 subtape e1 _ _ <- drev des accumMap e
    , STArr n _ <- typeOf e
    , Refl <- indexTupD1Id n ->
    Ret e0
        subtape
        (EShape ext e1)
        (subenvNone (select SMerge des))
        (ENil ext)

  ESum1Inner _ e
    | Ret e0 subtape e1 sub e2 <- drev des accumMap e
    , STArr (SS n) t <- typeOf e ->
    Ret (e0 `BPush` (STArr (SS n) t, e1)
            `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ)))
        (SEYesR (SENo subtape))
        (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ)))
        sub
        (EMaybe ext
          (zeroTup (subList (select SMerge des) sub))
          (ELet ext (EJust ext (EReplicate1Inner ext
                                  (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS (IS IZ))))
                                  (EVar ext (STArr n (d2 t)) IZ))) $
           weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
          (EVar ext (d2 (STArr n t)) IZ))

  EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e
  EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e

  -- These should be the next to be implemented, I think
  EFold1Inner{} -> err_unsupported "EFold1Inner"

  ENothing{} -> err_unsupported "ENothing"
  EJust{} -> err_unsupported "EJust"
  EMaybe{} -> err_unsupported "EMaybe"
  ELNil{} -> err_unsupported "ELNil"
  ELInl{} -> err_unsupported "ELInl"
  ELInr{} -> err_unsupported "ELInr"
  ELCase{} -> err_unsupported "ELCase"

  EWith{} -> err_accum
  EZero{} -> err_monoid
  EPlus{} -> err_monoid
  EOneHot{} -> err_monoid

  where
    err_accum = error "Accumulator operations unsupported in the source program"
    err_monoid = error "Monoid operations unsupported in the source program"
    err_unsupported s = error $ "CHAD: unsupported " ++ s

    deriv_extremum :: ScalIsNumeric t' ~ True
                   => (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t')))
                   -> Sparse (TArr n (D2s t')) sd'
                   -> Expr ValId env (TArr (S n) (TScal t')) -> Ret env sto sd' (TArr n (TScal t'))
    deriv_extremum extremum e
      | Ret e0 subtape e1 sub e2 <- drev des accumMap e
      , at@(STArr (SS n) t@(STScal st)) <- typeOf e
      , let at' = STArr n t
      , let tIxN = tTup (sreplicate (SS n) tIx) =
      Ret (e0 `BPush` (at, e1)
              `BPush` (at', extremum (EVar ext at IZ)))
          (SEYesR (SEYesR subtape))
          (EVar ext at' IZ)
          sub
          (EMaybe ext
            (zeroTup (subList (select SMerge des) sub))
            (ELet ext (EJust ext
                        (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS (IS IZ))))) $
                           eif (EOp ext (OEq st) (EPair ext
                                        (EIdx ext (EVar ext at (IS (IS (IS (IS IZ))))) (EVar ext tIxN IZ))
                                        (EIdx ext (EVar ext at' (IS (IS (IS IZ)))) (EFst ext (EVar ext tIxN IZ)))))
                             (EIdx ext (EVar ext (STArr n (d2 t)) (IS IZ)) (EFst ext (EVar ext tIxN IZ)))
                             (ezeroD2 t))) $
              weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2)
            (EVar ext (d2 at') IZ))

    contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs)
    contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub))

-- | Some storage choice other than @"discr"@. The concrete storage kind is
-- hidden, but the constructor carries evidence that it is not @"discr"@.
data ChosenStorage =
  forall st. ((st == "discr") ~ False) =>
    ChosenStorage (Storage st)

-- | Result of 'drevScoped': the CHAD transform of an expression that has one
-- extra variable of type @a@ in scope in front of @env0@. The cotangent for
-- that extra variable is returned separately from the contributions to the
-- enclosing merge environment (of @env0@). When the variable's storage @s@
-- is @"discr"@, the dual returns only the merge contributions.
data RetScoped env0 sto a s sd t =
  forall shbinds tapebinds contribs sa.
    RetScoped
        (Bindings Ex (D1E (a : env0)) shbinds)  -- ^ shared binds (primal bindings, in scope of the extra variable)
        (Subenv (Append shbinds '[D1 a]) tapebinds)
           -- ^ which of the shared binds, and whether the extra variable's
           -- primal (@D1 a@), are kept on the tape for the dual
        (Ex (Append shbinds (D1E (a : env0))) (D1 t))
           -- ^ primal result, in scope of the shared binds
        (SubenvS (D2E (Select env0 sto "merge")) contribs)
           -- ^ merge contributions to the _enclosing_ merge environment
        (Sparse (D2 a) sa)
           -- ^ contribution to the argument
        (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum")))
            (If (s == "discr") (Tup contribs)
                               (TPair (Tup contribs) sa)))
          -- ^ the merge contributions, plus the cotangent to the argument
          -- (if there is any)
deriving instance Show (RetScoped env0 sto a s sd t)

-- | Differentiate an expression that has one extra variable of type @a@ at
-- the head of its environment, compared to @env@. The cotangent for that
-- variable is separated from the contributions to the enclosing merge
-- environment.
--
-- The extra variable is described by @argty@ (its type), @argsto@ (its
-- storage) and @argids@ (its value identity, if known). @accumMap@ maps value
-- ids to accumulators that already exist in the accumulator environment.
-- @sd@ describes the (sparse) cotangent of the result, which the dual
-- receives as its first variable.
--
-- How the extra variable's cotangent is produced depends on its storage:
--
-- * @"merge"@ ('SMerge'): it is peeled off the body's merge contributions.
-- * @"accum"@ ('SAccum'): it is collected in an accumulator. Either an
--   existing accumulator for the same value id is reused, or a fresh one is
--   created.
-- * @"discr"@ ('SDiscr'): no cotangent is returned for it.
drevScoped :: forall a s env sto sd t.
              (?config :: CHADConfig)
           => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
           -> STy a -> Storage s -> Maybe (ValId a)
           -> Sparse (D2 t) sd
           -> Expr ValId (a : env) t
           -> RetScoped env sto a s sd t
drevScoped des accumMap argty argsto argids sd expr = case argsto of
  -- Merge storage: the body's dual reports the argument's cotangent at the
  -- head of 'sub'. Split it off into the argument slot. If it is absent, pair
  -- the dual's result with a unit to fill that slot. The argument's primal is
  -- not put on the tape (SENo). 'lemAppendNil' presumably reconciles the tape
  -- type after appending the argument's empty tape entry.
  SMerge
    | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr
    , Refl <- lemAppendNil @tapebinds ->
        case sub of
          SEYes sp sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' sp e2
          SENo sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' SpAbsent (EPair ext e2 (ENil ext))

  -- Accumulator storage.
  SAccum
    -- Reuse case: the argument has a 'VIArr' value id that already maps to an
    -- accumulator of the right type in 'accumMap'. The body's own accumulator
    -- variable is bound (ELet) to that existing accumulator instead of
    -- allocating a new one.
    | Just (VIArr i _) <- argids
    , Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap
    , Just Refl <- testEquality foundTy (STAccum (d2M argty))
    , Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) sd expr
    , Refl <- lemAppendNil @tapebinds ->
        -- Our contribution to the binding's cotangent _here_ is zero (absent),
        -- because we're contributing to an earlier binding of the same value
        -- instead.
        RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent $
          let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in
          -- The let-bound accumulator ends up in front of the dual's context,
          -- so e2 is weakened to move its #ac slot to the front.
          ELet ext (EVar ext (STAccum (d2M argty)) (WSink .> wtapebinds @> idx)) $
            weakenExpr (autoWeak (#d (auto1 @sd)
                                    &. #body (subList (bindingsBinds e0) subtape)
                                    &. #ac (auto1 @(TAccum (D2 a)))
                                    &. #tl (d2ace (select SAccum des)))
                                   (#d :++: #body :++: #ac :++: #tl)
                                   (#ac :++: #d :++: #body :++: #tl))
                       (EPair ext e2 (ENil ext))

    -- Fresh case: allocate a new accumulator with 'EWith', initialised with a
    -- zero. That zero needs the argument's primal value, so the primal is
    -- kept on the tape (SEYesR) and read back through #p. If the argument has
    -- a 'VIArr' id, the new accumulator is registered in the map. Inner scopes
    -- whose argument has the same id then take the reuse case above. The
    -- argument's cotangent is reported as dense (SpDense).
    | let accumMap' = case argids of
                        Just (VIArr i _) -> VarMap.insert i (STAccum (d2M argty)) IZ (VarMap.sink1 accumMap)
                        _ -> VarMap.sink1 accumMap
    , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' sd expr ->
        let library = #d (auto1 @sd)
                      &. #p (auto1 @(D1 a))
                      &. #body (subList (bindingsBinds e0) subtape)
                      &. #ac (auto1 @(TAccum (D2 a)))
                      &. #tl (d2ace (select SAccum des))
        in
        RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub SpDense $
          let primalIdx = autoWeak library #p (#d :++: (#body :++: #p) :++: #tl) @> IZ in
          EWith ext (d2M argty) (EZero ext (d2M argty) (d2zeroInfo argty (EVar ext (d1 argty) primalIdx))) $
            weakenExpr (autoWeak library
                                 (#d :++: #body :++: #ac :++: #tl)
                                 (#ac :++: #d :++: (#body :++: #p) :++: #tl))
                       e2

  -- Discrete storage: no cotangent slot and no tape entry for the argument.
  -- The body's dual is returned unchanged.
  SDiscr
    | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr
    , Refl <- lemAppendNil @tapebinds ->
        RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent e2

-- TODO: proper primal-only transform that doesn't depend on D1 = Id

-- | Primal of a source expression, obtained by reusing the expression itself.
-- This works only because 'D1' is the identity on every type that may occur
-- in a source program. The local proofs 'chadD1Id' and 'chadD1EId' establish
-- this, after which every annotation is replaced with 'ext' via 'mapExt'.
-- Calls 'error' if an accumulator type occurs in the source program.
drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t)
drevPrimal des e =
  -- Both proofs must be in scope before the source expression can be reused
  -- verbatim at the D1-transformed types.
  case chadD1Id (typeOf e) of
    Refl -> case chadD1EId (descrList des) of
      Refl -> mapExt (const ext) e
  where
    -- D1 is the identity on every type permitted in a source program.
    chadD1Id :: STy a -> D1 a :~: a
    chadD1Id ty = case ty of
      STNil -> Refl
      STPair lhs rhs -> case (chadD1Id lhs, chadD1Id rhs) of (Refl, Refl) -> Refl
      STEither lhs rhs -> case (chadD1Id lhs, chadD1Id rhs) of (Refl, Refl) -> Refl
      STLEither lhs rhs -> case (chadD1Id lhs, chadD1Id rhs) of (Refl, Refl) -> Refl
      STMaybe elt -> case chadD1Id elt of Refl -> Refl
      STArr _ elt -> case chadD1Id elt of Refl -> Refl
      STScal _ -> Refl
      STAccum{} -> error "accumulators not allowed in source program"

    -- The proof above, lifted pointwise over a whole environment.
    chadD1EId :: SList STy l -> D1E l :~: l
    chadD1EId tys = case tys of
      SNil -> Refl
      SCons ty rest -> case (chadD1Id ty, chadD1EId rest) of (Refl, Refl) -> Refl