aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Ranked
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Ranked')
-rw-r--r--src/Data/Array/Nested/Ranked/Base.hs36
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs166
2 files changed, 139 insertions, 63 deletions
diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs
index f50f671..11a8ffb 100644
--- a/src/Data/Array/Nested/Ranked/Base.hs
+++ b/src/Data/Array/Nested/Ranked/Base.hs
@@ -5,6 +5,7 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
@@ -25,6 +26,7 @@ import Data.Coerce (coerce)
import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
+import Data.Type.Equality
import Foreign.Storable (Storable)
import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
@@ -34,13 +36,13 @@ import GHC.TypeLits
import Data.Foldable (toList)
#endif
-import Data.Array.Mixed.Lemmas
-import Data.Array.Nested.Types
-import Data.Array.XArray (XArray(..))
+import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Ranked.Shape
+import Data.Array.Nested.Types
import Data.Array.Strided.Arith
+import Data.Array.XArray (XArray(..))
-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
@@ -95,8 +97,8 @@ instance Elt a => Elt (Ranked n a) where
mscalar (Ranked x) = M_Ranked (M_Nest ZSX x)
- mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a)
- mfromListOuter l = M_Ranked (mfromListOuter (coerce l))
+ mfromListOuterSN :: SNat m -> NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Just m : sh) (Ranked n a)
+ mfromListOuterSN sn l = M_Ranked (mfromListOuterSN sn (coerce l))
mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)]
mtoListOuter (M_Ranked arr) =
@@ -141,7 +143,7 @@ instance Elt a => Elt (Ranked n a) where
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
- mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+ mshapeTreeIsEmpty _ (sh, t) = shrSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t
mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
@@ -177,10 +179,10 @@ instance Elt a => Elt (Ranked n a) where
instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where
memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a)
- memptyArrayUnsafe i
+ memptyArrayUnsafe sh
| Dict <- lemKnownReplicate (SNat @n)
= coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $
- memptyArrayUnsafe i
+ memptyArrayUnsafe sh
mvecsUnsafeNew idx (Ranked arr)
| Dict <- lemKnownReplicate (SNat @n)
@@ -208,15 +210,15 @@ instance (NumElt a, PrimElt a) => Num (Ranked n a) where
negate = liftRanked1 negate
abs = liftRanked1 abs
signum = liftRanked1 signum
- fromInteger = error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicateScal"
+ fromInteger = error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicatePrim"
instance (FloatElt a, PrimElt a) => Fractional (Ranked n a) where
- fromRational _ = error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicateScal"
+ fromRational _ = error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicatePrim"
recip = liftRanked1 recip
(/) = liftRanked2 (/)
instance (FloatElt a, PrimElt a) => Floating (Ranked n a) where
- pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicateScal"
+ pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicatePrim"
exp = liftRanked1 exp
log = liftRanked1 log
sqrt = liftRanked1 sqrt
@@ -252,3 +254,15 @@ rshape (Ranked arr) = shrFromShX2 (mshape arr)
rrank :: Elt a => Ranked n a -> SNat n
rrank = shrRank . rshape
+
+-- Needed already here, but re-exported in Data.Array.Nested.Convert.
+shrFromShX :: forall sh. IShX sh -> IShR (Rank sh)
+shrFromShX ZSX = ZSR
+shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx
+
+-- Needed already here, but re-exported in Data.Array.Nested.Convert.
+-- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'.
+shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n
+shrFromShX2 sh
+ | Refl <- lemRankReplicate (Proxy @n)
+ = shrFromShX sh
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
index c0c4f17..6d61bd5 100644
--- a/src/Data/Array/Nested/Ranked/Shape.hs
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -1,13 +1,14 @@
+{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveFoldable #-}
-{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
@@ -18,9 +19,11 @@
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
+{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
@@ -33,17 +36,20 @@ import Data.Foldable qualified as Foldable
import Data.Kind (Type)
import Data.Proxy
import Data.Type.Equality
+import GHC.Exts (Int(..), Int#, quotRemInt#, build)
import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
import GHC.TypeNats qualified as TN
-import Data.Array.Mixed.Lemmas
-import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed.Shape.Internal
import Data.Array.Nested.Types
+-- * Ranked lists
+
type role ListR nominal representational
type ListR :: Nat -> Type -> Type
data ListR n i where
@@ -51,8 +57,6 @@ data ListR n i where
(:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
deriving instance Eq i => Eq (ListR n i)
deriving instance Ord i => Ord (ListR n i)
-deriving instance Functor (ListR n)
-deriving instance Foldable (ListR n)
infixr 3 :::
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
@@ -66,6 +70,22 @@ instance NFData i => NFData (ListR n i) where
rnf ZR = ()
rnf (x ::: l) = rnf x `seq` rnf l
+instance Functor (ListR n) where
+ {-# INLINE fmap #-}
+ fmap _ ZR = ZR
+ fmap f (x ::: xs) = f x ::: fmap f xs
+
+instance Foldable (ListR n) where
+ {-# INLINE foldMap #-}
+ foldMap _ ZR = mempty
+ foldMap f (x ::: xs) = f x <> foldMap f xs
+ {-# INLINE foldr #-}
+ foldr _ z ZR = z
+ foldr f z (x ::: xs) = f x (foldr f z xs)
+ toList = listrToList
+ null ZR = False
+ null _ = True
+
data UnconsListRRes i n1 =
forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i
listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1)
@@ -90,6 +110,7 @@ listrEqual (i ::: sh) (j ::: sh')
= Just Refl
listrEqual _ _ = Nothing
+{-# INLINE listrShow #-}
listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS
listrShow f l = showString "[" . go "" l . showString "]"
where
@@ -108,27 +129,41 @@ listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i
listrAppend ZR sh = sh
listrAppend (x ::: xs) sh = x ::: listrAppend xs sh
-listrFromList :: [i] -> (forall n. ListR n i -> r) -> r
-listrFromList [] k = k ZR
-listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l)
+listrFromList :: SNat n -> [i] -> ListR n i
+listrFromList topsn topl = go topsn topl
+ where
+ go :: SNat n' -> [i] -> ListR n' i
+ go SZ [] = ZR
+ go (SS n) (i : is) = i ::: go n is
+ go _ _ = error $ "listrFromList: Mismatched list length (type says "
+ ++ show (fromSNat topsn) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+
+{-# INLINEABLE listrToList #-}
+listrToList :: ListR n i -> [i]
+listrToList list = build (\(cons :: i -> is -> is) (nil :: is) ->
+ let go :: ListR n i -> is
+ go ZR = nil
+ go (i ::: is) = i `cons` go is
+ in go list)
listrHead :: ListR (n + 1) i -> i
listrHead (i ::: _) = i
-listrHead ZR = error "unreachable"
listrTail :: ListR (n + 1) i -> ListR n i
listrTail (_ ::: sh) = sh
-listrTail ZR = error "unreachable"
listrInit :: ListR (n + 1) i -> ListR n i
listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh
listrInit (_ ::: ZR) = ZR
-listrInit ZR = error "unreachable"
listrLast :: ListR (n + 1) i -> i
listrLast (_ ::: sh@(_ ::: _)) = listrLast sh
listrLast (n ::: ZR) = n
-listrLast ZR = error "unreachable"
+
+-- | Performs a runtime check that the lengths are identical.
+listrCast :: SNat n' -> ListR n i -> ListR n' i
+listrCast = listrCastWithName "listrCast"
listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
listrIndex SZ (x ::: _) = x
@@ -140,6 +175,7 @@ listrZip ZR ZR = ZR
listrZip (i ::: irest) (j ::: jrest) = (i, j) ::: listrZip irest jrest
listrZip _ _ = error "listrZip: impossible pattern needlessly required"
+{-# INLINE listrZipWith #-}
listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k
listrZipWith _ ZR ZR = ZR
listrZipWith f (i ::: irest) (j ::: jrest) =
@@ -149,13 +185,15 @@ listrZipWith _ _ _ =
listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
listrPermutePrefix = \perm sh ->
- listrFromList perm $ \sperm ->
- case (listrRank sperm, listrRank sh) of
- (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of
- LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
- EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
- GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
- ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
+ TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat ->
+ case listrRank sh of { shlen@SNat ->
+ let sperm = listrFromList permlen perm in
+ case cmpNat permlen shlen of
+ LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
+ EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
+ GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
+ ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
+ }
where
listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
listrSplitAt SZ sh = (ZR, sh)
@@ -172,6 +210,8 @@ listrPermutePrefix = \perm sh ->
GTI -> error "listrPermutePrefix: Index in permutation out of range"
+-- * Ranked indices
+
-- | An index into a rank-typed array.
type role IxR nominal representational
type IxR :: Nat -> Type -> Type
@@ -192,6 +232,8 @@ infixr 3 :.:
{-# COMPLETE ZIR, (:.:) #-}
+-- For convenience, this contains regular 'Int's instead of bounded integers
+-- (traditionally called \"@Fin@\").
type IIxR n = IxR n Int
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
@@ -213,15 +255,12 @@ ixrZero :: SNat n -> IIxR n
ixrZero SZ = ZIR
ixrZero (SS n) = 0 :.: ixrZero n
-ixrFromIxX :: IxX sh i -> IxR (Rank sh) i
-ixrFromIxX ZIX = ZIR
-ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx
+ixrFromList :: forall n i. SNat n -> [i] -> IxR n i
+ixrFromList = coerce (listrFromList @_ @i)
-ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i
-ixxFromIxR ZIR = ZIX
-ixxFromIxR (n :.: (idx :: IxR m i)) =
- castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) @m))
- (n :.% ixxFromIxR idx)
+{-# INLINEABLE ixrToList #-}
+ixrToList :: forall n i. IxR n i -> [i]
+ixrToList = coerce (listrToList @_ @i)
ixrHead :: IxR (n + 1) i -> i
ixrHead (IxR list) = listrHead list
@@ -235,12 +274,17 @@ ixrInit (IxR list) = IxR (listrInit list)
ixrLast :: IxR (n + 1) i -> i
ixrLast (IxR list) = listrLast list
+-- | Performs a runtime check that the lengths are identical.
+ixrCast :: SNat n' -> IxR n i -> IxR n' i
+ixrCast n (IxR idx) = IxR (listrCastWithName "ixrCast" n idx)
+
ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i
ixrAppend = coerce (listrAppend @_ @i)
ixrZip :: IxR n i -> IxR n j -> IxR n (i, j)
ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2
+{-# INLINE ixrZipWith #-}
ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k
ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2
@@ -248,6 +292,8 @@ ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
ixrPermutePrefix = coerce (listrPermutePrefix @i)
+-- * Ranked shapes
+
type role ShR nominal representational
type ShR :: Nat -> Type -> Type
newtype ShR n i = ShR (ListR n i)
@@ -278,22 +324,6 @@ instance Show i => Show (ShR n i) where
instance NFData i => NFData (ShR sh i)
-shrFromShX :: forall sh. IShX sh -> IShR (Rank sh)
-shrFromShX ZSX = ZSR
-shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx
-
--- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'.
-shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n
-shrFromShX2 sh
- | Refl <- lemRankReplicate (Proxy @n)
- = shrFromShX sh
-
-shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i
-shxFromShR ZSR = ZSX
-shxFromShR (n :$: (idx :: ShR m i)) =
- castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) @m))
- (SUnknown n :$% shxFromShR idx)
-
-- | This checks only whether the ranks are equal, not whether the actual
-- values are.
shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n')
@@ -317,6 +347,13 @@ shrSize :: IShR n -> Int
shrSize ZSR = 1
shrSize (n :$: sh) = n * shrSize sh
+shrFromList :: forall n i. SNat n -> [i] -> ShR n i
+shrFromList = coerce (listrFromList @_ @i)
+
+{-# INLINEABLE shrToList #-}
+shrToList :: forall n i. ShR n i -> [i]
+shrToList = coerce (listrToList @_ @i)
+
shrHead :: ShR (n + 1) i -> i
shrHead (ShR list) = listrHead list
@@ -329,30 +366,44 @@ shrInit (ShR list) = ShR (listrInit list)
shrLast :: ShR (n + 1) i -> i
shrLast (ShR list) = listrLast list
+-- | Performs a runtime check that the lengths are identical.
+shrCast :: SNat n' -> ShR n i -> ShR n' i
+shrCast n (ShR sh) = ShR (listrCastWithName "shrCast" n sh)
+
shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
shrAppend = coerce (listrAppend @_ @i)
shrZip :: ShR n i -> ShR n j -> ShR n (i, j)
shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2
+{-# INLINE shrZipWith #-}
shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k
shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2
shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i
shrPermutePrefix = coerce (listrPermutePrefix @i)
+shrEnum :: IShR sh -> [IIxR sh]
+shrEnum = shrEnum'
+
+{-# INLINABLE shrEnum' #-} -- ensure this can be specialised at use site
+shrEnum' :: Num i => IShR sh -> [IxR sh i]
+shrEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shrSize sh - 1]]
+ where
+ suffixes = drop 1 (scanr (*) 1 (shrToList sh))
+
+ fromLin :: Num i => IShR sh -> [Int] -> Int# -> IxR sh i
+ fromLin ZSR _ _ = ZIR
+ fromLin (_ :$: sh') (I# suff# : suffs) i# =
+ let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shrSize sh'
+ in fromIntegral (I# q#) :.: fromLin sh' suffs r#
+ fromLin _ _ _ = error "impossible"
+
-- | Untyped: length is checked at runtime.
instance KnownNat n => IsList (ListR n i) where
type Item (ListR n i) = i
- fromList topl = go (SNat @n) topl
- where
- go :: SNat n' -> [i] -> ListR n' i
- go SZ [] = ZR
- go (SS n) (i : is) = i ::: go n is
- go _ _ = error $ "IsList(ListR): Mismatched list length (type says "
- ++ show (fromSNat (SNat @n)) ++ ", list has length "
- ++ show (length topl) ++ ")"
+ fromList = listrFromList (SNat @n)
toList = Foldable.toList
-- | Untyped: length is checked at runtime.
@@ -366,3 +417,14 @@ instance KnownNat n => IsList (ShR n i) where
type Item (ShR n i) = i
fromList = ShR . IsList.fromList
toList = Foldable.toList
+
+
+-- * Internal helper functions
+
+listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i
+listrCastWithName _ SZ ZR = ZR
+listrCastWithName name (SS n) (i ::: idx) = i ::: listrCastWithName name n idx
+listrCastWithName name _ _ = error $ name ++ ": ranks don't match"
+
+$(ixFromLinearStub "ixrFromLinear" [t| IShR |] [t| IxR |] [p| ZSR |] (\a b -> [p| $a :$: $b |]) [| ZIR |] [| (:.:) |] [| shrToList |])
+{-# INLINEABLE ixrFromLinear #-}