summaryrefslogtreecommitdiff
path: root/src/AST.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-11 09:35:35 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-11 09:35:35 +0100
commit281229b7bf307132a428dde1b171e1db86637238 (patch)
treef5fbe70ef80132ac1ea084cd5731998a9a55f677 /src/AST.hs
parenta46f53695d1dfab8834c7cc52707c0c0bb9b8ba0 (diff)
Make EBuild derivative aware of zero cotangent arrays
Diffstat (limited to 'src/AST.hs')
-rw-r--r--src/AST.hs14
1 files changed, 14 insertions, 0 deletions
diff --git a/src/AST.hs b/src/AST.hs
index e7dde90..9ad0d4d 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -408,6 +408,20 @@ ezip a b =
(EIdx ext (EVar ext (STArr n t2) (IS IZ))
(EVar ext (tTup (sreplicate n tIx)) IZ))
+eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a
+eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c)
+
+-- | Returns whether the shape is all-zero, but returns False for the zero-dimensional shape (because it is _not_ empty).
+eshapeEmpty :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool)
+eshapeEmpty SZ _ = EConst ext STBool False
+eshapeEmpty (SS SZ) e = EOp ext (OEq STI64) (EPair ext (ESnd ext e) (EConst ext STI64 0))
+eshapeEmpty (SS n) e =
+ ELet ext e $
+ EOp ext OAnd (EPair ext
+ (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ))
+ (EConst ext STI64 0)))
+ (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ))))
+
arrIdxToAcIdx :: proxy t -> SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (AcIdx (TArr n t) n)
arrIdxToAcIdx = \p (n :: SNat n) e -> case lemPlusZero @n of Refl -> go p n SZ e (ENil ext)
where