diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-23 13:47:18 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-23 13:47:18 +0200 |
commit | 4c86a3a4231cecc5b7c31491398f43b4ba667eea (patch) | |
tree | 2e06f293f1350b7dd712bf1ad0eccb7b9d7686b4 /src/Data/Array/Mixed.hs | |
parent | 827a9ce7adc6cf1debc08d154e4c11b7b83bfdf0 (diff) |
Fast sum
Also fast product, but that's currently unused
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r-- | src/Data/Array/Mixed.hs | 64 |
1 files changed, 57 insertions, 7 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 9a77ccb..7293914 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -14,6 +14,7 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE NoStarIsType #-} {-# LANGUAGE StrictData #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} @@ -44,6 +45,8 @@ import GHC.TypeLits import qualified GHC.TypeNats as TypeNats import Unsafe.Coerce (unsafeCoerce) +import Data.Array.Nested.Internal.Arith + -- | Evidence for the constraint @c a@. data Dict c a where @@ -120,6 +123,10 @@ foldListX f (x ::% xs) = f x <> foldListX f xs lengthListX :: ListX sh f -> Int lengthListX = getSum . foldListX (\_ -> Sum 1) +snatLengthListX :: ListX sh f -> SNat (Rank sh) +snatLengthListX ZX = SNat +snatLengthListX (_ ::% l) | SNat <- snatLengthListX l = SNat + showListX :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS showListX f l = showString "[" . go "" l . showString "]" where @@ -419,6 +426,26 @@ ssxIotaFrom :: Int -> StaticShX sh -> [Int] ssxIotaFrom _ ZKX = [] ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh +type Flatten sh = Flatten' 1 sh + +type family Flatten' acc sh where + Flatten' acc '[] = Just acc + Flatten' acc (Nothing : sh) = Nothing + Flatten' acc (Just n : sh) = Flatten' (acc * n) sh + +flattenSh :: IShX sh -> SMayNat Int SNat (Flatten sh) +flattenSh = go (SNat @1) + where + go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh) + go acc ZSX = SKnown acc + go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh) + go acc (SKnown sn :$% sh) = go (mulSNat acc sn) sh + + goUnknown :: Int -> IShX sh -> Int + goUnknown acc ZSX = acc + goUnknown acc (SUnknown n :$% sh) = goUnknown (acc * n) sh + goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh + staticShapeFrom :: IShX sh -> StaticShX sh staticShapeFrom ZSX = ZKX staticShapeFrom (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% staticShapeFrom sh @@ -511,6 +538,10 @@ type family AddMaybe n m where plusSNat :: SNat n -> SNat m -> SNat (n + m) plusSNat n m = TypeNats.withSomeSNat (TypeNats.fromSNat n + TypeNats.fromSNat m) unsafeCoerce +-- This should be a function in base +mulSNat :: SNat n -> SNat m -> SNat (n * m) +mulSNat n m = TypeNats.withSomeSNat (TypeNats.fromSNat n * TypeNats.fromSNat m) unsafeCoerce + smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m) smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) @@ -719,17 +750,36 @@ transpose2 ssh1 ssh2 (XArray arr) sumFull :: (Storable a, Num a) => XArray sh a -> a sumFull (XArray arr) = S.sumA arr -sumInner :: forall sh sh' a. (Storable a, Num a) +sumInner :: forall sh sh' a. (Storable a, NumElt a) => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a -sumInner ssh ssh' +sumInner ssh ssh' arr | Refl <- lemAppNil @sh - = rerank ssh ssh' ZKX (scalar . sumFull) - -sumOuter :: forall sh sh' a. (Storable a, Num a) + = let (_, sh') = shAppSplit (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + sh'F = flattenSh sh' :$% ZSX + ssh'F = staticShapeFrom sh'F + + go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a + go (XArray arr') + | Refl <- lemRankApp ssh ssh'F + , let sn = snatLengthListX (let StaticShX l = ssh in l) + = XArray (numEltSum1Inner sn arr') + + in go $ + transpose2 ssh'F ssh $ + reshapePartial ssh' ssh sh'F $ + transpose2 ssh ssh' $ + arr + +sumOuter :: forall sh sh' a. (Storable a, NumElt a) => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a -sumOuter ssh ssh' +sumOuter ssh ssh' arr | Refl <- lemAppNil @sh - = sumInner ssh' ssh . transpose2 ssh ssh' + = let (sh, _) = shAppSplit (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + shF = flattenSh sh :$% ZSX + in sumInner ssh' (staticShapeFrom shF) $ + transpose2 (staticShapeFrom shF) ssh' $ + reshapePartial ssh ssh' shF $ + arr fromListOuter :: forall n sh a. Storable a => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a |