diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-30 11:58:40 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-30 11:58:40 +0200 |
commit | a65306ba5d80891b20ac86fa3a3242f9497751e6 (patch) | |
tree | 834af370556a46bbeca807a92c31bef098b47a89 /src/Data/Array/Mixed.hs | |
parent | d8e2fcf4ea979fe272db48fc2889f4c2636c50d7 (diff) |
Refactor Mixed (modules, regular function names)
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r-- | src/Data/Array/Mixed.hs | 757 |
1 files changed, 43 insertions, 714 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 4ae89a1..0100ec8 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -30,300 +30,23 @@ module Data.Array.Mixed where import Control.DeepSeq (NFData(..)) import qualified Data.Array.RankedS as S import qualified Data.Array.Ranked as ORB -import Data.Bifunctor (first) import Data.Coerce -import qualified Data.Foldable as Foldable -import Data.Functor.Const import Data.Kind -import Data.List (sort) -import Data.Monoid (Sum(..)) import Data.Proxy -import Data.Type.Bool import Data.Type.Equality import Data.Type.Ord import qualified Data.Vector.Storable as VS import Foreign.Storable (Storable) import GHC.Generics (Generic) -import GHC.IsList (IsList) -import qualified GHC.IsList as IsList -import GHC.TypeError import GHC.TypeLits -import qualified GHC.TypeNats as TypeNats -import Unsafe.Coerce (unsafeCoerce) -import Data.Array.Nested.Internal.Arith +import Data.Array.Mixed.Internal.Arith +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Permutation +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Types --- | Evidence for the constraint @c a@. -data Dict c a where - Dict :: c a => Dict c a - -fromSNat' :: SNat n -> Int -fromSNat' = fromIntegral . fromSNat - -pattern SZ :: () => (n ~ 0) => SNat n -pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl) - where SZ = SNat - -pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1 -pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl)) - where SS = snatSucc - -{-# COMPLETE SZ, SS #-} - -snatSucc :: SNat n -> SNat (n + 1) -snatSucc SNat = SNat - -data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1) -snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1) -snatPred snp1 = - withKnownNat snp1 $ - case cmpNat (Proxy @1) (Proxy @np1) of - LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) - EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) - GTI -> Nothing - - --- | Type-level list append. -type family l1 ++ l2 where - '[] ++ l2 = l2 - (x : xs) ++ l2 = x : xs ++ l2 - -lemAppNil :: l ++ '[] :~: l -lemAppNil = unsafeCoerce Refl - -lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) -lemAppAssoc _ _ _ = unsafeCoerce Refl - -type family Replicate n a where - Replicate 0 a = '[] - Replicate n a = a : Replicate (n - 1) a - - -type role ListX nominal representational -type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type -data ListX sh f where - ZX :: ListX '[] f - (::%) :: f n -> ListX sh f -> ListX (n : sh) f -deriving instance (forall n. Eq (f n)) => Eq (ListX sh f) -deriving instance (forall n. Ord (f n)) => Ord (ListX sh f) -infixr 3 ::% - -instance (forall n. Show (f n)) => Show (ListX sh f) where - showsPrec _ = showListX shows - -instance (forall n. NFData (f n)) => NFData (ListX sh f) where - rnf ZX = () - rnf (x ::% l) = rnf x `seq` rnf l - -data UnconsListXRes f sh1 = - forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n) -unconsListX :: ListX sh1 f -> Maybe (UnconsListXRes f sh1) -unconsListX (i ::% shl') = Just (UnconsListXRes shl' i) -unconsListX ZX = Nothing - -fmapListX :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g -fmapListX _ ZX = ZX -fmapListX f (x ::% xs) = f x ::% fmapListX f xs - -foldListX :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m -foldListX _ ZX = mempty -foldListX f (x ::% xs) = f x <> foldListX f xs - -lengthListX :: ListX sh f -> Int -lengthListX = getSum . foldListX (\_ -> Sum 1) - -snatLengthListX :: ListX sh f -> SNat (Rank sh) -snatLengthListX ZX = SNat -snatLengthListX (_ ::% l) | SNat <- snatLengthListX l = SNat - -showListX :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS -showListX f l = showString "[" . go "" l . showString "]" - where - go :: String -> ListX sh' f -> ShowS - go _ ZX = id - go prefix (x ::% xs) = showString prefix . f x . go "," xs - -listXToList :: ListX sh' (Const i) -> [i] -listXToList ZX = [] -listXToList (Const i ::% is) = i : listXToList is - - -type role IxX nominal representational -type IxX :: [Maybe Nat] -> Type -> Type -newtype IxX sh i = IxX (ListX sh (Const i)) - deriving (Eq, Ord, Generic) - -pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i -pattern ZIX = IxX ZX - -pattern (:.%) - :: forall {sh1} {i}. - forall n sh. (n : sh ~ sh1) - => i -> IxX sh i -> IxX sh1 i -pattern i :.% shl <- IxX (unconsListX -> Just (UnconsListXRes (IxX -> shl) (getConst -> i))) - where i :.% IxX shl = IxX (Const i ::% shl) -infixr 3 :.% - -{-# COMPLETE ZIX, (:.%) #-} - -type IIxX sh = IxX sh Int - -instance Show i => Show (IxX sh i) where - showsPrec _ (IxX l) = showListX (\(Const i) -> shows i) l - -instance Functor (IxX sh) where - fmap f (IxX l) = IxX (fmapListX (Const . f . getConst) l) - -instance Foldable (IxX sh) where - foldMap f (IxX l) = foldListX (f . getConst) l - -instance NFData i => NFData (IxX sh i) - - -data SMayNat i f n where - SUnknown :: i -> SMayNat i f Nothing - SKnown :: f n -> SMayNat i f (Just n) -deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n) -deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n) -deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n) - -instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where - rnf (SUnknown i) = rnf i - rnf (SKnown x) = rnf x - -fromSMayNat :: (n ~ Nothing => i -> r) -> (forall m. n ~ Just m => f m -> r) -> SMayNat i f n -> r -fromSMayNat f _ (SUnknown i) = f i -fromSMayNat _ g (SKnown s) = g s - -fromSMayNat' :: SMayNat Int SNat n -> Int -fromSMayNat' = fromSMayNat id fromSNat' - -type role ShX nominal representational -type ShX :: [Maybe Nat] -> Type -> Type -newtype ShX sh i = ShX (ListX sh (SMayNat i SNat)) - deriving (Eq, Ord, Generic) - -pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i -pattern ZSX = ShX ZX - -pattern (:$%) - :: forall {sh1} {i}. - forall n sh. (n : sh ~ sh1) - => SMayNat i SNat n -> ShX sh i -> ShX sh1 i -pattern i :$% shl <- ShX (unconsListX -> Just (UnconsListXRes (ShX -> shl) i)) - where i :$% ShX shl = ShX (i ::% shl) -infixr 3 :$% - -{-# COMPLETE ZSX, (:$%) #-} - -type IShX sh = ShX sh Int - -instance Show i => Show (ShX sh i) where - showsPrec _ (ShX l) = showListX (fromSMayNat shows (shows . fromSNat)) l - -instance Functor (ShX sh) where - fmap f (ShX l) = ShX (fmapListX (fromSMayNat (SUnknown . f) SKnown) l) - -instance NFData i => NFData (ShX sh i) where - rnf (ShX ZX) = () - rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l) - rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l) - -lengthShX :: ShX sh i -> Int -lengthShX (ShX l) = lengthListX l - -shXToList :: IShX sh -> [Int] -shXToList ZSX = [] -shXToList (smn :$% sh) = fromSMayNat' smn : shXToList sh - - --- | The part of a shape that is statically known. -type StaticShX :: [Maybe Nat] -> Type -newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat)) - deriving (Eq, Ord) - -pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh -pattern ZKX = StaticShX ZX - -pattern (:!%) - :: forall {sh1}. - forall n sh. (n : sh ~ sh1) - => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1 -pattern i :!% shl <- StaticShX (unconsListX -> Just (UnconsListXRes (StaticShX -> shl) i)) - where i :!% StaticShX shl = StaticShX (i ::% shl) -infixr 3 :!% - -{-# COMPLETE ZKX, (:!%) #-} - -instance Show (StaticShX sh) where - showsPrec _ (StaticShX l) = showListX (fromSMayNat shows (shows . fromSNat)) l - -lengthStaticShX :: StaticShX sh -> Int -lengthStaticShX (StaticShX l) = lengthListX l - -geqStaticShX :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh') -geqStaticShX ZKX ZKX = Just Refl -geqStaticShX (SKnown n@SNat :!% sh) (SKnown m@SNat :!% sh') - | Just Refl <- sameNat n m - , Just Refl <- geqStaticShX sh sh' - = Just Refl -geqStaticShX (SUnknown () :!% sh) (SUnknown () :!% sh') - | Just Refl <- geqStaticShX sh sh' - = Just Refl -geqStaticShX _ _ = Nothing - - --- | Evidence for the static part of a shape. This pops up only when you are --- polymorphic in the element type of an array. -type KnownShX :: [Maybe Nat] -> Constraint -class KnownShX sh where knownShX :: StaticShX sh -instance KnownShX '[] where knownShX = ZKX -instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX -instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX - - --- | Very untyped: only length is checked (at runtime). -instance KnownShX sh => IsList (ListX sh (Const i)) where - type Item (ListX sh (Const i)) = i - fromList topl = go (knownShX @sh) topl - where - go :: StaticShX sh' -> [i] -> ListX sh' (Const i) - go ZKX [] = ZX - go (_ :!% sh) (i : is) = Const i ::% go sh is - go _ _ = error $ "IsList(ListX): Mismatched list length (type says " - ++ show (lengthStaticShX (knownShX @sh)) ++ ", list has length " - ++ show (length topl) ++ ")" - toList = listXToList - --- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. -instance KnownShX sh => IsList (IxX sh i) where - type Item (IxX sh i) = i - fromList = IxX . IsList.fromList - toList = Foldable.toList - --- | Untyped: length and known dimensions are checked (at runtime). -instance KnownShX sh => IsList (ShX sh Int) where - type Item (ShX sh Int) = Int - fromList topl = ShX (go (knownShX @sh) topl) - where - go :: StaticShX sh' -> [Int] -> ListX sh' (SMayNat Int SNat) - go ZKX [] = ZX - go (SKnown sn :!% sh) (i : is) - | i == fromSNat' sn = SKnown sn ::% go sh is - | otherwise = error $ "IsList(ShX): Value does not match typing (type says " - ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" - go (SUnknown () :!% sh) (i : is) = SUnknown i ::% go sh is - go _ _ = error $ "IsList(ShX): Mismatched list length (type says " - ++ show (lengthStaticShX (knownShX @sh)) ++ ", list has length " - ++ show (length topl) ++ ")" - toList = shXToList - - -type family Rank sh where - Rank '[] = 0 - Rank (_ : sh) = Rank sh + 1 - type XArray :: [Maybe Nat] -> Type -> Type newtype XArray sh a = XArray (S.Array (Rank sh) a) deriving (Show, Eq, Generic) @@ -333,180 +56,6 @@ deriving instance (Ord a, Storable a) => Ord (XArray '[] a) instance NFData a => NFData (XArray sh a) -zeroIxX :: StaticShX sh -> IIxX sh -zeroIxX ZKX = ZIX -zeroIxX (_ :!% ssh) = 0 :.% zeroIxX ssh - -zeroIxX' :: IShX sh -> IIxX sh -zeroIxX' ZSX = ZIX -zeroIxX' (_ :$% sh) = 0 :.% zeroIxX' sh - --- This is a weird operation, so it has a long name -completeShXzeros :: StaticShX sh -> IShX sh -completeShXzeros ZKX = ZSX -completeShXzeros (SUnknown () :!% ssh) = SUnknown 0 :$% completeShXzeros ssh -completeShXzeros (SKnown n :!% ssh) = SKnown n :$% completeShXzeros ssh - -listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f -listxAppend ZX idx' = idx' -listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx' - -ixAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i -ixAppend = coerce (listxAppend @_ @(Const i)) - -shAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i -shAppend = coerce (listxAppend @_ @(SMayNat i SNat)) - -listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f -listxDrop long ZX = long -listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short - -ixDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i -ixDrop = coerce (listxDrop @(Const i) @(Const i)) - -shDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i -shDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat)) - -shDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i -shDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j)) - -shDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i -shDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat)) - -shTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i -shTakeSSX _ = flip go - where - go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i - go ZKX _ = ZSX - go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh - -ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh' -ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) - --- TODO: generalise all these things to arbitrary @i@ -shTail :: IShX (n : sh) -> IShX sh -shTail (_ :$% sh) = sh - -ssxTail :: StaticShX (n : sh) -> StaticShX sh -ssxTail (_ :!% ssh) = ssh - -shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh') -shAppSplit _ ZKX idx = (ZSX, idx) -shAppSplit p (_ :!% ssh) (i :$% idx) = first (i :$%) (shAppSplit p ssh idx) - -ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') -ssxAppend ZKX sh' = sh' -ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' - -shapeSize :: IShX sh -> Int -shapeSize ZSX = 1 -shapeSize (n :$% sh) = fromSMayNat' n * shapeSize sh - --- | This may fail if @sh@ has @Nothing@s in it. -ssxToShape' :: StaticShX sh -> Maybe (IShX sh) -ssxToShape' ZKX = Just ZSX -ssxToShape' (SKnown n :!% sh) = (SKnown n :$%) <$> ssxToShape' sh -ssxToShape' (SUnknown _ :!% _) = Nothing - -lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a -lemReplicateSucc = unsafeCoerce Refl - -ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) -ssxReplicate SZ = ZKX -ssxReplicate (SS (n :: SNat n')) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n' - = SUnknown () :!% ssxReplicate n - -fromLinearIdx :: IShX sh -> Int -> IIxX sh -fromLinearIdx = \sh i -> case go sh i of - (idx, 0) -> idx - _ -> error $ "fromLinearIdx: out of range (" ++ show i ++ - " in array of shape " ++ show sh ++ ")" - where - -- returns (index in subarray, remaining index in enclosing array) - go :: IShX sh -> Int -> (IIxX sh, Int) - go ZSX i = (ZIX, i) - go (n :$% sh) i = - let (idx, i') = go sh i - (upi, locali) = i' `quotRem` fromSMayNat' n - in (locali :.% idx, upi) - -toLinearIdx :: IShX sh -> IIxX sh -> Int -toLinearIdx = \sh i -> fst (go sh i) - where - -- returns (index in subarray, size of subarray) - go :: IShX sh -> IIxX sh -> (Int, Int) - go ZSX ZIX = (0, 1) - go (n :$% sh) (i :.% ix) = - let (lidx, sz) = go sh ix - in (sz * i + lidx, fromSMayNat' n * sz) - -enumShape :: IShX sh -> [IIxX sh] -enumShape = \sh -> go sh id [] - where - go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a] - go ZSX f = (f ZIX :) - go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]] - -shapeLshape :: IShX sh -> S.ShapeL -shapeLshape ZSX = [] -shapeLshape (n :$% sh) = fromSMayNat' n : shapeLshape sh - -ssxLength :: StaticShX sh -> Int -ssxLength ZKX = 0 -ssxLength (_ :!% ssh) = 1 + ssxLength ssh - -ssxIotaFrom :: Int -> StaticShX sh -> [Int] -ssxIotaFrom _ ZKX = [] -ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh - -type Flatten sh = Flatten' 1 sh - -type family Flatten' acc sh where - Flatten' acc '[] = Just acc - Flatten' acc (Nothing : sh) = Nothing - Flatten' acc (Just n : sh) = Flatten' (acc * n) sh - -flattenSSX :: StaticShX sh -> SMayNat () SNat (Flatten sh) -flattenSSX = go (SNat @1) - where - go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh) - go acc ZKX = SKnown acc - go _ (SUnknown () :!% _) = SUnknown () - go acc (SKnown sn :!% sh) = go (mulSNat acc sn) sh - -flattenSh :: IShX sh -> SMayNat Int SNat (Flatten sh) -flattenSh = go (SNat @1) - where - go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh) - go acc ZSX = SKnown acc - go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh) - go acc (SKnown sn :$% sh) = go (mulSNat acc sn) sh - - goUnknown :: Int -> IShX sh -> Int - goUnknown acc ZSX = acc - goUnknown acc (SUnknown n :$% sh) = goUnknown (acc * n) sh - goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh - -staticShapeFrom :: IShX sh -> StaticShX sh -staticShapeFrom ZSX = ZKX -staticShapeFrom (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% staticShapeFrom sh - -lemRankApp :: StaticShX sh1 -> StaticShX sh2 - -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2 -lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this - -lemRankAppComm :: StaticShX sh1 -> StaticShX sh2 - -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1) -lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this - -lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh) -lemKnownNatRank ZSX = Dict -lemKnownNatRank (_ :$% sh) | Dict <- lemKnownNatRank sh = Dict - -lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh) -lemKnownNatRankSSX ZKX = Dict -lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict shape :: forall sh a. StaticShX sh -> XArray sh a -> IShX sh shape = \ssh (XArray arr) -> go ssh (S.shapeL arr) @@ -519,7 +68,7 @@ shape = \ssh (XArray arr) -> go ssh (S.shapeL arr) fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a fromVector sh v | Dict <- lemKnownNatRank sh - = XArray (S.fromVector (shapeLshape sh) v) + = XArray (S.fromVector (shxToList sh) v) toVector :: Storable a => XArray sh a -> VS.Vector a toVector (XArray arr) = S.toVector arr @@ -527,23 +76,18 @@ toVector (XArray arr) = S.toVector arr scalar :: Storable a => a -> XArray '[] a scalar = XArray . S.scalar -eqShX :: IShX sh1 -> IShX sh2 -> Bool -eqShX ZSX ZSX = True -eqShX (n :$% sh1) (m :$% sh2) = fromSMayNat' n == fromSMayNat' m && eqShX sh1 sh2 -eqShX _ _ = False - -- | Will throw if the array does not have the casted-to shape. cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2 => StaticShX sh1 -> IShX sh2 -> StaticShX sh' -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a cast ssh1 sh2 ssh' (XArray arr) | Refl <- lemRankApp ssh1 ssh' - , Refl <- lemRankApp (staticShapeFrom sh2) ssh' + , Refl <- lemRankApp (ssxFromShape sh2) ssh' = let arrsh :: IShX sh1 - (arrsh, _) = shAppSplit (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr)) - in if eqShX arrsh sh2 - then XArray arr - else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")" + (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr)) + in case shxEqual arrsh sh2 of + Just _ -> XArray arr + Nothing -> error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")" unScalar :: Storable a => XArray '[] a -> a unScalar (XArray a) = S.unScalar a @@ -551,24 +95,24 @@ unScalar (XArray a) = S.unScalar a replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a replicate sh ssh' (XArray arr) | Dict <- lemKnownNatRankSSX ssh' - , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh) ssh') - , Refl <- lemRankApp (staticShapeFrom sh) ssh' - = XArray (S.stretch (shapeLshape sh ++ S.shapeL arr) $ - S.reshape (map (const 1) (shapeLshape sh) ++ S.shapeL arr) $ + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh) ssh') + , Refl <- lemRankApp (ssxFromShape sh) ssh' + = XArray (S.stretch (shxToList sh ++ S.shapeL arr) $ + S.reshape (map (const 1) (shxToList sh) ++ S.shapeL arr) $ arr) replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a replicateScal sh x | Dict <- lemKnownNatRank sh - = XArray (S.constant (shapeLshape sh) x) + = XArray (S.constant (shxToList sh) x) generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a -generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh) +generate sh f = fromVector sh $ VS.generate (shxSize sh) (f . ixxFromLinear sh) -- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a) -- generateM sh f | Dict <- lemKnownNatRank sh = --- XArray . S.fromVector (shapeLshape sh) --- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh) +-- XArray . S.fromVector (shxShapeL sh) +-- <$> VS.generateM (shxSize sh) (f . ixxFromLinear sh) indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a indexPartial (XArray arr) ZIX = XArray arr @@ -580,24 +124,6 @@ index xarr i = let XArray arr' = indexPartial xarr i :: XArray '[] a in S.unScalar arr' -type family AddMaybe n m where - AddMaybe Nothing _ = Nothing - AddMaybe (Just _) Nothing = Nothing - AddMaybe (Just n) (Just m) = Just (n + m) - --- This should be a function in base -plusSNat :: SNat n -> SNat m -> SNat (n + m) -plusSNat n m = TypeNats.withSomeSNat (TypeNats.fromSNat n + TypeNats.fromSNat m) unsafeCoerce - --- This should be a function in base -mulSNat :: SNat n -> SNat m -> SNat (n * m) -mulSNat n m = TypeNats.withSomeSNat (TypeNats.fromSNat n * TypeNats.fromSNat m) unsafeCoerce - -smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m) -smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) -smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) -smnAddMaybe (SKnown n) (SKnown m) = SKnown (plusSNat n m) - append :: forall n m sh a. Storable a => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a append ssh (XArray a) (XArray b) @@ -639,9 +165,9 @@ rerank :: forall sh sh1 sh2 a b. -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b rerank ssh ssh1 ssh2 f xarr@(XArray arr) | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) - = let (sh, _) = shAppSplit (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr) - in if any (== 0) (shapeLshape sh) - then XArray (S.fromList (shapeLshape (shAppend sh (completeShXzeros ssh2))) []) + = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr) + in if any (== 0) (shxToList sh) + then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) else case () of () | Dict <- lemKnownNatRankSSX ssh , Dict <- lemKnownNatRankSSX ssh2 @@ -666,9 +192,9 @@ rerank2 :: forall sh sh1 sh2 a b c. -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2) | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) - = let (sh, _) = shAppSplit (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1) - in if any (== 0) (shapeLshape sh) - then XArray (S.fromList (shapeLshape (shAppend sh (completeShXzeros ssh2))) []) + = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1) + in if any (== 0) (shxToList sh) + then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) else case () of () | Dict <- lemKnownNatRankSSX ssh , Dict <- lemKnownNatRankSSX ssh2 @@ -678,211 +204,14 @@ rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2) (\a b -> let XArray r = f (XArray a) (XArray b) in r) arr1 arr2) -type family Elem x l where - Elem x '[] = 'False - Elem x (x : _) = 'True - Elem x (_ : ys) = Elem x ys - -type family AllElem' as bs where - AllElem' '[] bs = 'True - AllElem' (a : as) bs = Elem a bs && AllElem' as bs - -type AllElem as bs = Assert (AllElem' as bs) - (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs)) - -type family Count i n where - Count n n = '[] - Count i n = i : Count (i + 1) n - -type Permutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as) - -type family Index i sh where - Index 0 (n : sh) = n - Index i (_ : sh) = Index (i - 1) sh - -type family Permute is sh where - Permute '[] sh = '[] - Permute (i : is) sh = Index i sh : Permute is sh - -type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh - -data HList f list where - HNil :: HList f '[] - HCons :: f a -> HList f l -> HList f (a : l) -infixr 5 `HCons` -deriving instance (forall a. Show (f a)) => Show (HList f list) -deriving instance (forall a. Eq (f a)) => Eq (HList f list) - -foldHList :: Monoid m => (forall a. f a -> m) -> HList f list -> m -foldHList _ HNil = mempty -foldHList f (x `HCons` l) = f x <> foldHList f l - -snatLengthHList :: HList f list -> SNat (Rank list) -snatLengthHList HNil = SNat -snatLengthHList (_ `HCons` l) | SNat <- snatLengthHList l = SNat - -permFromList :: [Int] -> (forall list. HList SNat list -> r) -> r -permFromList [] k = k HNil -permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case - Just sn -> permFromList xs $ \list -> k (sn `HCons` list) - Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x - -permToList :: HList SNat list -> [Int] -permToList = foldHList (pure . fromSNat') - -type family TakeLen ref l where - TakeLen '[] l = '[] - TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs - -type family DropLen ref l where - DropLen '[] l = l - DropLen (_ : ref) (_ : xs) = DropLen ref xs - -lemRankPermute :: Proxy sh -> HList SNat is -> Rank (Permute is sh) :~: Rank is -lemRankPermute _ HNil = Refl -lemRankPermute p (_ `HCons` is) | Refl <- lemRankPermute p is = Refl - -lemRankDropLen :: forall is sh. (Rank is <= Rank sh) - => StaticShX sh -> HList SNat is -> Rank (DropLen is sh) :~: Rank sh - Rank is -lemRankDropLen ZKX HNil = Refl -lemRankDropLen (_ :!% sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl -lemRankDropLen (_ :!% _) HNil = Refl -lemRankDropLen ZKX (_ `HCons` _) = error "1 <= 0" - -lemIndexSucc :: Proxy i -> Proxy a -> Proxy l -> Index (i + 1) (a : l) :~: Index i l -lemIndexSucc _ _ _ = unsafeCoerce Refl - -listxTakeLen :: forall f is sh. HList SNat is -> ListX sh f -> ListX (TakeLen is sh) f -listxTakeLen HNil _ = ZX -listxTakeLen (_ `HCons` is) (n ::% sh) = n ::% listxTakeLen is sh -listxTakeLen (_ `HCons` _) ZX = error "Permutation longer than shape" - -listxDropLen :: forall f is sh. HList SNat is -> ListX sh f -> ListX (DropLen is sh) f -listxDropLen HNil sh = sh -listxDropLen (_ `HCons` is) (_ ::% sh) = listxDropLen is sh -listxDropLen (_ `HCons` _) ZX = error "Permutation longer than shape" - -listxPermute :: forall f is sh. HList SNat is -> ListX sh f -> ListX (Permute is sh) f -listxPermute HNil _ = ZX -listxPermute (i `HCons` (is :: HList SNat is')) (sh :: ListX sh f) = listxIndex (Proxy @is') (Proxy @sh) i sh (listxPermute is sh) - -listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> ListX (Permute is shT) f -> ListX (Index i sh : Permute is shT) f -listxIndex _ _ SZ (n ::% _) rest = n ::% rest -listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) rest - | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = listxIndex p pT i sh rest -listxIndex _ _ _ ZX _ = error "Index into empty shape" - -listxPermutePrefix :: forall f is sh. HList SNat is -> ListX sh f -> ListX (PermutePrefix is sh) f -listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) - -ssxTakeLen :: forall is sh. HList SNat is -> StaticShX sh -> StaticShX (TakeLen is sh) -ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat)) - -ssxDropLen :: HList SNat is -> StaticShX sh -> StaticShX (DropLen is sh) -ssxDropLen = coerce (listxDropLen @(SMayNat () SNat)) - -ssxPermute :: HList SNat is -> StaticShX sh -> StaticShX (Permute is sh) -ssxPermute = coerce (listxPermute @(SMayNat () SNat)) - -ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT) -ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2) - -ssxPermutePrefix :: HList SNat is -> StaticShX sh -> StaticShX (PermutePrefix is sh) -ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat)) - -shPermutePrefix :: HList SNat is -> IShX sh -> IShX (PermutePrefix is sh) -shPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat)) - --- TODO: test this thing more properly -invertPermutation :: HList SNat is - -> (forall is'. - Permutation is' - => HList SNat is' - -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh) - -> r) - -> r -invertPermutation = \perm k -> - genPerm perm $ \(invperm :: HList SNat is') -> - let sn = snatLengthHList invperm - in case (provePerm1 (Proxy @is') sn invperm, provePerm2 (SNat @0) sn invperm) of - (Just Refl, Just Refl) -> - k invperm - (\ssh -> case provePermInverse perm invperm ssh of - Just eq -> eq - Nothing -> error $ "invertPermutation: did not generate inverse? perm = " ++ show perm - ++ " ; invperm = " ++ show invperm) - _ -> error $ "invertPermutation: did not generate permutation? perm = " ++ show perm - ++ " ; invperm = " ++ show invperm - where - genPerm :: HList SNat is -> (forall is'. HList SNat is' -> r) -> r - genPerm perm = - let permList = foldHList (pure . fromSNat) perm - in toHList $ map snd (sort (zip permList [0..])) - where - toHList :: [Natural] -> (forall is'. HList SNat is' -> r) -> r - toHList [] k = k HNil - toHList (n : ns) k = toHList ns $ \l -> TypeNats.withSomeSNat n $ \sn -> k (HCons sn l) - - lemElemCount :: (0 <= n, Compare n m ~ LT) => proxy n -> proxy m -> Elem n (Count 0 m) :~: True - lemElemCount _ _ = unsafeCoerce Refl - - lemCount :: (OrdCond (Compare i n) True False True ~ True) => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n - lemCount _ _ = unsafeCoerce Refl - - lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True - lemElem _ _ = unsafeCoerce Refl - - provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> HList SNat is' - -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True) - provePerm1 _ _ HNil = Just (Refl) - provePerm1 p rtop@SNat (HCons sn@SNat perm) - | Just Refl <- provePerm1 p rtop perm - = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of - (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl - (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl - _ -> Nothing - | otherwise - = Nothing - - provePerm2 :: SNat i -> SNat n -> HList SNat is' -> Maybe (AllElem' (Count i n) is' :~: True) - provePerm2 = \i@(SNat :: SNat i) n@SNat perm -> - case cmpNat i n of - EQI -> Just Refl - LTI | Refl <- lemCount i n - , Just Refl <- provePerm2 (SNat @(i + 1)) n perm - -> checkElem i perm - | otherwise -> Nothing - GTI -> error "unreachable" - where - checkElem :: SNat i -> HList SNat is' -> Maybe (Elem i is' :~: True) - checkElem _ HNil = Nothing - checkElem i@SNat (HCons k@SNat perm :: HList SNat is') = - case sameNat i k of - Just Refl -> Just Refl - Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl - | otherwise -> Nothing - - provePermInverse :: HList SNat is -> HList SNat is' -> StaticShX sh -> Maybe (Permute is' (Permute is sh) :~: sh) - provePermInverse perm perminv ssh = geqStaticShX (ssxPermute perminv (ssxPermute perm ssh)) ssh - -applyPermX :: forall f is sh. HList SNat is -> ListX sh f -> ListX (PermutePrefix is sh) f -applyPermX perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) - -applyPermIxX :: forall i is sh. HList SNat is -> IxX sh i -> IxX (PermutePrefix is sh) i -applyPermIxX = coerce (applyPermX @(Const i)) - -applyPermShX :: forall i is sh. HList SNat is -> ShX sh i -> ShX (PermutePrefix is sh) i -applyPermShX = coerce (applyPermX @(SMayNat i SNat)) - -class KnownNatList l where makeNatList :: HList SNat l -instance KnownNatList '[] where makeNatList = HNil -instance (KnownNat n, KnownNatList l) => KnownNatList (n : l) where makeNatList = natSing `HCons` makeNatList +class KnownNatList l where makeNatList :: Perm l +instance KnownNatList '[] where makeNatList = PNil +instance (KnownNat n, KnownNatList l) => KnownNatList (n : l) where makeNatList = natSing `PCons` makeNatList -- | The list argument gives indices into the original dimension list. -transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh) +transpose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh) => StaticShX sh - -> HList SNat is + -> Perm is -> XArray sh a -> XArray (PermutePrefix is sh) a transpose ssh perm (XArray arr) @@ -890,7 +219,7 @@ transpose ssh perm (XArray arr) , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh) , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm , Refl <- lemRankDropLen ssh perm - = XArray (S.transpose (permToList perm) arr) + = XArray (S.transpose (permToList' perm) arr) -- | The list argument gives indices into the original dimension list. -- @@ -929,14 +258,14 @@ sumInner :: forall sh sh' a. (Storable a, NumElt a) => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a sumInner ssh ssh' arr | Refl <- lemAppNil @sh - = let (_, sh') = shAppSplit (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) - sh'F = flattenSh sh' :$% ZSX - ssh'F = staticShapeFrom sh'F + = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + sh'F = shxFlatten sh' :$% ZSX + ssh'F = ssxFromShape sh'F go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a go (XArray arr') | Refl <- lemRankApp ssh ssh'F - , let sn = snatLengthListX (let StaticShX l = ssh in l) + , let sn = listxLengthSNat (let StaticShX l = ssh in l) = XArray (numEltSum1Inner sn arr') in go $ @@ -949,10 +278,10 @@ sumOuter :: forall sh sh' a. (Storable a, NumElt a) => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a sumOuter ssh ssh' arr | Refl <- lemAppNil @sh - = let (sh, _) = shAppSplit (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) - shF = flattenSh sh :$% ZSX - in sumInner ssh' (staticShapeFrom shF) $ - transpose2 (staticShapeFrom shF) ssh' $ + = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + shF = shxFlatten sh :$% ZSX + in sumInner ssh' (ssxFromShape shF) $ + transpose2 (ssxFromShape shF) ssh' $ reshapePartial ssh ssh' shF $ arr @@ -988,7 +317,7 @@ toList1 (XArray arr) = S.toList arr empty :: forall sh a. Storable a => IShX sh -> XArray sh a empty sh | Dict <- lemKnownNatRank sh - = XArray (S.constant (shapeLshape sh) + = XArray (S.constant (shxToList sh) (error "Data.Array.Mixed.empty: shape was not empty")) slice :: SNat i -> SNat n -> XArray (Just (i + n + k) : sh) a -> XArray (Just n : sh) a @@ -1005,14 +334,14 @@ reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray s reshape ssh1 sh2 (XArray arr) | Dict <- lemKnownNatRankSSX ssh1 , Dict <- lemKnownNatRank sh2 - = XArray (S.reshape (shapeLshape sh2) arr) + = XArray (S.reshape (shxToList sh2) arr) -- | Throws if the given array and the target shape do not have the same number of elements. reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a reshapePartial ssh1 ssh' sh2 (XArray arr) | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh') - , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh2) ssh') - = XArray (S.reshape (shapeLshape sh2 ++ drop (lengthStaticShX ssh1) (S.shapeL arr)) arr) + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh2) ssh') + = XArray (S.reshape (shxToList sh2 ++ drop (ssxLength ssh1) (S.shapeL arr)) arr) -- this was benchmarked to be (slightly) faster than S.iota, S.generate and S.fromVector(VS.enumFromTo). iota :: (Enum a, Storable a) => SNat n -> XArray '[Just n] a |