From 281229b7bf307132a428dde1b171e1db86637238 Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Mon, 11 Nov 2024 09:35:35 +0100
Subject: Make EBuild derivative aware of zero cotangent arrays

---
 src/AST.hs  | 14 ++++++++++++++
 src/CHAD.hs | 56 +++++++++++++++++++++++++++++---------------------------
 2 files changed, 43 insertions(+), 27 deletions(-)

diff --git a/src/AST.hs b/src/AST.hs
index e7dde90..9ad0d4d 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -408,6 +408,20 @@ ezip a b =
                    (EIdx ext (EVar ext (STArr n t2) (IS IZ))
                              (EVar ext (tTup (sreplicate n tIx)) IZ))
 
+eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a
+eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c)
+
+-- | Returns whether the shape is all-zero, but returns False for the zero-dimensional shape (because it is _not_ empty).
+eshapeEmpty :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool)
+eshapeEmpty SZ _ = EConst ext STBool False
+eshapeEmpty (SS SZ) e = EOp ext (OEq STI64) (EPair ext (ESnd ext e) (EConst ext STI64 0))
+eshapeEmpty (SS n) e =
+  ELet ext e $
+    EOp ext OAnd (EPair ext
+      (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ))
+                                      (EConst ext STI64 0)))
+      (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ))))
+
 arrIdxToAcIdx :: proxy t -> SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (AcIdx (TArr n t) n)
 arrIdxToAcIdx = \p (n :: SNat n) e -> case lemPlusZero @n of Refl -> go p n SZ e (ENil ext)
   where
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 8b9f17a..45fcc82 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1016,33 +1016,35 @@ drev des = \case
               (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ)))
         (subenvCompose subMergeUsed proSub)
         (let sinkOverEnvPro = wSinks @(D2 t : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in
-         ESnd ext $
-          uninvertTup (d2e envPro) (STArr ndim STNil) $
-            makeAccumulators @_ @_ @(TArr ndim TNil) envPro $
-              EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $
-                -- the cotangent for this element
-                ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ))
-                                   (EVar ext shty IZ)) $
-                -- the tape for this element
-                ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ))
-                                   (EVar ext shty (IS IZ))) $
-                let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ
-                in letBinds rebinds $
-                     weakenExpr (autoWeak (#d (auto1 @(D2 eltty))
-                                           &. #pro (d2ace envPro)
-                                           &. #etape (subList (bindingsBinds e0) subtapeE)
-                                           &. #prerebinds prerebinds
-                                           &. #tape (tapety `SCons` SNil)
-                                           &. #ix (shty `SCons` SNil)
-                                           &. #darr (STArr ndim (d2 eltty) `SCons` SNil)
-                                           &. #tapearr (STArr ndim tapety `SCons` SNil)
-                                           &. #sh (shty `SCons` SNil)
-                                           &. #d2acUsed (d2ace (select SAccum usedDes))
-                                           &. #d2acEnv (d2ace (select SAccum des)))
-                                          (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed))
-                                          ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #d2acEnv)
-                                 .> wPro (subList (bindingsBinds e0) subtapeE))
-                                e2)
+         eif (eshapeEmpty ndim (EShape ext (EVar ext (STArr ndim (d2 eltty)) IZ)))
+           (zeroTup envPro)
+           (ESnd ext $
+              uninvertTup (d2e envPro) (STArr ndim STNil) $
+                makeAccumulators @_ @_ @(TArr ndim TNil) envPro $
+                  EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $
+                    -- the cotangent for this element
+                    ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ))
+                                       (EVar ext shty IZ)) $
+                    -- the tape for this element
+                    ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ))
+                                       (EVar ext shty (IS IZ))) $
+                    let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ
+                    in letBinds rebinds $
+                         weakenExpr (autoWeak (#d (auto1 @(D2 eltty))
+                                               &. #pro (d2ace envPro)
+                                               &. #etape (subList (bindingsBinds e0) subtapeE)
+                                               &. #prerebinds prerebinds
+                                               &. #tape (tapety `SCons` SNil)
+                                               &. #ix (shty `SCons` SNil)
+                                               &. #darr (STArr ndim (d2 eltty) `SCons` SNil)
+                                               &. #tapearr (STArr ndim tapety `SCons` SNil)
+                                               &. #sh (shty `SCons` SNil)
+                                               &. #d2acUsed (d2ace (select SAccum usedDes))
+                                               &. #d2acEnv (d2ace (select SAccum des)))
+                                              (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed))
+                                              ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #d2acEnv)
+                                     .> wPro (subList (bindingsBinds e0) subtapeE))
+                                    e2))
     }}
 
   EUnit _ e
-- 
cgit v1.2.3-70-g09d2