{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed where

import qualified Data.Array.RankedS as S
import qualified Data.Array.Ranked as ORB
import Data.Coerce
import Data.Kind
import Data.Proxy
import Data.Type.Bool
import Data.Type.Equality
import qualified Data.Vector.Storable as VS
import Foreign.Storable (Storable)
import GHC.TypeError
import GHC.TypeLits
import Unsafe.Coerce (unsafeCoerce)


-- | A value-level witness that the constraint @c a@ holds; pattern matching
-- on 'Dict' brings that constraint into scope.
data Dict c a where
  Dict :: c a => Dict c a

-- | A copy of the 'SNat' pattern synonym that comes with a @COMPLETE@
-- pragma (the original is complete, but is not declared as such).
pattern GHC_SNat :: () => KnownNat n => SNat n
pattern GHC_SNat = SNat
{-# COMPLETE GHC_SNat #-}

-- | The number represented by a type-level natural singleton, converted to
-- an 'Int' via 'fromIntegral'.
fromSNat' :: SNat n -> Int
fromSNat' sn = fromIntegral (fromSNat sn)

-- | Matches the singleton for zero, revealing @n ~ 0@.
pattern SZ :: () => (n ~ 0) => SNat n
pattern SZ <- ((\k -> testEquality k (SNat @0)) -> Just Refl)
  where SZ = SNat

-- | Matches a singleton that has a predecessor (see 'snatPred') and exposes
-- that predecessor; as an expression it is 'snatSucc'.
pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1
pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl))
  where SS = snatSucc

{-# COMPLETE SZ, SS #-}

-- | The singleton for the successor of @n@.
snatSucc :: SNat n -> SNat (n + 1)
snatSucc SNat = SNat

-- | Result of 'snatPred': some @n@ together with a proof that @n + 1@ is the
-- number we started from.
data SNatPredResult np1 where
  SNatPredResult :: SNat n -> (n + 1 :~: np1) -> SNatPredResult np1

-- | The predecessor of a type-level natural, if there is one: 'Nothing'
-- exactly when @1 > np1@.
snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1)
snatPred sn =
  withKnownNat sn $
    case cmpNat (Proxy @1) (Proxy @np1) of
      -- The LT and EQ alternatives look the same, but each one is checked
      -- under the ordering evidence its own constructor brings into scope.
      LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
      EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
      GTI -> Nothing


-- | Type-level list append.
type family l1 ++ l2 where
  '[] ++ l2 = l2
  (x : xs) ++ l2 = x : xs ++ l2

-- | @'[]@ is a right identity of '++'. This is asserted with 'unsafeCoerce'
-- rather than proven by recursion over @l@.
lemAppNil :: l ++ '[] :~: l
lemAppNil = unsafeCoerce Refl

-- | Associativity of '++'. Asserted with 'unsafeCoerce', like 'lemAppNil'.
lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
lemAppAssoc _ _ _ = unsafeCoerce Refl

-- | The type-level list consisting of @n@ copies of @a@.
type family Replicate n a where
  Replicate 0 a = '[]
  Replicate n a = a : Replicate (n - 1) a


-- | An index into a mixed shape: one component of type @i@ per dimension.
-- The constructor used for a component mirrors whether that dimension is
-- @Just n@ or @Nothing@ in the shape type.
type IxX :: [Maybe Nat] -> Type -> Type
data IxX sh i where
  -- | The index into a rank-0 shape.
  ZIX :: IxX '[] i
  -- | A component for a dimension of statically known size.
  (:.@) :: forall n sh i. i -> IxX sh i -> IxX (Just n : sh) i
  -- | A component for a dimension of statically unknown size.
  (:.?) :: forall sh i. i -> IxX sh i -> IxX (Nothing : sh) i
deriving instance Show i => Show (IxX sh i)
deriving instance Eq i => Eq (IxX sh i)
deriving instance Ord i => Ord (IxX sh i)
deriving instance Functor (IxX sh)
deriving instance Foldable (IxX sh)
infixr 3 :.@
infixr 3 :.?

-- | An 'IxX' with 'Int' components.
type IIxX sh = IxX sh Int

-- | A mixed shape: statically known dimensions are stored as an 'SNat',
-- the other dimensions as a value of type @i@.
type ShX :: [Maybe Nat] -> Type -> Type
data ShX sh i where
  -- | The shape of a rank-0 array.
  ZSX :: ShX '[] i
  -- | A dimension whose size is known in the type.
  (:$@) :: forall n sh i. SNat n -> ShX sh i -> ShX (Just n : sh) i
  -- | A dimension whose size is only known at run time.
  (:$?) :: forall sh i. i -> ShX sh i -> ShX (Nothing : sh) i
deriving instance Show i => Show (ShX sh i)
deriving instance Eq i => Eq (ShX sh i)
deriving instance Ord i => Ord (ShX sh i)
deriving instance Functor (ShX sh)
deriving instance Foldable (ShX sh)
infixr 3 :$@
infixr 3 :$?

-- | A 'ShX' whose dynamic sizes are 'Int's.
type IShX sh = ShX sh Int

-- | The part of a shape that is statically known.
type StaticShX :: [Maybe Nat] -> Type
data StaticShX sh where
  ZKSX :: StaticShX '[]
  (:!$@) :: SNat n -> StaticShX sh -> StaticShX (Just n : sh)
  (:!$?) :: () -> StaticShX sh -> StaticShX (Nothing : sh)
deriving instance Show (StaticShX sh)
infixr 3 :!$@
infixr 3 :!$?

-- | Evidence for the static part of a shape.
type KnownShapeX :: [Maybe Nat] -> Constraint
class KnownShapeX sh where
  -- | The singleton describing the statically known part of @sh@.
  knownShapeX :: StaticShX sh

instance KnownShapeX '[] where
  knownShapeX = ZKSX

instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where
  knownShapeX = natSing :!$@ knownShapeX

instance KnownShapeX sh => KnownShapeX (Nothing : sh) where
  knownShapeX = () :!$? knownShapeX

-- | The rank of a shape: the length of the type-level list.
type family Rank sh where
  Rank '[] = 0
  Rank (_ : sh) = 1 + Rank sh

-- | An array with a mixed shape, stored as a ranked array (@S.Array@) whose
-- rank is @Rank sh@.
type XArray :: [Maybe Nat] -> Type -> Type
newtype XArray sh a = XArray (S.Array (Rank sh) a)
  deriving stock (Show)

-- | The index with a zero in every position, built from the static part of
-- a shape.
zeroIxX :: StaticShX sh -> IIxX sh
zeroIxX static = case static of
  ZKSX -> ZIX
  _ :!$@ rest -> 0 :.@ zeroIxX rest
  _ :!$? rest -> 0 :.? zeroIxX rest

-- | Like 'zeroIxX', but takes a full shape instead of only its static part.
zeroIxX' :: IShX sh -> IIxX sh
zeroIxX' full = case full of
  ZSX -> ZIX
  _ :$@ rest -> 0 :.@ zeroIxX' rest
  _ :$? rest -> 0 :.? zeroIxX' rest

-- This is a weird operation, so it has a long name
--
-- | Turn the static part of a shape into a full shape by putting size 0 in
-- every dimension that is not statically known.
completeShXzeros :: StaticShX sh -> IShX sh
completeShXzeros static = case static of
  ZKSX -> ZSX
  size :!$@ rest -> size :$@ completeShXzeros rest
  _ :!$? rest -> 0 :$? completeShXzeros rest

-- | Concatenate two indices.
ixAppend :: IIxX sh -> IIxX sh' -> IIxX (sh ++ sh')
ixAppend ZIX suffix = suffix
ixAppend (i :.@ is) suffix = i :.@ ixAppend is suffix
ixAppend (i :.? is) suffix = i :.? ixAppend is suffix

-- | Concatenate two shapes.
shAppend :: IShX sh -> IShX sh' -> IShX (sh ++ sh')
shAppend ZSX suffix = suffix
shAppend (size :$@ rest) suffix = size :$@ shAppend rest suffix
shAppend (size :$? rest) suffix = size :$? shAppend rest suffix

-- | Strip as many leading components off the first index as the second
-- index has; only the structure of the second index is inspected, its
-- component values are ignored.
ixDrop :: IIxX (sh ++ sh') -> IIxX sh -> IIxX sh'
ixDrop whole ZIX = whole
ixDrop (_ :.@ whole) (_ :.@ prefix) = ixDrop whole prefix
ixDrop (_ :.? whole) (_ :.? prefix) = ixDrop whole prefix

-- | Concatenate the static parts of two shapes.
ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
ssxAppend ZKSX suffix = suffix
ssxAppend (size :!$@ rest) suffix = size :!$@ ssxAppend rest suffix
ssxAppend (() :!$? rest) suffix = () :!$? ssxAppend rest suffix

-- | The product of all dimension sizes; 1 for the rank-0 shape.
shapeSize :: IShX sh -> Int
shapeSize full = case full of
  ZSX -> 1
  size :$@ rest -> fromSNat' size * shapeSize rest
  size :$? rest -> size * shapeSize rest

-- | This fails (returns 'Nothing') if @sh@ has @Nothing@s in it.
ssxToShape' :: StaticShX sh -> Maybe (IShX sh)
ssxToShape' ZKSX = Just ZSX
ssxToShape' (n :!$@ sh) = (n :$@) <$> ssxToShape' sh
ssxToShape' (_ :!$? _) = Nothing

-- | Unfolding 'Replicate' by one step. Asserted with 'unsafeCoerce', not
-- proven.
lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
lemReplicateSucc = unsafeCoerce Refl

-- | The static shape that consists of @n@ unknown dimensions.
ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)
ssxReplicate SZ = ZKSX
ssxReplicate (SS (n :: SNat n'))
  | Refl <- lemReplicateSucc @(Nothing @Nat) @n'
  = () :!$? ssxReplicate n

-- | Convert a linear offset into a multi-dimensional index; the last
-- dimension varies fastest (the inverse of 'toLinearIdx').
--
-- Throws an out-of-range error if the offset is negative or not smaller
-- than @'shapeSize' sh@. (The range check up front also prevents a division
-- by zero when the shape has a zero-sized dimension: in that case every
-- offset is out of range.)
fromLinearIdx :: IShX sh -> Int -> IIxX sh
fromLinearIdx = \sh ->
  -- Bound outside the inner lambda so that a partial application to a shape
  -- (as in 'generate') computes the size only once.
  let size = shapeSize sh
  in \i ->
       if i < 0 || i >= size
         then outOfRange sh i
         else case go sh i of
                (idx, 0) -> idx
                -- Not expected after the range check above; kept as a
                -- safety net.
                _ -> outOfRange sh i
  where
    outOfRange :: IShX sh -> Int -> a
    outOfRange sh i =
      error $ "fromLinearIdx: out of range (" ++ show i ++
              " in array of shape " ++ show sh ++ ")"

    -- returns (index in subarray, remaining index in enclosing array)
    go :: IShX sh -> Int -> (IIxX sh, Int)
    go ZSX i = (ZIX, i)
    go (n :$@ sh) i =
      let (idx, i') = go sh i
          (upi, locali) = i' `quotRem` fromSNat' n
      in (locali :.@ idx, upi)
    go (n :$? sh) i =
      let (idx, i') = go sh i
          (upi, locali) = i' `quotRem` n
      in (locali :.? idx, upi)

-- | Convert a multi-dimensional index into a linear offset; the last
-- dimension varies fastest. The index is not checked against the shape.
toLinearIdx :: IShX sh -> IIxX sh -> Int
toLinearIdx = \sh i -> fst (go sh i)
  where
    -- returns (index in subarray, size of subarray)
    go :: IShX sh -> IIxX sh -> (Int, Int)
    go ZSX ZIX = (0, 1)
    go (n :$@ sh) (i :.@ ix) =
      let (lidx, sz) = go sh ix
      in (sz * i + lidx, fromSNat' n * sz)
    go (n :$? sh) (i :.? ix) =
      let (lidx, sz) = go sh ix
      in (sz * i + lidx, n * sz)

-- | All indices of a shape, in lexicographic order.
enumShape :: IShX sh -> [IIxX sh]
enumShape = \sh -> go sh id []
  where
    -- Builds the list as a difference list; @f@ wraps a partial index with
    -- the components of the enclosing dimensions.
    go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a]
    go ZSX f = (f ZIX :)
    go (n :$@ sh) f = foldr (.) id [go sh (f . (i :.@)) | i <- [0 .. fromSNat' n - 1]]
    go (n :$? sh) f = foldr (.) id [go sh (f . (i :.?)) | i <- [0 .. n-1]]

-- | The dimension sizes of a shape as a plain list.
shapeLshape :: IShX sh -> S.ShapeL
shapeLshape ZSX = []
shapeLshape (n :$@ sh) = fromSNat' n : shapeLshape sh
shapeLshape (n :$? sh) = n : shapeLshape sh

-- | The number of dimensions of a static shape.
ssxLength :: StaticShX sh -> Int
ssxLength ZKSX = 0
ssxLength (_ :!$@ ssh) = 1 + ssxLength ssh
ssxLength (_ :!$? ssh) = 1 + ssxLength ssh

-- | @ssxIotaFrom i ssh@ counts upwards from @i@, producing one number per
-- dimension of @ssh@.
ssxIotaFrom :: Int -> StaticShX sh -> [Int]
ssxIotaFrom _ ZKSX = []
ssxIotaFrom i (_ :!$@ ssh) = i : ssxIotaFrom (i+1) ssh
ssxIotaFrom i (_ :!$? ssh) = i : ssxIotaFrom (i+1) ssh

-- | Forget the dynamic sizes of a shape, keeping only its static part.
staticShapeFrom :: IShX sh -> StaticShX sh
staticShapeFrom ZSX = ZKSX
staticShapeFrom (n :$@ sh) = n :!$@ staticShapeFrom sh
staticShapeFrom (_ :$? sh) = () :!$? staticShapeFrom sh

-- | The rank of an appended shape is the sum of the ranks. Asserted with
-- 'unsafeCoerce'; the arguments are not inspected.
lemRankApp :: StaticShX sh1 -> StaticShX sh2
           -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
lemRankApp _ _ = unsafeCoerce Refl  -- TODO improve this

-- | The rank of an appended shape does not depend on the order of the
-- parts. Asserted with 'unsafeCoerce'; the arguments are not inspected.
lemRankAppComm :: StaticShX sh1 -> StaticShX sh2
               -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)
lemRankAppComm _ _ = unsafeCoerce Refl  -- TODO improve this

-- | The rank of a shape is a known natural; established by recursion over
-- the shape.
lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh)
lemKnownNatRank ZSX = Dict
lemKnownNatRank (_ :$@ sh) | Dict <- lemKnownNatRank sh = Dict
lemKnownNatRank (_ :$? sh) | Dict <- lemKnownNatRank sh = Dict

-- | Like 'lemKnownNatRank', but for the static part of a shape.
lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh)
lemKnownNatRankSSX ZKSX = Dict
lemKnownNatRankSSX (_ :!$@ ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
lemKnownNatRankSSX (_ :!$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict

-- | Recover the 'KnownShapeX' instance from an explicit static shape.
lemKnownShapeX :: StaticShX sh -> Dict KnownShapeX sh
lemKnownShapeX ZKSX = Dict
lemKnownShapeX (GHC_SNat :!$@ ssh) | Dict <- lemKnownShapeX ssh = Dict
lemKnownShapeX (() :!$? ssh) | Dict <- lemKnownShapeX ssh = Dict

-- | 'KnownShapeX' for the concatenation of two static shapes.
lemAppKnownShapeX :: StaticShX sh1 -> StaticShX sh2 -> Dict KnownShapeX (sh1 ++ sh2)
lemAppKnownShapeX ZKSX ssh' = lemKnownShapeX ssh'
lemAppKnownShapeX (GHC_SNat :!$@ ssh) ssh'
  | Dict <- lemAppKnownShapeX ssh ssh'
  = Dict
lemAppKnownShapeX (() :!$? ssh) ssh'
  | Dict <- lemAppKnownShapeX ssh ssh'
  = Dict

-- | The shape of an array. Statically known sizes are taken from the type
-- (the corresponding run-time size is ignored); unknown sizes are read from
-- the underlying array. Throws @Invalid shapeL@ if the underlying shape
-- list and the shape type differ in length.
shape :: forall sh a. KnownShapeX sh => XArray sh a -> IShX sh
shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr)
  where
    go :: StaticShX sh' -> [Int] -> IShX sh'
    go ZKSX [] = ZSX
    go (n :!$@ ssh) (_ : l) = n :$@ go ssh l
    go (() :!$? ssh) (n : l) = n :$? go ssh l
    go _ _ = error "Invalid shapeL"

-- | Build an array of the given shape from a vector, using @S.fromVector@.
fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a
fromVector sh v
  | Dict <- lemKnownNatRank sh
  = XArray (S.fromVector (shapeLshape sh) v)

-- | The elements of an array as a vector, using @S.toVector@.
toVector :: Storable a => XArray sh a -> VS.Vector a
toVector (XArray arr) = S.toVector arr

-- | A rank-0 array holding a single value.
scalar :: Storable a => a -> XArray '[] a
scalar = XArray . S.scalar

-- | The single value of a rank-0 array.
unScalar :: Storable a => XArray '[] a -> a
unScalar (XArray a) = S.unScalar a

-- | An array of the given shape with the same value everywhere.
constant :: forall sh a. Storable a => IShX sh -> a -> XArray sh a
constant sh x
  | Dict <- lemKnownNatRank sh
  = XArray (S.constant (shapeLshape sh) x)

-- | An array of the given shape whose element at each index is computed by
-- the given function.
generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a
generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh)

-- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a)
-- generateM sh f | Dict <- lemKnownNatRank sh =
--   XArray . S.fromVector (shapeLshape sh)
--     <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh)

-- | Index the leading dimensions of an array, one @S.index@ call per index
-- component, leaving a subarray of the remaining dimensions.
indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a
indexPartial (XArray arr) ZIX = XArray arr
indexPartial (XArray arr) (i :.@ idx) = indexPartial (XArray (S.index arr i)) idx
indexPartial (XArray arr) (i :.? idx) = indexPartial (XArray (S.index arr i)) idx

-- | The element at a full index: 'indexPartial' with nothing left over.
index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a
index xarr i
  | Refl <- lemAppNil @sh
  = let XArray arr' = indexPartial xarr i :: XArray '[] a
    in S.unScalar arr'

-- | The sum of two possibly-unknown sizes; known only if both are known.
type family AddMaybe n m where
  AddMaybe Nothing _ = Nothing
  AddMaybe (Just _) Nothing = Nothing
  AddMaybe (Just n) (Just m) = Just (n + m)

-- | Append two arrays along their outermost dimension, using @S.append@.
append :: forall n m sh a. (KnownShapeX sh, Storable a)
       => XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a
append (XArray a) (XArray b)
  | Dict <- lemKnownNatRankSSX (knownShapeX @sh)
  = XArray (S.append a b)

-- | Apply a function to every @sh1@-shaped subarray under the leading
-- dimensions @sh@, using @S.rerank@.
rerank :: forall sh sh1 sh2 a b.
          (Storable a, Storable b)
       => StaticShX sh -> StaticShX sh1 -> StaticShX sh2
       -> (XArray sh1 a -> XArray sh2 b)
       -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
rerank ssh ssh1 ssh2 f (XArray arr)
  | Dict <- lemKnownNatRankSSX ssh
  , Dict <- lemKnownNatRankSSX ssh2
  , Refl <- lemRankApp ssh ssh1
  , Refl <- lemRankApp ssh ssh2
  , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)  -- should be redundant but the solver is not clever enough
  = XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2)
              (\a -> unXArray (f (XArray a)))
              arr)
  where
    unXArray (XArray a) = a

-- | Like 'rerank', but the function acts on the leading dimensions
-- (@sh1@) and the trailing dimensions @sh@ are kept. Implemented by moving
-- @sh@ to the front with 'transpose2', calling 'rerank', and moving it back.
rerankTop :: forall sh1 sh2 sh a b.
             (Storable a, Storable b)
          => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh
          -> (XArray sh1 a -> XArray sh2 b)
          -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b
rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh

-- | Two-argument version of 'rerank', using @S.rerank2@.
rerank2 :: forall sh sh1 sh2 a b c.
           (Storable a, Storable b, Storable c)
        => StaticShX sh -> StaticShX sh1 -> StaticShX sh2
        -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c)
        -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2)
  | Dict <- lemKnownNatRankSSX ssh
  , Dict <- lemKnownNatRankSSX ssh2
  , Refl <- lemRankApp ssh ssh1
  , Refl <- lemRankApp ssh ssh2
  , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)  -- should be redundant but the solver is not clever enough
  = XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2)
              (\a b -> unXArray (f (XArray a) (XArray b)))
              arr1 arr2)
  where
    unXArray (XArray a) = a

-- | Type-level list membership.
type family Elem x l where
  Elem x '[] = 'False
  Elem x (x : _) = 'True
  Elem x (_ : ys) = Elem x ys

-- | Whether every element of @as@ occurs in @bs@.
type family AllElem' as bs where
  AllElem' '[] bs = 'True
  AllElem' (a : as) bs = Elem a bs && AllElem' as bs

-- | 'AllElem'' as a constraint, with a custom type error on failure.
type AllElem as bs = Assert (AllElem' as bs)
  (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs))

-- | @Count i n@ is the list @[i, i+1, ..., n-1]@.
type family Count i n where
  Count n n = '[]
  Count i n = i : Count (i + 1) n

-- | @as@ and @[0 .. Rank as - 1]@ contain the same elements.
type Permutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as)

-- | Type-level list indexing.
type family Index i sh where
  Index 0 (n : sh) = n
  Index i (_ : sh) = Index (i - 1) sh

-- | The list of the elements of @sh@ at the positions given by @is@.
type family Permute is sh where
  Permute '[] sh = '[]
  Permute (i : is) sh = Index i sh : Permute is sh

-- | Permute the first @length is@ elements of @sh@ and leave the remainder
-- in place.
type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh

-- | A heterogeneous list: one @f a@ for each @a@ in the type-level list.
data HList f list where
  HNil :: HList f '[]
  HCons :: f a -> HList f l -> HList f (a : l)
infixr 5 `HCons`

-- | Map every element to a monoid and combine the results from left to
-- right.
foldHList :: Monoid m => (forall a. f a -> m) -> HList f list -> m
foldHList _ HNil = mempty
foldHList f (x `HCons` l) = f x <> foldHList f l

-- | Type-level lists of naturals for which the list of singletons can be
-- produced.
class KnownNatList l where
  makeNatList :: HList SNat l
instance KnownNatList '[] where
  makeNatList = HNil
instance (KnownNat n, KnownNatList l) => KnownNatList (n : l) where
  makeNatList = natSing `HCons` makeNatList

-- | The prefix of @l@ that is as long as @ref@.
type family TakeLen ref l where
  TakeLen '[] l = '[]
  TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs

-- | What remains of @l@ after dropping as many elements as @ref@ has.
type family DropLen ref l where
  DropLen '[] l = l
  DropLen (_ : ref) (_ : xs) = DropLen ref xs

-- | A permuted list has as many elements as the permutation.
lemRankPermute :: Proxy sh -> HList SNat is -> Rank (Permute is sh) :~: Rank is
lemRankPermute _ HNil = Refl
lemRankPermute p (_ `HCons` is) | Refl <- lemRankPermute p is = Refl

-- | The rank of @DropLen is sh@ is the difference of the ranks.
lemRankDropLen :: forall is sh. (Rank is <= Rank sh)
               => StaticShX sh -> HList SNat is -> Rank (DropLen is sh) :~: Rank sh - Rank is
lemRankDropLen ZKSX HNil = Refl
lemRankDropLen (_ :!$@ sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl
lemRankDropLen (_ :!$? sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl
lemRankDropLen (_ :!$@ _) HNil = Refl
lemRankDropLen (_ :!$? _) HNil = Refl
-- This case contradicts the @Rank is <= Rank sh@ constraint.
lemRankDropLen ZKSX (_ `HCons` _) = error "1 <= 0"

-- | One unfolding step of 'Index'. Asserted with 'unsafeCoerce'.
lemIndexSucc :: Proxy i -> Proxy a -> Proxy l -> Index (i + 1) (a : l) :~: Index i l
lemIndexSucc _ _ _ = unsafeCoerce Refl

-- | Value-level counterpart of 'TakeLen'. Throws if the list is longer than
-- the shape.
ssxTakeLen :: HList SNat is -> StaticShX sh -> StaticShX (TakeLen is sh)
ssxTakeLen HNil _ = ZKSX
ssxTakeLen (_ `HCons` is) (n :!$@ sh) = n :!$@ ssxTakeLen is sh
ssxTakeLen (_ `HCons` is) (n :!$? sh) = n :!$? ssxTakeLen is sh
ssxTakeLen (_ `HCons` _) ZKSX = error "Permutation longer than shape"

-- | Value-level counterpart of 'DropLen'. Throws if the list is longer than
-- the shape.
ssxDropLen :: HList SNat is -> StaticShX sh -> StaticShX (DropLen is sh)
ssxDropLen HNil sh = sh
ssxDropLen (_ `HCons` is) (_ :!$@ sh) = ssxDropLen is sh
ssxDropLen (_ `HCons` is) (_ :!$? sh) = ssxDropLen is sh
ssxDropLen (_ `HCons` _) ZKSX = error "Permutation longer than shape"

-- | Value-level counterpart of 'Permute'.
ssxPermute :: HList SNat is -> StaticShX sh -> StaticShX (Permute is sh)
ssxPermute HNil _ = ZKSX
ssxPermute (i `HCons` (is :: HList SNat is')) (sh :: StaticShX sh) = ssxIndex (Proxy @is') (Proxy @sh) i sh (ssxPermute is sh)

-- | Look up dimension @i@ of a static shape and put it in front of an
-- already permuted remainder. Throws if @i@ is out of range for the shape.
ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT)
ssxIndex _ _ SZ (n :!$@ _) rest = n :!$@ rest
ssxIndex _ _ SZ (n :!$? _) rest = n :!$? rest
ssxIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :!$@ (sh :: StaticShX sh')) rest
  | Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @sh')
  = ssxIndex p pT i sh rest
ssxIndex p pT (SS (i :: SNat i')) (() :!$? (sh :: StaticShX sh')) rest
  | Refl <- lemIndexSucc (Proxy @i') (Proxy @Nothing) (Proxy @sh')
  = ssxIndex p pT i sh rest
ssxIndex _ _ _ ZKSX _ = error "Index into empty shape"

-- | The list argument gives indices into the original dimension list.
transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh, KnownShapeX sh)
          => HList SNat is
          -> XArray sh a
          -> XArray (PermutePrefix is sh) a
transpose perm (XArray arr)
  | Dict <- lemKnownNatRankSSX (knownShapeX @sh)
  , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm (knownShapeX @sh))) (ssxDropLen perm (knownShapeX @sh))
  , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm
  , Refl <- lemRankDropLen (knownShapeX @sh) perm
  = let perm' = foldHList (\sn -> [fromSNat' sn]) perm :: [Int]
    in XArray (S.transpose perm' arr)

-- | The list argument gives indices into the original dimension list.
--
-- The permutation (the list) must have length <= @n@. If it is longer, this
-- function throws.
transposeUntyped :: forall n sh a.
                    SNat n -> StaticShX sh -> [Int]
                 -> XArray (Replicate n Nothing ++ sh) a -> XArray (Replicate n Nothing ++ sh) a
transposeUntyped sn ssh perm (XArray arr)
  | length perm <= fromSNat' sn
  , Dict <- lemKnownNatRankSSX (ssxAppend (ssxReplicate sn) ssh)
  = XArray (S.transpose perm arr)
  | otherwise
  = error "Data.Array.Mixed.transposeUntyped: Permutation longer than length of unshaped prefix of shape type"

-- | Swap the two parts of the shape: the dimensions of @sh2@ are moved in
-- front of those of @sh1@.
transpose2 :: forall sh1 sh2 a.
              StaticShX sh1 -> StaticShX sh2
           -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a
transpose2 ssh1 ssh2 (XArray arr)
  | Refl <- lemRankApp ssh1 ssh2
  , Refl <- lemRankApp ssh2 ssh1
  , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2)
  , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1)
  , Refl <- lemRankAppComm ssh1 ssh2
  , let n1 = ssxLength ssh1
    -- The permutation lists the positions of the sh2 dimensions
    -- (n1, n1+1, ...) followed by those of the sh1 dimensions (0 .. n1-1).
  = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)

-- | The sum of all elements of an array, using @S.sumA@.
sumFull :: (Storable a, Num a) => XArray sh a -> a
sumFull (XArray arr) = S.sumA arr

-- | Sum over the trailing dimensions @sh'@: every @sh'@-shaped subarray is
-- replaced by the sum of its elements.
sumInner :: forall sh sh' a. (Storable a, Num a)
         => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a
sumInner ssh ssh'
  | Refl <- lemAppNil @sh
  = rerank ssh ssh' ZKSX (scalar . sumFull)

-- | Sum over the leading dimensions @sh@; implemented by moving them to the
-- back with 'transpose2' and then calling 'sumInner'.
sumOuter :: forall sh sh' a. (Storable a, Num a)
         => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a
sumOuter ssh ssh'
  | Refl <- lemAppNil @sh
  = sumInner ssh' ssh . transpose2 ssh ssh'

-- | Stack a list of equally-typed arrays along a new outermost dimension.
--
-- If the outermost dimension is statically known, this throws when the
-- length of the list differs from it.
fromList1 :: forall n sh a. Storable a
          => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
fromList1 ssh l
  | Dict <- lemKnownNatRankSSX ssh
  = case ssh of
      m@GHC_SNat :!$@ _ | natVal m /= fromIntegral (length l) ->
        error $ "Data.Array.Mixed.fromList1: length of list (" ++ show (length l) ++ ")" ++
                " does not match the type (" ++ show (natVal m) ++ ")"
      _ -> XArray (S.ravel (ORB.fromList [length l]
                              (coerce @[XArray sh a] @[S.Array (Rank sh) a] l)))

-- | Split an array into the list of its subarrays along the outermost
-- dimension.
toList1 :: Storable a => XArray (n : sh) a -> [XArray sh a]
toList1 (XArray arr) = coerce (ORB.toList (S.unravel arr))

-- | Throws if the given shape is not, in fact, empty.
empty :: forall sh a. Storable a => IShX sh -> XArray sh a
empty sh
  | Dict <- lemKnownNatRank sh
  = XArray (S.constant (shapeLshape sh)
                       (error "Data.Array.Mixed.empty: shape was not empty"))

-- | Pass the argument list straight to @S.slice@ on the underlying array.
slice :: [(Int, Int)] -> XArray sh a -> XArray sh a
slice ivs (XArray arr) = XArray (S.slice ivs arr)

-- | Reverse an array along its outermost dimension.
rev1 :: XArray (n : sh) a -> XArray (n : sh) a
rev1 (XArray arr) = XArray (S.rev [0] arr)

-- | Throws if the given array and the target shape do not have the same number of elements.
reshape :: forall sh1 sh2 a. Storable a
        => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a
reshape ssh1 sh2 (XArray arr)
  | Dict <- lemKnownNatRankSSX ssh1
  , Dict <- lemKnownNatRank sh2
  = XArray (S.reshape (shapeLshape sh2) arr)

-- | Reshape the leading dimensions @sh1@ of an array to @sh2@, keeping the
-- trailing dimensions @sh'@ as they are.
--
-- Throws if the given array and the target shape do not have the same number of elements.
reshapePartial :: forall sh1 sh2 sh' a. Storable a
               => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
reshapePartial ssh1 ssh' sh2 (XArray arr)
  | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh')
  , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh2) ssh')
    -- The sizes of the @sh'@ dimensions are what is left of the current
    -- shape after dropping the @sh1@ prefix, so the number of entries to
    -- drop is the rank of @sh1@ (not anything derived from @sh2@).
  = XArray (S.reshape (shapeLshape sh2 ++ drop (ssxLength ssh1) (S.shapeL arr)) arr)