aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed/Shape.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Mixed/Shape.hs')
-rw-r--r--src/Data/Array/Mixed/Shape.hs57
1 files changed, 51 insertions, 6 deletions
diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs
index 16f62fe..4dd0aa6 100644
--- a/src/Data/Array/Mixed/Shape.hs
+++ b/src/Data/Array/Mixed/Shape.hs
@@ -130,6 +130,9 @@ listxToList :: ListX sh' (Const i) -> [i]
listxToList ZX = []
listxToList (Const i ::% is) = i : listxToList is
+listxHead :: ListX (mn ': sh) f -> f mn
+listxHead (i ::% _) = i
+
listxTail :: ListX (n : sh) i -> ListX sh i
listxTail (_ ::% sh) = sh
@@ -149,6 +152,19 @@ listxLast :: forall f n sh. ListX (n : sh) f -> f (Last (n : sh))
listxLast (_ ::% sh@(_ ::% _)) = listxLast sh
listxLast (x ::% ZX) = x
+listxZip :: ListX sh (Const i) -> ListX sh (Const j) -> ListX sh (Const (i, j))
+listxZip ZX ZX = ZX
+listxZip (Const i ::% irest) (Const j ::% jrest) =
+ Const (i, j) ::% listxZip irest jrest
+--listxZip _ _ = error "listxZip: impossible pattern needlessly required"
+
+listxZipWith :: (i -> j -> k) -> ListX sh (Const i) -> ListX sh (Const j)
+ -> ListX sh (Const k)
+listxZipWith _ ZX ZX = ZX
+listxZipWith f (Const i ::% irest) (Const j ::% jrest) =
+ Const (f i j) ::% listxZipWith f irest jrest
+--listxZipWith _ _ _ = error "listxZipWith: impossible pattern needlessly required"
+
-- * Mixed indices
@@ -184,6 +200,12 @@ instance Foldable (IxX sh) where
instance NFData i => NFData (IxX sh i)
+ixxLength :: IxX sh i -> Int
+ixxLength (IxX l) = listxLength l
+
+ixxRank :: IxX sh i -> SNat (Rank sh)
+ixxRank (IxX l) = listxRank l
+
ixxZero :: StaticShX sh -> IIxX sh
ixxZero ZKX = ZIX
ixxZero (_ :!% ssh) = 0 :.% ixxZero ssh
@@ -195,6 +217,9 @@ ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh
ixxFromList :: forall sh i. StaticShX sh -> [i] -> IxX sh i
ixxFromList = coerce (listxFromList @_ @i)
+ixxHead :: IxX (n : sh) i -> i
+ixxHead (IxX list) = getConst (listxHead list)
+
ixxTail :: IxX (n : sh) i -> IxX sh i
ixxTail (IxX list) = IxX (listxTail list)
@@ -210,6 +235,12 @@ ixxInit = coerce (listxInit @(Const i))
ixxLast :: forall n sh i. IxX (n : sh) i -> i
ixxLast = coerce (listxLast @(Const i))
+ixxZip :: IxX n i -> IxX n j -> IxX n (i, j)
+ixxZip (IxX l1) (IxX l2) = IxX $ listxZip l1 l2
+
+ixxZipWith :: (i -> j -> k) -> IxX n i -> IxX n j -> IxX n k
+ixxZipWith f (IxX l1) (IxX l2) = IxX $ listxZipWith f l1 l2
+
ixxFromLinear :: IShX sh -> Int -> IIxX sh
ixxFromLinear = \sh i -> case go sh i of
(idx, 0) -> idx
@@ -305,12 +336,6 @@ instance NFData i => NFData (ShX sh i) where
rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l)
rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l)
-shxLength :: ShX sh i -> Int
-shxLength (ShX l) = listxLength l
-
-shxRank :: ShX sh i -> SNat (Rank sh)
-shxRank (ShX list) = listxRank list
-
-- | This checks only whether the types are equal; unknown dimensions might
-- still differ. This corresponds to 'testEquality', except on the penultimate
-- type parameter.
@@ -340,6 +365,12 @@ shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh')
= Just Refl
shxEqual _ _ = Nothing
+shxLength :: ShX sh i -> Int
+shxLength (ShX l) = listxLength l
+
+shxRank :: ShX sh i -> SNat (Rank sh)
+shxRank (ShX l) = listxRank l
+
-- | The number of elements in an array described by this shape.
shxSize :: IShX sh -> Int
shxSize ZSX = 1
@@ -366,6 +397,9 @@ shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh
shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i
shxAppend = coerce (listxAppend @_ @(SMayNat i SNat))
+shxHead :: ShX (n : sh) i -> SMayNat i SNat n
+shxHead (ShX list) = listxHead list
+
shxTail :: ShX (n : sh) i -> ShX sh i
shxTail (ShX list) = ShX (listxTail list)
@@ -446,12 +480,20 @@ infixr 3 :!%
instance Show (StaticShX sh) where
showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l
+instance NFData (StaticShX sh) where
+ rnf (StaticShX ZX) = ()
+ rnf (StaticShX (SUnknown () ::% l)) = rnf (StaticShX l)
+ rnf (StaticShX (SKnown SNat ::% l)) = rnf (StaticShX l)
+
instance TestEquality StaticShX where
testEquality (StaticShX l1) (StaticShX l2) = listxEqType l1 l2
ssxLength :: StaticShX sh -> Int
ssxLength (StaticShX l) = listxLength l
+ssxRank :: StaticShX sh -> SNat (Rank sh)
+ssxRank (StaticShX l) = listxRank l
+
-- | @ssxEqType = 'testEquality'@. Provided for consistency.
ssxEqType :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh')
ssxEqType = testEquality
@@ -460,6 +502,9 @@ ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
ssxAppend ZKX sh' = sh'
ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh'
+ssxHead :: StaticShX (n : sh) -> SMayNat () SNat n
+ssxHead (StaticShX list) = listxHead list
+
ssxTail :: StaticShX (n : sh) -> StaticShX sh
ssxTail (_ :!% ssh) = ssh