diff options
Diffstat (limited to 'src/Data/Array/Mixed')
-rw-r--r-- | src/Data/Array/Mixed/Shape.hs | 57 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Types.hs | 4 |
2 files changed, 55 insertions, 6 deletions
diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs index 16f62fe..4dd0aa6 100644 --- a/src/Data/Array/Mixed/Shape.hs +++ b/src/Data/Array/Mixed/Shape.hs @@ -130,6 +130,9 @@ listxToList :: ListX sh' (Const i) -> [i] listxToList ZX = [] listxToList (Const i ::% is) = i : listxToList is +listxHead :: ListX (mn ': sh) f -> f mn +listxHead (i ::% _) = i + listxTail :: ListX (n : sh) i -> ListX sh i listxTail (_ ::% sh) = sh @@ -149,6 +152,19 @@ listxLast :: forall f n sh. ListX (n : sh) f -> f (Last (n : sh)) listxLast (_ ::% sh@(_ ::% _)) = listxLast sh listxLast (x ::% ZX) = x +listxZip :: ListX sh (Const i) -> ListX sh (Const j) -> ListX sh (Const (i, j)) +listxZip ZX ZX = ZX +listxZip (Const i ::% irest) (Const j ::% jrest) = + Const (i, j) ::% listxZip irest jrest +--listxZip _ _ = error "listxZip: impossible pattern needlessly required" + +listxZipWith :: (i -> j -> k) -> ListX sh (Const i) -> ListX sh (Const j) + -> ListX sh (Const k) +listxZipWith _ ZX ZX = ZX +listxZipWith f (Const i ::% irest) (Const j ::% jrest) = + Const (f i j) ::% listxZipWith f irest jrest +--listxZipWith _ _ _ = error "listxZipWith: impossible pattern needlessly required" + -- * Mixed indices @@ -184,6 +200,12 @@ instance Foldable (IxX sh) where instance NFData i => NFData (IxX sh i) +ixxLength :: IxX sh i -> Int +ixxLength (IxX l) = listxLength l + +ixxRank :: IxX sh i -> SNat (Rank sh) +ixxRank (IxX l) = listxRank l + ixxZero :: StaticShX sh -> IIxX sh ixxZero ZKX = ZIX ixxZero (_ :!% ssh) = 0 :.% ixxZero ssh @@ -195,6 +217,9 @@ ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh ixxFromList :: forall sh i. StaticShX sh -> [i] -> IxX sh i ixxFromList = coerce (listxFromList @_ @i) +ixxHead :: IxX (n : sh) i -> i +ixxHead (IxX list) = getConst (listxHead list) + ixxTail :: IxX (n : sh) i -> IxX sh i ixxTail (IxX list) = IxX (listxTail list) @@ -210,6 +235,12 @@ ixxInit = coerce (listxInit @(Const i)) ixxLast :: forall n sh i. IxX (n : sh) i -> i ixxLast = coerce (listxLast @(Const i)) +ixxZip :: IxX n i -> IxX n j -> IxX n (i, j) +ixxZip (IxX l1) (IxX l2) = IxX $ listxZip l1 l2 + +ixxZipWith :: (i -> j -> k) -> IxX n i -> IxX n j -> IxX n k +ixxZipWith f (IxX l1) (IxX l2) = IxX $ listxZipWith f l1 l2 + ixxFromLinear :: IShX sh -> Int -> IIxX sh ixxFromLinear = \sh i -> case go sh i of (idx, 0) -> idx @@ -305,12 +336,6 @@ instance NFData i => NFData (ShX sh i) where rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l) rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l) -shxLength :: ShX sh i -> Int -shxLength (ShX l) = listxLength l - -shxRank :: ShX sh i -> SNat (Rank sh) -shxRank (ShX list) = listxRank list - -- | This checks only whether the types are equal; unknown dimensions might -- still differ. This corresponds to 'testEquality', except on the penultimate -- type parameter. @@ -340,6 +365,12 @@ shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh') = Just Refl shxEqual _ _ = Nothing +shxLength :: ShX sh i -> Int +shxLength (ShX l) = listxLength l + +shxRank :: ShX sh i -> SNat (Rank sh) +shxRank (ShX l) = listxRank l + -- | The number of elements in an array described by this shape. shxSize :: IShX sh -> Int shxSize ZSX = 1 @@ -366,6 +397,9 @@ shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i shxAppend = coerce (listxAppend @_ @(SMayNat i SNat)) +shxHead :: ShX (n : sh) i -> SMayNat i SNat n +shxHead (ShX list) = listxHead list + shxTail :: ShX (n : sh) i -> ShX sh i shxTail (ShX list) = ShX (listxTail list) @@ -446,12 +480,20 @@ infixr 3 :!% instance Show (StaticShX sh) where showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l +instance NFData (StaticShX sh) where + rnf (StaticShX ZX) = () + rnf (StaticShX (SUnknown () ::% l)) = rnf (StaticShX l) + rnf (StaticShX (SKnown SNat ::% l)) = rnf (StaticShX l) + instance TestEquality StaticShX where testEquality (StaticShX l1) (StaticShX l2) = listxEqType l1 l2 ssxLength :: StaticShX sh -> Int ssxLength (StaticShX l) = listxLength l +ssxRank :: StaticShX sh -> SNat (Rank sh) +ssxRank (StaticShX l) = listxRank l + -- | @ssxEqType = 'testEquality'@. Provided for consistency. ssxEqType :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh') ssxEqType = testEquality @@ -460,6 +502,9 @@ ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') ssxAppend ZKX sh' = sh' ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' +ssxHead :: StaticShX (n : sh) -> SMayNat () SNat n +ssxHead (StaticShX list) = listxHead list + ssxTail :: StaticShX (n : sh) -> StaticShX sh ssxTail (_ :!% ssh) = ssh diff --git a/src/Data/Array/Mixed/Types.hs b/src/Data/Array/Mixed/Types.hs index 13675d0..736ced6 100644 --- a/src/Data/Array/Mixed/Types.hs +++ b/src/Data/Array/Mixed/Types.hs @@ -27,6 +27,7 @@ module Data.Array.Mixed.Types ( Replicate, lemReplicateSucc, MapJust, + Head, Tail, Init, Last, @@ -103,6 +104,9 @@ type family MapJust l where MapJust '[] = '[] MapJust (x : xs) = Just x : MapJust xs +type family Head l where + Head (x : _) = x + type family Tail l where Tail (_ : xs) = xs |