aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Mixed/Shape.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Mixed/Shape.hs')
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs190
1 files changed, 117 insertions, 73 deletions
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index bf14bf5..c999853 100644
--- a/src/Data/Array/Nested/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -1,9 +1,11 @@
+{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
@@ -14,9 +16,11 @@
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
+{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
@@ -31,14 +35,17 @@ import Data.Functor.Const
import Data.Functor.Product
import Data.Kind (Constraint, Type)
import Data.Monoid (Sum(..))
-import Data.Proxy
import Data.Type.Equality
-import GHC.Exts (withDict)
+import GHC.Exts (Int(..), Int#, quotRemInt#, withDict, build)
import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
+#if !MIN_VERSION_GLASGOW_HASKELL(9,8,0,0)
+import GHC.TypeLits.Orphans ()
+#endif
+import Data.Array.Nested.Mixed.Shape.Internal
import Data.Array.Nested.Types
@@ -100,21 +107,24 @@ listxEqual (n ::% sh) (m ::% sh')
= Just Refl
listxEqual _ _ = Nothing
+{-# INLINE listxFmap #-}
listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g
listxFmap _ ZX = ZX
listxFmap f (x ::% xs) = f x ::% listxFmap f xs
-listxFold :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m
-listxFold _ ZX = mempty
-listxFold f (x ::% xs) = f x <> listxFold f xs
+{-# INLINE listxFoldMap #-}
+listxFoldMap :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m
+listxFoldMap _ ZX = mempty
+listxFoldMap f (x ::% xs) = f x <> listxFoldMap f xs
listxLength :: ListX sh f -> Int
-listxLength = getSum . listxFold (\_ -> Sum 1)
+listxLength = getSum . listxFoldMap (\_ -> Sum 1)
listxRank :: ListX sh f -> SNat (Rank sh)
listxRank ZX = SNat
listxRank (_ ::% l) | SNat <- listxRank l = SNat
+{-# INLINE listxShow #-}
listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS
listxShow f l = showString "[" . go "" l . showString "]"
where
@@ -132,9 +142,13 @@ listxFromList topssh topl = go topssh topl
++ show (ssxLength topssh) ++ ", list has length "
++ show (length topl) ++ ")"
-listxToList :: ListX sh' (Const i) -> [i]
-listxToList ZX = []
-listxToList (Const i ::% is) = i : listxToList is
+{-# INLINEABLE listxToList #-}
+listxToList :: ListX sh (Const i) -> [i]
+listxToList list = build (\(cons :: i -> is -> is) (nil :: is) ->
+ let go :: ListX sh (Const i) -> is
+ go ZX = nil
+ go (Const i ::% is) = i `cons` go is
+ in go list)
listxHead :: ListX (mn ': sh) f -> f mn
listxHead (i ::% _) = i
@@ -146,9 +160,9 @@ listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f
listxAppend ZX idx' = idx'
listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx'
-listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f
-listxDrop long ZX = long
-listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short
+listxDrop :: forall f g sh sh'. ListX sh g -> ListX (sh ++ sh') f -> ListX sh' f
+listxDrop ZX long = long
+listxDrop (_ ::% short) long = case long of _ ::% long' -> listxDrop short long'
listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f
listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh
@@ -160,19 +174,18 @@ listxLast (x ::% ZX) = x
listxZip :: ListX sh f -> ListX sh g -> ListX sh (Product f g)
listxZip ZX ZX = ZX
-listxZip (i ::% irest) (j ::% jrest) =
- Pair i j ::% listxZip irest jrest
+listxZip (i ::% irest) (j ::% jrest) = Pair i j ::% listxZip irest jrest
+{-# INLINE listxZipWith #-}
listxZipWith :: (forall a. f a -> g a -> h a) -> ListX sh f -> ListX sh g
-> ListX sh h
listxZipWith _ ZX ZX = ZX
-listxZipWith f (i ::% is) (j ::% js) =
- f i j ::% listxZipWith f is js
+listxZipWith f (i ::% is) (j ::% js) = f i j ::% listxZipWith f is js
-- * Mixed indices
--- | This is a newtype over 'ListX'.
+-- | An index into a mixed-typed array.
type role IxX nominal representational
type IxX :: [Maybe Nat] -> Type -> Type
newtype IxX sh i = IxX (ListX sh (Const i))
@@ -191,6 +204,8 @@ infixr 3 :.%
{-# COMPLETE ZIX, (:.%) #-}
+-- For convenience, this contains regular 'Int's instead of bounded integers
+-- (traditionally called \"@Fin@\").
type IIxX sh = IxX sh Int
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
@@ -201,10 +216,18 @@ instance Show i => Show (IxX sh i) where
#endif
instance Functor (IxX sh) where
+ {-# INLINE fmap #-}
fmap f (IxX l) = IxX (listxFmap (Const . f . getConst) l)
instance Foldable (IxX sh) where
- foldMap f (IxX l) = listxFold (f . getConst) l
+ {-# INLINE foldMap #-}
+ foldMap f (IxX l) = listxFoldMap (f . getConst) l
+ {-# INLINE foldr #-}
+ foldr _ z ZIX = z
+ foldr f z (x :.% xs) = f x (foldr f z xs)
+ toList = ixxToList
+ null ZIX = False
+ null _ = True
instance NFData i => NFData (IxX sh i)
@@ -225,6 +248,10 @@ ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh
ixxFromList :: forall sh i. StaticShX sh -> [i] -> IxX sh i
ixxFromList = coerce (listxFromList @_ @i)
+{-# INLINEABLE ixxToList #-}
+ixxToList :: forall sh i. IxX sh i -> [i]
+ixxToList = coerce (listxToList @_ @i)
+
ixxHead :: IxX (n : sh) i -> i
ixxHead (IxX list) = getConst (listxHead list)
@@ -234,7 +261,7 @@ ixxTail (IxX list) = IxX (listxTail list)
ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i
ixxAppend = coerce (listxAppend @_ @(Const i))
-ixxDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i
+ixxDrop :: forall sh sh' i. IxX sh i -> IxX (sh ++ sh') i -> IxX sh' i
ixxDrop = coerce (listxDrop @(Const i) @(Const i))
ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i
@@ -243,28 +270,20 @@ ixxInit = coerce (listxInit @(Const i))
ixxLast :: forall n sh i. IxX (n : sh) i -> i
ixxLast = coerce (listxLast @(Const i))
+ixxCast :: StaticShX sh' -> IxX sh i -> IxX sh' i
+ixxCast ZKX ZIX = ZIX
+ixxCast (_ :!% sh) (i :.% idx) = i :.% ixxCast sh idx
+ixxCast _ _ = error "ixxCast: ranks don't match"
+
ixxZip :: IxX sh i -> IxX sh j -> IxX sh (i, j)
ixxZip ZIX ZIX = ZIX
ixxZip (i :.% is) (j :.% js) = (i, j) :.% ixxZip is js
+{-# INLINE ixxZipWith #-}
ixxZipWith :: (i -> j -> k) -> IxX sh i -> IxX sh j -> IxX sh k
ixxZipWith _ ZIX ZIX = ZIX
ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js
-ixxFromLinear :: IShX sh -> Int -> IIxX sh
-ixxFromLinear = \sh i -> case go sh i of
- (idx, 0) -> idx
- _ -> error $ "ixxFromLinear: out of range (" ++ show i ++
- " in array of shape " ++ show sh ++ ")"
- where
- -- returns (index in subarray, remaining index in enclosing array)
- go :: IShX sh -> Int -> (IIxX sh, Int)
- go ZSX i = (ZIX, i)
- go (n :$% sh) i =
- let (idx, i') = go sh i
- (upi, locali) = i' `quotRem` fromSMayNat' n
- in (locali :.% idx, upi)
-
ixxToLinear :: IShX sh -> IIxX sh -> Int
ixxToLinear = \sh i -> fst (go sh i)
where
@@ -294,6 +313,7 @@ instance TestEquality f => TestEquality (SMayNat i f) where
testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl
testEquality _ _ = Nothing
+{-# INLINE fromSMayNat #-}
fromSMayNat :: (n ~ Nothing => i -> r)
-> (forall m. n ~ Just m => f m -> r)
-> SMayNat i f n -> r
@@ -343,6 +363,7 @@ instance Show i => Show (ShX sh i) where
#endif
instance Functor (ShX sh) where
+ {-# INLINE fmap #-}
fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l)
instance NFData i => NFData (ShX sh i) where
@@ -390,10 +411,10 @@ shxSize :: IShX sh -> Int
shxSize ZSX = 1
shxSize (n :$% sh) = fromSMayNat' n * shxSize sh
-shxFromList :: StaticShX sh -> [Int] -> ShX sh Int
+shxFromList :: StaticShX sh -> [Int] -> IShX sh
shxFromList topssh topl = go topssh topl
where
- go :: StaticShX sh' -> [Int] -> ShX sh' Int
+ go :: StaticShX sh' -> [Int] -> IShX sh'
go ZKX [] = ZSX
go (SKnown sn :!% sh) (i : is)
| i == fromSNat' sn = SKnown sn :$% go sh is
@@ -404,15 +425,26 @@ shxFromList topssh topl = go topssh topl
++ show (ssxLength topssh) ++ ", list has length "
++ show (length topl) ++ ")"
+{-# INLINEABLE shxToList #-}
shxToList :: IShX sh -> [Int]
-shxToList ZSX = []
-shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh
+shxToList list = build (\(cons :: i -> is -> is) (nil :: is) ->
+ let go :: IShX sh -> is
+ go ZSX = nil
+ go (smn :$% sh) = fromSMayNat' smn `cons` go sh
+ in go list)
+
+shxFromSSX :: StaticShX (MapJust sh) -> ShX (MapJust sh) i
+shxFromSSX ZKX = ZSX
+shxFromSSX (SKnown n :!% sh :: StaticShX (MapJust sh))
+ | Refl <- lemMapJustCons @sh Refl
+ = SKnown n :$% shxFromSSX sh
+shxFromSSX (SUnknown _ :!% _) = error "unreachable"
-- | This may fail if @sh@ has @Nothing@s in it.
-shxFromSSX' :: StaticShX sh -> Maybe (IShX sh)
-shxFromSSX' ZKX = Just ZSX
-shxFromSSX' (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX' sh
-shxFromSSX' (SUnknown _ :!% _) = Nothing
+shxFromSSX2 :: StaticShX sh -> Maybe (ShX sh i)
+shxFromSSX2 ZKX = Just ZSX
+shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh
+shxFromSSX2 (SUnknown _ :!% _) = Nothing
shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i
shxAppend = coerce (listxAppend @_ @(SMayNat i SNat))
@@ -423,13 +455,13 @@ shxHead (ShX list) = listxHead list
shxTail :: ShX (n : sh) i -> ShX sh i
shxTail (ShX list) = ShX (listxTail list)
-shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i
+shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i
shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat))
-shxDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i
+shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i
shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j))
-shxDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i
+shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i
shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat))
shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i
@@ -438,13 +470,11 @@ shxInit = coerce (listxInit @(SMayNat i SNat))
shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh))
shxLast = coerce (listxLast @(SMayNat i SNat))
-shxTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i
-shxTakeSSX _ = flip go
- where
- go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i
- go ZKX _ = ZSX
- go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh
+shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i
+shxTakeSSX _ ZKX _ = ZSX
+shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh
+{-# INLINE shxZipWith #-}
shxZipWith :: (forall n. SMayNat i SNat n -> SMayNat j SNat n -> SMayNat k SNat n)
-> ShX sh i -> ShX sh j -> ShX sh k
shxZipWith _ ZSX ZSX = ZSX
@@ -456,28 +486,37 @@ shxCompleteZeros ZKX = ZSX
shxCompleteZeros (SUnknown () :!% ssh) = SUnknown 0 :$% shxCompleteZeros ssh
shxCompleteZeros (SKnown n :!% ssh) = SKnown n :$% shxCompleteZeros ssh
-shxSplitApp :: Proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i)
+shxSplitApp :: proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i)
shxSplitApp _ ZKX idx = (ZSX, idx)
shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx)
shxEnum :: IShX sh -> [IIxX sh]
-shxEnum = \sh -> go sh id []
+shxEnum = shxEnum'
+
+{-# INLINABLE shxEnum' #-} -- ensure this can be specialised at use site
+shxEnum' :: Num i => IShX sh -> [IxX sh i]
+shxEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shxSize sh - 1]]
where
- go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a]
- go ZSX f = (f ZIX :)
- go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]]
+ suffixes = drop 1 (scanr (*) 1 (shxToList sh))
-shxCast :: IShX sh -> StaticShX sh' -> Maybe (IShX sh')
-shxCast ZSX ZKX = Just ZSX
-shxCast (SKnown n :$% sh) (SKnown m :!% ssh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast sh ssh
-shxCast (SUnknown n :$% sh) (SKnown m :!% ssh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast sh ssh
-shxCast (SKnown n :$% sh) (SUnknown () :!% ssh) = (SUnknown (fromSNat' n) :$%) <$> shxCast sh ssh
-shxCast (SUnknown n :$% sh) (SUnknown () :!% ssh) = (SUnknown n :$%) <$> shxCast sh ssh
+ fromLin :: Num i => IShX sh -> [Int] -> Int# -> IxX sh i
+ fromLin ZSX _ _ = ZIX
+ fromLin (_ :$% sh') (I# suff# : suffs) i# =
+ let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shrSize sh'
+ in fromIntegral (I# q#) :.% fromLin sh' suffs r#
+ fromLin _ _ _ = error "impossible"
+
+shxCast :: StaticShX sh' -> IShX sh -> Maybe (IShX sh')
+shxCast ZKX ZSX = Just ZSX
+shxCast (SKnown m :!% ssh) (SKnown n :$% sh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast ssh sh
+shxCast (SKnown m :!% ssh) (SUnknown n :$% sh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast ssh sh
+shxCast (SUnknown () :!% ssh) (SKnown n :$% sh) = (SUnknown (fromSNat' n) :$%) <$> shxCast ssh sh
+shxCast (SUnknown () :!% ssh) (SUnknown n :$% sh) = (SUnknown n :$%) <$> shxCast ssh sh
shxCast _ _ = Nothing
-- | Partial version of 'shxCast'.
-shxCast' :: IShX sh -> StaticShX sh' -> IShX sh'
-shxCast' sh ssh = case shxCast sh ssh of
+shxCast' :: StaticShX sh' -> IShX sh -> IShX sh'
+shxCast' ssh sh = case shxCast ssh sh of
Just sh' -> sh'
Nothing -> error $ "shxCast': Mismatch: (" ++ show sh ++ ") does not match (" ++ show ssh ++ ")"
@@ -537,13 +576,13 @@ ssxHead (StaticShX list) = listxHead list
ssxTail :: StaticShX (n : sh) -> StaticShX sh
ssxTail (_ :!% ssh) = ssh
-ssxDropSSX :: forall sh sh'. StaticShX (sh ++ sh') -> StaticShX sh -> StaticShX sh'
+ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh'
ssxDropSSX = coerce (listxDrop @(SMayNat () SNat) @(SMayNat () SNat))
-ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh'
+ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'
ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i))
-ssxDropSh :: forall sh sh' i. StaticShX (sh ++ sh') -> ShX sh i -> StaticShX sh'
+ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'
ssxDropSh = coerce (listxDrop @(SMayNat () SNat) @(SMayNat i SNat))
ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh))
@@ -555,20 +594,20 @@ ssxLast = coerce (listxLast @(SMayNat () SNat))
ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)
ssxReplicate SZ = ZKX
ssxReplicate (SS (n :: SNat n'))
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n'
+ | Refl <- lemReplicateSucc @(Nothing @Nat) n
= SUnknown () :!% ssxReplicate n
-ssxIotaFrom :: Int -> StaticShX sh -> [Int]
-ssxIotaFrom _ ZKX = []
-ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh
+ssxIotaFrom :: StaticShX sh -> Int -> [Int]
+ssxIotaFrom ZKX _ = []
+ssxIotaFrom (_ :!% ssh) i = i : ssxIotaFrom ssh (i+1)
-ssxFromShX :: IShX sh -> StaticShX sh
+ssxFromShX :: ShX sh i -> StaticShX sh
ssxFromShX ZSX = ZKX
ssxFromShX (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShX sh
ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing)
ssxFromSNat SZ = ZKX
-ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n
+ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) n = SUnknown () :!% ssxFromSNat n
-- | Evidence for the static part of a shape. This pops up only when you are
@@ -580,7 +619,7 @@ instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SK
instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX
withKnownShX :: forall sh r. StaticShX sh -> (KnownShX sh => r) -> r
-withKnownShX k = withDict @(KnownShX sh) k
+withKnownShX = withDict @(KnownShX sh)
-- * Flattening
@@ -632,3 +671,8 @@ instance KnownShX sh => IsList (ShX sh Int) where
type Item (ShX sh Int) = Int
fromList = shxFromList (knownShX @sh)
toList = shxToList
+
+-- This needs to be at the bottom of the file to not split the file into
+-- pieces; some of the shape/index stuff refers to StaticShX.
+$(ixFromLinearStub "ixxFromLinear" [t| IShX |] [t| IxX |] [p| ZSX |] (\a b -> [p| (fromSMayNat' -> $a) :$% $b |]) [| ZIX |] [| (:.%) |] [| shxToList |])
+{-# INLINEABLE ixxFromLinear #-}