diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-05 22:45:04 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-05 22:45:04 +0100 |
commit | a76ec3bcbdea7beaf9066e4ce0b8c5868f571cdb (patch) | |
tree | 5caa5c2042d7405b60f97385a0cf8dc66913ccc2 /src/CHAD.hs | |
parent | 889aa1757a0fdf003f38f9d565a4a91660757f38 (diff) |
Generate EOneHot in D[EIdx]
This generates a one-hot for the zero-dimensional inner array because
indexing one level further to the actual element is too difficult. But
this should simplify away fine.
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r-- | src/CHAD.hs | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index b3e2358..6b0627d 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1156,10 +1156,12 @@ drev des = \case (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ)) sub - (ELet ext (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) (IS (IS IZ))) $ - ECase ext (EOp ext OIf (eidxEq n (EVar ext tIxN IZ) (EVar ext tIxN (IS (IS IZ))))) - (EVar ext (d2 eltty) (IS (IS IZ))) - (EZero eltty)) $ + (ELet ext (EOneHot (STArr n eltty) n + (arrIdxToAcIdx (d2 eltty) n $ EVar ext tIxN (IS IZ)) + (case n of SZ -> EUnit ext (EVar ext (d2 eltty) IZ) + SS{} | Refl <- lemAcValArrN (d2 eltty) n -> + EPair ext (EVar ext tIxN (IS (IS IZ))) + (EUnit ext (EVar ext (d2 eltty) IZ)))) $ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) EShape _ e |