summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-05 22:45:04 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-05 22:45:04 +0100
commita76ec3bcbdea7beaf9066e4ce0b8c5868f571cdb (patch)
tree5caa5c2042d7405b60f97385a0cf8dc66913ccc2
parent889aa1757a0fdf003f38f9d565a4a91660757f38 (diff)
Generate EOneHot in D[EIdx]
This generates a one-hot for the zero-dimensional inner array because indexing one level further to the actual element is too difficult. But this should simplify away fine.
-rw-r--r--src/AST.hs21
-rw-r--r--src/CHAD.hs10
2 files changed, 27 insertions, 4 deletions
diff --git a/src/AST.hs b/src/AST.hs
index b9b10ad..60fc5ad 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -9,8 +9,10 @@
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
@@ -379,3 +381,22 @@ ezip a b =
(EVar ext (tTup (sreplicate n tIx)) IZ))
(EIdx ext (EVar ext (STArr n t2) (IS IZ))
(EVar ext (tTup (sreplicate n tIx)) IZ))
+
+arrIdxToAcIdx :: proxy t -> SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (AcIdx (TArr n t) n)
+arrIdxToAcIdx = \p (n :: SNat n) e -> case lemPlusZero @n of Refl -> go p n SZ e (ENil ext)
+ where
+ -- symbolic version of 'invert' in Interpreter
+ go :: forall n m t env proxy. proxy t -> SNat n -> SNat m
+ -> Ex env (Tup (Replicate n TIx)) -> Ex env (AcIdx (TArr m t) m) -> Ex env (AcIdx (TArr (n + m) t) (n + m))
+ go _ SZ _ _ acidx = acidx
+ go p (SS n) m idx acidx
+ | Refl <- lemPlusSuccRight @n @m
+ = ELet ext idx $
+ go p n (SS m)
+ (EFst ext (EVar ext (typeOf idx) IZ))
+ (EPair ext (ESnd ext (EVar ext (typeOf idx) IZ))
+ (weakenExpr WSink acidx))
+
+lemAcValArrN :: proxy t -> SNat n -> AcValArr n t n :~: TArr Z t
+lemAcValArrN _ SZ = Refl
+lemAcValArrN p (SS n) | Refl <- lemAcValArrN p n = Refl
diff --git a/src/CHAD.hs b/src/CHAD.hs
index b3e2358..6b0627d 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1156,10 +1156,12 @@ drev des = \case
(EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ)))
(EVar ext (tTup (sreplicate n tIx)) IZ))
sub
- (ELet ext (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) (IS (IS IZ))) $
- ECase ext (EOp ext OIf (eidxEq n (EVar ext tIxN IZ) (EVar ext tIxN (IS (IS IZ)))))
- (EVar ext (d2 eltty) (IS (IS IZ)))
- (EZero eltty)) $
+ (ELet ext (EOneHot (STArr n eltty) n
+ (arrIdxToAcIdx (d2 eltty) n $ EVar ext tIxN (IS IZ))
+ (case n of SZ -> EUnit ext (EVar ext (d2 eltty) IZ)
+ SS{} | Refl <- lemAcValArrN (d2 eltty) n ->
+ EPair ext (EVar ext tIxN (IS (IS IZ)))
+ (EUnit ext (EVar ext (d2 eltty) IZ)))) $
weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
EShape _ e