diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-11-05 22:45:04 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-05 22:45:04 +0100 | 
| commit | a76ec3bcbdea7beaf9066e4ce0b8c5868f571cdb (patch) | |
| tree | 5caa5c2042d7405b60f97385a0cf8dc66913ccc2 | |
| parent | 889aa1757a0fdf003f38f9d565a4a91660757f38 (diff) | |
Generate EOneHot in D[EIdx]
This generates a one-hot for the zero-dimensional inner array because
indexing one level further to the actual element is too difficult. But
this should simplify away fine.
| -rw-r--r-- | src/AST.hs | 21 | ||||
| -rw-r--r-- | src/CHAD.hs | 10 | 
2 files changed, 27 insertions, 4 deletions
| @@ -9,8 +9,10 @@  {-# LANGUAGE PolyKinds #-}  {-# LANGUAGE QuantifiedConstraints #-}  {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE StandaloneDeriving #-}  {-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-}  {-# LANGUAGE UndecidableInstances #-} @@ -379,3 +381,22 @@ ezip a b =                               (EVar ext (tTup (sreplicate n tIx)) IZ))                     (EIdx ext (EVar ext (STArr n t2) (IS IZ))                               (EVar ext (tTup (sreplicate n tIx)) IZ)) + +arrIdxToAcIdx :: proxy t -> SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (AcIdx (TArr n t) n) +arrIdxToAcIdx = \p (n :: SNat n) e -> case lemPlusZero @n of Refl -> go p n SZ e (ENil ext) +  where +    -- symbolic version of 'invert' in Interpreter +    go :: forall n m t env proxy. proxy t -> SNat n -> SNat m +       -> Ex env (Tup (Replicate n TIx)) -> Ex env (AcIdx (TArr m t) m) -> Ex env (AcIdx (TArr (n + m) t) (n + m)) +    go _ SZ _ _ acidx = acidx +    go p (SS n) m idx acidx +      | Refl <- lemPlusSuccRight @n @m +      = ELet ext idx $ +          go p n (SS m) +             (EFst ext (EVar ext (typeOf idx) IZ)) +             (EPair ext (ESnd ext (EVar ext (typeOf idx) IZ)) +                        (weakenExpr WSink acidx)) + +lemAcValArrN :: proxy t -> SNat n -> AcValArr n t n :~: TArr Z t +lemAcValArrN _ SZ = Refl +lemAcValArrN p (SS n) | Refl <- lemAcValArrN p n = Refl diff --git a/src/CHAD.hs b/src/CHAD.hs index b3e2358..6b0627d 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1156,10 +1156,12 @@ drev des = \case          (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ)))                    (EVar ext (tTup (sreplicate n tIx)) IZ))          sub -        (ELet ext (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) (IS (IS IZ))) $ -                     ECase ext (EOp ext OIf (eidxEq n (EVar ext tIxN IZ) (EVar ext tIxN (IS (IS IZ))))) -                       (EVar ext (d2 eltty) (IS (IS IZ))) -                       (EZero eltty)) $ +        (ELet ext (EOneHot (STArr n eltty) n +                           (arrIdxToAcIdx (d2 eltty) n $ EVar ext tIxN (IS IZ)) +                           (case n of SZ -> EUnit ext (EVar ext (d2 eltty) IZ) +                                      SS{} | Refl <- lemAcValArrN (d2 eltty) n -> +                                        EPair ext (EVar ext tIxN (IS (IS IZ))) +                                                  (EUnit ext (EVar ext (d2 eltty) IZ)))) $           weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)    EShape _ e | 
