aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ox-arrays.cabal19
-rw-r--r--src/Data/Array/Mixed/Shape.hs31
-rw-r--r--src/Data/Array/Mixed/Types.hs4
-rw-r--r--src/Data/Array/Nested/Internal/Shape.hs37
4 files changed, 83 insertions, 8 deletions
diff --git a/ox-arrays.cabal b/ox-arrays.cabal
index 515d7ff..243dcd8 100644
--- a/ox-arrays.cabal
+++ b/ox-arrays.cabal
@@ -1,8 +1,11 @@
cabal-version: 3.0
name: ox-arrays
version: 0.1.0.0
+synopsis: An efficient CPU-based multidimensional array (tensor) library
+description: An efficient and richly typed CPU-based multidimensional array (tensor) library built upon the optimized tensor representation (strides list) implemented in the orthotope package.
author: Tom Smeding
license: BSD-3-Clause
+category: Array, Tensors
build-type: Simple
extra-source-files: cbits/arith_lists.h
@@ -61,11 +64,11 @@ library
build-depends:
strided-array-ops,
- base >=4.18 && <4.22,
- deepseq,
+ base,
+ deepseq < 1.7,
ghc-typelits-knownnat,
ghc-typelits-natnormalise,
- orthotope,
+ orthotope < 0.2,
template-haskell,
vector
hs-source-dirs: src
@@ -84,11 +87,11 @@ library strided-array-ops
Data.Array.Strided.Arith.Internal.Lists
Data.Array.Strided.Arith.Internal.Lists.TH
build-depends:
- base,
- ghc-typelits-knownnat,
- ghc-typelits-natnormalise,
- template-haskell,
- vector
+ base >=4.18 && <4.22,
+ ghc-typelits-knownnat < 1,
+ ghc-typelits-natnormalise < 1,
+ template-haskell < 3,
+ vector < 0.14
hs-source-dirs: ops
c-sources: cbits/arith.c
diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs
index 99a137d..4dd0aa6 100644
--- a/src/Data/Array/Mixed/Shape.hs
+++ b/src/Data/Array/Mixed/Shape.hs
@@ -130,6 +130,9 @@ listxToList :: ListX sh' (Const i) -> [i]
listxToList ZX = []
listxToList (Const i ::% is) = i : listxToList is
+listxHead :: ListX (mn ': sh) f -> f mn
+listxHead (i ::% _) = i
+
listxTail :: ListX (n : sh) i -> ListX sh i
listxTail (_ ::% sh) = sh
@@ -149,6 +152,19 @@ listxLast :: forall f n sh. ListX (n : sh) f -> f (Last (n : sh))
listxLast (_ ::% sh@(_ ::% _)) = listxLast sh
listxLast (x ::% ZX) = x
+listxZip :: ListX sh (Const i) -> ListX sh (Const j) -> ListX sh (Const (i, j))
+listxZip ZX ZX = ZX
+listxZip (Const i ::% irest) (Const j ::% jrest) =
+ Const (i, j) ::% listxZip irest jrest
+--listxZip _ _ = error "listxZip: impossible pattern needlessly required"
+
+listxZipWith :: (i -> j -> k) -> ListX sh (Const i) -> ListX sh (Const j)
+ -> ListX sh (Const k)
+listxZipWith _ ZX ZX = ZX
+listxZipWith f (Const i ::% irest) (Const j ::% jrest) =
+ Const (f i j) ::% listxZipWith f irest jrest
+--listxZipWith _ _ _ = error "listxZipWith: impossible pattern needlessly required"
+
-- * Mixed indices
@@ -201,6 +217,9 @@ ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh
ixxFromList :: forall sh i. StaticShX sh -> [i] -> IxX sh i
ixxFromList = coerce (listxFromList @_ @i)
+ixxHead :: IxX (n : sh) i -> i
+ixxHead (IxX list) = getConst (listxHead list)
+
ixxTail :: IxX (n : sh) i -> IxX sh i
ixxTail (IxX list) = IxX (listxTail list)
@@ -216,6 +235,12 @@ ixxInit = coerce (listxInit @(Const i))
ixxLast :: forall n sh i. IxX (n : sh) i -> i
ixxLast = coerce (listxLast @(Const i))
+ixxZip :: IxX n i -> IxX n j -> IxX n (i, j)
+ixxZip (IxX l1) (IxX l2) = IxX $ listxZip l1 l2
+
+ixxZipWith :: (i -> j -> k) -> IxX n i -> IxX n j -> IxX n k
+ixxZipWith f (IxX l1) (IxX l2) = IxX $ listxZipWith f l1 l2
+
ixxFromLinear :: IShX sh -> Int -> IIxX sh
ixxFromLinear = \sh i -> case go sh i of
(idx, 0) -> idx
@@ -372,6 +397,9 @@ shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh
shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i
shxAppend = coerce (listxAppend @_ @(SMayNat i SNat))
+shxHead :: ShX (n : sh) i -> SMayNat i SNat n
+shxHead (ShX list) = listxHead list
+
shxTail :: ShX (n : sh) i -> ShX sh i
shxTail (ShX list) = ShX (listxTail list)
@@ -474,6 +502,9 @@ ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
ssxAppend ZKX sh' = sh'
ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh'
+ssxHead :: StaticShX (n : sh) -> SMayNat () SNat n
+ssxHead (StaticShX list) = listxHead list
+
ssxTail :: StaticShX (n : sh) -> StaticShX sh
ssxTail (_ :!% ssh) = ssh
diff --git a/src/Data/Array/Mixed/Types.hs b/src/Data/Array/Mixed/Types.hs
index 13675d0..736ced6 100644
--- a/src/Data/Array/Mixed/Types.hs
+++ b/src/Data/Array/Mixed/Types.hs
@@ -27,6 +27,7 @@ module Data.Array.Mixed.Types (
Replicate,
lemReplicateSucc,
MapJust,
+ Head,
Tail,
Init,
Last,
@@ -103,6 +104,9 @@ type family MapJust l where
MapJust '[] = '[]
MapJust (x : xs) = Just x : MapJust xs
+type family Head l where
+ Head (x : _) = x
+
type family Tail l where
Tail (_ : xs) = xs
diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs
index 5d5f8e3..102d9d8 100644
--- a/src/Data/Array/Nested/Internal/Shape.hs
+++ b/src/Data/Array/Nested/Internal/Shape.hs
@@ -135,6 +135,18 @@ listrIndex SZ (x ::: _) = x
listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
listrIndex _ ZR = error "k + 1 <= 0"
+listrZip :: ListR n i -> ListR n j -> ListR n (i, j)
+listrZip ZR ZR = ZR
+listrZip (i ::: irest) (j ::: jrest) = (i, j) ::: listrZip irest jrest
+listrZip _ _ = error "listrZip: impossible pattern needlessly required"
+
+listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k
+listrZipWith _ ZR ZR = ZR
+listrZipWith f (i ::: irest) (j ::: jrest) =
+ f i j ::: listrZipWith f irest jrest
+listrZipWith _ _ _ =
+ error "listrZipWith: impossible pattern needlessly required"
+
listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
listrPermutePrefix = \perm sh ->
listrFromList perm $ \sperm ->
@@ -222,6 +234,12 @@ ixrLast (IxR list) = listrLast list
ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i
ixrAppend = coerce (listrAppend @_ @i)
+ixrZip :: IxR n i -> IxR n j -> IxR n (i, j)
+ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2
+
+ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k
+ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2
+
ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
ixrPermutePrefix = coerce (listrPermutePrefix @i)
@@ -434,6 +452,19 @@ listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f
listsAppend ZS idx' = idx'
listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx'
+listsZip :: ListS sh (Const i) -> ListS sh (Const j) -> ListS sh (Const (i, j))
+listsZip ZS ZS = ZS
+listsZip (Const i ::$ irest) (Const j ::$ jrest) =
+ Const (i, j) ::$ listsZip irest jrest
+--listsZip _ _ = error "listsZip: impossible pattern needlessly required"
+
+listsZipWith :: (i -> j -> k) -> ListS sh (Const i) -> ListS sh (Const j)
+ -> ListS sh (Const k)
+listsZipWith _ ZS ZS = ZS
+listsZipWith f (Const i ::$ irest) (Const j ::$ jrest) =
+ Const (f i j) ::$ listsZipWith f irest jrest
+--listsZipWith _ _ _ = error "listsZipWith: impossible pattern needlessly required"
+
listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f
listsTakeLenPerm PNil _ = ZS
listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh
@@ -530,6 +561,12 @@ ixsLast (IxS list) = getConst (listsLast list)
ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i
ixsAppend = coerce (listsAppend @_ @(Const i))
+ixsZip :: IxS n i -> IxS n j -> IxS n (i, j)
+ixsZip (IxS l1) (IxS l2) = IxS $ listsZip l1 l2
+
+ixsZipWith :: (i -> j -> k) -> IxS n i -> IxS n j -> IxS n k
+ixsZipWith f (IxS l1) (IxS l2) = IxS $ listsZipWith f l1 l2
+
ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i
ixsPermutePrefix = coerce (listsPermutePrefix @(Const i))