summaryrefslogtreecommitdiff
path: root/src/ForwardAD/DualNumbers.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/ForwardAD/DualNumbers.hs')
-rw-r--r--src/ForwardAD/DualNumbers.hs18
1 files changed, 10 insertions, 8 deletions
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index a93b8e6..f02b93e 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -171,8 +171,10 @@ dfwdDN = \case
(EConst ext t x)
EIdx0 _ e -> EIdx0 ext (dfwdDN e)
EIdx1 _ a b -> EIdx1 ext (dfwdDN a) (dfwdDN b)
- EIdx _ n a b
- | Refl <- dnPreservesTupIx n -> EIdx ext n (dfwdDN a) (dfwdDN b)
+ EIdx _ a b
+ | STArr n _ <- typeOf a
+ , Refl <- dnPreservesTupIx n
+ -> EIdx ext (dfwdDN a) (dfwdDN b)
EShape _ e
| Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n) -> EShape ext (dfwdDN e)
EOp _ op e -> dop op (dfwdDN e)
@@ -191,8 +193,8 @@ emap f arr =
let STArr n t = typeOf arr
in ELet ext arr $
EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $
- ELet ext (EIdx ext n (EVar ext (STArr n t) (IS IZ))
- (EVar ext (tTup (sreplicate n tIx)) IZ)) $
+ ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ))
+ (EVar ext (tTup (sreplicate n tIx)) IZ)) $
weakenExpr (WCopy (WSink .> WSink)) f
ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b))
@@ -202,7 +204,7 @@ ezip a b =
in ELet ext a $
ELet ext (weakenExpr WSink b) $
EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $
- EPair ext (EIdx ext n (EVar ext (STArr n t1) (IS (IS IZ)))
- (EVar ext (tTup (sreplicate n tIx)) IZ))
- (EIdx ext n (EVar ext (STArr n t2) (IS IZ))
- (EVar ext (tTup (sreplicate n tIx)) IZ))
+ EPair ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ)))
+ (EVar ext (tTup (sreplicate n tIx)) IZ))
+ (EIdx ext (EVar ext (STArr n t2) (IS IZ))
+ (EVar ext (tTup (sreplicate n tIx)) IZ))