From 8d07a43f0b364156433dc453b9d1cc762c032634 Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Thu, 21 Sep 2023 09:29:05 +0200
Subject: WIP mixed environment description

---
 src/CHAD.hs | 156 ++++++++++++++++++++++++++++++++++++++++--------------------
 1 file changed, 104 insertions(+), 52 deletions(-)

diff --git a/src/CHAD.hs b/src/CHAD.hs
index 9a1c7d2..0c856b1 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -9,8 +9,14 @@
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE EmptyCase #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
 module CHAD where
 
+import Data.Bifunctor (first, second)
+import Data.Kind (Type)
+import GHC.TypeLits (Symbol)
+
 import AST
 
 
@@ -104,6 +110,12 @@ type family D2E env where
   D2E '[] = '[]
   D2E (t : env) = D2 t : D2E env
 
+-- | Select only the types from the environment that have the specified storage
+type family Select env sto s where
+  Select '[] '[] _ = '[]
+  Select (t : ts) (s : sto) s = t : Select ts sto s
+  Select (_ : ts) (_ : sto) s = Select ts sto s
+
 d1 :: STy t -> STy (D1 t)
 d1 STNil = STNil
 d1 (STPair a b) = STPair (d1 a) (d1 b)
@@ -125,21 +137,17 @@ d2 (STScal t) = case t of
   STBool -> STNil
 d2 STEVM{} = error "EVM not allowed in input program"
 
-d2e :: SList STy list -> SList STy (D2E list)
-d2e SNil = SNil
-d2e (SCons t list) = SCons (d2 t) (d2e list)
-
-d2list :: SList STy env -> SList STy (D2E env)
-d2list SNil = SNil
-d2list (SCons x l) = SCons (d2 x) (d2list l)
-
 conv1Idx :: Idx env t -> Idx (D1E env) (D1 t)
 conv1Idx IZ = IZ
 conv1Idx (IS i) = IS (conv1Idx i)
 
-conv2Idx :: Idx env t -> Idx (D2E env) (D2 t)
-conv2Idx IZ = IZ
-conv2Idx (IS i) = IS (conv2Idx i)
+conv2Idx :: Descr env sto -> Idx env t -> Either (Idx (D2E (Select env sto "accum")) (D2 t))
+                                                 (Idx (D2E (Select env sto "merge")) (D2 t))
+conv2Idx (DPush _   _ SAccum) IZ = Left IZ
+conv2Idx (DPush _   _ SMerge) IZ = Right IZ
+conv2Idx (DPush des _ SAccum) (IS i) = first IS (conv2Idx des i)
+conv2Idx (DPush des _ SMerge) (IS i) = second IS (conv2Idx des i)
+conv2Idx DTop i = case i of {}
 
 zero :: STy t -> Ex env (D2 t)
 zero STNil = ENil ext
@@ -154,23 +162,35 @@ zero (STScal t) = case t of
   STBool -> ENil ext
 zero STEVM{} = error "EVM not allowed in input program"
 
-data Ret env t =
+type family Tup env where
+  Tup '[] = TNil
+  Tup (t : ts) = TPair t (Tup ts)
+
+tTup :: SList STy env -> STy (Tup env)
+tTup SNil = STNil
+tTup (SCons t ts) = STPair t (tTup ts)
+
+zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0))
+zeroTup SNil = ENil ext
+zeroTup (SCons t env) = EPair ext (zero t) (zeroTup env)
+
+data Ret env sto t =
   forall env'.
     Ret (Bindings Ex (D1E env) env')
         (Ex env' (D1 t))
-        (Ex (D2 t : env') (TEVM (D2E env) TNil))
-deriving instance Show (Ret env t)
+        (Ex (D2 t : env') (TEVM (D2E (Select env sto "accum")) (Tup (D2E (Select env sto "merge")))))
+deriving instance Show (Ret env sto t)
 
-data RetPair env0 env t =
+data RetPair env0 sto env t =
     RetPair (Ex env (D1 t))
-            (Ex (D2 t : env) (TEVM (D2E env0) TNil))
+            (Ex (D2 t : env) (TEVM (D2E (Select env0 sto "accum")) (Tup (D2E (Select env0 sto "merge")))))
   deriving (Show)
 
-data Rets env0 env list =
+data Rets env0 sto env list =
   forall env'.
     Rets (Bindings Ex env env')
-         (SList (RetPair env0 env') list)
-deriving instance Show (Rets env0 env list)
+         (SList (RetPair env0 sto env') list)
+deriving instance Show (Rets env0 sto env list)
 
 -- d1W :: env :> env' -> D1E env :> D1E env'
 -- d1W WId = WId
@@ -179,15 +199,15 @@ deriving instance Show (Rets env0 env list)
 -- d1W (WPop w) = WPop (d1W w)
 -- d1W (WThen u w) = WThen (d1W u) (d1W w)
 
-weakenRetPair :: env :> env' -> RetPair env0 env t -> RetPair env0 env' t
+weakenRetPair :: env :> env' -> RetPair env0 sto env t -> RetPair env0 sto env' t
 weakenRetPair w (RetPair e1 e2) = RetPair (weakenExpr w e1) (weakenExpr (WCopy w) e2)
 
-weakenRets :: env :> env' -> Rets env0 env list -> Rets env0 env' list
+weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list
 weakenRets w (Rets binds list) =
   weakenBindings weakenExpr w binds $ \binds' wbinds' ->
     Rets binds' (slistMap (weakenRetPair wbinds') list)
 
-retConcat :: forall env list. SList (Ret env) list -> Rets env (D1E env) list
+retConcat :: forall env sto list. SList (Ret env sto) list -> Rets env sto (D1E env) list
 retConcat SNil = Rets BTop SNil
 retConcat (SCons (Ret (b :: Bindings Ex (D1E env) env2) p d) list)
   | Rets binds pairs <- weakenRets (sinkWithBindings b) (retConcat list)
@@ -206,10 +226,10 @@ d1op (OEq t) e = EOp ext (OEq t) e
 d1op ONot e = EOp ext ONot e
 d1op OIf e = EOp ext OIf e
 
+-- | Both primal and dual must be duplicable expressions
 data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))
               | Nonlinear (forall env. Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a))
 
--- both primal and dual must be duplicable expressions
 d2op :: SOp a t -> D2Op a t
 d2op op = case op of
   OAdd _ -> Linear $ \d -> EInr ext STNil (EPair ext d d)
@@ -243,21 +263,50 @@ d2op op = case op of
       STF64 -> float
       STBool -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
 
-freezeRet :: Ret env t
+freezeRet :: Ret env sto t
           -> (forall env'. Ex env' (D2 t))  -- the incoming cotangent value
-          -> Ex (D1E env) (TPair (D1 t) (TEVM (D2E env) TNil))
+          -> Ex (D1E env) (TPair (D1 t) (TEVM (D2E (Select env sto "accum")) (Tup (D2E (Select env sto "merge")))))
 freezeRet (Ret e0 e1 e2) d = letBinds e0 $ EPair ext e1 (ELet ext d e2)
 
-drev :: SList STy env -> Ex env t -> Ret env t
-drev senv = \case
+type Storage :: Symbol -> Type
+data Storage s where
+  SAccum :: Storage "accum"  -- ^ in the monad state as a mutable accumulator
+  SMerge :: Storage "merge"  -- ^ just return and merge
+deriving instance Show (Storage s)
+
+-- | Environment description
+data Descr env sto where
+  DTop :: Descr '[] '[]
+  DPush :: Descr env sto -> STy t -> Storage s -> Descr (t : env) (s : sto)
+deriving instance Show (Descr env sto)
+
+select :: Storage s -> Descr env sto -> SList STy (Select env sto s)
+select _ DTop = SNil
+select s@SAccum (DPush des t SAccum) = SCons t (select s des)
+select s@SMerge (DPush des _ SAccum) = select s des
+select s@SAccum (DPush des _ SMerge) = select s des
+select s@SMerge (DPush des t SMerge) = SCons t (select s des)
+
+d2e :: SList STy env -> SList STy (D2E env)
+d2e SNil = SNil
+d2e (SCons t ts) = SCons (d2 t) (d2e ts)
+
+drev :: Descr env sto -> Ex env t -> Ret env sto t
+drev des = \case
   EVar _ t i ->
-    Ret BTop
-        (EVar ext (d1 t) (conv1Idx i))
-        (EMOne (d2list senv) (conv2Idx i) (EVar ext (d2 t) IZ))
+    case conv2Idx des i of
+      Left accumI ->
+        Ret BTop
+            (EVar ext (d1 t) (conv1Idx i))
+            (EMBind
+              (EMOne d2mon accumI (EVar ext (d2 t) IZ))
+              (EMReturn d2mon (zeroTup (select SMerge des))))
+      Right tupI ->
+        _
 
   ELet _ rhs body
-    | Ret rhs0 rhs1 rhs2 <- drev senv rhs
-    , Ret body0 body1 body2 <- drev (SCons (typeOf rhs) senv) body ->
+    | Ret rhs0 rhs1 rhs2 <- drev des rhs
+    , Ret body0 body1 body2 <- drev (DPush des (typeOf rhs) SMerge) body ->
     weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 $ \body0' wbody0' ->
     Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0')
         (weakenExpr wbody0' body1)
@@ -267,19 +316,19 @@ drev senv = \case
 
   EPair _ a b
     | Rets binds (RetPair a1 a2 `SCons` RetPair b1 b2 `SCons` SNil)
-        <- retConcat $ drev senv a `SCons` drev senv b `SCons` SNil
+        <- retConcat $ drev des a `SCons` drev des b `SCons` SNil
     , let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) ->
     Ret binds
         (EPair ext a1 b1)
         (ECase ext (EVar ext (STEither STNil (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ)
-           (EMReturn (d2e senv) (ENil ext))
+           (EMReturn d2mon (ENil ext))
            (EMBind (ELet ext (EFst ext (EVar ext dt IZ))
                       (weakenExpr (WCopy (wSinks @[_,_])) a2))
                    (ELet ext (ESnd ext (EVar ext dt (IS IZ)))
                       (weakenExpr (WCopy (wSinks @[_,_,_])) b2))))
 
   EFst _ e
-    | Ret e0 e1 e2 <- drev senv e
+    | Ret e0 e1 e2 <- drev des e
     , STPair t1 t2 <- typeOf e ->
     Ret e0
         (EFst ext e1)
@@ -287,40 +336,40 @@ drev senv = \case
            weakenExpr (WCopy WSink) e2)
 
   ESnd _ e
-    | Ret e0 e1 e2 <- drev senv e
+    | Ret e0 e1 e2 <- drev des e
     , STPair t1 t2 <- typeOf e ->
     Ret e0
         (ESnd ext e1)
         (ELet ext (EInr ext STNil (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $
            weakenExpr (WCopy WSink) e2)
 
-  ENil _ -> Ret BTop (ENil ext) (EMReturn (d2e senv) (ENil ext))
+  ENil _ -> Ret BTop (ENil ext) (EMReturn d2mon (ENil ext))
 
   EInl _ t2 e
-    | Ret e0 e1 e2 <- drev senv e ->
+    | Ret e0 e1 e2 <- drev des e ->
     Ret e0
         (EInl ext (d1 t2) e1)
         (ECase ext (EVar ext (STEither STNil (STEither (d2 (typeOf e)) (d2 t2))) IZ)
-           (EMReturn (d2e senv) (ENil ext))
+           (EMReturn d2mon (ENil ext))
            (ECase ext (EVar ext (STEither (d2 (typeOf e)) (d2 t2)) IZ)
               (weakenExpr (WCopy (wSinks @[_,_])) e2)
-              (EError (STEVM (d2e senv) STNil) "inl<-dinr")))
+              (EError (STEVM d2mon STNil) "inl<-dinr")))
 
   EInr _ t1 e
-    | Ret e0 e1 e2 <- drev senv e ->
+    | Ret e0 e1 e2 <- drev des e ->
     Ret e0
         (EInr ext (d1 t1) e1)
         (ECase ext (EVar ext (STEither STNil (STEither (d2 t1) (d2 (typeOf e)))) IZ)
-           (EMReturn (d2e senv) (ENil ext))
+           (EMReturn d2mon (ENil ext))
            (ECase ext (EVar ext (STEither (d2 t1) (d2 (typeOf e))) IZ)
-              (EError (STEVM (d2e senv) STNil) "inr<-dinl")
+              (EError (STEVM d2mon STNil) "inr<-dinl")
               (weakenExpr (WCopy (wSinks @[_,_])) e2)))
 
   ECase _ e a b
     | STEither t1 t2 <- typeOf e
-    , Ret e0 e1 e2 <- drev senv e
-    , Ret a0 a1 a2 <- drev (SCons t1 senv) a
-    , Ret b0 b1 b2 <- drev (SCons t2 senv) b
+    , Ret e0 e1 e2 <- drev des e
+    , Ret a0 a1 a2 <- drev (DPush des t1 SMerge) a
+    , Ret b0 b1 b2 <- drev (DPush des t2 SMerge) b
     , TupBinds tapeA collectA reconA <- tupBinds a0
     , TupBinds tapeB collectB reconB <- tupBinds b0
     , let tPrimal = STPair (d1 (typeOf a)) (STEither tapeA tapeB) ->
@@ -341,18 +390,18 @@ drev senv = \case
                       letBinds rebinds $
                         ELet ext (EVar ext (d2 (typeOf a)) (sinkWithBindings rebinds @> IS (IS IZ))) $
                           EMBind (weakenExpr (WCopy wrebinds) (EMScope a2))
-                                 (EMReturn (d2e senv)
+                                 (EMReturn d2mon
                                     (EInr ext STNil (EInl ext (d2 t2)
                                        (ESnd ext (EVar ext (STPair STNil (d2 t1)) IZ))))))
-                 (EError (STEVM (d2e senv) (STEither STNil (STEither (d2 t1) (d2 t2)))) "dcase l/rtape"))
+                 (EError (STEVM d2mon (STEither STNil (STEither (d2 t1) (d2 t2)))) "dcase l/rtape"))
               (ECase ext (ESnd ext (EVar ext tPrimal (IS (IS IZ))))
-                 (EError (STEVM (d2e senv) (STEither STNil (STEither (d2 t1) (d2 t2)))) "dcase r/ltape")
+                 (EError (STEVM d2mon (STEither STNil (STEither (d2 t1) (d2 t2)))) "dcase r/ltape")
                  (case reconB (WSink .> WCopy (wSinks @[_,_,_] .> sinkWithBindings e0)) IZ of
                     TupBindsReconstruct rebinds wrebinds ->
                       letBinds rebinds $
                         ELet ext (EVar ext (d2 (typeOf a)) (sinkWithBindings rebinds @> IS (IS IZ))) $
                           EMBind (weakenExpr (WCopy wrebinds) (EMScope b2))
-                                 (EMReturn (d2e senv)
+                                 (EMReturn d2mon
                                     (EInr ext STNil (EInr ext (d2 t1)
                                        (ESnd ext (EVar ext (STPair STNil (d2 t2)) IZ))))))))
            (weakenExpr (WCopy (wSinks @[_,_,_])) e2))
@@ -360,10 +409,10 @@ drev senv = \case
   EConst _ t val ->
     Ret BTop
         (EConst ext t val)
-        (EMReturn (d2e senv) (ENil ext))
+        (EMReturn d2mon (ENil ext))
 
   EOp _ op e
-    | Ret e0 e1 e2 <- drev senv e ->
+    | Ret e0 e1 e2 <- drev des e ->
     case d2op op of
       Linear d2opfun ->
         Ret e0
@@ -378,3 +427,6 @@ drev senv = \case
                (weakenExpr (WCopy (wSinks @[_,_])) e2))
 
   e -> error $ "CHAD: unsupported " ++ takeWhile (/= ' ') (show e)
+
+  where
+    d2mon = d2e (select SAccum des)
-- 
cgit v1.2.3-70-g09d2