From a76ec3bcbdea7beaf9066e4ce0b8c5868f571cdb Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Tue, 5 Nov 2024 22:45:04 +0100 Subject: Generate EOneHot in D[EIdx] This generates a one-hot for the zero-dimensional inner array because indexing one level further to the actual element is too difficult. But this should simplify away fine. --- src/AST.hs | 21 +++++++++++++++++++++ src/CHAD.hs | 10 ++++++---- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/src/AST.hs b/src/AST.hs index b9b10ad..60fc5ad 100644 --- a/src/AST.hs +++ b/src/AST.hs @@ -9,8 +9,10 @@ {-# LANGUAGE PolyKinds #-} {-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} @@ -379,3 +381,22 @@ ezip a b = (EVar ext (tTup (sreplicate n tIx)) IZ)) (EIdx ext (EVar ext (STArr n t2) (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ)) + +arrIdxToAcIdx :: proxy t -> SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (AcIdx (TArr n t) n) +arrIdxToAcIdx = \p (n :: SNat n) e -> case lemPlusZero @n of Refl -> go p n SZ e (ENil ext) + where + -- symbolic version of 'invert' in Interpreter + go :: forall n m t env proxy. proxy t -> SNat n -> SNat m + -> Ex env (Tup (Replicate n TIx)) -> Ex env (AcIdx (TArr m t) m) -> Ex env (AcIdx (TArr (n + m) t) (n + m)) + go _ SZ _ _ acidx = acidx + go p (SS n) m idx acidx + | Refl <- lemPlusSuccRight @n @m + = ELet ext idx $ + go p n (SS m) + (EFst ext (EVar ext (typeOf idx) IZ)) + (EPair ext (ESnd ext (EVar ext (typeOf idx) IZ)) + (weakenExpr WSink acidx)) + +lemAcValArrN :: proxy t -> SNat n -> AcValArr n t n :~: TArr Z t +lemAcValArrN _ SZ = Refl +lemAcValArrN p (SS n) | Refl <- lemAcValArrN p n = Refl diff --git a/src/CHAD.hs b/src/CHAD.hs index b3e2358..6b0627d 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1156,10 +1156,12 @@ drev des = \case (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ)) sub - (ELet ext (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) (IS (IS IZ))) $ - ECase ext (EOp ext OIf (eidxEq n (EVar ext tIxN IZ) (EVar ext tIxN (IS (IS IZ))))) - (EVar ext (d2 eltty) (IS (IS IZ))) - (EZero eltty)) $ + (ELet ext (EOneHot (STArr n eltty) n + (arrIdxToAcIdx (d2 eltty) n $ EVar ext tIxN (IS IZ)) + (case n of SZ -> EUnit ext (EVar ext (d2 eltty) IZ) + SS{} | Refl <- lemAcValArrN (d2 eltty) n -> + EPair ext (EVar ext tIxN (IS (IS IZ))) + (EUnit ext (EVar ext (d2 eltty) IZ)))) $ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) EShape _ e -- cgit v1.2.3-70-g09d2