{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed where

import qualified Data.Array.RankedS as S
import qualified Data.Array.Ranked as ORB
import Data.Bifunctor (first)
import Data.Coerce
import Data.Functor.Const
import Data.Kind
import Data.Proxy
import Data.Type.Bool
import Data.Type.Equality
import qualified Data.Vector.Storable as VS
import Foreign.Storable (Storable)
import GHC.IsList
import GHC.TypeError
import GHC.TypeLits
import qualified GHC.TypeNats as TypeNats
import Unsafe.Coerce (unsafeCoerce)


-- | Evidence for the constraint @c a@. Pattern matching on 'Dict' brings
-- @c a@ into scope.
data Dict c a where
  Dict :: c a => Dict c a

-- | A copy of GHC's 'SNat' pattern synonym. The original is complete, but it
-- doesn't have a @COMPLETE@ pragma; this copy does.
pattern GHC_SNat :: () => KnownNat n => SNat n
pattern GHC_SNat = SNat
{-# COMPLETE GHC_SNat #-}

-- | The value of a natural-number singleton, converted to an 'Int'.
fromSNat' :: SNat n -> Int
fromSNat' sn = fromInteger (fromSNat sn)

-- | Matches (or builds) the singleton for zero.
pattern SZ :: () => (n ~ 0) => SNat n
pattern SZ <- ((\s -> testEquality s (SNat @0)) -> Just Refl)
  where SZ = SNat

-- | Matches (or builds) a non-zero singleton, exposing its predecessor.
pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1
pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl))
  where SS = snatSucc

{-# COMPLETE SZ, SS #-}

-- | Successor of a natural-number singleton.
snatSucc :: SNat n -> SNat (n + 1)
snatSucc GHC_SNat = SNat

-- | The predecessor of a non-zero singleton, together with a proof that it
-- is the predecessor.
data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1)

-- | The predecessor of a singleton, or 'Nothing' if the singleton is zero.
snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1)
snatPred sn =
  withKnownNat sn $
    case cmpNat (Proxy @1) (Proxy @np1) of
      -- 1 > np1, so np1 is zero and has no predecessor.
      GTI -> Nothing
      -- NOTE(review): the LTI and EQI branches are identical; presumably they
      -- are written out separately (rather than as a wildcard) so that each
      -- brings its own comparison evidence into scope for the type-checker
      -- plugins. Confirm before merging them.
      LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
      EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)

-- | Type-level list append.
type family l1 ++ l2 where
  '[] ++ l2 = l2
  (x : xs) ++ l2 = x : xs ++ l2

-- | Right identity of list append. Asserted with 'unsafeCoerce' rather than
-- proven.
lemAppNil :: l ++ '[] :~: l
lemAppNil = unsafeCoerce Refl

-- | Associativity of list append. Asserted with 'unsafeCoerce' rather than
-- proven.
lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
lemAppAssoc _ _ _ = unsafeCoerce Refl

-- | A type-level list of @n@ copies of @a@.
type family Replicate n a where
  Replicate 0 a = '[]
  Replicate n a = a : Replicate (n - 1) a

type role ListX nominal representational
-- | A list indexed by a type-level list of (possibly unknown) dimensions: the
-- element at each position has type @f n@ for the corresponding @n@ in @sh@.
type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type
data ListX sh f where
  ZX :: ListX '[] f
  (::%) :: f n -> ListX sh f -> ListX (n : sh) f
deriving instance (forall n. Eq (f n)) => Eq (ListX sh f)
deriving instance (forall n. Ord (f n)) => Ord (ListX sh f)
infixr 3 ::%

instance (forall n. Show (f n)) => Show (ListX sh f) where
  showsPrec _ l = showListX shows l

-- | The result of 'unconsListX': the tail and the head of a non-empty 'ListX'.
data UnconsListXRes f sh1 =
  forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n)

-- | Split a 'ListX' into its tail and head, or 'Nothing' if it is empty.
unconsListX :: ListX sh1 f -> Maybe (UnconsListXRes f sh1)
unconsListX ZX = Nothing
unconsListX (x ::% rest) = Just (UnconsListXRes rest x)

-- | Apply an index-preserving function to every element of a 'ListX'.
fmapListX :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g
fmapListX _ ZX = ZX
fmapListX h (x ::% rest) = h x ::% fmapListX h rest

-- | Map every element to a monoid and combine the results left to right.
foldListX :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m
foldListX _ ZX = mempty
foldListX h (x ::% rest) = h x <> foldListX h rest

-- | Render a 'ListX' as @[x1,x2,...]@ using the given element printer.
showListX :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS
showListX showElem l = showChar '[' . go "" l . showChar ']'
  where
    -- The separator is empty before the first element and "," afterwards.
    go :: String -> ListX sh' f -> ShowS
    go _ ZX = id
    go sep (x ::% rest) = showString sep . showElem x . go "," rest

type role IxX nominal representational
-- | A multi-dimensional index: one component of type @i@ per dimension of @sh@.
type IxX :: [Maybe Nat] -> Type -> Type
newtype IxX sh i = IxX (ListX sh (Const i))
  deriving (Eq, Ord)

-- | The empty index.
pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i
pattern ZIX = IxX ZX

-- | Prepend an index component for the first dimension.
pattern (:.%) :: forall {sh1} {i}. forall n sh. (n : sh ~ sh1) => i -> IxX sh i -> IxX sh1 i
pattern x :.% rest <- IxX (unconsListX -> Just (UnconsListXRes (IxX -> rest) (getConst -> x)))
  where x :.% IxX rest = IxX (Const x ::% rest)
infixr 3 :.%

{-# COMPLETE ZIX, (:.%) #-}

-- | Indices with 'Int' components.
type IIxX sh = IxX sh Int

instance Show i => Show (IxX sh i) where
  showsPrec _ (IxX l) = showListX (shows . getConst) l

instance Functor (IxX sh) where
  fmap f (IxX l) = IxX (fmapListX (\(Const x) -> Const (f x)) l)

instance Foldable (IxX sh) where
  foldMap f (IxX l) = foldListX (\(Const x) -> f x) l

-- | A dimension size that is either not known at the type level (carrying a
-- runtime value of type @i@) or known (carrying an @f n@).
data SMayNat i f n where
  SUnknown :: i -> SMayNat i f Nothing
  SKnown :: f n -> SMayNat i f (Just n)
deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n)
deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n)
deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n)

-- | Eliminate an 'SMayNat' by giving a handler for each of its two cases.
fromSMayNat :: (n ~ Nothing => i -> r) -> (forall m. n ~ Just m => f m -> r) -> SMayNat i f n -> r
fromSMayNat _ onKnown (SKnown s) = onKnown s
fromSMayNat onUnknown _ (SUnknown i) = onUnknown i

-- | The size as an 'Int', whether or not it is statically known.
fromSMayNat' :: SMayNat Int SNat n -> Int
fromSMayNat' m = fromSMayNat id fromSNat' m

type role ShX nominal representational
-- | A shape: one dimension size per element of @sh@. Statically known sizes
-- are stored as 'SNat's, unknown ones as runtime values of type @i@.
type ShX :: [Maybe Nat] -> Type -> Type
newtype ShX sh i = ShX (ListX sh (SMayNat i SNat))
  deriving (Eq, Ord)

-- | The empty shape.
pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i
pattern ZSX = ShX ZX

-- | Prepend the size of the first dimension.
pattern (:$%) :: forall {sh1} {i}. forall n sh. (n : sh ~ sh1) => SMayNat i SNat n -> ShX sh i -> ShX sh1 i
pattern d :$% rest <- ShX (unconsListX -> Just (UnconsListXRes (ShX -> rest) d))
  where d :$% ShX rest = ShX (d ::% rest)
infixr 3 :$%

{-# COMPLETE ZSX, (:$%) #-}

-- | Shapes whose unknown dimensions carry 'Int' sizes.
type IShX sh = ShX sh Int

instance Show i => Show (ShX sh i) where
  showsPrec _ (ShX l) = showListX (fromSMayNat shows (\sn -> shows (fromSNat sn))) l

instance Functor (ShX sh) where
  fmap f (ShX l) = ShX (fmapListX (fromSMayNat (\i -> SUnknown (f i)) SKnown) l)

-- | The number of dimensions of a shape.
lengthShX :: ShX sh i -> Int
lengthShX = go 0
  where
    go :: Int -> ShX sh' i' -> Int
    go acc ZSX = acc
    go acc (_ :$% rest) = go (acc + 1) rest

-- | The part of a shape that is statically known: like 'ShX', but unknown
-- dimensions carry no runtime size (only @()@).
type StaticShX :: [Maybe Nat] -> Type
newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat))
  deriving (Eq, Ord)

-- | The empty static shape.
pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh
pattern ZKX = StaticShX ZX

-- | Prepend the static information for the first dimension.
pattern (:!%) :: forall {sh1}. forall n sh. (n : sh ~ sh1) => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1
pattern d :!% rest <- StaticShX (unconsListX -> Just (UnconsListXRes (StaticShX -> rest) d))
  where d :!% StaticShX rest = StaticShX (d ::% rest)
infixr 3 :!%

{-# COMPLETE ZKX, (:!%) #-}

instance Show (StaticShX sh) where
  showsPrec _ (StaticShX l) = showListX (fromSMayNat shows (\sn -> shows (fromSNat sn))) l

-- | Evidence for the static part of a shape. This pops up only when you are
-- polymorphic in the element type of an array.
type KnownShX :: [Maybe Nat] -> Constraint
class KnownShX sh where
  knownShX :: StaticShX sh

instance KnownShX '[] where
  knownShX = ZKX

instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where
  knownShX = SKnown SNat :!% knownShX

instance KnownShX sh => KnownShX (Nothing : sh) where
  knownShX = SUnknown () :!% knownShX

-- | Very untyped: the length of the list is only checked at runtime, against
-- the static rank of @sh@.
instance KnownShX sh => IsList (ListX sh (Const i)) where
  type Item (ListX sh (Const i)) = i

  -- Walk the statically known rank of @sh@ alongside the input list; surplus
  -- or missing elements are a runtime error.
  fromList = go (knownShX @sh)
    where
      go :: StaticShX sh' -> [i] -> ListX sh' (Const i)
      go ZKX [] = ZX
      go (_ :!% sh) (i : is) = Const i ::% go sh is
      go _ _ = error "IsList(ListX): Mismatched list length"

  toList = go
    where
      go :: ListX sh' (Const i) -> [i]
      go ZX = []
      go (Const i ::% is) = i : go is

-- | Very untyped; length is checked at runtime, and index bounds are *not checked*.
instance KnownShX sh => IsList (IxX sh i) where
  type Item (IxX sh i) = i
  fromList = IxX . fromList
  toList (IxX l) = toList l

-- (Disabled.) Very untyped; length is checked at runtime, and known dimensions are *not checked*.
-- instance KnownShX sh => IsList (ShX sh i) where
--   type Item (ShX sh i) = i
--   fromList = ShX . fmapListX (\(Const i) -> _) . fromList
--   toList = go
--     where
--       go :: ShX sh' i -> [i]
--       go ZSX = []
--       go (Const i :$% is) = i : go is

-- | The number of dimensions of a shape type (the length of @sh@).
type family Rank sh where
  Rank '[] = 0
  Rank (_ : sh) = 1 + Rank sh

-- | A mixed-shape array: a wrapper around an 'S.Array' whose rank is the
-- length of the shape type @sh@.
type XArray :: [Maybe Nat] -> Type -> Type
newtype XArray sh a = XArray (S.Array (Rank sh) a)
  deriving (Show)

-- | The all-zeros index for a shape with the given static part.
zeroIxX :: StaticShX sh -> IIxX sh
zeroIxX ZKX = ZIX
zeroIxX (_ :!% ssh) = 0 :.% zeroIxX ssh

-- | The all-zeros index for the given shape.
zeroIxX' :: IShX sh -> IIxX sh
zeroIxX' ZSX = ZIX
zeroIxX' (_ :$% sh) = 0 :.% zeroIxX' sh

-- | Turn a static shape into a full shape, using size 0 for every dimension
-- that is not statically known.
--
-- This is a weird operation, so it has a long name.
completeShXzeros :: StaticShX sh -> IShX sh
completeShXzeros ZKX = ZSX
completeShXzeros (SUnknown () :!% ssh) = SUnknown 0 :$% completeShXzeros ssh
completeShXzeros (SKnown n :!% ssh) = SKnown n :$% completeShXzeros ssh

-- TODO: generalise all these things to arbitrary @i@

-- | Concatenate two indices.
ixAppend :: IIxX sh -> IIxX sh' -> IIxX (sh ++ sh')
ixAppend ZIX idx' = idx'
ixAppend (i :.% idx) idx' = i :.% ixAppend idx idx'

-- | Concatenate two shapes.
shAppend :: IShX sh -> IShX sh' -> IShX (sh ++ sh')
shAppend ZSX sh' = sh'
shAppend (n :$% sh) sh' = n :$% shAppend sh sh'

-- | Drop the @sh@ prefix of an index. The second argument only determines how
-- many components are dropped; its values are ignored.
ixDrop :: IIxX (sh ++ sh') -> IIxX sh -> IIxX sh'
ixDrop long ZIX = long
ixDrop long (_ :.% short) = case long of _ :.% long' -> ixDrop long' short

-- | Drop as many leading dimensions from a shape as the index has components.
shDropIx :: IShX (sh ++ sh') -> IIxX sh -> IShX sh'
shDropIx sh ZIX = sh
shDropIx sh (_ :.% idx) = case sh of _ :$% sh' -> shDropIx sh' idx

-- | Drop as many leading dimensions from a static shape as the index has
-- components.
ssxDropIx :: StaticShX (sh ++ sh') -> IIxX sh -> StaticShX sh'
ssxDropIx ssh ZIX = ssh
ssxDropIx ssh (_ :.% idx) = case ssh of _ :!% ssh' -> ssxDropIx ssh' idx

-- | Drop the first dimension of a shape.
shTail :: IShX (n : sh) -> IShX sh
shTail (_ :$% sh) = sh

-- | Drop the first dimension of a static shape.
ssxTail :: StaticShX (n : sh) -> StaticShX sh
ssxTail (_ :!% ssh) = ssh

-- | Split a shape of type @sh ++ sh'@ into its @sh@ and @sh'@ parts, using the
-- static shape of the prefix to find the boundary.
shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh')
shAppSplit _ ZKX idx = (ZSX, idx)
shAppSplit p (_ :!% ssh) (i :$% idx) = first (i :$%) (shAppSplit p ssh idx)

-- | Concatenate two static shapes.
ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
ssxAppend ZKX sh' = sh'
ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh'

-- | The number of elements of an array of this shape: the product of all
-- dimension sizes (1 for the empty shape).
shapeSize :: IShX sh -> Int
shapeSize ZSX = 1
shapeSize (n :$% sh) = fromSMayNat' n * shapeSize sh

-- | This may fail if @sh@ has @Nothing@s in it.
ssxToShape' :: StaticShX sh -> Maybe (IShX sh)
ssxToShape' ZKX = Just ZSX
ssxToShape' (SKnown n :!% sh) = (SKnown n :$%) <$> ssxToShape' sh
ssxToShape' (SUnknown _ :!% _) = Nothing

-- | One unfolding step of 'Replicate'. Asserted with 'unsafeCoerce' rather
-- than proven.
lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
lemReplicateSucc = unsafeCoerce Refl

-- | The static shape consisting of @n@ unknown dimensions.
ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)
ssxReplicate SZ = ZKX
ssxReplicate (SS (n :: SNat n'))
  | Refl <- lemReplicateSucc @(Nothing @Nat) @n'
  = SUnknown () :!% ssxReplicate n

-- | Convert a row-major linear index (last dimension varying fastest) into a
-- multi-dimensional index for the given shape.
--
-- Throws if the linear index is not in the range @[0, 'shapeSize' sh)@. That
-- also covers negative indices, which would otherwise decompose into
-- negative index components, and shapes containing a zero-sized dimension,
-- which would otherwise fail with a division by zero.
fromLinearIdx :: IShX sh -> Int -> IIxX sh
fromLinearIdx = \sh i ->
  if i < 0 || i >= shapeSize sh
    then error $ "fromLinearIdx: out of range (" ++ show i ++ " in array of shape " ++ show sh ++ ")"
    else fst (go sh i)
  where
    -- returns (index in subarray, remaining index in enclosing array); the
    -- innermost dimension is peeled off first
    go :: IShX sh -> Int -> (IIxX sh, Int)
    go ZSX i = (ZIX, i)
    go (n :$% sh) i =
      let (idx, i') = go sh i
          (upi, locali) = i' `quotRem` fromSMayNat' n
      in (locali :.% idx, upi)

-- | The row-major linear index of a multi-dimensional index. The index
-- components are not bounds-checked.
toLinearIdx :: IShX sh -> IIxX sh -> Int
toLinearIdx = \sh i -> fst (go sh i)
  where
    -- returns (index in subarray, size of subarray)
    go :: IShX sh -> IIxX sh -> (Int, Int)
    go ZSX ZIX = (0, 1)
    go (n :$% sh) (i :.% ix) =
      let (lidx, sz) = go sh ix
      in (sz * i + lidx, fromSMayNat' n * sz)

-- | All indices of a shape, in row-major order.
enumShape :: IShX sh -> [IIxX sh]
enumShape = \sh -> go sh id []
  where
    -- difference-list construction: prepends the indices of the subshape,
    -- each transformed by @f@, to the given tail
    go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a]
    go ZSX f = (f ZIX :)
    go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]]

-- | The dimension sizes of a shape as a plain list.
shapeLshape :: IShX sh -> S.ShapeL
shapeLshape ZSX = []
shapeLshape (n :$% sh) = fromSMayNat' n : shapeLshape sh

-- | The number of dimensions of a static shape.
ssxLength :: StaticShX sh -> Int
ssxLength ZKX = 0
ssxLength (_ :!% ssh) = 1 + ssxLength ssh

-- | @[i, i+1, ...]@ with one entry per dimension of the static shape.
ssxIotaFrom :: Int -> StaticShX sh -> [Int]
ssxIotaFrom _ ZKX = []
ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh

-- | Forget the runtime sizes of the statically unknown dimensions.
staticShapeFrom :: IShX sh -> StaticShX sh
staticShapeFrom ZSX = ZKX
staticShapeFrom (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% staticShapeFrom sh

-- | 'Rank' distributes over list append. Asserted with 'unsafeCoerce'.
lemRankApp :: StaticShX sh1 -> StaticShX sh2 -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
lemRankApp _ _ = unsafeCoerce Refl  -- TODO improve this

-- | 'Rank' of an append does not depend on the order of the operands.
-- Asserted with 'unsafeCoerce'.
lemRankAppComm :: StaticShX sh1 -> StaticShX sh2 -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)
lemRankAppComm _ _ = unsafeCoerce Refl  -- TODO improve this

-- | 'KnownNat' evidence for the rank of a shape, by induction on the shape.
lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh)
lemKnownNatRank ZSX = Dict
lemKnownNatRank (_ :$% sh) | Dict <- lemKnownNatRank sh = Dict

-- | 'KnownNat' evidence for the rank of a static shape, by induction on the
-- static shape.
lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh)
lemKnownNatRankSSX ZKX = Dict
lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict

-- | The full shape of an array, combining the static shape with the runtime
-- dimension sizes of the underlying array. Throws ("Invalid shapeL") if the
-- two disagree in length.
shape :: forall sh a. StaticShX sh -> XArray sh a -> IShX sh
shape = \ssh (XArray arr) -> go ssh (S.shapeL arr)
  where
    go :: StaticShX sh' -> [Int] -> IShX sh'
    go ZKX [] = ZSX
    go (n :!% ssh) (i : l) = fromSMayNat (\_ -> SUnknown i) SKnown n :$% go ssh l
    go _ _ = error "Invalid shapeL"

-- | Build an array of the given shape from a flat vector of elements.
fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a
fromVector sh v
  | Dict <- lemKnownNatRank sh
  = XArray (S.fromVector (shapeLshape sh) v)

-- | The elements of an array as a flat vector.
toVector :: Storable a => XArray sh a -> VS.Vector a
toVector (XArray arr) = S.toVector arr

-- | A rank-0 array holding a single value.
scalar :: Storable a => a -> XArray '[] a
scalar = XArray . S.scalar

-- | Whether two shapes have the same dimension sizes, regardless of which
-- dimensions are statically known.
eqShX :: IShX sh1 -> IShX sh2 -> Bool
eqShX ZSX ZSX = True
eqShX (n :$% sh1) (m :$% sh2) = fromSMayNat' n == fromSMayNat' m && eqShX sh1 sh2
eqShX _ _ = False

-- | Will throw if the array does not have the casted-to shape.
cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2
     => StaticShX sh1 -> IShX sh2 -> StaticShX sh'
     -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
cast ssh1 sh2 ssh' (XArray arr)
  | Refl <- lemRankApp ssh1 ssh'
  , Refl <- lemRankApp (staticShapeFrom sh2) ssh'
  = let arrsh :: IShX sh1
        (arrsh, _) = shAppSplit (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr))
    in if eqShX arrsh sh2
         then XArray arr
         else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")"

-- | The value held by a rank-0 array.
unScalar :: Storable a => XArray '[] a -> a
unScalar (XArray a) = S.unScalar a

-- | An array of the given shape with every element equal to @x@.
constant :: forall sh a. Storable a => IShX sh -> a -> XArray sh a
constant sh x
  | Dict <- lemKnownNatRank sh
  = XArray (S.constant (shapeLshape sh) x)

-- | An array of the given shape whose element at each index is computed by
-- the function. Elements are generated in row-major order.
generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a
generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh)

-- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a)
-- generateM sh f | Dict <- lemKnownNatRank sh =
--   XArray . S.fromVector (shapeLshape sh)
--     <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh)

-- | Index the outer @sh@ dimensions, returning the remaining sub-array.
indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a
indexPartial (XArray arr) ZIX = XArray arr
indexPartial (XArray arr) (i :.% idx) = indexPartial (XArray (S.index arr i)) idx

-- | Index a single element.
index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a
index xarr i
  | Refl <- lemAppNil @sh
  = let XArray arr' = indexPartial xarr i :: XArray '[] a
    in S.unScalar arr'

-- | The sum of two possibly-unknown sizes: known only if both are known.
type family AddMaybe n m where
  AddMaybe Nothing _ = Nothing
  AddMaybe (Just _) Nothing = Nothing
  AddMaybe (Just n) (Just m) = Just (n + m)

-- | Add two natural-number singletons.
--
-- This should be a function in base.
plusSNat :: SNat n -> SNat m -> SNat (n + m)
plusSNat n m = TypeNats.withSomeSNat (TypeNats.fromSNat n + TypeNats.fromSNat m) unsafeCoerce

-- | Add two possibly-known sizes; the result is statically known only if both
-- inputs are.
smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m)
smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m)
smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m)
smnAddMaybe (SKnown n) (SKnown m) = SKnown (plusSNat n m)

-- | Concatenate two arrays along their first dimension.
append :: forall n m sh a. Storable a
       => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a
append ssh (XArray a) (XArray b)
  | Dict <- lemKnownNatRankSSX ssh
  = XArray (S.append a b)

-- | Apply a function to every @sh1@-shaped inner sub-array, for all indices
-- of the outer @sh@ dimensions.
rerank :: forall sh sh1 sh2 a b.
          (Storable a, Storable b)
       => StaticShX sh -> StaticShX sh1 -> StaticShX sh2
       -> (XArray sh1 a -> XArray sh2 b)
       -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
rerank ssh ssh1 ssh2 f (XArray arr)
  | Dict <- lemKnownNatRankSSX ssh
  , Dict <- lemKnownNatRankSSX ssh2
  , Refl <- lemRankApp ssh ssh1
  , Refl <- lemRankApp ssh ssh2
  , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)  -- should be redundant but the solver is not clever enough
  = XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2)
                     (\a -> unXArray (f (XArray a)))
                     arr)
  where
    unXArray (XArray a) = a

-- | Like 'rerank', but the function acts on the /outer/ @sh1@ dimensions.
-- Implemented by moving those dimensions inward, applying 'rerank', and
-- moving the result dimensions back out.
rerankTop :: forall sh1 sh2 sh a b.
             (Storable a, Storable b)
          => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh
          -> (XArray sh1 a -> XArray sh2 b)
          -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b
rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh

-- | Binary version of 'rerank': apply a function to corresponding
-- @sh1@-shaped inner sub-arrays of two arrays.
rerank2 :: forall sh sh1 sh2 a b c.
           (Storable a, Storable b, Storable c)
        => StaticShX sh -> StaticShX sh1 -> StaticShX sh2
        -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c)
        -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2)
  | Dict <- lemKnownNatRankSSX ssh
  , Dict <- lemKnownNatRankSSX ssh2
  , Refl <- lemRankApp ssh ssh1
  , Refl <- lemRankApp ssh ssh2
  , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)  -- should be redundant but the solver is not clever enough
  = XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2)
                      (\a b -> unXArray (f (XArray a) (XArray b)))
                      arr1 arr2)
  where
    unXArray (XArray a) = a

-- | Type-level list membership.
type family Elem x l where
  Elem x '[] = 'False
  Elem x (x : _) = 'True
  Elem x (_ : ys) = Elem x ys

-- | Whether every element of @as@ occurs in @bs@.
type family AllElem' as bs where
  AllElem' '[] bs = 'True
  AllElem' (a : as) bs = Elem a bs && AllElem' as bs

-- | Constraint form of 'AllElem'', with a custom type error on failure.
type AllElem as bs =
  Assert (AllElem' as bs)
    (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs))

-- | The type-level list @[i, i+1, ..., n-1]@.
type family Count i n where
  Count n n = '[]
  Count i n = i : Count (i + 1) n

-- | @as@ contains exactly the numbers @0 .. length as - 1@ (membership is
-- checked in both directions).
type Permutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as)

-- | The element at position @i@ of a type-level list.
type family Index i sh where
  Index 0 (n : sh) = n
  Index i (_ : sh) = Index (i - 1) sh

-- | Select the elements of @sh@ at the positions listed in @is@, in order.
type family Permute is sh where
  Permute '[] sh = '[]
  Permute (i : is) sh = Index i sh : Permute is sh

-- | Permute the prefix of @sh@ that is as long as @is@; keep the rest as is.
type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh

-- | A heterogeneous list whose elements have types @f a@ for @a@ in @list@.
data HList f list where
  HNil :: HList f '[]
  HCons :: f a -> HList f l -> HList f (a : l)
infixr 5 `HCons`

-- | Map every element to a monoid and combine the results left to right.
foldHList :: Monoid m => (forall a. f a -> m) -> HList f list -> m
foldHList _ HNil = mempty
foldHList f (x `HCons` l) = f x <> foldHList f l

-- | The prefix of @l@ that is as long as @ref@.
type family TakeLen ref l where
  TakeLen '[] l = '[]
  TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs

-- | @l@ without its prefix that is as long as @ref@.
type family DropLen ref l where
  DropLen '[] l = l
  DropLen (_ : ref) (_ : xs) = DropLen ref xs

-- | A permuted list is as long as the list of indices used to permute it.
lemRankPermute :: Proxy sh -> HList SNat is -> Rank (Permute is sh) :~: Rank is
lemRankPermute _ HNil = Refl
lemRankPermute p (_ `HCons` is) | Refl <- lemRankPermute p is = Refl

-- | Dropping @length is@ elements reduces the rank by @length is@.
lemRankDropLen :: forall is sh. (Rank is <= Rank sh)
               => StaticShX sh -> HList SNat is -> Rank (DropLen is sh) :~: Rank sh - Rank is
lemRankDropLen ZKX HNil = Refl
lemRankDropLen (_ :!% sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl
lemRankDropLen (_ :!% _) HNil = Refl
lemRankDropLen ZKX (_ `HCons` _) = error "1 <= 0"

-- | One unfolding step of 'Index'. Asserted with 'unsafeCoerce'.
lemIndexSucc :: Proxy i -> Proxy a -> Proxy l -> Index (i + 1) (a : l) :~: Index i l
lemIndexSucc _ _ _ = unsafeCoerce Refl

-- | Term-level 'TakeLen' on static shapes. Throws if @is@ is longer than the
-- shape.
ssxTakeLen :: HList SNat is -> StaticShX sh -> StaticShX (TakeLen is sh)
ssxTakeLen HNil _ = ZKX
ssxTakeLen (_ `HCons` is) (n :!% sh) = n :!% ssxTakeLen is sh
ssxTakeLen (_ `HCons` _) ZKX = error "Permutation longer than shape"

-- | Term-level 'DropLen' on static shapes. Throws if @is@ is longer than the
-- shape.
ssxDropLen :: HList SNat is -> StaticShX sh -> StaticShX (DropLen is sh)
ssxDropLen HNil sh = sh
ssxDropLen (_ `HCons` is) (_ :!% sh) = ssxDropLen is sh
ssxDropLen (_ `HCons` _) ZKX = error "Permutation longer than shape"

-- | Term-level 'Permute' on static shapes.
ssxPermute :: HList SNat is -> StaticShX sh -> StaticShX (Permute is sh)
ssxPermute HNil _ = ZKX
ssxPermute (i `HCons` (is :: HList SNat is')) (sh :: StaticShX sh) =
  ssxIndex (Proxy @is') (Proxy @sh) i sh (ssxPermute is sh)

-- | Prepend the @i@-th dimension of @sh@ to @rest@. Throws if @i@ is out of
-- range.
ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT)
         -> StaticShX (Index i sh : Permute is shT)
ssxIndex _ _ SZ (n :!% _) rest = n :!% rest
ssxIndex p pT (SS (i :: SNat i')) ((_ :: SMayNat () SNat n) :!% (sh :: StaticShX sh')) rest
  | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
  = ssxIndex p pT i sh rest
ssxIndex _ _ _ ZKX _ = error "Index into empty shape"

-- | The list argument gives indices into the original dimension list.
transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh)
          => StaticShX sh -> HList SNat is -> XArray sh a -> XArray (PermutePrefix is sh) a
transpose ssh perm (XArray arr)
  | Dict <- lemKnownNatRankSSX ssh
  , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh)
  , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm
  , Refl <- lemRankDropLen ssh perm
  = let perm' = foldHList (\sn -> [fromSNat' sn]) perm :: [Int]
    in XArray (S.transpose perm' arr)

-- | The list argument gives indices into the original dimension list.
--
-- The permutation (the list) must have length <= @n@. If it is longer, this
-- function throws.
transposeUntyped :: forall n sh a.
                    SNat n -> StaticShX sh -> [Int]
                 -> XArray (Replicate n Nothing ++ sh) a -> XArray (Replicate n Nothing ++ sh) a
transposeUntyped sn ssh perm (XArray arr)
  | length perm <= fromSNat' sn
  , Dict <- lemKnownNatRankSSX (ssxAppend (ssxReplicate sn) ssh)
  = XArray (S.transpose perm arr)
  | otherwise
  = error "Data.Array.Mixed.transposeUntyped: Permutation longer than length of unshaped prefix of shape type"

-- | Swap the @sh1@ and @sh2@ groups of dimensions.
transpose2 :: forall sh1 sh2 a.
              StaticShX sh1 -> StaticShX sh2
           -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a
transpose2 ssh1 ssh2 (XArray arr)
  | Refl <- lemRankApp ssh1 ssh2
  , Refl <- lemRankApp ssh2 ssh1
  , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2)
  , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1)
  , Refl <- lemRankAppComm ssh1 ssh2
  , let n1 = ssxLength ssh1
  -- new dimension order: the sh2 dimensions (n1 ..) first, then sh1 (0 ..)
  = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)

-- | The sum of all elements.
sumFull :: (Storable a, Num a) => XArray sh a -> a
sumFull (XArray arr) = S.sumA arr

-- | Sum over the inner @sh'@ dimensions.
sumInner :: forall sh sh' a. (Storable a, Num a)
         => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a
sumInner ssh ssh'
  | Refl <- lemAppNil @sh
  = rerank ssh ssh' ZKX (scalar . sumFull)

-- | Sum over the outer @sh@ dimensions.
sumOuter :: forall sh sh' a. (Storable a, Num a)
         => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a
sumOuter ssh ssh'
  | Refl <- lemAppNil @sh
  = sumInner ssh' ssh . transpose2 ssh ssh'

-- | Stack a list of arrays along a new first dimension. Throws if that
-- dimension is statically known and differs from the length of the list.
fromList1 :: forall n sh a. Storable a
          => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
fromList1 ssh l
  | Dict <- lemKnownNatRankSSX ssh
  = case ssh of
      SKnown m@GHC_SNat :!% _
        | natVal m /= fromIntegral (length l) ->
            error $ "Data.Array.Mixed.fromList1: length of list (" ++ show (length l) ++ ") "
                    ++ "does not match the type (" ++ show (natVal m) ++ ")"
      _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l)))

-- | Split an array along its first dimension.
toList1 :: Storable a => XArray (n : sh) a -> [XArray sh a]
toList1 (XArray arr) = coerce (ORB.toList (S.unravel arr))

-- | Throws if the given shape is not, in fact, empty.
--
-- NOTE(review): the array is filled with an 'error' value, so the failure
-- presumably surfaces only when an element is actually forced; confirm
-- against the strictness of 'S.constant'.
empty :: forall sh a. Storable a => IShX sh -> XArray sh a
empty sh
  | Dict <- lemKnownNatRank sh
  = XArray (S.constant (shapeLshape sh) (error "Data.Array.Mixed.empty: shape was not empty"))

-- | Take @n@ elements starting at offset @i@ along the first dimension.
slice :: SNat i -> SNat n -> XArray (Just (i + n + k) : sh) a -> XArray (Just n : sh) a
slice i n (XArray arr) = XArray (S.slice [(fromSNat' i, fromSNat' n)] arr)

-- | Take @n@ elements starting at offset @i@ along the (statically unknown)
-- first dimension.
sliceU :: Int -> Int -> XArray (Nothing : sh) a -> XArray (Nothing : sh) a
sliceU i n (XArray arr) = XArray (S.slice [(i, n)] arr)

-- | Reverse the order along the first dimension.
rev1 :: XArray (n : sh) a -> XArray (n : sh) a
rev1 (XArray arr) = XArray (S.rev [0] arr)

-- | Throws if the given array and the target shape do not have the same number of elements.
reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a
reshape ssh1 sh2 (XArray arr)
  | Dict <- lemKnownNatRankSSX ssh1
  , Dict <- lemKnownNatRank sh2
  = XArray (S.reshape (shapeLshape sh2) arr)

-- | Reshape the outer @sh1@ dimensions of an array into @sh2@, leaving the
-- inner @sh'@ dimensions untouched.
--
-- Throws if the given array and the target shape do not have the same number of elements.
reshapePartial :: forall sh1 sh2 sh' a. Storable a
               => StaticShX sh1 -> StaticShX sh' -> IShX sh2
               -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
reshapePartial ssh1 ssh' sh2 (XArray arr)
  | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh')
  , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh2) ssh')
  -- The runtime shape of arr is the sh1 dimensions followed by the sh'
  -- dimensions, so the sh' part is what remains after dropping as many
  -- dimensions as sh1 has (not as many as sh2 has).
  = XArray (S.reshape (shapeLshape sh2 ++ drop (ssxLength ssh1) (S.shapeL arr)) arr)