diff options
Diffstat (limited to 'src/ForwardAD/DualNumbers.hs')
-rw-r--r-- | src/ForwardAD/DualNumbers.hs | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index a93b8e6..f02b93e 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -171,8 +171,10 @@ dfwdDN = \case (EConst ext t x) EIdx0 _ e -> EIdx0 ext (dfwdDN e) EIdx1 _ a b -> EIdx1 ext (dfwdDN a) (dfwdDN b) - EIdx _ n a b - | Refl <- dnPreservesTupIx n -> EIdx ext n (dfwdDN a) (dfwdDN b) + EIdx _ a b + | STArr n _ <- typeOf a + , Refl <- dnPreservesTupIx n + -> EIdx ext (dfwdDN a) (dfwdDN b) EShape _ e | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n) -> EShape ext (dfwdDN e) EOp _ op e -> dop op (dfwdDN e) @@ -191,8 +193,8 @@ emap f arr = let STArr n t = typeOf arr in ELet ext arr $ EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $ - ELet ext (EIdx ext n (EVar ext (STArr n t) (IS IZ)) - (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ)) + (EVar ext (tTup (sreplicate n tIx)) IZ)) $ weakenExpr (WCopy (WSink .> WSink)) f ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) @@ -202,7 +204,7 @@ ezip a b = in ELet ext a $ ELet ext (weakenExpr WSink b) $ EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ - EPair ext (EIdx ext n (EVar ext (STArr n t1) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) IZ)) - (EIdx ext n (EVar ext (STArr n t2) (IS IZ)) - (EVar ext (tTup (sreplicate n tIx)) IZ)) + EPair ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ))) + (EVar ext (tTup (sreplicate n tIx)) IZ)) + (EIdx ext (EVar ext (STArr n t2) (IS IZ)) + (EVar ext (tTup (sreplicate n tIx)) IZ)) |