aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-23 13:47:18 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-23 13:47:18 +0200
commit4c86a3a4231cecc5b7c31491398f43b4ba667eea (patch)
tree2e06f293f1350b7dd712bf1ad0eccb7b9d7686b4 /src/Data/Array/Mixed.hs
parent827a9ce7adc6cf1debc08d154e4c11b7b83bfdf0 (diff)
Fast sum
Also fast product, but that's currently unused
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r--src/Data/Array/Mixed.hs64
1 files changed, 57 insertions, 7 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 9a77ccb..7293914 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -14,6 +14,7 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
@@ -44,6 +45,8 @@ import GHC.TypeLits
import qualified GHC.TypeNats as TypeNats
import Unsafe.Coerce (unsafeCoerce)
+import Data.Array.Nested.Internal.Arith
+
-- | Evidence for the constraint @c a@.
data Dict c a where
@@ -120,6 +123,10 @@ foldListX f (x ::% xs) = f x <> foldListX f xs
lengthListX :: ListX sh f -> Int
lengthListX = getSum . foldListX (\_ -> Sum 1)
+snatLengthListX :: ListX sh f -> SNat (Rank sh)
+snatLengthListX ZX = SNat
+snatLengthListX (_ ::% l) | SNat <- snatLengthListX l = SNat
+
showListX :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS
showListX f l = showString "[" . go "" l . showString "]"
where
@@ -419,6 +426,26 @@ ssxIotaFrom :: Int -> StaticShX sh -> [Int]
ssxIotaFrom _ ZKX = []
ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh
+type Flatten sh = Flatten' 1 sh
+
+type family Flatten' acc sh where
+ Flatten' acc '[] = Just acc
+ Flatten' acc (Nothing : sh) = Nothing
+ Flatten' acc (Just n : sh) = Flatten' (acc * n) sh
+
+flattenSh :: IShX sh -> SMayNat Int SNat (Flatten sh)
+flattenSh = go (SNat @1)
+ where
+ go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh)
+ go acc ZSX = SKnown acc
+ go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh)
+ go acc (SKnown sn :$% sh) = go (mulSNat acc sn) sh
+
+ goUnknown :: Int -> IShX sh -> Int
+ goUnknown acc ZSX = acc
+ goUnknown acc (SUnknown n :$% sh) = goUnknown (acc * n) sh
+ goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh
+
staticShapeFrom :: IShX sh -> StaticShX sh
staticShapeFrom ZSX = ZKX
staticShapeFrom (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% staticShapeFrom sh
@@ -511,6 +538,10 @@ type family AddMaybe n m where
plusSNat :: SNat n -> SNat m -> SNat (n + m)
plusSNat n m = TypeNats.withSomeSNat (TypeNats.fromSNat n + TypeNats.fromSNat m) unsafeCoerce
+-- This should be a function in base
+mulSNat :: SNat n -> SNat m -> SNat (n * m)
+mulSNat n m = TypeNats.withSomeSNat (TypeNats.fromSNat n * TypeNats.fromSNat m) unsafeCoerce
+
smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m)
smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m)
smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m)
@@ -719,17 +750,36 @@ transpose2 ssh1 ssh2 (XArray arr)
sumFull :: (Storable a, Num a) => XArray sh a -> a
sumFull (XArray arr) = S.sumA arr
-sumInner :: forall sh sh' a. (Storable a, Num a)
+sumInner :: forall sh sh' a. (Storable a, NumElt a)
=> StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a
-sumInner ssh ssh'
+sumInner ssh ssh' arr
| Refl <- lemAppNil @sh
- = rerank ssh ssh' ZKX (scalar . sumFull)
-
-sumOuter :: forall sh sh' a. (Storable a, Num a)
+ = let (_, sh') = shAppSplit (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
+ sh'F = flattenSh sh' :$% ZSX
+ ssh'F = staticShapeFrom sh'F
+
+ go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a
+ go (XArray arr')
+ | Refl <- lemRankApp ssh ssh'F
+ , let sn = snatLengthListX (let StaticShX l = ssh in l)
+ = XArray (numEltSum1Inner sn arr')
+
+ in go $
+ transpose2 ssh'F ssh $
+ reshapePartial ssh' ssh sh'F $
+ transpose2 ssh ssh' $
+ arr
+
+sumOuter :: forall sh sh' a. (Storable a, NumElt a)
=> StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a
-sumOuter ssh ssh'
+sumOuter ssh ssh' arr
| Refl <- lemAppNil @sh
- = sumInner ssh' ssh . transpose2 ssh ssh'
+ = let (sh, _) = shAppSplit (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
+ shF = flattenSh sh :$% ZSX
+ in sumInner ssh' (staticShapeFrom shF) $
+ transpose2 (staticShapeFrom shF) ssh' $
+ reshapePartial ssh ssh' shF $
+ arr
fromListOuter :: forall n sh a. Storable a
=> StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a