aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Shaped
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Shaped')
-rw-r--r--src/Data/Array/Nested/Shaped/Base.hs27
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs168
2 files changed, 136 insertions, 59 deletions
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs
index ddd44bf..98f1241 100644
--- a/src/Data/Array/Nested/Shaped/Base.hs
+++ b/src/Data/Array/Nested/Shaped/Base.hs
@@ -90,8 +90,8 @@ instance Elt a => Elt (Shaped sh a) where
mscalar (Shaped x) = M_Shaped (M_Nest ZSX x)
- mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a)
- mfromListOuter l = M_Shaped (mfromListOuter (coerce l))
+ mfromListOuterSN :: SNat n -> NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Just n : sh') (Shaped sh a)
+ mfromListOuterSN sn l = M_Shaped (mfromListOuterSN sn (coerce l))
mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)]
mtoListOuter (M_Shaped arr)
@@ -136,7 +136,7 @@ instance Elt a => Elt (Shaped sh a) where
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
- mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+ mshapeTreeIsEmpty _ (sh, t) = shsSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t
mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
@@ -172,10 +172,10 @@ instance Elt a => Elt (Shaped sh a) where
instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
- memptyArrayUnsafe i
+ memptyArrayUnsafe sh
| Dict <- lemKnownMapJust (Proxy @sh)
= coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
- memptyArrayUnsafe i
+ memptyArrayUnsafe sh
mvecsUnsafeNew idx (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh)
@@ -203,15 +203,15 @@ instance (NumElt a, PrimElt a) => Num (Shaped sh a) where
negate = liftShaped1 negate
abs = liftShaped1 abs
signum = liftShaped1 signum
- fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicateScal"
+ fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicatePrim"
instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where
- fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal"
+ fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicatePrim"
recip = liftShaped1 recip
(/) = liftShaped2 (/)
instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where
- pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal"
+ pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicatePrim"
exp = liftShaped1 exp
log = liftShaped1 log
sqrt = liftShaped1 sqrt
@@ -246,15 +246,10 @@ sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
sshape (Shaped arr) = shsFromShX (mshape arr)
-- Needed already here, but re-exported in Data.Array.Nested.Convert.
-shsFromShX :: forall sh. IShX (MapJust sh) -> ShS sh
+shsFromShX :: forall sh i. ShX (MapJust sh) i -> ShS sh
shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
-shsFromShX (SKnown n@SNat :$% (idx :: IShX mjshT)) =
- castWith (subst1 (lem Refl)) $
+shsFromShX (SKnown n@SNat :$% (idx :: ShX mjshT i)) =
+ castWith (subst1 (sym (lemMapJustCons Refl))) $
n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
idx)
- where
- lem :: forall sh1 sh' n.
- Just n : sh1 :~: MapJust sh'
- -> n : Tail sh' :~: sh'
- lem Refl = unsafeCoerceRefl
shsFromShX (SUnknown _ :$% _) = error "impossible"
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
index fbfc7f5..0d90e91 100644
--- a/src/Data/Array/Nested/Shaped/Shape.hs
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -1,13 +1,12 @@
+{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveFoldable #-}
-{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
-{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
@@ -18,9 +17,11 @@
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
+{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
@@ -37,17 +38,22 @@ import Data.Kind (Constraint, Type)
import Data.Monoid (Sum(..))
import Data.Proxy
import Data.Type.Equality
-import GHC.Exts (withDict)
+import GHC.Exts (Int(..), Int#, quotRemInt#, withDict, build)
import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Mixed.Shape.Internal
import Data.Array.Nested.Permutation
import Data.Array.Nested.Types
+-- * Shaped lists
+
+-- | Note: The 'KnownNat' constraint on '(::$)' is deprecated and should be
+-- removed in a future release.
type role ListS nominal representational
type ListS :: [Nat] -> (Nat -> Type) -> Type
data ListS sh f where
@@ -98,13 +104,15 @@ listsEqual (n ::$ sh) (m ::$ sh')
= Just Refl
listsEqual _ _ = Nothing
+{-# INLINE listsFmap #-}
listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g
listsFmap _ ZS = ZS
listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs
-listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m
-listsFold _ ZS = mempty
-listsFold f (x ::$ xs) = f x <> listsFold f xs
+{-# INLINE listsFoldMap #-}
+listsFoldMap :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m
+listsFoldMap _ ZS = mempty
+listsFoldMap f (x ::$ xs) = f x <> listsFoldMap f xs
listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS
listsShow f l = showString "[" . go "" l . showString "]"
@@ -114,15 +122,29 @@ listsShow f l = showString "[" . go "" l . showString "]"
go prefix (x ::$ xs) = showString prefix . f x . go "," xs
listsLength :: ListS sh f -> Int
-listsLength = getSum . listsFold (\_ -> Sum 1)
+listsLength = getSum . listsFoldMap (\_ -> Sum 1)
listsRank :: ListS sh f -> SNat (Rank sh)
listsRank ZS = SNat
listsRank (_ ::$ sh) = snatSucc (listsRank sh)
+listsFromList :: ShS sh -> [i] -> ListS sh (Const i)
+listsFromList topsh topl = go topsh topl
+ where
+ go :: ShS sh' -> [i] -> ListS sh' (Const i)
+ go ZSS [] = ZS
+ go (_ :$$ sh) (i : is) = Const i ::$ go sh is
+ go _ _ = error $ "listsFromList: Mismatched list length (type says "
+ ++ show (shsLength topsh) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+
+{-# INLINEABLE listsToList #-}
listsToList :: ListS sh (Const i) -> [i]
-listsToList ZS = []
-listsToList (Const i ::$ is) = i : listsToList is
+listsToList list = build (\(cons :: i -> is -> is) (nil :: is) ->
+ let go :: ListS sh (Const i) -> is
+ go ZS = nil
+ go (Const i ::$ is) = i `cons` go is
+ in go list)
listsHead :: ListS (n : sh) f -> f n
listsHead (i ::$ _) = i
@@ -144,14 +166,13 @@ listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx'
listsZip :: ListS sh f -> ListS sh g -> ListS sh (Fun.Product f g)
listsZip ZS ZS = ZS
-listsZip (i ::$ is) (j ::$ js) =
- Fun.Pair i j ::$ listsZip is js
+listsZip (i ::$ is) (j ::$ js) = Fun.Pair i j ::$ listsZip is js
+{-# INLINE listsZipWith #-}
listsZipWith :: (forall a. f a -> g a -> h a) -> ListS sh f -> ListS sh g
-> ListS sh h
listsZipWith _ ZS ZS = ZS
-listsZipWith f (i ::$ is) (j ::$ js) =
- f i j ::$ listsZipWith f is js
+listsZipWith f (i ::$ is) (j ::$ js) = f i j ::$ listsZipWith f is js
listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f
listsTakeLenPerm PNil _ = ZS
@@ -180,11 +201,9 @@ listsIndex _ _ _ ZS = error "Index into empty shape"
listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f
listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh)
+-- * Shaped indices
-- | An index into a shape-typed array.
---
--- For convenience, this contains regular 'Int's instead of bounded integers
--- (traditionally called \"@Fin@\").
type role IxS nominal representational
type IxS :: [Nat] -> Type -> Type
newtype IxS sh i = IxS (ListS sh (Const i))
@@ -193,6 +212,8 @@ newtype IxS sh i = IxS (ListS sh (Const i))
pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
pattern ZIS = IxS ZS
+-- | Note: The 'KnownNat' constraint on '(:.$)' is deprecated and should be
+-- removed in a future release.
pattern (:.$)
:: forall {sh1} {i}.
forall n sh. (KnownNat n, n : sh ~ sh1)
@@ -203,6 +224,8 @@ infixr 3 :.$
{-# COMPLETE ZIS, (:.$) #-}
+-- For convenience, this contains regular 'Int's instead of bounded integers
+-- (traditionally called \"@Fin@\").
type IIxS sh = IxS sh Int
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
@@ -213,10 +236,18 @@ instance Show i => Show (IxS sh i) where
#endif
instance Functor (IxS sh) where
+ {-# INLINE fmap #-}
fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l)
instance Foldable (IxS sh) where
- foldMap f (IxS l) = listsFold (f . getConst) l
+ {-# INLINE foldMap #-}
+ foldMap f (IxS l) = listsFoldMap (f . getConst) l
+ {-# INLINE foldr #-}
+ foldr _ z ZIS = z
+ foldr f z (x :.$ xs) = f x (foldr f z xs)
+ toList = ixsToList
+ null ZIS = False
+ null _ = True
instance NFData i => NFData (IxS sh i)
@@ -226,6 +257,13 @@ ixsLength (IxS l) = listsLength l
ixsRank :: IxS sh i -> SNat (Rank sh)
ixsRank (IxS l) = listsRank l
+ixsFromList :: forall sh i. ShS sh -> [i] -> IxS sh i
+ixsFromList = coerce (listsFromList @_ @i)
+
+{-# INLINEABLE ixsToList #-}
+ixsToList :: forall sh i. IxS sh i -> [i]
+ixsToList = coerce (listsToList @_ @i)
+
ixsZero :: ShS sh -> IIxS sh
ixsZero ZSS = ZIS
ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh
@@ -242,14 +280,21 @@ ixsInit (IxS list) = IxS (listsInit list)
ixsLast :: IxS (n : sh) i -> i
ixsLast (IxS list) = getConst (listsLast list)
+-- TODO: this takes a ShS because there are KnownNats inside IxS.
+ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i
+ixsCast ZSS ZIS = ZIS
+ixsCast (_ :$$ sh) (i :.$ idx) = i :.$ ixsCast sh idx
+ixsCast _ _ = error "ixsCast: ranks don't match"
+
ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i
ixsAppend = coerce (listsAppend @_ @(Const i))
-ixsZip :: IxS n i -> IxS n j -> IxS n (i, j)
+ixsZip :: IxS sh i -> IxS sh j -> IxS sh (i, j)
ixsZip ZIS ZIS = ZIS
ixsZip (i :.$ is) (j :.$ js) = (i, j) :.$ ixsZip is js
-ixsZipWith :: (i -> j -> k) -> IxS n i -> IxS n j -> IxS n k
+{-# INLINE ixsZipWith #-}
+ixsZipWith :: (i -> j -> k) -> IxS sh i -> IxS sh j -> IxS sh k
ixsZipWith _ ZIS ZIS = ZIS
ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js
@@ -257,6 +302,8 @@ ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is
ixsPermutePrefix = coerce (listsPermutePrefix @(Const i))
+-- * Shaped shapes
+
-- | The shape of a shape-typed array given as a list of 'SNat' values.
--
-- Note that because the shape of a shape-typed array is known statically, you
@@ -264,7 +311,10 @@ ixsPermutePrefix = coerce (listsPermutePrefix @(Const i))
type role ShS nominal
type ShS :: [Nat] -> Type
newtype ShS sh = ShS (ListS sh SNat)
- deriving (Eq, Ord, Generic)
+ deriving (Generic)
+
+instance Eq (ShS sh) where _ == _ = True
+instance Ord (ShS sh) where compare _ _ = EQ
pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh
pattern ZSS = ShS ZS
@@ -309,9 +359,28 @@ shsSize :: ShS sh -> Int
shsSize ZSS = 1
shsSize (n :$$ sh) = fromSNat' n * shsSize sh
+-- | This is a partial @const@ that fails when the second argument
+-- doesn't match the first.
+shsFromList :: ShS sh -> [Int] -> ShS sh
+shsFromList topsh topl = go topsh topl `seq` topsh
+ where
+ go :: ShS sh' -> [Int] -> ()
+ go ZSS [] = ()
+ go (sn :$$ sh) (i : is)
+ | i == fromSNat' sn = go sh is
+ | otherwise = error $ "shsFromList: Value does not match typing (type says "
+ ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")"
+ go _ _ = error $ "shsFromList: Mismatched list length (type says "
+ ++ show (shsLength topsh) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+
+{-# INLINEABLE shsToList #-}
shsToList :: ShS sh -> [Int]
-shsToList ZSS = []
-shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh
+shsToList topsh = build (\(cons :: Int -> is -> is) (nil :: is) ->
+ let go :: ShS sh -> is
+ go ZSS = nil
+ go (sn :$$ sh) = fromSNat' sn `cons` go sh
+ in go topsh)
shsHead :: ShS (n : sh) -> SNat n
shsHead (ShS list) = listsHead list
@@ -356,7 +425,7 @@ instance KnownShS '[] where knownShS = ZSS
instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS
withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r
-withKnownShS k = withDict @(KnownShS sh) k
+withKnownShS = withDict @(KnownShS sh)
shsKnownShS :: ShS sh -> Dict KnownShS sh
shsKnownShS ZSS = Dict
@@ -366,18 +435,38 @@ shsOrthotopeShape :: ShS sh -> Dict O.Shape sh
shsOrthotopeShape ZSS = Dict
shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict
+-- | This function is a hack made possible by the 'KnownNat' inside 'ListS'.
+-- This function may be removed in a future release.
+shsFromListS :: ListS sh f -> ShS sh
+shsFromListS ZS = ZSS
+shsFromListS (_ ::$ l) = SNat :$$ shsFromListS l
+
+-- | This function is a hack made possible by the 'KnownNat' inside 'IxS'. This
+-- function may be removed in a future release.
+shsFromIxS :: IxS sh i -> ShS sh
+shsFromIxS (IxS l) = shsFromListS l
+
+shsEnum :: ShS sh -> [IIxS sh]
+shsEnum = shsEnum'
+
+{-# INLINABLE shsEnum' #-} -- ensure this can be specialised at use site
+shsEnum' :: Num i => ShS sh -> [IxS sh i]
+shsEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shsSize sh - 1]]
+ where
+ suffixes = drop 1 (scanr (*) 1 (shsToList sh))
+
+ fromLin :: Num i => ShS sh -> [Int] -> Int# -> IxS sh i
+ fromLin ZSS _ _ = ZIS
+ fromLin (_ :$$ sh') (I# suff# : suffs) i# =
+ let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shsSize sh'
+ in fromIntegral (I# q#) :.$ fromLin sh' suffs r#
+ fromLin _ _ _ = error "impossible"
+
-- | Untyped: length is checked at runtime.
instance KnownShS sh => IsList (ListS sh (Const i)) where
type Item (ListS sh (Const i)) = i
- fromList topl = go (knownShS @sh) topl
- where
- go :: ShS sh' -> [i] -> ListS sh' (Const i)
- go ZSS [] = ZS
- go (_ :$$ sh) (i : is) = Const i ::$ go sh is
- go _ _ = error $ "IsList(ListS): Mismatched list length (type says "
- ++ show (shsLength (knownShS @sh)) ++ ", list has length "
- ++ show (length topl) ++ ")"
+ fromList = listsFromList (knownShS @sh)
toList = listsToList
-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
@@ -389,15 +478,8 @@ instance KnownShS sh => IsList (IxS sh i) where
-- | Untyped: length and values are checked at runtime.
instance KnownShS sh => IsList (ShS sh) where
type Item (ShS sh) = Int
- fromList topl = ShS (go (knownShS @sh) topl)
- where
- go :: ShS sh' -> [Int] -> ListS sh' SNat
- go ZSS [] = ZS
- go (sn :$$ sh) (i : is)
- | i == fromSNat' sn = sn ::$ go sh is
- | otherwise = error $ "IsList(ShS): Value does not match typing (type says "
- ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")"
- go _ _ = error $ "IsList(ShS): Mismatched list length (type says "
- ++ show (shsLength (knownShS @sh)) ++ ", list has length "
- ++ show (length topl) ++ ")"
+ fromList = shsFromList (knownShS @sh)
toList = shsToList
+
+$(ixFromLinearStub "ixsFromLinear" [t| ShS |] [t| IxS |] [p| ZSS |] (\a b -> [p| (fromSNat' -> $a) :$$ $b |]) [| ZIS |] [| (:.$) |] [| shsToList |])
+{-# INLINEABLE ixsFromLinear #-}