diff options
Diffstat (limited to 'src/Data/Array/Mixed')
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 43 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Lemmas.hs | 137 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Permutation.hs | 273 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Shape.hs | 586 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Types.hs | 134 | ||||
-rw-r--r-- | src/Data/Array/Mixed/XArray.hs | 349 |
6 files changed, 0 insertions, 1522 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs deleted file mode 100644 index b1c7031..0000000 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ /dev/null @@ -1,43 +0,0 @@ -{-# LANGUAGE ImportQualifiedPost #-} -module Data.Array.Mixed.Internal.Arith ( - module Data.Array.Mixed.Internal.Arith, - module Data.Array.Strided.Arith, -) where - -import Data.Array.Internal qualified as OI -import Data.Array.Internal.RankedG qualified as RG -import Data.Array.Internal.RankedS qualified as RS - -import Data.Array.Strided qualified as AS -import Data.Array.Strided.Arith - --- for liftVEltwise1 -import Foreign.Storable -import GHC.TypeLits -import Data.Vector.Storable qualified as VS -import Data.Array.Strided.Arith.Internal (stridesDense) - - -fromO :: RS.Array n a -> AS.Array n a -fromO (RS.A (RG.A sh (OI.T strides offset vec))) = AS.Array sh strides offset vec - -toO :: AS.Array n a -> RS.Array n a -toO (AS.Array sh strides offset vec) = RS.A (RG.A sh (OI.T strides offset vec)) - -liftO1 :: (AS.Array n a -> AS.Array n' b) - -> RS.Array n a -> RS.Array n' b -liftO1 f = toO . f . fromO - -liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c) - -> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c -liftO2 f x y = toO (f (fromO x) (fromO y)) - -liftVEltwise1 :: (Storable a, Storable b) - => SNat n - -> (VS.Vector a -> VS.Vector b) - -> RS.Array n a -> RS.Array n b -liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) - | Just (blockOff, blockSz) <- stridesDense sh offset strides = - let vec' = f (VS.slice blockOff blockSz vec) - in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec')) - | otherwise = RS.fromVector sh (f (RS.toVector arr)) diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Mixed/Lemmas.hs deleted file mode 100644 index ec7e7bd..0000000 --- a/src/Data/Array/Mixed/Lemmas.hs +++ /dev/null @@ -1,137 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Mixed.Lemmas where - -import Data.Proxy -import Data.Type.Equality -import GHC.TypeLits - -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape -import Data.Array.Mixed.Types - - --- * Reasoning helpers - -subst1 :: forall f a b. a :~: b -> f a :~: f b -subst1 Refl = Refl - -subst2 :: forall f c a b. a :~: b -> f a c :~: f b c -subst2 Refl = Refl - - --- * Lemmas - --- ** Nat - -lemLeqSuccSucc :: (k + 1 <= n) => Proxy k -> Proxy n -> (k <=? n - 1) :~: True -lemLeqSuccSucc _ _ = unsafeCoerceRefl - -lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True -lemLeqPlus _ _ _ = Refl - - --- ** Append - -lemAppNil :: l ++ '[] :~: l -lemAppNil = unsafeCoerceRefl - -lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) -lemAppAssoc _ _ _ = unsafeCoerceRefl - -lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l -lemAppLeft _ Refl = Refl - - --- ** Rank - -lemRankApp :: forall sh1 sh2. - StaticShX sh1 -> StaticShX sh2 - -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2 -lemRankApp ZKX _ = Refl -lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2 - = lem2 (Proxy @(Rank sh1T)) Proxy Proxy $ - lem (Proxy @(Rank sh2)) (Proxy @(Rank sh1T)) (Proxy @(Rank (sh1T ++ sh2))) $ - lemRankApp ssh1 ssh2 - where - lem :: proxy a -> proxy b -> proxy c - -> c :~: b + a - -> b + a :~: c - lem _ _ _ Refl = Refl - - lem2 :: proxy a -> proxy b -> proxy c - -> (a + b :~: c) - -> c + 1 :~: (a + 1 + b) - lem2 _ _ _ Refl = Refl - -lemRankAppComm :: StaticShX sh1 -> StaticShX sh2 - -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1) -lemRankAppComm _ _ = unsafeCoerceRefl -- TODO improve this - -lemRankReplicate :: SNat n -> Rank (Replicate n (Nothing @Nat)) :~: n -lemRankReplicate SZ = Refl -lemRankReplicate (SS (n :: SNat nm1)) - | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 - , Refl <- lemRankReplicate n - = Refl - - --- ** Various type families - -lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a - -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a -lemReplicatePlusApp sn _ _ = go sn - where - go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a - go SZ = Refl - go (SS (n :: SNat n'm1)) - | Refl <- lemReplicateSucc @a @n'm1 - , Refl <- go n - = sym (lemReplicateSucc @a @(n'm1 + m)) - -lemDropLenApp :: Rank l1 <= Rank l2 - => Proxy l1 -> Proxy l2 -> Proxy rest - -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest) -lemDropLenApp _ _ _ = unsafeCoerceRefl - -lemTakeLenApp :: Rank l1 <= Rank l2 - => Proxy l1 -> Proxy l2 -> Proxy rest - -> TakeLen l1 l2 :~: TakeLen l1 (l2 ++ rest) -lemTakeLenApp _ _ _ = unsafeCoerceRefl - -lemInitApp :: Proxy l -> Proxy x -> Init (l ++ '[x]) :~: l -lemInitApp _ _ = unsafeCoerceRefl - -lemLastApp :: Proxy l -> Proxy x -> Last (l ++ '[x]) :~: x -lemLastApp _ _ = unsafeCoerceRefl - - --- ** KnownNat - -lemKnownNatSucc :: KnownNat n => Dict KnownNat (n + 1) -lemKnownNatSucc = Dict - -lemKnownNatRank :: ShX sh i -> Dict KnownNat (Rank sh) -lemKnownNatRank ZSX = Dict -lemKnownNatRank (_ :$% sh) | Dict <- lemKnownNatRank sh = Dict - -lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh) -lemKnownNatRankSSX ZKX = Dict -lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict - - --- ** Known shapes - -lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing) -lemKnownReplicate sn = lemKnownShX (ssxFromSNat sn) - -lemKnownShX :: StaticShX sh -> Dict KnownShX sh -lemKnownShX ZKX = Dict -lemKnownShX (SKnown SNat :!% ssh) | Dict <- lemKnownShX ssh = Dict -lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs deleted file mode 100644 index 8efcbe8..0000000 --- a/src/Data/Array/Mixed/Permutation.hs +++ /dev/null @@ -1,273 +0,0 @@ -{-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StrictData #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Mixed.Permutation where - -import Data.Coerce (coerce) -import Data.Functor.Const -import Data.List (sort) -import Data.Maybe (fromMaybe) -import Data.Proxy -import Data.Type.Bool -import Data.Type.Equality -import Data.Type.Ord -import GHC.TypeError -import GHC.TypeLits -import GHC.TypeNats qualified as TN - -import Data.Array.Mixed.Shape -import Data.Array.Mixed.Types - - --- * Permutations - --- | A "backward" permutation of a dimension list. The operation on the --- dimension list is most similar to 'Data.Vector.backpermute'; see 'Permute' --- for code that implements this. -data Perm list where - PNil :: Perm '[] - PCons :: SNat a -> Perm l -> Perm (a : l) -infixr 5 `PCons` -deriving instance Show (Perm list) -deriving instance Eq (Perm list) - -permRank :: Perm list -> SNat (Rank list) -permRank PNil = SNat -permRank (_ `PCons` l) | SNat <- permRank l = SNat - -permFromList :: [Int] -> (forall list. Perm list -> r) -> r -permFromList [] k = k PNil -permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case - Just sn -> permFromList xs $ \list -> k (sn `PCons` list) - Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x - -permToList :: Perm list -> [Natural] -permToList PNil = mempty -permToList (x `PCons` l) = TN.fromSNat x : permToList l - -permToList' :: Perm list -> [Int] -permToList' = map fromIntegral . permToList - --- | When called as @permCheckPermutation p k@, if @p@ is a permutation of --- @[0 .. 'length' ('permToList' p) - 1]@, @Just k@ is returned. If it isn't, --- then @Nothing@ is returned. -permCheckPermutation :: forall r list. Perm list -> (IsPermutation list => r) -> Maybe r -permCheckPermutation = \p k -> - let n = permRank p - in case (provePerm1 (Proxy @list) n p, provePerm2 (SNat @0) n p) of - (Just Refl, Just Refl) -> Just k - _ -> Nothing - where - lemElemCount :: (0 <= n, Compare n m ~ LT) - => proxy n -> proxy m -> Elem n (Count 0 m) :~: True - lemElemCount _ _ = unsafeCoerceRefl - - lemCount :: (OrdCond (Compare i n) True False True ~ True) - => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n - lemCount _ _ = unsafeCoerceRefl - - lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True - lemElem _ _ = unsafeCoerceRefl - - provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is' - -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True) - provePerm1 _ _ PNil = Just (Refl) - provePerm1 p rtop@SNat (PCons sn@SNat perm) - | Just Refl <- provePerm1 p rtop perm - = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of - (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl - (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl - _ -> Nothing - | otherwise - = Nothing - - provePerm2 :: SNat i -> SNat n -> Perm is' - -> Maybe (AllElem' (Count i n) is' :~: True) - provePerm2 = \i@(SNat :: SNat i) n@SNat perm -> - case cmpNat i n of - EQI -> Just Refl - LTI | Refl <- lemCount i n - , Just Refl <- provePerm2 (SNat @(i + 1)) n perm - -> checkElem i perm - | otherwise -> Nothing - GTI -> error "unreachable" - where - checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True) - checkElem _ PNil = Nothing - checkElem i@SNat (PCons k@SNat perm :: Perm is') = - case sameNat i k of - Just Refl -> Just Refl - Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl - | otherwise -> Nothing - --- | Utility class for generating permutations from type class information. -class KnownPerm l where makePerm :: Perm l -instance KnownPerm '[] where makePerm = PNil -instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = natSing `PCons` makePerm - --- | Untyped permutations for ranked arrays -type PermR = [Int] - - --- ** Applying permutations - -type family Elem x l where - Elem x '[] = 'False - Elem x (x : _) = 'True - Elem x (_ : ys) = Elem x ys - -type family AllElem' as bs where - AllElem' '[] bs = 'True - AllElem' (a : as) bs = Elem a bs && AllElem' as bs - -type AllElem as bs = Assert (AllElem' as bs) - (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs)) - -type family Count i n where - Count n n = '[] - Count i n = i : Count (i + 1) n - -type IsPermutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as) - -type family Index i sh where - Index 0 (n : sh) = n - Index i (_ : sh) = Index (i - 1) sh - -type family Permute is sh where - Permute '[] sh = '[] - Permute (i : is) sh = Index i sh : Permute is sh - -type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh - -type family TakeLen ref l where - TakeLen '[] l = '[] - TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs - -type family DropLen ref l where - DropLen '[] l = l - DropLen (_ : ref) (_ : xs) = DropLen ref xs - -listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f -listxTakeLen PNil _ = ZX -listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh -listxTakeLen (_ `PCons` _) ZX = error "Permutation longer than shape" - -listxDropLen :: forall f is sh. Perm is -> ListX sh f -> ListX (DropLen is sh) f -listxDropLen PNil sh = sh -listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh -listxDropLen (_ `PCons` _) ZX = error "Permutation longer than shape" - -listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f -listxPermute PNil _ = ZX -listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) = - listxIndex (Proxy @is') (Proxy @sh) i sh ::% listxPermute is sh - -listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh) -listxIndex _ _ SZ (n ::% _) = n -listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) - | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = listxIndex p pT i sh -listxIndex _ _ _ ZX = error "Index into empty shape" - -listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f -listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) - -ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i -ixxPermutePrefix = coerce (listxPermutePrefix @(Const i)) - -ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh) -ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat)) - -ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh) -ssxDropLen = coerce (listxDropLen @(SMayNat () SNat)) - -ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh) -ssxPermute = coerce (listxPermute @(SMayNat () SNat)) - -ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh) -ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2) - -ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) -ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat)) - -shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh) -shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat)) - - --- * Operations on permutations - -permInverse :: Perm is - -> (forall is'. - IsPermutation is' - => Perm is' - -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh) - -> r) - -> r -permInverse = \perm k -> - genPerm perm $ \(invperm :: Perm is') -> - fromMaybe - (error $ "permInverse: did not generate permutation? perm = " ++ show perm - ++ " ; invperm = " ++ show invperm) - (permCheckPermutation invperm - (k invperm - (\ssh -> case provePermInverse perm invperm ssh of - Just eq -> eq - Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm - ++ " ; invperm = " ++ show invperm))) - where - genPerm :: Perm is -> (forall is'. Perm is' -> r) -> r - genPerm perm = - let permList = permToList' perm - in toHList $ map snd (sort (zip permList [0..])) - where - toHList :: [Natural] -> (forall is'. Perm is' -> r) -> r - toHList [] k = k PNil - toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l) - - provePermInverse :: Perm is -> Perm is' -> StaticShX sh - -> Maybe (Permute is' (Permute is sh) :~: sh) - provePermInverse perm perminv ssh = - ssxEqType (ssxPermute perminv (ssxPermute perm ssh)) ssh - -type family MapSucc is where - MapSucc '[] = '[] - MapSucc (i : is) = i + 1 : MapSucc is - -permShift1 :: Perm l -> Perm (0 : MapSucc l) -permShift1 = (SNat @0 `PCons`) . permMapSucc - where - permMapSucc :: Perm l -> Perm (MapSucc l) - permMapSucc PNil = PNil - permMapSucc ((SNat :: SNat i) `PCons` ns) = SNat @(i + 1) `PCons` permMapSucc ns - - --- * Lemmas - -lemRankPermute :: Proxy sh -> Perm is -> Rank (Permute is sh) :~: Rank is -lemRankPermute _ PNil = Refl -lemRankPermute p (_ `PCons` is) | Refl <- lemRankPermute p is = Refl - -lemRankDropLen :: forall is sh. (Rank is <= Rank sh) - => StaticShX sh -> Perm is -> Rank (DropLen is sh) :~: Rank sh - Rank is -lemRankDropLen ZKX PNil = Refl -lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl -lemRankDropLen (_ :!% _) PNil = Refl -lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0" - -lemIndexSucc :: Proxy i -> Proxy a -> Proxy l - -> Index (i + 1) (a : l) :~: Index i l -lemIndexSucc _ _ _ = unsafeCoerceRefl diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs deleted file mode 100644 index b49e005..0000000 --- a/src/Data/Array/Mixed/Shape.hs +++ /dev/null @@ -1,586 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE NoStarIsType #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE RoleAnnotations #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE StrictData #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Mixed.Shape where - -import Control.DeepSeq (NFData(..)) -import Data.Bifunctor (first) -import Data.Coerce -import Data.Foldable qualified as Foldable -import Data.Functor.Const -import Data.Kind (Type, Constraint) -import Data.Monoid (Sum(..)) -import Data.Proxy -import Data.Type.Equality -import GHC.Exts (withDict) -import GHC.Generics (Generic) -import GHC.IsList (IsList) -import GHC.IsList qualified as IsList -import GHC.TypeLits - -import Data.Array.Mixed.Types - - --- | The length of a type-level list. If the argument is a shape, then the --- result is the rank of that shape. -type family Rank sh where - Rank '[] = 0 - Rank (_ : sh) = Rank sh + 1 - - --- * Mixed lists - -type role ListX nominal representational -type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type -data ListX sh f where - ZX :: ListX '[] f - (::%) :: f n -> ListX sh f -> ListX (n : sh) f -deriving instance (forall n. Eq (f n)) => Eq (ListX sh f) -deriving instance (forall n. Ord (f n)) => Ord (ListX sh f) -infixr 3 ::% - -instance (forall n. Show (f n)) => Show (ListX sh f) where - showsPrec _ = listxShow shows - -instance (forall n. NFData (f n)) => NFData (ListX sh f) where - rnf ZX = () - rnf (x ::% l) = rnf x `seq` rnf l - -data UnconsListXRes f sh1 = - forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n) -listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1) -listxUncons (i ::% shl') = Just (UnconsListXRes shl' i) -listxUncons ZX = Nothing - --- | This checks only whether the types are equal; if the elements of the list --- are not singletons, their values may still differ. This corresponds to --- 'testEquality', except on the penultimate type parameter. -listxEqType :: TestEquality f => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh') -listxEqType ZX ZX = Just Refl -listxEqType (n ::% sh) (m ::% sh') - | Just Refl <- testEquality n m - , Just Refl <- listxEqType sh sh' - = Just Refl -listxEqType _ _ = Nothing - --- | This checks whether the two lists actually contain equal values. This is --- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ --- in the @some@ package (except on the penultimate type parameter). -listxEqual :: (TestEquality f, forall n. Eq (f n)) => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh') -listxEqual ZX ZX = Just Refl -listxEqual (n ::% sh) (m ::% sh') - | Just Refl <- testEquality n m - , n == m - , Just Refl <- listxEqual sh sh' - = Just Refl -listxEqual _ _ = Nothing - -listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g -listxFmap _ ZX = ZX -listxFmap f (x ::% xs) = f x ::% listxFmap f xs - -listxFold :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m -listxFold _ ZX = mempty -listxFold f (x ::% xs) = f x <> listxFold f xs - -listxLength :: ListX sh f -> Int -listxLength = getSum . listxFold (\_ -> Sum 1) - -listxRank :: ListX sh f -> SNat (Rank sh) -listxRank ZX = SNat -listxRank (_ ::% l) | SNat <- listxRank l = SNat - -listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS -listxShow f l = showString "[" . go "" l . showString "]" - where - go :: String -> ListX sh' f -> ShowS - go _ ZX = id - go prefix (x ::% xs) = showString prefix . f x . go "," xs - -listxFromList :: StaticShX sh -> [i] -> ListX sh (Const i) -listxFromList topssh topl = go topssh topl - where - go :: StaticShX sh' -> [i] -> ListX sh' (Const i) - go ZKX [] = ZX - go (_ :!% sh) (i : is) = Const i ::% go sh is - go _ _ = error $ "listxFromList: Mismatched list length (type says " - ++ show (ssxLength topssh) ++ ", list has length " - ++ show (length topl) ++ ")" - -listxToList :: ListX sh' (Const i) -> [i] -listxToList ZX = [] -listxToList (Const i ::% is) = i : listxToList is - -listxHead :: ListX (mn ': sh) f -> f mn -listxHead (i ::% _) = i - -listxTail :: ListX (n : sh) i -> ListX sh i -listxTail (_ ::% sh) = sh - -listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f -listxAppend ZX idx' = idx' -listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx' - -listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f -listxDrop long ZX = long -listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short - -listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f -listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh -listxInit (_ ::% ZX) = ZX - -listxLast :: forall f n sh. ListX (n : sh) f -> f (Last (n : sh)) -listxLast (_ ::% sh@(_ ::% _)) = listxLast sh -listxLast (x ::% ZX) = x - - --- * Mixed indices - --- | This is a newtype over 'ListX'. -type role IxX nominal representational -type IxX :: [Maybe Nat] -> Type -> Type -newtype IxX sh i = IxX (ListX sh (Const i)) - deriving (Eq, Ord, Generic) - -pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i -pattern ZIX = IxX ZX - -pattern (:.%) - :: forall {sh1} {i}. - forall n sh. (n : sh ~ sh1) - => i -> IxX sh i -> IxX sh1 i -pattern i :.% shl <- IxX (listxUncons -> Just (UnconsListXRes (IxX -> shl) (getConst -> i))) - where i :.% IxX shl = IxX (Const i ::% shl) -infixr 3 :.% - -{-# COMPLETE ZIX, (:.%) #-} - -type IIxX sh = IxX sh Int - -instance Show i => Show (IxX sh i) where - showsPrec _ (IxX l) = listxShow (\(Const i) -> shows i) l - -instance Functor (IxX sh) where - fmap f (IxX l) = IxX (listxFmap (Const . f . getConst) l) - -instance Foldable (IxX sh) where - foldMap f (IxX l) = listxFold (f . getConst) l - -instance NFData i => NFData (IxX sh i) - -ixxLength :: IxX sh i -> Int -ixxLength (IxX l) = listxLength l - -ixxRank :: IxX sh i -> SNat (Rank sh) -ixxRank (IxX l) = listxRank l - -ixxZero :: StaticShX sh -> IIxX sh -ixxZero ZKX = ZIX -ixxZero (_ :!% ssh) = 0 :.% ixxZero ssh - -ixxZero' :: IShX sh -> IIxX sh -ixxZero' ZSX = ZIX -ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh - -ixxFromList :: forall sh i. StaticShX sh -> [i] -> IxX sh i -ixxFromList = coerce (listxFromList @_ @i) - -ixxHead :: IxX (n : sh) i -> i -ixxHead (IxX list) = getConst (listxHead list) - -ixxTail :: IxX (n : sh) i -> IxX sh i -ixxTail (IxX list) = IxX (listxTail list) - -ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i -ixxAppend = coerce (listxAppend @_ @(Const i)) - -ixxDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i -ixxDrop = coerce (listxDrop @(Const i) @(Const i)) - -ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i -ixxInit = coerce (listxInit @(Const i)) - -ixxLast :: forall n sh i. IxX (n : sh) i -> i -ixxLast = coerce (listxLast @(Const i)) - -ixxFromLinear :: IShX sh -> Int -> IIxX sh -ixxFromLinear = \sh i -> case go sh i of - (idx, 0) -> idx - _ -> error $ "ixxFromLinear: out of range (" ++ show i ++ - " in array of shape " ++ show sh ++ ")" - where - -- returns (index in subarray, remaining index in enclosing array) - go :: IShX sh -> Int -> (IIxX sh, Int) - go ZSX i = (ZIX, i) - go (n :$% sh) i = - let (idx, i') = go sh i - (upi, locali) = i' `quotRem` fromSMayNat' n - in (locali :.% idx, upi) - -ixxToLinear :: IShX sh -> IIxX sh -> Int -ixxToLinear = \sh i -> fst (go sh i) - where - -- returns (index in subarray, size of subarray) - go :: IShX sh -> IIxX sh -> (Int, Int) - go ZSX ZIX = (0, 1) - go (n :$% sh) (i :.% ix) = - let (lidx, sz) = go sh ix - in (sz * i + lidx, fromSMayNat' n * sz) - - --- * Mixed shapes - -data SMayNat i f n where - SUnknown :: i -> SMayNat i f Nothing - SKnown :: f n -> SMayNat i f (Just n) -deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n) -deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n) -deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n) - -instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where - rnf (SUnknown i) = rnf i - rnf (SKnown x) = rnf x - -instance TestEquality f => TestEquality (SMayNat i f) where - testEquality SUnknown{} SUnknown{} = Just Refl - testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl - testEquality _ _ = Nothing - -fromSMayNat :: (n ~ Nothing => i -> r) - -> (forall m. n ~ Just m => f m -> r) - -> SMayNat i f n -> r -fromSMayNat f _ (SUnknown i) = f i -fromSMayNat _ g (SKnown s) = g s - -fromSMayNat' :: SMayNat Int SNat n -> Int -fromSMayNat' = fromSMayNat id fromSNat' - -type family AddMaybe n m where - AddMaybe Nothing _ = Nothing - AddMaybe (Just _) Nothing = Nothing - AddMaybe (Just n) (Just m) = Just (n + m) - -smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m) -smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) -smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) -smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) - - --- | This is a newtype over 'ListX'. -type role ShX nominal representational -type ShX :: [Maybe Nat] -> Type -> Type -newtype ShX sh i = ShX (ListX sh (SMayNat i SNat)) - deriving (Eq, Ord, Generic) - -pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i -pattern ZSX = ShX ZX - -pattern (:$%) - :: forall {sh1} {i}. - forall n sh. (n : sh ~ sh1) - => SMayNat i SNat n -> ShX sh i -> ShX sh1 i -pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i)) - where i :$% ShX shl = ShX (i ::% shl) -infixr 3 :$% - -{-# COMPLETE ZSX, (:$%) #-} - -type IShX sh = ShX sh Int - -instance Show i => Show (ShX sh i) where - showsPrec _ (ShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l - -instance Functor (ShX sh) where - fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l) - -instance NFData i => NFData (ShX sh i) where - rnf (ShX ZX) = () - rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l) - rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l) - --- | This checks only whether the types are equal; unknown dimensions might --- still differ. This corresponds to 'testEquality', except on the penultimate --- type parameter. -shxEqType :: ShX sh i -> ShX sh' i -> Maybe (sh :~: sh') -shxEqType ZSX ZSX = Just Refl -shxEqType (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh') - | Just Refl <- sameNat n m - , Just Refl <- shxEqType sh sh' - = Just Refl -shxEqType (SUnknown _ :$% sh) (SUnknown _ :$% sh') - | Just Refl <- shxEqType sh sh' - = Just Refl -shxEqType _ _ = Nothing - --- | This checks whether all dimensions have the same value. This is more than --- 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ in the --- @some@ package (except on the penultimate type parameter). -shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh') -shxEqual ZSX ZSX = Just Refl -shxEqual (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh') - | Just Refl <- sameNat n m - , Just Refl <- shxEqual sh sh' - = Just Refl -shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh') - | i == j - , Just Refl <- shxEqual sh sh' - = Just Refl -shxEqual _ _ = Nothing - -shxLength :: ShX sh i -> Int -shxLength (ShX l) = listxLength l - -shxRank :: ShX sh i -> SNat (Rank sh) -shxRank (ShX l) = listxRank l - --- | The number of elements in an array described by this shape. -shxSize :: IShX sh -> Int -shxSize ZSX = 1 -shxSize (n :$% sh) = fromSMayNat' n * shxSize sh - -shxFromList :: StaticShX sh -> [Int] -> ShX sh Int -shxFromList topssh topl = go topssh topl - where - go :: StaticShX sh' -> [Int] -> ShX sh' Int - go ZKX [] = ZSX - go (SKnown sn :!% sh) (i : is) - | i == fromSNat' sn = SKnown sn :$% go sh is - | otherwise = error $ "shxFromList: Value does not match typing (type says " - ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" - go (SUnknown () :!% sh) (i : is) = SUnknown i :$% go sh is - go _ _ = error $ "shxFromList: Mismatched list length (type says " - ++ show (ssxLength topssh) ++ ", list has length " - ++ show (length topl) ++ ")" - -shxToList :: IShX sh -> [Int] -shxToList ZSX = [] -shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh - -shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i -shxAppend = coerce (listxAppend @_ @(SMayNat i SNat)) - -shxHead :: ShX (n : sh) i -> SMayNat i SNat n -shxHead (ShX list) = listxHead list - -shxTail :: ShX (n : sh) i -> ShX sh i -shxTail (ShX list) = ShX (listxTail list) - -shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i -shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat)) - -shxDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i -shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j)) - -shxDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i -shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat)) - -shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i -shxInit = coerce (listxInit @(SMayNat i SNat)) - -shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh)) -shxLast = coerce (listxLast @(SMayNat i SNat)) - -shxTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i -shxTakeSSX _ = flip go - where - go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i - go ZKX _ = ZSX - go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh - --- This is a weird operation, so it has a long name -shxCompleteZeros :: StaticShX sh -> IShX sh -shxCompleteZeros ZKX = ZSX -shxCompleteZeros (SUnknown () :!% ssh) = SUnknown 0 :$% shxCompleteZeros ssh -shxCompleteZeros (SKnown n :!% ssh) = SKnown n :$% shxCompleteZeros ssh - -shxSplitApp :: Proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i) -shxSplitApp _ ZKX idx = (ZSX, idx) -shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx) - -shxEnum :: IShX sh -> [IIxX sh] -shxEnum = \sh -> go sh id [] - where - go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a] - go ZSX f = (f ZIX :) - go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]] - -shxCast :: IShX sh -> StaticShX sh' -> Maybe (IShX sh') -shxCast ZSX ZKX = Just ZSX -shxCast (SKnown n :$% sh) (SKnown m :!% ssh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast sh ssh -shxCast (SUnknown n :$% sh) (SKnown m :!% ssh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast sh ssh -shxCast (SKnown n :$% sh) (SUnknown () :!% ssh) = (SUnknown (fromSNat' n) :$%) <$> shxCast sh ssh -shxCast (SUnknown n :$% sh) (SUnknown () :!% ssh) = (SUnknown n :$%) <$> shxCast sh ssh -shxCast _ _ = Nothing - --- | Partial version of 'shxCast'. -shxCast' :: IShX sh -> StaticShX sh' -> IShX sh' -shxCast' sh ssh = case shxCast sh ssh of - Just sh' -> sh' - Nothing -> error $ "shxCast': Mismatch: (" ++ show sh ++ ") does not match (" ++ show ssh ++ ")" - - --- * Static mixed shapes - --- | The part of a shape that is statically known. (A newtype over 'ListX'.) -type StaticShX :: [Maybe Nat] -> Type -newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat)) - deriving (Eq, Ord) - -pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh -pattern ZKX = StaticShX ZX - -pattern (:!%) - :: forall {sh1}. - forall n sh. (n : sh ~ sh1) - => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1 -pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i)) - where i :!% StaticShX shl = StaticShX (i ::% shl) -infixr 3 :!% - -{-# COMPLETE ZKX, (:!%) #-} - -instance Show (StaticShX sh) where - showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l - -instance NFData (StaticShX sh) where - rnf (StaticShX ZX) = () - rnf (StaticShX (SUnknown () ::% l)) = rnf (StaticShX l) - rnf (StaticShX (SKnown SNat ::% l)) = rnf (StaticShX l) - -instance TestEquality StaticShX where - testEquality (StaticShX l1) (StaticShX l2) = listxEqType l1 l2 - -ssxLength :: StaticShX sh -> Int -ssxLength (StaticShX l) = listxLength l - -ssxRank :: StaticShX sh -> SNat (Rank sh) -ssxRank (StaticShX l) = listxRank l - --- | @ssxEqType = 'testEquality'@. Provided for consistency. -ssxEqType :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh') -ssxEqType = testEquality - -ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') -ssxAppend ZKX sh' = sh' -ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' - -ssxHead :: StaticShX (n : sh) -> SMayNat () SNat n -ssxHead (StaticShX list) = listxHead list - -ssxTail :: StaticShX (n : sh) -> StaticShX sh -ssxTail (_ :!% ssh) = ssh - -ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh' -ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) - -ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh)) -ssxInit = coerce (listxInit @(SMayNat () SNat)) - -ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh)) -ssxLast = coerce (listxLast @(SMayNat () SNat)) - --- | This may fail if @sh@ has @Nothing@s in it. -ssxToShX' :: StaticShX sh -> Maybe (IShX sh) -ssxToShX' ZKX = Just ZSX -ssxToShX' (SKnown n :!% sh) = (SKnown n :$%) <$> ssxToShX' sh -ssxToShX' (SUnknown _ :!% _) = Nothing - -ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) -ssxReplicate SZ = ZKX -ssxReplicate (SS (n :: SNat n')) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n' - = SUnknown () :!% ssxReplicate n - -ssxIotaFrom :: Int -> StaticShX sh -> [Int] -ssxIotaFrom _ ZKX = [] -ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh - -ssxFromShape :: IShX sh -> StaticShX sh -ssxFromShape ZSX = ZKX -ssxFromShape (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShape sh - -ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing) -ssxFromSNat SZ = ZKX -ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n - - --- | Evidence for the static part of a shape. This pops up only when you are --- polymorphic in the element type of an array. -type KnownShX :: [Maybe Nat] -> Constraint -class KnownShX sh where knownShX :: StaticShX sh -instance KnownShX '[] where knownShX = ZKX -instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX -instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX - -withKnownShX :: forall sh r. StaticShX sh -> (KnownShX sh => r) -> r -withKnownShX sh = withDict @(KnownShX sh) sh - - --- * Flattening - -type Flatten sh = Flatten' 1 sh - -type family Flatten' acc sh where - Flatten' acc '[] = Just acc - Flatten' acc (Nothing : sh) = Nothing - Flatten' acc (Just n : sh) = Flatten' (acc * n) sh - --- This function is currently unused -ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh) -ssxFlatten = go (SNat @1) - where - go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh) - go acc ZKX = SKnown acc - go _ (SUnknown () :!% _) = SUnknown () - go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh - -shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh) -shxFlatten = go (SNat @1) - where - go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh) - go acc ZSX = SKnown acc - go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh) - go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh - - goUnknown :: Int -> IShX sh -> Int - goUnknown acc ZSX = acc - goUnknown acc (SUnknown n :$% sh) = goUnknown (acc * n) sh - goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh - - --- | Very untyped: only length is checked (at runtime). -instance KnownShX sh => IsList (ListX sh (Const i)) where - type Item (ListX sh (Const i)) = i - fromList = listxFromList (knownShX @sh) - toList = listxToList - --- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. -instance KnownShX sh => IsList (IxX sh i) where - type Item (IxX sh i) = i - fromList = IxX . IsList.fromList - toList = Foldable.toList - --- | Untyped: length and known dimensions are checked (at runtime). -instance KnownShX sh => IsList (ShX sh Int) where - type Item (ShX sh Int) = Int - fromList = shxFromList (knownShX @sh) - toList = shxToList diff --git a/src/Data/Array/Mixed/Types.hs b/src/Data/Array/Mixed/Types.hs deleted file mode 100644 index 736ced6..0000000 --- a/src/Data/Array/Mixed/Types.hs +++ /dev/null @@ -1,134 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE NoStarIsType #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Mixed.Types ( - -- * Reified evidence of a type class - Dict(..), - - -- * Type-level naturals - pattern SZ, pattern SS, - fromSNat', sameNat', - snatPlus, snatMinus, snatMul, - snatSucc, - - -- * Type-level lists - type (++), - Replicate, - lemReplicateSucc, - MapJust, - Head, - Tail, - Init, - Last, - - -- * Unsafe - unsafeCoerceRefl, -) where - -import Data.Type.Equality -import Data.Proxy -import GHC.TypeLits -import GHC.TypeNats qualified as TN -import Unsafe.Coerce qualified - - --- | Evidence for the constraint @c a@. -data Dict c a where - Dict :: c a => Dict c a - -fromSNat' :: SNat n -> Int -fromSNat' = fromIntegral . fromSNat - -sameNat' :: SNat n -> SNat m -> Maybe (n :~: m) -sameNat' n@SNat m@SNat = sameNat n m - -pattern SZ :: () => (n ~ 0) => SNat n -pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl) - where SZ = SNat - -pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1 -pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl)) - where SS = snatSucc - -{-# COMPLETE SZ, SS #-} - -snatSucc :: SNat n -> SNat (n + 1) -snatSucc SNat = SNat - -data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1) -snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1) -snatPred snp1 = - withKnownNat snp1 $ - case cmpNat (Proxy @1) (Proxy @np1) of - LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) - EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) - GTI -> Nothing - --- This should be a function in base -snatPlus :: SNat n -> SNat m -> SNat (n + m) -snatPlus n m = TN.withSomeSNat (TN.fromSNat n + TN.fromSNat m) Unsafe.Coerce.unsafeCoerce - --- This should be a function in base -snatMinus :: SNat n -> SNat m -> SNat (n - m) -snatMinus n m = let res = TN.fromSNat n - TN.fromSNat m in res `seq` TN.withSomeSNat res Unsafe.Coerce.unsafeCoerce - --- This should be a function in base -snatMul :: SNat n -> SNat m -> SNat (n * m) -snatMul n m = TN.withSomeSNat (TN.fromSNat n * TN.fromSNat m) Unsafe.Coerce.unsafeCoerce - - --- | Type-level list append. -type family l1 ++ l2 where - '[] ++ l2 = l2 - (x : xs) ++ l2 = x : xs ++ l2 - -type family Replicate n a where - Replicate 0 a = '[] - Replicate n a = a : Replicate (n - 1) a - -lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a -lemReplicateSucc = unsafeCoerceRefl - -type family MapJust l where - MapJust '[] = '[] - MapJust (x : xs) = Just x : MapJust xs - -type family Head l where - Head (x : _) = x - -type family Tail l where - Tail (_ : xs) = xs - -type family Init l where - Init (x : y : xs) = x : Init (y : xs) - Init '[x] = '[] - -type family Last l where - Last (x : y : xs) = Last (y : xs) - Last '[x] = x - - --- | This is just @'Unsafe.Coerce.unsafeCoerce' 'Refl'@, but specialised to --- only typecheck for actual type equalities. One cannot, e.g. accidentally --- write this: --- --- @ --- foo :: Proxy a -> Proxy b -> a :~: b --- foo = unsafeCoerceRefl --- @ --- --- which would have been permitted with normal 'Unsafe.Coerce.unsafeCoerce', --- but would have resulted in interesting memory errors at runtime. -unsafeCoerceRefl :: a :~: b -unsafeCoerceRefl = Unsafe.Coerce.unsafeCoerce Refl diff --git a/src/Data/Array/Mixed/XArray.hs b/src/Data/Array/Mixed/XArray.hs deleted file mode 100644 index 93484dc..0000000 --- a/src/Data/Array/Mixed/XArray.hs +++ /dev/null @@ -1,349 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE NoStarIsType #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE StrictData #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Mixed.XArray where - -import Control.DeepSeq (NFData) -import Data.Array.Internal.RankedG qualified as ORG -import Data.Array.Internal.RankedS qualified as ORS -import Data.Array.Internal qualified as OI -import Data.Array.Ranked qualified as ORB -import Data.Array.RankedS qualified as S -import Data.Coerce -import Data.Foldable (toList) -import Data.Kind -import Data.List.NonEmpty (NonEmpty) -import Data.Proxy -import Data.Type.Equality -import Data.Type.Ord -import Data.Vector.Storable qualified as VS -import Foreign.Storable (Storable) -import GHC.Generics (Generic) -import GHC.TypeLits - -import Data.Array.Mixed.Internal.Arith -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape -import Data.Array.Mixed.Types - - -type XArray :: [Maybe Nat] -> Type -> Type -newtype XArray sh a = XArray (S.Array (Rank sh) a) - deriving (Show, Eq, Ord, Generic) - -instance NFData (XArray sh a) - - -shape :: forall sh a. StaticShX sh -> XArray sh a -> IShX sh -shape = \ssh (XArray arr) -> go ssh (S.shapeL arr) - where - go :: StaticShX sh' -> [Int] -> IShX sh' - go ZKX [] = ZSX - go (n :!% ssh) (i : l) = fromSMayNat (\_ -> SUnknown i) SKnown n :$% go ssh l - go _ _ = error "Invalid shapeL" - -fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a -fromVector sh v - | Dict <- lemKnownNatRank sh - = XArray (S.fromVector (shxToList sh) v) - -toVector :: Storable a => XArray sh a -> VS.Vector a -toVector (XArray arr) = S.toVector arr - --- | This allows observing the strides in the underlying orthotope array. This --- can be useful for optimisation, but should be considered an implementation --- detail: strides may change in new versions of this library without notice. -arrayStrides :: XArray sh a -> [Int] -arrayStrides (XArray (ORS.A (ORG.A _ (OI.T strides _ _)))) = strides - -scalar :: Storable a => a -> XArray '[] a -scalar = XArray . S.scalar - --- | Will throw if the array does not have the casted-to shape. -cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2 - => StaticShX sh1 -> IShX sh2 -> StaticShX sh' - -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a -cast ssh1 sh2 ssh' (XArray arr) - | Refl <- lemRankApp ssh1 ssh' - , Refl <- lemRankApp (ssxFromShape sh2) ssh' - = let arrsh :: IShX sh1 - (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr)) - in if shxToList arrsh == shxToList sh2 - then XArray arr - else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")" - -unScalar :: Storable a => XArray '[] a -> a -unScalar (XArray a) = S.unScalar a - -replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a -replicate sh ssh' (XArray arr) - | Dict <- lemKnownNatRankSSX ssh' - , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh) ssh') - , Refl <- lemRankApp (ssxFromShape sh) ssh' - = XArray (S.stretch (shxToList sh ++ S.shapeL arr) $ - S.reshape (map (const 1) (shxToList sh) ++ S.shapeL arr) $ - arr) - -replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a -replicateScal sh x - | Dict <- lemKnownNatRank sh - = XArray (S.constant (shxToList sh) x) - -generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a -generate sh f = fromVector sh $ VS.generate (shxSize sh) (f . ixxFromLinear sh) - --- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a) --- generateM sh f | Dict <- lemKnownNatRank sh = --- XArray . S.fromVector (shxShapeL sh) --- <$> VS.generateM (shxSize sh) (f . ixxFromLinear sh) - -indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a -indexPartial (XArray arr) ZIX = XArray arr -indexPartial (XArray arr) (i :.% idx) = indexPartial (XArray (S.index arr i)) idx - -index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a -index xarr i - | Refl <- lemAppNil @sh - = let XArray arr' = indexPartial xarr i :: XArray '[] a - in S.unScalar arr' - -append :: forall n m sh a. Storable a - => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a -append ssh (XArray a) (XArray b) - | Dict <- lemKnownNatRankSSX ssh - = XArray (S.append a b) - --- | All arrays must have the same shape, except possibly for the outermost --- dimension. -concat :: Storable a - => StaticShX sh -> NonEmpty (XArray (Nothing : sh) a) -> XArray (Nothing : sh) a -concat ssh l - | Dict <- lemKnownNatRankSSX ssh - = XArray (S.concatOuter (coerce (toList l))) - --- | If the prefix of the shape of the input array (@sh@) is empty (i.e. --- contains a zero), then there is no way to deduce the full shape of the output --- array (more precisely, the @sh2@ part): that could only come from calling --- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in --- this case; we choose to fill the shape with zeros wherever we cannot deduce --- what it should be. --- --- For example, if: --- --- @ --- arr :: XArray '[Just 3, Just 0, Just 4, Just 2, Nothing] Int -- of shape [3, 0, 4, 2, 21] --- f :: XArray '[Just 2, Nothing] Int -> XArray '[Just 5, Nothing, Just 17] Float --- @ --- --- then: --- --- @ --- rerank _ _ _ f arr :: XArray '[Just 3, Just 0, Just 4, Just 5, Nothing, Just 17] Float --- @ --- --- and this result will have shape @[3, 0, 4, 5, 0, 17]@. Note the second @0@ --- in this shape: we don't know if @f@ intended to return an array with shape 0 --- here (it probably didn't), but there is no better number to put here absent --- a subarray of the input to pass to @f@. --- --- In this particular case the fact that @sh@ is empty was evident from the --- type-level information, but the same situation occurs when @sh@ consists of --- @Nothing@s, and some of those happen to be zero at runtime. -rerank :: forall sh sh1 sh2 a b. - (Storable a, Storable b) - => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 - -> (XArray sh1 a -> XArray sh2 b) - -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b -rerank ssh ssh1 ssh2 f xarr@(XArray arr) - | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) - = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr) - in if any (== 0) (shxToList sh) - then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) - else case () of - () | Dict <- lemKnownNatRankSSX ssh - , Dict <- lemKnownNatRankSSX ssh2 - , Refl <- lemRankApp ssh ssh1 - , Refl <- lemRankApp ssh ssh2 - -> XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2) - (\a -> let XArray r = f (XArray a) in r) - arr) - -rerankTop :: forall sh1 sh2 sh a b. - (Storable a, Storable b) - => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh - -> (XArray sh1 a -> XArray sh2 b) - -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b -rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh - --- | The caveat about empty arrays at @rerank@ applies here too. -rerank2 :: forall sh sh1 sh2 a b c. - (Storable a, Storable b, Storable c) - => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 - -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) - -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c -rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2) - | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) - = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1) - in if any (== 0) (shxToList sh) - then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) - else case () of - () | Dict <- lemKnownNatRankSSX ssh - , Dict <- lemKnownNatRankSSX ssh2 - , Refl <- lemRankApp ssh ssh1 - , Refl <- lemRankApp ssh ssh2 - -> XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2) - (\a b -> let XArray r = f (XArray a) (XArray b) in r) - arr1 arr2) - --- | The list argument gives indices into the original dimension list. -transpose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh) - => StaticShX sh - -> Perm is - -> XArray sh a - -> XArray (PermutePrefix is sh) a -transpose ssh perm (XArray arr) - | Dict <- lemKnownNatRankSSX ssh - , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh) - , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm - , Refl <- lemRankDropLen ssh perm - = XArray (S.transpose (permToList' perm) arr) - --- | The list argument gives indices into the original dimension list. --- --- The permutation (the list) must have length <= @n@. If it is longer, this --- function throws. -transposeUntyped :: forall n sh a. - SNat n -> StaticShX sh -> [Int] - -> XArray (Replicate n Nothing ++ sh) a -> XArray (Replicate n Nothing ++ sh) a -transposeUntyped sn ssh perm (XArray arr) - | length perm <= fromSNat' sn - , Dict <- lemKnownNatRankSSX (ssxAppend (ssxReplicate sn) ssh) - = XArray (S.transpose perm arr) - | otherwise - = error "Data.Array.Mixed.transposeUntyped: Permutation longer than length of unshaped prefix of shape type" - -transpose2 :: forall sh1 sh2 a. - StaticShX sh1 -> StaticShX sh2 - -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a -transpose2 ssh1 ssh2 (XArray arr) - | Refl <- lemRankApp ssh1 ssh2 - , Refl <- lemRankApp ssh2 ssh1 - , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2) - , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1) - , Refl <- lemRankAppComm ssh1 ssh2 - , let n1 = ssxLength ssh1 - = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) - -sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a -sumFull _ (XArray arr) = - S.unScalar $ - liftO1 (numEltSum1Inner (SNat @0)) $ - S.fromVector [product (S.shapeL arr)] $ - S.toVector arr - -sumInner :: forall sh sh' a. (Storable a, NumElt a) - => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a -sumInner ssh ssh' arr - | Refl <- lemAppNil @sh - = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) - sh'F = shxFlatten sh' :$% ZSX - ssh'F = ssxFromShape sh'F - - go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a - go (XArray arr') - | Refl <- lemRankApp ssh ssh'F - , let sn = listxRank (let StaticShX l = ssh in l) - = XArray (liftO1 (numEltSum1Inner sn) arr') - - in go $ - transpose2 ssh'F ssh $ - reshapePartial ssh' ssh sh'F $ - transpose2 ssh ssh' $ - arr - -sumOuter :: forall sh sh' a. (Storable a, NumElt a) - => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a -sumOuter ssh ssh' arr - | Refl <- lemAppNil @sh - = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) - shF = shxFlatten sh :$% ZSX - in sumInner ssh' (ssxFromShape shF) $ - transpose2 (ssxFromShape shF) ssh' $ - reshapePartial ssh ssh' shF $ - arr - -fromListOuter :: forall n sh a. Storable a - => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a -fromListOuter ssh l - | Dict <- lemKnownNatRankSSX ssh - = case ssh of - SKnown m :!% _ | fromSNat' m /= length l -> - error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++ - "does not match the type (" ++ show (fromSNat' m) ++ ")" - _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l))) - -toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a] -toListOuter (XArray arr) = - case S.shapeL arr of - 0 : _ -> [] - _ -> coerce (ORB.toList (S.unravel arr)) - -fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a -fromList1 ssh l = - let n = length l - in case ssh of - SKnown m :!% _ | fromSNat' m /= n -> - error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ")" ++ - "does not match the type (" ++ show (fromSNat' m) ++ ")" - _ -> XArray (S.fromVector [n] (VS.fromListN n l)) - -toList1 :: Storable a => XArray '[n] a -> [a] -toList1 (XArray arr) = S.toList arr - --- | Throws if the given shape is not, in fact, empty. -empty :: forall sh a. Storable a => IShX sh -> XArray sh a -empty sh - | Dict <- lemKnownNatRank sh - , shxSize sh == 0 - = XArray (S.fromVector (shxToList sh) VS.empty) - | otherwise - = error $ "Data.Array.Mixed.empty: shape was not empty: " ++ show sh - -slice :: SNat i -> SNat n -> XArray (Just (i + n + k) : sh) a -> XArray (Just n : sh) a -slice i n (XArray arr) = XArray (S.slice [(fromSNat' i, fromSNat' n)] arr) - -sliceU :: Int -> Int -> XArray (Nothing : sh) a -> XArray (Nothing : sh) a -sliceU i n (XArray arr) = XArray (S.slice [(i, n)] arr) - -rev1 :: XArray (n : sh) a -> XArray (n : sh) a -rev1 (XArray arr) = XArray (S.rev [0] arr) - --- | Throws if the given array and the target shape do not have the same number of elements. -reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a -reshape ssh1 sh2 (XArray arr) - | Dict <- lemKnownNatRankSSX ssh1 - , Dict <- lemKnownNatRank sh2 - = XArray (S.reshape (shxToList sh2) arr) - --- | Throws if the given array and the target shape do not have the same number of elements. -reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a -reshapePartial ssh1 ssh' sh2 (XArray arr) - | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh') - , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh2) ssh') - = XArray (S.reshape (shxToList sh2 ++ drop (ssxLength ssh1) (S.shapeL arr)) arr) - --- this was benchmarked to be (slightly) faster than S.iota, S.generate and S.fromVector(VS.enumFromTo). -iota :: (Enum a, Storable a) => SNat n -> XArray '[Just n] a -iota sn = XArray (S.fromVector [fromSNat' sn] (VS.fromListN (fromSNat' sn) [toEnum 0 .. toEnum (fromSNat' sn - 1)])) |