diff options
-rw-r--r-- | ox-arrays.cabal | 19 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Shape.hs | 36 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Types.hs | 4 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Shape.hs | 66 |
4 files changed, 113 insertions, 12 deletions
diff --git a/ox-arrays.cabal b/ox-arrays.cabal index 515d7ff..243dcd8 100644 --- a/ox-arrays.cabal +++ b/ox-arrays.cabal @@ -1,8 +1,11 @@ cabal-version: 3.0 name: ox-arrays version: 0.1.0.0 +synopsis: An efficient CPU-based multidimensional array (tensor) library +description: An efficient and richly typed CPU-based multidimensional array (tensor) library built upon the optimized tensor representation (strides list) implemented in the orthotope package. author: Tom Smeding license: BSD-3-Clause +category: Array, Tensors build-type: Simple extra-source-files: cbits/arith_lists.h @@ -61,11 +64,11 @@ library build-depends: strided-array-ops, - base >=4.18 && <4.22, - deepseq, + base, + deepseq < 1.7, ghc-typelits-knownnat, ghc-typelits-natnormalise, - orthotope, + orthotope < 0.2, template-haskell, vector hs-source-dirs: src @@ -84,11 +87,11 @@ library strided-array-ops Data.Array.Strided.Arith.Internal.Lists Data.Array.Strided.Arith.Internal.Lists.TH build-depends: - base, - ghc-typelits-knownnat, - ghc-typelits-natnormalise, - template-haskell, - vector + base >=4.18 && <4.22, + ghc-typelits-knownnat < 1, + ghc-typelits-natnormalise < 1, + template-haskell < 3, + vector < 0.14 hs-source-dirs: ops c-sources: cbits/arith.c diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs index 80bd55e..4dd0aa6 100644 --- a/src/Data/Array/Mixed/Shape.hs +++ b/src/Data/Array/Mixed/Shape.hs @@ -130,6 +130,9 @@ listxToList :: ListX sh' (Const i) -> [i] listxToList ZX = [] listxToList (Const i ::% is) = i : listxToList is +listxHead :: ListX (mn ': sh) f -> f mn +listxHead (i ::% _) = i + listxTail :: ListX (n : sh) i -> ListX sh i listxTail (_ ::% sh) = sh @@ -149,6 +152,19 @@ listxLast :: forall f n sh. ListX (n : sh) f -> f (Last (n : sh)) listxLast (_ ::% sh@(_ ::% _)) = listxLast sh listxLast (x ::% ZX) = x +listxZip :: ListX sh (Const i) -> ListX sh (Const j) -> ListX sh (Const (i, j)) +listxZip ZX ZX = ZX +listxZip (Const i ::% irest) (Const j ::% jrest) = + Const (i, j) ::% listxZip irest jrest +--listxZip _ _ = error "listxZip: impossible pattern needlessly required" + +listxZipWith :: (i -> j -> k) -> ListX sh (Const i) -> ListX sh (Const j) + -> ListX sh (Const k) +listxZipWith _ ZX ZX = ZX +listxZipWith f (Const i ::% irest) (Const j ::% jrest) = + Const (f i j) ::% listxZipWith f irest jrest +--listxZipWith _ _ _ = error "listxZipWith: impossible pattern needlessly required" + -- * Mixed indices @@ -201,6 +217,9 @@ ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh ixxFromList :: forall sh i. StaticShX sh -> [i] -> IxX sh i ixxFromList = coerce (listxFromList @_ @i) +ixxHead :: IxX (n : sh) i -> i +ixxHead (IxX list) = getConst (listxHead list) + ixxTail :: IxX (n : sh) i -> IxX sh i ixxTail (IxX list) = IxX (listxTail list) @@ -216,6 +235,12 @@ ixxInit = coerce (listxInit @(Const i)) ixxLast :: forall n sh i. IxX (n : sh) i -> i ixxLast = coerce (listxLast @(Const i)) +ixxZip :: IxX n i -> IxX n j -> IxX n (i, j) +ixxZip (IxX l1) (IxX l2) = IxX $ listxZip l1 l2 + +ixxZipWith :: (i -> j -> k) -> IxX n i -> IxX n j -> IxX n k +ixxZipWith f (IxX l1) (IxX l2) = IxX $ listxZipWith f l1 l2 + ixxFromLinear :: IShX sh -> Int -> IIxX sh ixxFromLinear = \sh i -> case go sh i of (idx, 0) -> idx @@ -372,6 +397,9 @@ shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i shxAppend = coerce (listxAppend @_ @(SMayNat i SNat)) +shxHead :: ShX (n : sh) i -> SMayNat i SNat n +shxHead (ShX list) = listxHead list + shxTail :: ShX (n : sh) i -> ShX sh i shxTail (ShX list) = ShX (listxTail list) @@ -452,6 +480,11 @@ infixr 3 :!% instance Show (StaticShX sh) where showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l +instance NFData (StaticShX sh) where + rnf (StaticShX ZX) = () + rnf (StaticShX (SUnknown () ::% l)) = rnf (StaticShX l) + rnf (StaticShX (SKnown SNat ::% l)) = rnf (StaticShX l) + instance TestEquality StaticShX where testEquality (StaticShX l1) (StaticShX l2) = listxEqType l1 l2 @@ -469,6 +502,9 @@ ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') ssxAppend ZKX sh' = sh' ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' +ssxHead :: StaticShX (n : sh) -> SMayNat () SNat n +ssxHead (StaticShX list) = listxHead list + ssxTail :: StaticShX (n : sh) -> StaticShX sh ssxTail (_ :!% ssh) = ssh diff --git a/src/Data/Array/Mixed/Types.hs b/src/Data/Array/Mixed/Types.hs index 13675d0..736ced6 100644 --- a/src/Data/Array/Mixed/Types.hs +++ b/src/Data/Array/Mixed/Types.hs @@ -27,6 +27,7 @@ module Data.Array.Mixed.Types ( Replicate, lemReplicateSucc, MapJust, + Head, Tail, Init, Last, @@ -103,6 +104,9 @@ type family MapJust l where MapJust '[] = '[] MapJust (x : xs) = Just x : MapJust xs +type family Head l where + Head (x : _) = x + type family Tail l where Tail (_ : xs) = xs diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs index 878ea7e..102d9d8 100644 --- a/src/Data/Array/Nested/Internal/Shape.hs +++ b/src/Data/Array/Nested/Internal/Shape.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFoldable #-} {-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} @@ -25,6 +26,7 @@ {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module Data.Array.Nested.Internal.Shape where +import Control.DeepSeq (NFData (..)) import Data.Array.Shape qualified as O import Data.Array.Mixed.Types import Data.Coerce (coerce) @@ -35,6 +37,7 @@ import Data.Monoid (Sum(..)) import Data.Proxy import Data.Type.Equality import GHC.Exts (withDict) +import GHC.Generics (Generic) import GHC.IsList (IsList) import GHC.IsList qualified as IsList import GHC.TypeLits @@ -59,6 +62,10 @@ infixr 3 ::: instance Show i => Show (ListR n i) where showsPrec _ = listrShow shows +instance NFData i => NFData (ListR n i) where + rnf ZR = () + rnf (x ::: l) = rnf x `seq` rnf l + data UnconsListRRes i n1 = forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1) @@ -128,6 +135,18 @@ listrIndex SZ (x ::: _) = x listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs listrIndex _ ZR = error "k + 1 <= 0" +listrZip :: ListR n i -> ListR n j -> ListR n (i, j) +listrZip ZR ZR = ZR +listrZip (i ::: irest) (j ::: jrest) = (i, j) ::: listrZip irest jrest +listrZip _ _ = error "listrZip: impossible pattern needlessly required" + +listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k +listrZipWith _ ZR ZR = ZR +listrZipWith f (i ::: irest) (j ::: jrest) = + f i j ::: listrZipWith f irest jrest +listrZipWith _ _ _ = + error "listrZipWith: impossible pattern needlessly required" + listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i listrPermutePrefix = \perm sh -> listrFromList perm $ \sperm -> @@ -157,7 +176,7 @@ listrPermutePrefix = \perm sh -> type role IxR nominal representational type IxR :: Nat -> Type -> Type newtype IxR n i = IxR (ListR n i) - deriving (Eq, Ord) + deriving (Eq, Ord, Generic) deriving newtype (Functor, Foldable) pattern ZIR :: forall n i. () => n ~ 0 => IxR n i @@ -178,6 +197,8 @@ type IIxR n = IxR n Int instance Show i => Show (IxR n i) where showsPrec _ (IxR l) = listrShow shows l +instance NFData i => NFData (IxR sh i) + ixrLength :: IxR sh i -> Int ixrLength (IxR l) = listrLength l @@ -213,6 +234,12 @@ ixrLast (IxR list) = listrLast list ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i ixrAppend = coerce (listrAppend @_ @i) +ixrZip :: IxR n i -> IxR n j -> IxR n (i, j) +ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2 + +ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k +ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2 + ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i ixrPermutePrefix = coerce (listrPermutePrefix @i) @@ -220,7 +247,7 @@ ixrPermutePrefix = coerce (listrPermutePrefix @i) type role ShR nominal representational type ShR :: Nat -> Type -> Type newtype ShR n i = ShR (ListR n i) - deriving (Eq, Ord) + deriving (Eq, Ord, Generic) deriving newtype (Functor, Foldable) pattern ZSR :: forall n i. () => n ~ 0 => ShR n i @@ -241,6 +268,8 @@ type IShR n = ShR n Int instance Show i => Show (ShR n i) where showsPrec _ (ShR l) = listrShow shows l +instance NFData i => NFData (ShR sh i) + shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n shCvtXR' ZSX = castWith (subst2 (unsafeCoerceRefl :: 0 :~: n)) @@ -346,6 +375,10 @@ infixr 3 ::$ instance (forall n. Show (f n)) => Show (ListS sh f) where showsPrec _ = listsShow shows +instance (forall m. NFData (f m)) => NFData (ListS n f) where + rnf ZS = () + rnf (x ::$ l) = rnf x `seq` rnf l + data UnconsListSRes f sh1 = forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n) listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1) @@ -419,6 +452,19 @@ listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f listsAppend ZS idx' = idx' listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' +listsZip :: ListS sh (Const i) -> ListS sh (Const j) -> ListS sh (Const (i, j)) +listsZip ZS ZS = ZS +listsZip (Const i ::$ irest) (Const j ::$ jrest) = + Const (i, j) ::$ listsZip irest jrest +--listsZip _ _ = error "listsZip: impossible pattern needlessly required" + +listsZipWith :: (i -> j -> k) -> ListS sh (Const i) -> ListS sh (Const j) + -> ListS sh (Const k) +listsZipWith _ ZS ZS = ZS +listsZipWith f (Const i ::$ irest) (Const j ::$ jrest) = + Const (f i j) ::$ listsZipWith f irest jrest +--listsZipWith _ _ _ = error "listsZipWith: impossible pattern needlessly required" + listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f listsTakeLenPerm PNil _ = ZS listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh @@ -454,7 +500,7 @@ listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm pe type role IxS nominal representational type IxS :: [Nat] -> Type -> Type newtype IxS sh i = IxS (ListS sh (Const i)) - deriving (Eq, Ord) + deriving (Eq, Ord, Generic) pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i pattern ZIS = IxS ZS @@ -480,6 +526,8 @@ instance Functor (IxS sh) where instance Foldable (IxS sh) where foldMap f (IxS l) = listsFold (f . getConst) l +instance NFData i => NFData (IxS sh i) + ixsLength :: IxS sh i -> Int ixsLength (IxS l) = listsLength l @@ -513,6 +561,12 @@ ixsLast (IxS list) = getConst (listsLast list) ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i ixsAppend = coerce (listsAppend @_ @(Const i)) +ixsZip :: IxS n i -> IxS n j -> IxS n (i, j) +ixsZip (IxS l1) (IxS l2) = IxS $ listsZip l1 l2 + +ixsZipWith :: (i -> j -> k) -> IxS n i -> IxS n j -> IxS n k +ixsZipWith f (IxS l1) (IxS l2) = IxS $ listsZipWith f l1 l2 + ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i ixsPermutePrefix = coerce (listsPermutePrefix @(Const i)) @@ -524,7 +578,7 @@ ixsPermutePrefix = coerce (listsPermutePrefix @(Const i)) type role ShS nominal type ShS :: [Nat] -> Type newtype ShS sh = ShS (ListS sh SNat) - deriving (Eq, Ord) + deriving (Eq, Ord, Generic) pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh pattern ZSS = ShS ZS @@ -543,6 +597,10 @@ infixr 3 :$$ instance Show (ShS sh) where showsPrec _ (ShS l) = listsShow (shows . fromSNat) l +instance NFData (ShS sh) where + rnf (ShS ZS) = () + rnf (ShS (SNat ::$ l)) = rnf (ShS l) + instance TestEquality ShS where testEquality (ShS l1) (ShS l2) = listsEqType l1 l2 |