aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Mixed')
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs43
-rw-r--r--src/Data/Array/Mixed/Lemmas.hs137
-rw-r--r--src/Data/Array/Mixed/Permutation.hs273
-rw-r--r--src/Data/Array/Mixed/Shape.hs586
-rw-r--r--src/Data/Array/Mixed/Types.hs134
-rw-r--r--src/Data/Array/Mixed/XArray.hs349
6 files changed, 0 insertions, 1522 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
deleted file mode 100644
index b1c7031..0000000
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ /dev/null
@@ -1,43 +0,0 @@
-{-# LANGUAGE ImportQualifiedPost #-}
-module Data.Array.Mixed.Internal.Arith (
- module Data.Array.Mixed.Internal.Arith,
- module Data.Array.Strided.Arith,
-) where
-
-import Data.Array.Internal qualified as OI
-import Data.Array.Internal.RankedG qualified as RG
-import Data.Array.Internal.RankedS qualified as RS
-
-import Data.Array.Strided qualified as AS
-import Data.Array.Strided.Arith
-
--- for liftVEltwise1
-import Foreign.Storable
-import GHC.TypeLits
-import Data.Vector.Storable qualified as VS
-import Data.Array.Strided.Arith.Internal (stridesDense)
-
-
-fromO :: RS.Array n a -> AS.Array n a
-fromO (RS.A (RG.A sh (OI.T strides offset vec))) = AS.Array sh strides offset vec
-
-toO :: AS.Array n a -> RS.Array n a
-toO (AS.Array sh strides offset vec) = RS.A (RG.A sh (OI.T strides offset vec))
-
-liftO1 :: (AS.Array n a -> AS.Array n' b)
- -> RS.Array n a -> RS.Array n' b
-liftO1 f = toO . f . fromO
-
-liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c)
- -> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c
-liftO2 f x y = toO (f (fromO x) (fromO y))
-
-liftVEltwise1 :: (Storable a, Storable b)
- => SNat n
- -> (VS.Vector a -> VS.Vector b)
- -> RS.Array n a -> RS.Array n b
-liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
- | Just (blockOff, blockSz) <- stridesDense sh offset strides =
- let vec' = f (VS.slice blockOff blockSz vec)
- in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec'))
- | otherwise = RS.fromVector sh (f (RS.toVector arr))
diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Mixed/Lemmas.hs
deleted file mode 100644
index ec7e7bd..0000000
--- a/src/Data/Array/Mixed/Lemmas.hs
+++ /dev/null
@@ -1,137 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeOperators #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Mixed.Lemmas where
-
-import Data.Proxy
-import Data.Type.Equality
-import GHC.TypeLits
-
-import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
-import Data.Array.Mixed.Types
-
-
--- * Reasoning helpers
-
-subst1 :: forall f a b. a :~: b -> f a :~: f b
-subst1 Refl = Refl
-
-subst2 :: forall f c a b. a :~: b -> f a c :~: f b c
-subst2 Refl = Refl
-
-
--- * Lemmas
-
--- ** Nat
-
-lemLeqSuccSucc :: (k + 1 <= n) => Proxy k -> Proxy n -> (k <=? n - 1) :~: True
-lemLeqSuccSucc _ _ = unsafeCoerceRefl
-
-lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True
-lemLeqPlus _ _ _ = Refl
-
-
--- ** Append
-
-lemAppNil :: l ++ '[] :~: l
-lemAppNil = unsafeCoerceRefl
-
-lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
-lemAppAssoc _ _ _ = unsafeCoerceRefl
-
-lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l
-lemAppLeft _ Refl = Refl
-
-
--- ** Rank
-
-lemRankApp :: forall sh1 sh2.
- StaticShX sh1 -> StaticShX sh2
- -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
-lemRankApp ZKX _ = Refl
-lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2
- = lem2 (Proxy @(Rank sh1T)) Proxy Proxy $
- lem (Proxy @(Rank sh2)) (Proxy @(Rank sh1T)) (Proxy @(Rank (sh1T ++ sh2))) $
- lemRankApp ssh1 ssh2
- where
- lem :: proxy a -> proxy b -> proxy c
- -> c :~: b + a
- -> b + a :~: c
- lem _ _ _ Refl = Refl
-
- lem2 :: proxy a -> proxy b -> proxy c
- -> (a + b :~: c)
- -> c + 1 :~: (a + 1 + b)
- lem2 _ _ _ Refl = Refl
-
-lemRankAppComm :: StaticShX sh1 -> StaticShX sh2
- -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)
-lemRankAppComm _ _ = unsafeCoerceRefl -- TODO improve this
-
-lemRankReplicate :: SNat n -> Rank (Replicate n (Nothing @Nat)) :~: n
-lemRankReplicate SZ = Refl
-lemRankReplicate (SS (n :: SNat nm1))
- | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1
- , Refl <- lemRankReplicate n
- = Refl
-
-
--- ** Various type families
-
-lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a
- -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
-lemReplicatePlusApp sn _ _ = go sn
- where
- go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a
- go SZ = Refl
- go (SS (n :: SNat n'm1))
- | Refl <- lemReplicateSucc @a @n'm1
- , Refl <- go n
- = sym (lemReplicateSucc @a @(n'm1 + m))
-
-lemDropLenApp :: Rank l1 <= Rank l2
- => Proxy l1 -> Proxy l2 -> Proxy rest
- -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest)
-lemDropLenApp _ _ _ = unsafeCoerceRefl
-
-lemTakeLenApp :: Rank l1 <= Rank l2
- => Proxy l1 -> Proxy l2 -> Proxy rest
- -> TakeLen l1 l2 :~: TakeLen l1 (l2 ++ rest)
-lemTakeLenApp _ _ _ = unsafeCoerceRefl
-
-lemInitApp :: Proxy l -> Proxy x -> Init (l ++ '[x]) :~: l
-lemInitApp _ _ = unsafeCoerceRefl
-
-lemLastApp :: Proxy l -> Proxy x -> Last (l ++ '[x]) :~: x
-lemLastApp _ _ = unsafeCoerceRefl
-
-
--- ** KnownNat
-
-lemKnownNatSucc :: KnownNat n => Dict KnownNat (n + 1)
-lemKnownNatSucc = Dict
-
-lemKnownNatRank :: ShX sh i -> Dict KnownNat (Rank sh)
-lemKnownNatRank ZSX = Dict
-lemKnownNatRank (_ :$% sh) | Dict <- lemKnownNatRank sh = Dict
-
-lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh)
-lemKnownNatRankSSX ZKX = Dict
-lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
-
-
--- ** Known shapes
-
-lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing)
-lemKnownReplicate sn = lemKnownShX (ssxFromSNat sn)
-
-lemKnownShX :: StaticShX sh -> Dict KnownShX sh
-lemKnownShX ZKX = Dict
-lemKnownShX (SKnown SNat :!% ssh) | Dict <- lemKnownShX ssh = Dict
-lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict
diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs
deleted file mode 100644
index 8efcbe8..0000000
--- a/src/Data/Array/Mixed/Permutation.hs
+++ /dev/null
@@ -1,273 +0,0 @@
-{-# LANGUAGE ConstraintKinds #-}
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE ImportQualifiedPost #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE QuantifiedConstraints #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE StrictData #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Mixed.Permutation where
-
-import Data.Coerce (coerce)
-import Data.Functor.Const
-import Data.List (sort)
-import Data.Maybe (fromMaybe)
-import Data.Proxy
-import Data.Type.Bool
-import Data.Type.Equality
-import Data.Type.Ord
-import GHC.TypeError
-import GHC.TypeLits
-import GHC.TypeNats qualified as TN
-
-import Data.Array.Mixed.Shape
-import Data.Array.Mixed.Types
-
-
--- * Permutations
-
--- | A "backward" permutation of a dimension list. The operation on the
--- dimension list is most similar to 'Data.Vector.backpermute'; see 'Permute'
--- for code that implements this.
-data Perm list where
- PNil :: Perm '[]
- PCons :: SNat a -> Perm l -> Perm (a : l)
-infixr 5 `PCons`
-deriving instance Show (Perm list)
-deriving instance Eq (Perm list)
-
-permRank :: Perm list -> SNat (Rank list)
-permRank PNil = SNat
-permRank (_ `PCons` l) | SNat <- permRank l = SNat
-
-permFromList :: [Int] -> (forall list. Perm list -> r) -> r
-permFromList [] k = k PNil
-permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case
- Just sn -> permFromList xs $ \list -> k (sn `PCons` list)
- Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x
-
-permToList :: Perm list -> [Natural]
-permToList PNil = mempty
-permToList (x `PCons` l) = TN.fromSNat x : permToList l
-
-permToList' :: Perm list -> [Int]
-permToList' = map fromIntegral . permToList
-
--- | When called as @permCheckPermutation p k@, if @p@ is a permutation of
--- @[0 .. 'length' ('permToList' p) - 1]@, @Just k@ is returned. If it isn't,
--- then @Nothing@ is returned.
-permCheckPermutation :: forall r list. Perm list -> (IsPermutation list => r) -> Maybe r
-permCheckPermutation = \p k ->
- let n = permRank p
- in case (provePerm1 (Proxy @list) n p, provePerm2 (SNat @0) n p) of
- (Just Refl, Just Refl) -> Just k
- _ -> Nothing
- where
- lemElemCount :: (0 <= n, Compare n m ~ LT)
- => proxy n -> proxy m -> Elem n (Count 0 m) :~: True
- lemElemCount _ _ = unsafeCoerceRefl
-
- lemCount :: (OrdCond (Compare i n) True False True ~ True)
- => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n
- lemCount _ _ = unsafeCoerceRefl
-
- lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True
- lemElem _ _ = unsafeCoerceRefl
-
- provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is'
- -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True)
- provePerm1 _ _ PNil = Just (Refl)
- provePerm1 p rtop@SNat (PCons sn@SNat perm)
- | Just Refl <- provePerm1 p rtop perm
- = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of
- (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
- (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
- _ -> Nothing
- | otherwise
- = Nothing
-
- provePerm2 :: SNat i -> SNat n -> Perm is'
- -> Maybe (AllElem' (Count i n) is' :~: True)
- provePerm2 = \i@(SNat :: SNat i) n@SNat perm ->
- case cmpNat i n of
- EQI -> Just Refl
- LTI | Refl <- lemCount i n
- , Just Refl <- provePerm2 (SNat @(i + 1)) n perm
- -> checkElem i perm
- | otherwise -> Nothing
- GTI -> error "unreachable"
- where
- checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True)
- checkElem _ PNil = Nothing
- checkElem i@SNat (PCons k@SNat perm :: Perm is') =
- case sameNat i k of
- Just Refl -> Just Refl
- Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl
- | otherwise -> Nothing
-
--- | Utility class for generating permutations from type class information.
-class KnownPerm l where makePerm :: Perm l
-instance KnownPerm '[] where makePerm = PNil
-instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = natSing `PCons` makePerm
-
--- | Untyped permutations for ranked arrays
-type PermR = [Int]
-
-
--- ** Applying permutations
-
-type family Elem x l where
- Elem x '[] = 'False
- Elem x (x : _) = 'True
- Elem x (_ : ys) = Elem x ys
-
-type family AllElem' as bs where
- AllElem' '[] bs = 'True
- AllElem' (a : as) bs = Elem a bs && AllElem' as bs
-
-type AllElem as bs = Assert (AllElem' as bs)
- (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs))
-
-type family Count i n where
- Count n n = '[]
- Count i n = i : Count (i + 1) n
-
-type IsPermutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as)
-
-type family Index i sh where
- Index 0 (n : sh) = n
- Index i (_ : sh) = Index (i - 1) sh
-
-type family Permute is sh where
- Permute '[] sh = '[]
- Permute (i : is) sh = Index i sh : Permute is sh
-
-type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh
-
-type family TakeLen ref l where
- TakeLen '[] l = '[]
- TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs
-
-type family DropLen ref l where
- DropLen '[] l = l
- DropLen (_ : ref) (_ : xs) = DropLen ref xs
-
-listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f
-listxTakeLen PNil _ = ZX
-listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh
-listxTakeLen (_ `PCons` _) ZX = error "Permutation longer than shape"
-
-listxDropLen :: forall f is sh. Perm is -> ListX sh f -> ListX (DropLen is sh) f
-listxDropLen PNil sh = sh
-listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh
-listxDropLen (_ `PCons` _) ZX = error "Permutation longer than shape"
-
-listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f
-listxPermute PNil _ = ZX
-listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) =
- listxIndex (Proxy @is') (Proxy @sh) i sh ::% listxPermute is sh
-
-listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh)
-listxIndex _ _ SZ (n ::% _) = n
-listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f))
- | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = listxIndex p pT i sh
-listxIndex _ _ _ ZX = error "Index into empty shape"
-
-listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f
-listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh)
-
-ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i
-ixxPermutePrefix = coerce (listxPermutePrefix @(Const i))
-
-ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh)
-ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat))
-
-ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh)
-ssxDropLen = coerce (listxDropLen @(SMayNat () SNat))
-
-ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh)
-ssxPermute = coerce (listxPermute @(SMayNat () SNat))
-
-ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh)
-ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2)
-
-ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
-ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat))
-
-shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh)
-shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat))
-
-
--- * Operations on permutations
-
-permInverse :: Perm is
- -> (forall is'.
- IsPermutation is'
- => Perm is'
- -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh)
- -> r)
- -> r
-permInverse = \perm k ->
- genPerm perm $ \(invperm :: Perm is') ->
- fromMaybe
- (error $ "permInverse: did not generate permutation? perm = " ++ show perm
- ++ " ; invperm = " ++ show invperm)
- (permCheckPermutation invperm
- (k invperm
- (\ssh -> case provePermInverse perm invperm ssh of
- Just eq -> eq
- Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm
- ++ " ; invperm = " ++ show invperm)))
- where
- genPerm :: Perm is -> (forall is'. Perm is' -> r) -> r
- genPerm perm =
- let permList = permToList' perm
- in toHList $ map snd (sort (zip permList [0..]))
- where
- toHList :: [Natural] -> (forall is'. Perm is' -> r) -> r
- toHList [] k = k PNil
- toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l)
-
- provePermInverse :: Perm is -> Perm is' -> StaticShX sh
- -> Maybe (Permute is' (Permute is sh) :~: sh)
- provePermInverse perm perminv ssh =
- ssxEqType (ssxPermute perminv (ssxPermute perm ssh)) ssh
-
-type family MapSucc is where
- MapSucc '[] = '[]
- MapSucc (i : is) = i + 1 : MapSucc is
-
-permShift1 :: Perm l -> Perm (0 : MapSucc l)
-permShift1 = (SNat @0 `PCons`) . permMapSucc
- where
- permMapSucc :: Perm l -> Perm (MapSucc l)
- permMapSucc PNil = PNil
- permMapSucc ((SNat :: SNat i) `PCons` ns) = SNat @(i + 1) `PCons` permMapSucc ns
-
-
--- * Lemmas
-
-lemRankPermute :: Proxy sh -> Perm is -> Rank (Permute is sh) :~: Rank is
-lemRankPermute _ PNil = Refl
-lemRankPermute p (_ `PCons` is) | Refl <- lemRankPermute p is = Refl
-
-lemRankDropLen :: forall is sh. (Rank is <= Rank sh)
- => StaticShX sh -> Perm is -> Rank (DropLen is sh) :~: Rank sh - Rank is
-lemRankDropLen ZKX PNil = Refl
-lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl
-lemRankDropLen (_ :!% _) PNil = Refl
-lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0"
-
-lemIndexSucc :: Proxy i -> Proxy a -> Proxy l
- -> Index (i + 1) (a : l) :~: Index i l
-lemIndexSucc _ _ _ = unsafeCoerceRefl
diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs
deleted file mode 100644
index b49e005..0000000
--- a/src/Data/Array/Mixed/Shape.hs
+++ /dev/null
@@ -1,586 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveGeneric #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE ImportQualifiedPost #-}
-{-# LANGUAGE NoStarIsType #-}
-{-# LANGUAGE PatternSynonyms #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE QuantifiedConstraints #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE RoleAnnotations #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE StrictData #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
-{-# LANGUAGE ViewPatterns #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Mixed.Shape where
-
-import Control.DeepSeq (NFData(..))
-import Data.Bifunctor (first)
-import Data.Coerce
-import Data.Foldable qualified as Foldable
-import Data.Functor.Const
-import Data.Kind (Type, Constraint)
-import Data.Monoid (Sum(..))
-import Data.Proxy
-import Data.Type.Equality
-import GHC.Exts (withDict)
-import GHC.Generics (Generic)
-import GHC.IsList (IsList)
-import GHC.IsList qualified as IsList
-import GHC.TypeLits
-
-import Data.Array.Mixed.Types
-
-
--- | The length of a type-level list. If the argument is a shape, then the
--- result is the rank of that shape.
-type family Rank sh where
- Rank '[] = 0
- Rank (_ : sh) = Rank sh + 1
-
-
--- * Mixed lists
-
-type role ListX nominal representational
-type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type
-data ListX sh f where
- ZX :: ListX '[] f
- (::%) :: f n -> ListX sh f -> ListX (n : sh) f
-deriving instance (forall n. Eq (f n)) => Eq (ListX sh f)
-deriving instance (forall n. Ord (f n)) => Ord (ListX sh f)
-infixr 3 ::%
-
-instance (forall n. Show (f n)) => Show (ListX sh f) where
- showsPrec _ = listxShow shows
-
-instance (forall n. NFData (f n)) => NFData (ListX sh f) where
- rnf ZX = ()
- rnf (x ::% l) = rnf x `seq` rnf l
-
-data UnconsListXRes f sh1 =
- forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n)
-listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1)
-listxUncons (i ::% shl') = Just (UnconsListXRes shl' i)
-listxUncons ZX = Nothing
-
--- | This checks only whether the types are equal; if the elements of the list
--- are not singletons, their values may still differ. This corresponds to
--- 'testEquality', except on the penultimate type parameter.
-listxEqType :: TestEquality f => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh')
-listxEqType ZX ZX = Just Refl
-listxEqType (n ::% sh) (m ::% sh')
- | Just Refl <- testEquality n m
- , Just Refl <- listxEqType sh sh'
- = Just Refl
-listxEqType _ _ = Nothing
-
--- | This checks whether the two lists actually contain equal values. This is
--- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@
--- in the @some@ package (except on the penultimate type parameter).
-listxEqual :: (TestEquality f, forall n. Eq (f n)) => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh')
-listxEqual ZX ZX = Just Refl
-listxEqual (n ::% sh) (m ::% sh')
- | Just Refl <- testEquality n m
- , n == m
- , Just Refl <- listxEqual sh sh'
- = Just Refl
-listxEqual _ _ = Nothing
-
-listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g
-listxFmap _ ZX = ZX
-listxFmap f (x ::% xs) = f x ::% listxFmap f xs
-
-listxFold :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m
-listxFold _ ZX = mempty
-listxFold f (x ::% xs) = f x <> listxFold f xs
-
-listxLength :: ListX sh f -> Int
-listxLength = getSum . listxFold (\_ -> Sum 1)
-
-listxRank :: ListX sh f -> SNat (Rank sh)
-listxRank ZX = SNat
-listxRank (_ ::% l) | SNat <- listxRank l = SNat
-
-listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS
-listxShow f l = showString "[" . go "" l . showString "]"
- where
- go :: String -> ListX sh' f -> ShowS
- go _ ZX = id
- go prefix (x ::% xs) = showString prefix . f x . go "," xs
-
-listxFromList :: StaticShX sh -> [i] -> ListX sh (Const i)
-listxFromList topssh topl = go topssh topl
- where
- go :: StaticShX sh' -> [i] -> ListX sh' (Const i)
- go ZKX [] = ZX
- go (_ :!% sh) (i : is) = Const i ::% go sh is
- go _ _ = error $ "listxFromList: Mismatched list length (type says "
- ++ show (ssxLength topssh) ++ ", list has length "
- ++ show (length topl) ++ ")"
-
-listxToList :: ListX sh' (Const i) -> [i]
-listxToList ZX = []
-listxToList (Const i ::% is) = i : listxToList is
-
-listxHead :: ListX (mn ': sh) f -> f mn
-listxHead (i ::% _) = i
-
-listxTail :: ListX (n : sh) i -> ListX sh i
-listxTail (_ ::% sh) = sh
-
-listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f
-listxAppend ZX idx' = idx'
-listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx'
-
-listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f
-listxDrop long ZX = long
-listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short
-
-listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f
-listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh
-listxInit (_ ::% ZX) = ZX
-
-listxLast :: forall f n sh. ListX (n : sh) f -> f (Last (n : sh))
-listxLast (_ ::% sh@(_ ::% _)) = listxLast sh
-listxLast (x ::% ZX) = x
-
-
--- * Mixed indices
-
--- | This is a newtype over 'ListX'.
-type role IxX nominal representational
-type IxX :: [Maybe Nat] -> Type -> Type
-newtype IxX sh i = IxX (ListX sh (Const i))
- deriving (Eq, Ord, Generic)
-
-pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i
-pattern ZIX = IxX ZX
-
-pattern (:.%)
- :: forall {sh1} {i}.
- forall n sh. (n : sh ~ sh1)
- => i -> IxX sh i -> IxX sh1 i
-pattern i :.% shl <- IxX (listxUncons -> Just (UnconsListXRes (IxX -> shl) (getConst -> i)))
- where i :.% IxX shl = IxX (Const i ::% shl)
-infixr 3 :.%
-
-{-# COMPLETE ZIX, (:.%) #-}
-
-type IIxX sh = IxX sh Int
-
-instance Show i => Show (IxX sh i) where
- showsPrec _ (IxX l) = listxShow (\(Const i) -> shows i) l
-
-instance Functor (IxX sh) where
- fmap f (IxX l) = IxX (listxFmap (Const . f . getConst) l)
-
-instance Foldable (IxX sh) where
- foldMap f (IxX l) = listxFold (f . getConst) l
-
-instance NFData i => NFData (IxX sh i)
-
-ixxLength :: IxX sh i -> Int
-ixxLength (IxX l) = listxLength l
-
-ixxRank :: IxX sh i -> SNat (Rank sh)
-ixxRank (IxX l) = listxRank l
-
-ixxZero :: StaticShX sh -> IIxX sh
-ixxZero ZKX = ZIX
-ixxZero (_ :!% ssh) = 0 :.% ixxZero ssh
-
-ixxZero' :: IShX sh -> IIxX sh
-ixxZero' ZSX = ZIX
-ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh
-
-ixxFromList :: forall sh i. StaticShX sh -> [i] -> IxX sh i
-ixxFromList = coerce (listxFromList @_ @i)
-
-ixxHead :: IxX (n : sh) i -> i
-ixxHead (IxX list) = getConst (listxHead list)
-
-ixxTail :: IxX (n : sh) i -> IxX sh i
-ixxTail (IxX list) = IxX (listxTail list)
-
-ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i
-ixxAppend = coerce (listxAppend @_ @(Const i))
-
-ixxDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i
-ixxDrop = coerce (listxDrop @(Const i) @(Const i))
-
-ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i
-ixxInit = coerce (listxInit @(Const i))
-
-ixxLast :: forall n sh i. IxX (n : sh) i -> i
-ixxLast = coerce (listxLast @(Const i))
-
-ixxFromLinear :: IShX sh -> Int -> IIxX sh
-ixxFromLinear = \sh i -> case go sh i of
- (idx, 0) -> idx
- _ -> error $ "ixxFromLinear: out of range (" ++ show i ++
- " in array of shape " ++ show sh ++ ")"
- where
- -- returns (index in subarray, remaining index in enclosing array)
- go :: IShX sh -> Int -> (IIxX sh, Int)
- go ZSX i = (ZIX, i)
- go (n :$% sh) i =
- let (idx, i') = go sh i
- (upi, locali) = i' `quotRem` fromSMayNat' n
- in (locali :.% idx, upi)
-
-ixxToLinear :: IShX sh -> IIxX sh -> Int
-ixxToLinear = \sh i -> fst (go sh i)
- where
- -- returns (index in subarray, size of subarray)
- go :: IShX sh -> IIxX sh -> (Int, Int)
- go ZSX ZIX = (0, 1)
- go (n :$% sh) (i :.% ix) =
- let (lidx, sz) = go sh ix
- in (sz * i + lidx, fromSMayNat' n * sz)
-
-
--- * Mixed shapes
-
-data SMayNat i f n where
- SUnknown :: i -> SMayNat i f Nothing
- SKnown :: f n -> SMayNat i f (Just n)
-deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n)
-deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n)
-deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n)
-
-instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where
- rnf (SUnknown i) = rnf i
- rnf (SKnown x) = rnf x
-
-instance TestEquality f => TestEquality (SMayNat i f) where
- testEquality SUnknown{} SUnknown{} = Just Refl
- testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl
- testEquality _ _ = Nothing
-
-fromSMayNat :: (n ~ Nothing => i -> r)
- -> (forall m. n ~ Just m => f m -> r)
- -> SMayNat i f n -> r
-fromSMayNat f _ (SUnknown i) = f i
-fromSMayNat _ g (SKnown s) = g s
-
-fromSMayNat' :: SMayNat Int SNat n -> Int
-fromSMayNat' = fromSMayNat id fromSNat'
-
-type family AddMaybe n m where
- AddMaybe Nothing _ = Nothing
- AddMaybe (Just _) Nothing = Nothing
- AddMaybe (Just n) (Just m) = Just (n + m)
-
-smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m)
-smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m)
-smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m)
-smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m)
-
-
--- | This is a newtype over 'ListX'.
-type role ShX nominal representational
-type ShX :: [Maybe Nat] -> Type -> Type
-newtype ShX sh i = ShX (ListX sh (SMayNat i SNat))
- deriving (Eq, Ord, Generic)
-
-pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i
-pattern ZSX = ShX ZX
-
-pattern (:$%)
- :: forall {sh1} {i}.
- forall n sh. (n : sh ~ sh1)
- => SMayNat i SNat n -> ShX sh i -> ShX sh1 i
-pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i))
- where i :$% ShX shl = ShX (i ::% shl)
-infixr 3 :$%
-
-{-# COMPLETE ZSX, (:$%) #-}
-
-type IShX sh = ShX sh Int
-
-instance Show i => Show (ShX sh i) where
- showsPrec _ (ShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l
-
-instance Functor (ShX sh) where
- fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l)
-
-instance NFData i => NFData (ShX sh i) where
- rnf (ShX ZX) = ()
- rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l)
- rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l)
-
--- | This checks only whether the types are equal; unknown dimensions might
--- still differ. This corresponds to 'testEquality', except on the penultimate
--- type parameter.
-shxEqType :: ShX sh i -> ShX sh' i -> Maybe (sh :~: sh')
-shxEqType ZSX ZSX = Just Refl
-shxEqType (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh')
- | Just Refl <- sameNat n m
- , Just Refl <- shxEqType sh sh'
- = Just Refl
-shxEqType (SUnknown _ :$% sh) (SUnknown _ :$% sh')
- | Just Refl <- shxEqType sh sh'
- = Just Refl
-shxEqType _ _ = Nothing
-
--- | This checks whether all dimensions have the same value. This is more than
--- 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ in the
--- @some@ package (except on the penultimate type parameter).
-shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh')
-shxEqual ZSX ZSX = Just Refl
-shxEqual (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh')
- | Just Refl <- sameNat n m
- , Just Refl <- shxEqual sh sh'
- = Just Refl
-shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh')
- | i == j
- , Just Refl <- shxEqual sh sh'
- = Just Refl
-shxEqual _ _ = Nothing
-
-shxLength :: ShX sh i -> Int
-shxLength (ShX l) = listxLength l
-
-shxRank :: ShX sh i -> SNat (Rank sh)
-shxRank (ShX l) = listxRank l
-
--- | The number of elements in an array described by this shape.
-shxSize :: IShX sh -> Int
-shxSize ZSX = 1
-shxSize (n :$% sh) = fromSMayNat' n * shxSize sh
-
-shxFromList :: StaticShX sh -> [Int] -> ShX sh Int
-shxFromList topssh topl = go topssh topl
- where
- go :: StaticShX sh' -> [Int] -> ShX sh' Int
- go ZKX [] = ZSX
- go (SKnown sn :!% sh) (i : is)
- | i == fromSNat' sn = SKnown sn :$% go sh is
- | otherwise = error $ "shxFromList: Value does not match typing (type says "
- ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")"
- go (SUnknown () :!% sh) (i : is) = SUnknown i :$% go sh is
- go _ _ = error $ "shxFromList: Mismatched list length (type says "
- ++ show (ssxLength topssh) ++ ", list has length "
- ++ show (length topl) ++ ")"
-
-shxToList :: IShX sh -> [Int]
-shxToList ZSX = []
-shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh
-
-shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i
-shxAppend = coerce (listxAppend @_ @(SMayNat i SNat))
-
-shxHead :: ShX (n : sh) i -> SMayNat i SNat n
-shxHead (ShX list) = listxHead list
-
-shxTail :: ShX (n : sh) i -> ShX sh i
-shxTail (ShX list) = ShX (listxTail list)
-
-shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i
-shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat))
-
-shxDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i
-shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j))
-
-shxDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i
-shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat))
-
-shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i
-shxInit = coerce (listxInit @(SMayNat i SNat))
-
-shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh))
-shxLast = coerce (listxLast @(SMayNat i SNat))
-
-shxTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i
-shxTakeSSX _ = flip go
- where
- go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i
- go ZKX _ = ZSX
- go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh
-
--- This is a weird operation, so it has a long name
-shxCompleteZeros :: StaticShX sh -> IShX sh
-shxCompleteZeros ZKX = ZSX
-shxCompleteZeros (SUnknown () :!% ssh) = SUnknown 0 :$% shxCompleteZeros ssh
-shxCompleteZeros (SKnown n :!% ssh) = SKnown n :$% shxCompleteZeros ssh
-
-shxSplitApp :: Proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i)
-shxSplitApp _ ZKX idx = (ZSX, idx)
-shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx)
-
-shxEnum :: IShX sh -> [IIxX sh]
-shxEnum = \sh -> go sh id []
- where
- go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a]
- go ZSX f = (f ZIX :)
- go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]]
-
-shxCast :: IShX sh -> StaticShX sh' -> Maybe (IShX sh')
-shxCast ZSX ZKX = Just ZSX
-shxCast (SKnown n :$% sh) (SKnown m :!% ssh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast sh ssh
-shxCast (SUnknown n :$% sh) (SKnown m :!% ssh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast sh ssh
-shxCast (SKnown n :$% sh) (SUnknown () :!% ssh) = (SUnknown (fromSNat' n) :$%) <$> shxCast sh ssh
-shxCast (SUnknown n :$% sh) (SUnknown () :!% ssh) = (SUnknown n :$%) <$> shxCast sh ssh
-shxCast _ _ = Nothing
-
--- | Partial version of 'shxCast'.
-shxCast' :: IShX sh -> StaticShX sh' -> IShX sh'
-shxCast' sh ssh = case shxCast sh ssh of
- Just sh' -> sh'
- Nothing -> error $ "shxCast': Mismatch: (" ++ show sh ++ ") does not match (" ++ show ssh ++ ")"
-
-
--- * Static mixed shapes
-
--- | The part of a shape that is statically known. (A newtype over 'ListX'.)
-type StaticShX :: [Maybe Nat] -> Type
-newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat))
- deriving (Eq, Ord)
-
-pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh
-pattern ZKX = StaticShX ZX
-
-pattern (:!%)
- :: forall {sh1}.
- forall n sh. (n : sh ~ sh1)
- => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1
-pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i))
- where i :!% StaticShX shl = StaticShX (i ::% shl)
-infixr 3 :!%
-
-{-# COMPLETE ZKX, (:!%) #-}
-
-instance Show (StaticShX sh) where
- showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l
-
-instance NFData (StaticShX sh) where
- rnf (StaticShX ZX) = ()
- rnf (StaticShX (SUnknown () ::% l)) = rnf (StaticShX l)
- rnf (StaticShX (SKnown SNat ::% l)) = rnf (StaticShX l)
-
-instance TestEquality StaticShX where
- testEquality (StaticShX l1) (StaticShX l2) = listxEqType l1 l2
-
-ssxLength :: StaticShX sh -> Int
-ssxLength (StaticShX l) = listxLength l
-
-ssxRank :: StaticShX sh -> SNat (Rank sh)
-ssxRank (StaticShX l) = listxRank l
-
--- | @ssxEqType = 'testEquality'@. Provided for consistency.
-ssxEqType :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh')
-ssxEqType = testEquality
-
-ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
-ssxAppend ZKX sh' = sh'
-ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh'
-
-ssxHead :: StaticShX (n : sh) -> SMayNat () SNat n
-ssxHead (StaticShX list) = listxHead list
-
-ssxTail :: StaticShX (n : sh) -> StaticShX sh
-ssxTail (_ :!% ssh) = ssh
-
-ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh'
-ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i))
-
-ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh))
-ssxInit = coerce (listxInit @(SMayNat () SNat))
-
-ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh))
-ssxLast = coerce (listxLast @(SMayNat () SNat))
-
--- | This may fail if @sh@ has @Nothing@s in it.
-ssxToShX' :: StaticShX sh -> Maybe (IShX sh)
-ssxToShX' ZKX = Just ZSX
-ssxToShX' (SKnown n :!% sh) = (SKnown n :$%) <$> ssxToShX' sh
-ssxToShX' (SUnknown _ :!% _) = Nothing
-
-ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)
-ssxReplicate SZ = ZKX
-ssxReplicate (SS (n :: SNat n'))
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n'
- = SUnknown () :!% ssxReplicate n
-
-ssxIotaFrom :: Int -> StaticShX sh -> [Int]
-ssxIotaFrom _ ZKX = []
-ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh
-
-ssxFromShape :: IShX sh -> StaticShX sh
-ssxFromShape ZSX = ZKX
-ssxFromShape (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShape sh
-
-ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing)
-ssxFromSNat SZ = ZKX
-ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n
-
-
--- | Evidence for the static part of a shape. This pops up only when you are
--- polymorphic in the element type of an array.
-type KnownShX :: [Maybe Nat] -> Constraint
-class KnownShX sh where knownShX :: StaticShX sh
-instance KnownShX '[] where knownShX = ZKX
-instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX
-instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX
-
-withKnownShX :: forall sh r. StaticShX sh -> (KnownShX sh => r) -> r
-withKnownShX sh = withDict @(KnownShX sh) sh
-
-
--- * Flattening
-
-type Flatten sh = Flatten' 1 sh
-
-type family Flatten' acc sh where
- Flatten' acc '[] = Just acc
- Flatten' acc (Nothing : sh) = Nothing
- Flatten' acc (Just n : sh) = Flatten' (acc * n) sh
-
--- This function is currently unused
-ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh)
-ssxFlatten = go (SNat @1)
- where
- go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh)
- go acc ZKX = SKnown acc
- go _ (SUnknown () :!% _) = SUnknown ()
- go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh
-
-shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh)
-shxFlatten = go (SNat @1)
- where
- go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh)
- go acc ZSX = SKnown acc
- go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh)
- go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh
-
- goUnknown :: Int -> IShX sh -> Int
- goUnknown acc ZSX = acc
- goUnknown acc (SUnknown n :$% sh) = goUnknown (acc * n) sh
- goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh
-
-
--- | Very untyped: only length is checked (at runtime).
-instance KnownShX sh => IsList (ListX sh (Const i)) where
- type Item (ListX sh (Const i)) = i
- fromList = listxFromList (knownShX @sh)
- toList = listxToList
-
--- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
-instance KnownShX sh => IsList (IxX sh i) where
- type Item (IxX sh i) = i
- fromList = IxX . IsList.fromList
- toList = Foldable.toList
-
--- | Untyped: length and known dimensions are checked (at runtime).
-instance KnownShX sh => IsList (ShX sh Int) where
- type Item (ShX sh Int) = Int
- fromList = shxFromList (knownShX @sh)
- toList = shxToList
diff --git a/src/Data/Array/Mixed/Types.hs b/src/Data/Array/Mixed/Types.hs
deleted file mode 100644
index 736ced6..0000000
--- a/src/Data/Array/Mixed/Types.hs
+++ /dev/null
@@ -1,134 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE ImportQualifiedPost #-}
-{-# LANGUAGE NoStarIsType #-}
-{-# LANGUAGE PatternSynonyms #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
-{-# LANGUAGE ViewPatterns #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Mixed.Types (
- -- * Reified evidence of a type class
- Dict(..),
-
- -- * Type-level naturals
- pattern SZ, pattern SS,
- fromSNat', sameNat',
- snatPlus, snatMinus, snatMul,
- snatSucc,
-
- -- * Type-level lists
- type (++),
- Replicate,
- lemReplicateSucc,
- MapJust,
- Head,
- Tail,
- Init,
- Last,
-
- -- * Unsafe
- unsafeCoerceRefl,
-) where
-
-import Data.Type.Equality
-import Data.Proxy
-import GHC.TypeLits
-import GHC.TypeNats qualified as TN
-import Unsafe.Coerce qualified
-
-
--- | Evidence for the constraint @c a@.
-data Dict c a where
- Dict :: c a => Dict c a
-
-fromSNat' :: SNat n -> Int
-fromSNat' = fromIntegral . fromSNat
-
-sameNat' :: SNat n -> SNat m -> Maybe (n :~: m)
-sameNat' n@SNat m@SNat = sameNat n m
-
-pattern SZ :: () => (n ~ 0) => SNat n
-pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl)
- where SZ = SNat
-
-pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1
-pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl))
- where SS = snatSucc
-
-{-# COMPLETE SZ, SS #-}
-
-snatSucc :: SNat n -> SNat (n + 1)
-snatSucc SNat = SNat
-
-data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1)
-snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1)
-snatPred snp1 =
- withKnownNat snp1 $
- case cmpNat (Proxy @1) (Proxy @np1) of
- LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
- EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
- GTI -> Nothing
-
--- This should be a function in base
-snatPlus :: SNat n -> SNat m -> SNat (n + m)
-snatPlus n m = TN.withSomeSNat (TN.fromSNat n + TN.fromSNat m) Unsafe.Coerce.unsafeCoerce
-
--- This should be a function in base
-snatMinus :: SNat n -> SNat m -> SNat (n - m)
-snatMinus n m = let res = TN.fromSNat n - TN.fromSNat m in res `seq` TN.withSomeSNat res Unsafe.Coerce.unsafeCoerce
-
--- This should be a function in base
-snatMul :: SNat n -> SNat m -> SNat (n * m)
-snatMul n m = TN.withSomeSNat (TN.fromSNat n * TN.fromSNat m) Unsafe.Coerce.unsafeCoerce
-
-
--- | Type-level list append.
-type family l1 ++ l2 where
- '[] ++ l2 = l2
- (x : xs) ++ l2 = x : xs ++ l2
-
-type family Replicate n a where
- Replicate 0 a = '[]
- Replicate n a = a : Replicate (n - 1) a
-
-lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
-lemReplicateSucc = unsafeCoerceRefl
-
-type family MapJust l where
- MapJust '[] = '[]
- MapJust (x : xs) = Just x : MapJust xs
-
-type family Head l where
- Head (x : _) = x
-
-type family Tail l where
- Tail (_ : xs) = xs
-
-type family Init l where
- Init (x : y : xs) = x : Init (y : xs)
- Init '[x] = '[]
-
-type family Last l where
- Last (x : y : xs) = Last (y : xs)
- Last '[x] = x
-
-
--- | This is just @'Unsafe.Coerce.unsafeCoerce' 'Refl'@, but specialised to
--- only typecheck for actual type equalities. One cannot, e.g. accidentally
--- write this:
---
--- @
--- foo :: Proxy a -> Proxy b -> a :~: b
--- foo = unsafeCoerceRefl
--- @
---
--- which would have been permitted with normal 'Unsafe.Coerce.unsafeCoerce',
--- but would have resulted in interesting memory errors at runtime.
-unsafeCoerceRefl :: a :~: b
-unsafeCoerceRefl = Unsafe.Coerce.unsafeCoerce Refl
diff --git a/src/Data/Array/Mixed/XArray.hs b/src/Data/Array/Mixed/XArray.hs
deleted file mode 100644
index 93484dc..0000000
--- a/src/Data/Array/Mixed/XArray.hs
+++ /dev/null
@@ -1,349 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveGeneric #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE ImportQualifiedPost #-}
-{-# LANGUAGE NoStarIsType #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE StrictData #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeOperators #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Mixed.XArray where
-
-import Control.DeepSeq (NFData)
-import Data.Array.Internal.RankedG qualified as ORG
-import Data.Array.Internal.RankedS qualified as ORS
-import Data.Array.Internal qualified as OI
-import Data.Array.Ranked qualified as ORB
-import Data.Array.RankedS qualified as S
-import Data.Coerce
-import Data.Foldable (toList)
-import Data.Kind
-import Data.List.NonEmpty (NonEmpty)
-import Data.Proxy
-import Data.Type.Equality
-import Data.Type.Ord
-import Data.Vector.Storable qualified as VS
-import Foreign.Storable (Storable)
-import GHC.Generics (Generic)
-import GHC.TypeLits
-
-import Data.Array.Mixed.Internal.Arith
-import Data.Array.Mixed.Lemmas
-import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
-import Data.Array.Mixed.Types
-
-
-type XArray :: [Maybe Nat] -> Type -> Type
-newtype XArray sh a = XArray (S.Array (Rank sh) a)
- deriving (Show, Eq, Ord, Generic)
-
-instance NFData (XArray sh a)
-
-
-shape :: forall sh a. StaticShX sh -> XArray sh a -> IShX sh
-shape = \ssh (XArray arr) -> go ssh (S.shapeL arr)
- where
- go :: StaticShX sh' -> [Int] -> IShX sh'
- go ZKX [] = ZSX
- go (n :!% ssh) (i : l) = fromSMayNat (\_ -> SUnknown i) SKnown n :$% go ssh l
- go _ _ = error "Invalid shapeL"
-
-fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a
-fromVector sh v
- | Dict <- lemKnownNatRank sh
- = XArray (S.fromVector (shxToList sh) v)
-
-toVector :: Storable a => XArray sh a -> VS.Vector a
-toVector (XArray arr) = S.toVector arr
-
--- | This allows observing the strides in the underlying orthotope array. This
--- can be useful for optimisation, but should be considered an implementation
--- detail: strides may change in new versions of this library without notice.
-arrayStrides :: XArray sh a -> [Int]
-arrayStrides (XArray (ORS.A (ORG.A _ (OI.T strides _ _)))) = strides
-
-scalar :: Storable a => a -> XArray '[] a
-scalar = XArray . S.scalar
-
--- | Will throw if the array does not have the casted-to shape.
-cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2
- => StaticShX sh1 -> IShX sh2 -> StaticShX sh'
- -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
-cast ssh1 sh2 ssh' (XArray arr)
- | Refl <- lemRankApp ssh1 ssh'
- , Refl <- lemRankApp (ssxFromShape sh2) ssh'
- = let arrsh :: IShX sh1
- (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr))
- in if shxToList arrsh == shxToList sh2
- then XArray arr
- else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")"
-
-unScalar :: Storable a => XArray '[] a -> a
-unScalar (XArray a) = S.unScalar a
-
-replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a
-replicate sh ssh' (XArray arr)
- | Dict <- lemKnownNatRankSSX ssh'
- , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh) ssh')
- , Refl <- lemRankApp (ssxFromShape sh) ssh'
- = XArray (S.stretch (shxToList sh ++ S.shapeL arr) $
- S.reshape (map (const 1) (shxToList sh) ++ S.shapeL arr) $
- arr)
-
-replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a
-replicateScal sh x
- | Dict <- lemKnownNatRank sh
- = XArray (S.constant (shxToList sh) x)
-
-generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a
-generate sh f = fromVector sh $ VS.generate (shxSize sh) (f . ixxFromLinear sh)
-
--- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a)
--- generateM sh f | Dict <- lemKnownNatRank sh =
--- XArray . S.fromVector (shxShapeL sh)
--- <$> VS.generateM (shxSize sh) (f . ixxFromLinear sh)
-
-indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a
-indexPartial (XArray arr) ZIX = XArray arr
-indexPartial (XArray arr) (i :.% idx) = indexPartial (XArray (S.index arr i)) idx
-
-index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a
-index xarr i
- | Refl <- lemAppNil @sh
- = let XArray arr' = indexPartial xarr i :: XArray '[] a
- in S.unScalar arr'
-
-append :: forall n m sh a. Storable a
- => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a
-append ssh (XArray a) (XArray b)
- | Dict <- lemKnownNatRankSSX ssh
- = XArray (S.append a b)
-
--- | All arrays must have the same shape, except possibly for the outermost
--- dimension.
-concat :: Storable a
- => StaticShX sh -> NonEmpty (XArray (Nothing : sh) a) -> XArray (Nothing : sh) a
-concat ssh l
- | Dict <- lemKnownNatRankSSX ssh
- = XArray (S.concatOuter (coerce (toList l)))
-
--- | If the prefix of the shape of the input array (@sh@) is empty (i.e.
--- contains a zero), then there is no way to deduce the full shape of the output
--- array (more precisely, the @sh2@ part): that could only come from calling
--- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in
--- this case; we choose to fill the shape with zeros wherever we cannot deduce
--- what it should be.
---
--- For example, if:
---
--- @
--- arr :: XArray '[Just 3, Just 0, Just 4, Just 2, Nothing] Int -- of shape [3, 0, 4, 2, 21]
--- f :: XArray '[Just 2, Nothing] Int -> XArray '[Just 5, Nothing, Just 17] Float
--- @
---
--- then:
---
--- @
--- rerank _ _ _ f arr :: XArray '[Just 3, Just 0, Just 4, Just 5, Nothing, Just 17] Float
--- @
---
--- and this result will have shape @[3, 0, 4, 5, 0, 17]@. Note the second @0@
--- in this shape: we don't know if @f@ intended to return an array with shape 0
--- here (it probably didn't), but there is no better number to put here absent
--- a subarray of the input to pass to @f@.
---
--- In this particular case the fact that @sh@ is empty was evident from the
--- type-level information, but the same situation occurs when @sh@ consists of
--- @Nothing@s, and some of those happen to be zero at runtime.
-rerank :: forall sh sh1 sh2 a b.
- (Storable a, Storable b)
- => StaticShX sh -> StaticShX sh1 -> StaticShX sh2
- -> (XArray sh1 a -> XArray sh2 b)
- -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
-rerank ssh ssh1 ssh2 f xarr@(XArray arr)
- | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)
- = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr)
- in if any (== 0) (shxToList sh)
- then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) [])
- else case () of
- () | Dict <- lemKnownNatRankSSX ssh
- , Dict <- lemKnownNatRankSSX ssh2
- , Refl <- lemRankApp ssh ssh1
- , Refl <- lemRankApp ssh ssh2
- -> XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2)
- (\a -> let XArray r = f (XArray a) in r)
- arr)
-
-rerankTop :: forall sh1 sh2 sh a b.
- (Storable a, Storable b)
- => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh
- -> (XArray sh1 a -> XArray sh2 b)
- -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b
-rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh
-
--- | The caveat about empty arrays at @rerank@ applies here too.
-rerank2 :: forall sh sh1 sh2 a b c.
- (Storable a, Storable b, Storable c)
- => StaticShX sh -> StaticShX sh1 -> StaticShX sh2
- -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c)
- -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
-rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2)
- | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)
- = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1)
- in if any (== 0) (shxToList sh)
- then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) [])
- else case () of
- () | Dict <- lemKnownNatRankSSX ssh
- , Dict <- lemKnownNatRankSSX ssh2
- , Refl <- lemRankApp ssh ssh1
- , Refl <- lemRankApp ssh ssh2
- -> XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2)
- (\a b -> let XArray r = f (XArray a) (XArray b) in r)
- arr1 arr2)
-
--- | The list argument gives indices into the original dimension list.
-transpose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh)
- => StaticShX sh
- -> Perm is
- -> XArray sh a
- -> XArray (PermutePrefix is sh) a
-transpose ssh perm (XArray arr)
- | Dict <- lemKnownNatRankSSX ssh
- , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh)
- , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm
- , Refl <- lemRankDropLen ssh perm
- = XArray (S.transpose (permToList' perm) arr)
-
--- | The list argument gives indices into the original dimension list.
---
--- The permutation (the list) must have length <= @n@. If it is longer, this
--- function throws.
-transposeUntyped :: forall n sh a.
- SNat n -> StaticShX sh -> [Int]
- -> XArray (Replicate n Nothing ++ sh) a -> XArray (Replicate n Nothing ++ sh) a
-transposeUntyped sn ssh perm (XArray arr)
- | length perm <= fromSNat' sn
- , Dict <- lemKnownNatRankSSX (ssxAppend (ssxReplicate sn) ssh)
- = XArray (S.transpose perm arr)
- | otherwise
- = error "Data.Array.Mixed.transposeUntyped: Permutation longer than length of unshaped prefix of shape type"
-
-transpose2 :: forall sh1 sh2 a.
- StaticShX sh1 -> StaticShX sh2
- -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a
-transpose2 ssh1 ssh2 (XArray arr)
- | Refl <- lemRankApp ssh1 ssh2
- , Refl <- lemRankApp ssh2 ssh1
- , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2)
- , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1)
- , Refl <- lemRankAppComm ssh1 ssh2
- , let n1 = ssxLength ssh1
- = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)
-
-sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a
-sumFull _ (XArray arr) =
- S.unScalar $
- liftO1 (numEltSum1Inner (SNat @0)) $
- S.fromVector [product (S.shapeL arr)] $
- S.toVector arr
-
-sumInner :: forall sh sh' a. (Storable a, NumElt a)
- => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a
-sumInner ssh ssh' arr
- | Refl <- lemAppNil @sh
- = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
- sh'F = shxFlatten sh' :$% ZSX
- ssh'F = ssxFromShape sh'F
-
- go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a
- go (XArray arr')
- | Refl <- lemRankApp ssh ssh'F
- , let sn = listxRank (let StaticShX l = ssh in l)
- = XArray (liftO1 (numEltSum1Inner sn) arr')
-
- in go $
- transpose2 ssh'F ssh $
- reshapePartial ssh' ssh sh'F $
- transpose2 ssh ssh' $
- arr
-
-sumOuter :: forall sh sh' a. (Storable a, NumElt a)
- => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a
-sumOuter ssh ssh' arr
- | Refl <- lemAppNil @sh
- = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
- shF = shxFlatten sh :$% ZSX
- in sumInner ssh' (ssxFromShape shF) $
- transpose2 (ssxFromShape shF) ssh' $
- reshapePartial ssh ssh' shF $
- arr
-
-fromListOuter :: forall n sh a. Storable a
- => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
-fromListOuter ssh l
- | Dict <- lemKnownNatRankSSX ssh
- = case ssh of
- SKnown m :!% _ | fromSNat' m /= length l ->
- error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++
- "does not match the type (" ++ show (fromSNat' m) ++ ")"
- _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l)))
-
-toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a]
-toListOuter (XArray arr) =
- case S.shapeL arr of
- 0 : _ -> []
- _ -> coerce (ORB.toList (S.unravel arr))
-
-fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a
-fromList1 ssh l =
- let n = length l
- in case ssh of
- SKnown m :!% _ | fromSNat' m /= n ->
- error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ")" ++
- "does not match the type (" ++ show (fromSNat' m) ++ ")"
- _ -> XArray (S.fromVector [n] (VS.fromListN n l))
-
-toList1 :: Storable a => XArray '[n] a -> [a]
-toList1 (XArray arr) = S.toList arr
-
--- | Throws if the given shape is not, in fact, empty.
-empty :: forall sh a. Storable a => IShX sh -> XArray sh a
-empty sh
- | Dict <- lemKnownNatRank sh
- , shxSize sh == 0
- = XArray (S.fromVector (shxToList sh) VS.empty)
- | otherwise
- = error $ "Data.Array.Mixed.empty: shape was not empty: " ++ show sh
-
-slice :: SNat i -> SNat n -> XArray (Just (i + n + k) : sh) a -> XArray (Just n : sh) a
-slice i n (XArray arr) = XArray (S.slice [(fromSNat' i, fromSNat' n)] arr)
-
-sliceU :: Int -> Int -> XArray (Nothing : sh) a -> XArray (Nothing : sh) a
-sliceU i n (XArray arr) = XArray (S.slice [(i, n)] arr)
-
-rev1 :: XArray (n : sh) a -> XArray (n : sh) a
-rev1 (XArray arr) = XArray (S.rev [0] arr)
-
--- | Throws if the given array and the target shape do not have the same number of elements.
-reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a
-reshape ssh1 sh2 (XArray arr)
- | Dict <- lemKnownNatRankSSX ssh1
- , Dict <- lemKnownNatRank sh2
- = XArray (S.reshape (shxToList sh2) arr)
-
--- | Throws if the given array and the target shape do not have the same number of elements.
-reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
-reshapePartial ssh1 ssh' sh2 (XArray arr)
- | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh')
- , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh2) ssh')
- = XArray (S.reshape (shxToList sh2 ++ drop (ssxLength ssh1) (S.shapeL arr)) arr)
-
--- this was benchmarked to be (slightly) faster than S.iota, S.generate and S.fromVector(VS.enumFromTo).
-iota :: (Enum a, Storable a) => SNat n -> XArray '[Just n] a
-iota sn = XArray (S.fromVector [fromSNat' sn] (VS.fromListN (fromSNat' sn) [toEnum 0 .. toEnum (fromSNat' sn - 1)]))