diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-30 11:58:40 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-30 11:58:40 +0200 |
commit | a65306ba5d80891b20ac86fa3a3242f9497751e6 (patch) | |
tree | 834af370556a46bbeca807a92c31bef098b47a89 /src/Data | |
parent | d8e2fcf4ea979fe272db48fc2889f4c2636c50d7 (diff) |
Refactor Mixed (modules, regular function names)
Diffstat (limited to 'src/Data')
-rw-r--r-- | src/Data/Array/Mixed.hs | 757 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs (renamed from src/Data/Array/Nested/Internal/Arith.hs) | 6 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Foreign.hs (renamed from src/Data/Array/Nested/Internal/Arith/Foreign.hs) | 4 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Lists.hs (renamed from src/Data/Array/Nested/Internal/Arith/Lists.hs) | 4 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs (renamed from src/Data/Array/Nested/Internal/Arith/Lists/TH.hs) | 2 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Lemmas.hs | 47 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Permutation.hs | 252 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Shape.hs | 455 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Types.hs | 110 | ||||
-rw-r--r-- | src/Data/Array/Nested.hs | 10 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 326 |
11 files changed, 1087 insertions, 886 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 4ae89a1..0100ec8 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -30,300 +30,23 @@ module Data.Array.Mixed where import Control.DeepSeq (NFData(..)) import qualified Data.Array.RankedS as S import qualified Data.Array.Ranked as ORB -import Data.Bifunctor (first) import Data.Coerce -import qualified Data.Foldable as Foldable -import Data.Functor.Const import Data.Kind -import Data.List (sort) -import Data.Monoid (Sum(..)) import Data.Proxy -import Data.Type.Bool import Data.Type.Equality import Data.Type.Ord import qualified Data.Vector.Storable as VS import Foreign.Storable (Storable) import GHC.Generics (Generic) -import GHC.IsList (IsList) -import qualified GHC.IsList as IsList -import GHC.TypeError import GHC.TypeLits -import qualified GHC.TypeNats as TypeNats -import Unsafe.Coerce (unsafeCoerce) -import Data.Array.Nested.Internal.Arith +import Data.Array.Mixed.Internal.Arith +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Permutation +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Types --- | Evidence for the constraint @c a@. -data Dict c a where - Dict :: c a => Dict c a - -fromSNat' :: SNat n -> Int -fromSNat' = fromIntegral . fromSNat - -pattern SZ :: () => (n ~ 0) => SNat n -pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl) - where SZ = SNat - -pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1 -pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl)) - where SS = snatSucc - -{-# COMPLETE SZ, SS #-} - -snatSucc :: SNat n -> SNat (n + 1) -snatSucc SNat = SNat - -data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1) -snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1) -snatPred snp1 = - withKnownNat snp1 $ - case cmpNat (Proxy @1) (Proxy @np1) of - LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) - EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) - GTI -> Nothing - - --- | Type-level list append. -type family l1 ++ l2 where - '[] ++ l2 = l2 - (x : xs) ++ l2 = x : xs ++ l2 - -lemAppNil :: l ++ '[] :~: l -lemAppNil = unsafeCoerce Refl - -lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) -lemAppAssoc _ _ _ = unsafeCoerce Refl - -type family Replicate n a where - Replicate 0 a = '[] - Replicate n a = a : Replicate (n - 1) a - - -type role ListX nominal representational -type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type -data ListX sh f where - ZX :: ListX '[] f - (::%) :: f n -> ListX sh f -> ListX (n : sh) f -deriving instance (forall n. Eq (f n)) => Eq (ListX sh f) -deriving instance (forall n. Ord (f n)) => Ord (ListX sh f) -infixr 3 ::% - -instance (forall n. Show (f n)) => Show (ListX sh f) where - showsPrec _ = showListX shows - -instance (forall n. NFData (f n)) => NFData (ListX sh f) where - rnf ZX = () - rnf (x ::% l) = rnf x `seq` rnf l - -data UnconsListXRes f sh1 = - forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n) -unconsListX :: ListX sh1 f -> Maybe (UnconsListXRes f sh1) -unconsListX (i ::% shl') = Just (UnconsListXRes shl' i) -unconsListX ZX = Nothing - -fmapListX :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g -fmapListX _ ZX = ZX -fmapListX f (x ::% xs) = f x ::% fmapListX f xs - -foldListX :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m -foldListX _ ZX = mempty -foldListX f (x ::% xs) = f x <> foldListX f xs - -lengthListX :: ListX sh f -> Int -lengthListX = getSum . foldListX (\_ -> Sum 1) - -snatLengthListX :: ListX sh f -> SNat (Rank sh) -snatLengthListX ZX = SNat -snatLengthListX (_ ::% l) | SNat <- snatLengthListX l = SNat - -showListX :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS -showListX f l = showString "[" . go "" l . showString "]" - where - go :: String -> ListX sh' f -> ShowS - go _ ZX = id - go prefix (x ::% xs) = showString prefix . f x . go "," xs - -listXToList :: ListX sh' (Const i) -> [i] -listXToList ZX = [] -listXToList (Const i ::% is) = i : listXToList is - - -type role IxX nominal representational -type IxX :: [Maybe Nat] -> Type -> Type -newtype IxX sh i = IxX (ListX sh (Const i)) - deriving (Eq, Ord, Generic) - -pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i -pattern ZIX = IxX ZX - -pattern (:.%) - :: forall {sh1} {i}. - forall n sh. (n : sh ~ sh1) - => i -> IxX sh i -> IxX sh1 i -pattern i :.% shl <- IxX (unconsListX -> Just (UnconsListXRes (IxX -> shl) (getConst -> i))) - where i :.% IxX shl = IxX (Const i ::% shl) -infixr 3 :.% - -{-# COMPLETE ZIX, (:.%) #-} - -type IIxX sh = IxX sh Int - -instance Show i => Show (IxX sh i) where - showsPrec _ (IxX l) = showListX (\(Const i) -> shows i) l - -instance Functor (IxX sh) where - fmap f (IxX l) = IxX (fmapListX (Const . f . getConst) l) - -instance Foldable (IxX sh) where - foldMap f (IxX l) = foldListX (f . getConst) l - -instance NFData i => NFData (IxX sh i) - - -data SMayNat i f n where - SUnknown :: i -> SMayNat i f Nothing - SKnown :: f n -> SMayNat i f (Just n) -deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n) -deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n) -deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n) - -instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where - rnf (SUnknown i) = rnf i - rnf (SKnown x) = rnf x - -fromSMayNat :: (n ~ Nothing => i -> r) -> (forall m. n ~ Just m => f m -> r) -> SMayNat i f n -> r -fromSMayNat f _ (SUnknown i) = f i -fromSMayNat _ g (SKnown s) = g s - -fromSMayNat' :: SMayNat Int SNat n -> Int -fromSMayNat' = fromSMayNat id fromSNat' - -type role ShX nominal representational -type ShX :: [Maybe Nat] -> Type -> Type -newtype ShX sh i = ShX (ListX sh (SMayNat i SNat)) - deriving (Eq, Ord, Generic) - -pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i -pattern ZSX = ShX ZX - -pattern (:$%) - :: forall {sh1} {i}. - forall n sh. (n : sh ~ sh1) - => SMayNat i SNat n -> ShX sh i -> ShX sh1 i -pattern i :$% shl <- ShX (unconsListX -> Just (UnconsListXRes (ShX -> shl) i)) - where i :$% ShX shl = ShX (i ::% shl) -infixr 3 :$% - -{-# COMPLETE ZSX, (:$%) #-} - -type IShX sh = ShX sh Int - -instance Show i => Show (ShX sh i) where - showsPrec _ (ShX l) = showListX (fromSMayNat shows (shows . fromSNat)) l - -instance Functor (ShX sh) where - fmap f (ShX l) = ShX (fmapListX (fromSMayNat (SUnknown . f) SKnown) l) - -instance NFData i => NFData (ShX sh i) where - rnf (ShX ZX) = () - rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l) - rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l) - -lengthShX :: ShX sh i -> Int -lengthShX (ShX l) = lengthListX l - -shXToList :: IShX sh -> [Int] -shXToList ZSX = [] -shXToList (smn :$% sh) = fromSMayNat' smn : shXToList sh - - --- | The part of a shape that is statically known. -type StaticShX :: [Maybe Nat] -> Type -newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat)) - deriving (Eq, Ord) - -pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh -pattern ZKX = StaticShX ZX - -pattern (:!%) - :: forall {sh1}. - forall n sh. (n : sh ~ sh1) - => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1 -pattern i :!% shl <- StaticShX (unconsListX -> Just (UnconsListXRes (StaticShX -> shl) i)) - where i :!% StaticShX shl = StaticShX (i ::% shl) -infixr 3 :!% - -{-# COMPLETE ZKX, (:!%) #-} - -instance Show (StaticShX sh) where - showsPrec _ (StaticShX l) = showListX (fromSMayNat shows (shows . fromSNat)) l - -lengthStaticShX :: StaticShX sh -> Int -lengthStaticShX (StaticShX l) = lengthListX l - -geqStaticShX :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh') -geqStaticShX ZKX ZKX = Just Refl -geqStaticShX (SKnown n@SNat :!% sh) (SKnown m@SNat :!% sh') - | Just Refl <- sameNat n m - , Just Refl <- geqStaticShX sh sh' - = Just Refl -geqStaticShX (SUnknown () :!% sh) (SUnknown () :!% sh') - | Just Refl <- geqStaticShX sh sh' - = Just Refl -geqStaticShX _ _ = Nothing - - --- | Evidence for the static part of a shape. This pops up only when you are --- polymorphic in the element type of an array. -type KnownShX :: [Maybe Nat] -> Constraint -class KnownShX sh where knownShX :: StaticShX sh -instance KnownShX '[] where knownShX = ZKX -instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX -instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX - - --- | Very untyped: only length is checked (at runtime). -instance KnownShX sh => IsList (ListX sh (Const i)) where - type Item (ListX sh (Const i)) = i - fromList topl = go (knownShX @sh) topl - where - go :: StaticShX sh' -> [i] -> ListX sh' (Const i) - go ZKX [] = ZX - go (_ :!% sh) (i : is) = Const i ::% go sh is - go _ _ = error $ "IsList(ListX): Mismatched list length (type says " - ++ show (lengthStaticShX (knownShX @sh)) ++ ", list has length " - ++ show (length topl) ++ ")" - toList = listXToList - --- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. -instance KnownShX sh => IsList (IxX sh i) where - type Item (IxX sh i) = i - fromList = IxX . IsList.fromList - toList = Foldable.toList - --- | Untyped: length and known dimensions are checked (at runtime). -instance KnownShX sh => IsList (ShX sh Int) where - type Item (ShX sh Int) = Int - fromList topl = ShX (go (knownShX @sh) topl) - where - go :: StaticShX sh' -> [Int] -> ListX sh' (SMayNat Int SNat) - go ZKX [] = ZX - go (SKnown sn :!% sh) (i : is) - | i == fromSNat' sn = SKnown sn ::% go sh is - | otherwise = error $ "IsList(ShX): Value does not match typing (type says " - ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" - go (SUnknown () :!% sh) (i : is) = SUnknown i ::% go sh is - go _ _ = error $ "IsList(ShX): Mismatched list length (type says " - ++ show (lengthStaticShX (knownShX @sh)) ++ ", list has length " - ++ show (length topl) ++ ")" - toList = shXToList - - -type family Rank sh where - Rank '[] = 0 - Rank (_ : sh) = Rank sh + 1 - type XArray :: [Maybe Nat] -> Type -> Type newtype XArray sh a = XArray (S.Array (Rank sh) a) deriving (Show, Eq, Generic) @@ -333,180 +56,6 @@ deriving instance (Ord a, Storable a) => Ord (XArray '[] a) instance NFData a => NFData (XArray sh a) -zeroIxX :: StaticShX sh -> IIxX sh -zeroIxX ZKX = ZIX -zeroIxX (_ :!% ssh) = 0 :.% zeroIxX ssh - -zeroIxX' :: IShX sh -> IIxX sh -zeroIxX' ZSX = ZIX -zeroIxX' (_ :$% sh) = 0 :.% zeroIxX' sh - --- This is a weird operation, so it has a long name -completeShXzeros :: StaticShX sh -> IShX sh -completeShXzeros ZKX = ZSX -completeShXzeros (SUnknown () :!% ssh) = SUnknown 0 :$% completeShXzeros ssh -completeShXzeros (SKnown n :!% ssh) = SKnown n :$% completeShXzeros ssh - -listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f -listxAppend ZX idx' = idx' -listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx' - -ixAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i -ixAppend = coerce (listxAppend @_ @(Const i)) - -shAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i -shAppend = coerce (listxAppend @_ @(SMayNat i SNat)) - -listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f -listxDrop long ZX = long -listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short - -ixDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i -ixDrop = coerce (listxDrop @(Const i) @(Const i)) - -shDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i -shDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat)) - -shDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i -shDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j)) - -shDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i -shDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat)) - -shTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i -shTakeSSX _ = flip go - where - go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i - go ZKX _ = ZSX - go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh - -ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh' -ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) - --- TODO: generalise all these things to arbitrary @i@ -shTail :: IShX (n : sh) -> IShX sh -shTail (_ :$% sh) = sh - -ssxTail :: StaticShX (n : sh) -> StaticShX sh -ssxTail (_ :!% ssh) = ssh - -shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh') -shAppSplit _ ZKX idx = (ZSX, idx) -shAppSplit p (_ :!% ssh) (i :$% idx) = first (i :$%) (shAppSplit p ssh idx) - -ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') -ssxAppend ZKX sh' = sh' -ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' - -shapeSize :: IShX sh -> Int -shapeSize ZSX = 1 -shapeSize (n :$% sh) = fromSMayNat' n * shapeSize sh - --- | This may fail if @sh@ has @Nothing@s in it. -ssxToShape' :: StaticShX sh -> Maybe (IShX sh) -ssxToShape' ZKX = Just ZSX -ssxToShape' (SKnown n :!% sh) = (SKnown n :$%) <$> ssxToShape' sh -ssxToShape' (SUnknown _ :!% _) = Nothing - -lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a -lemReplicateSucc = unsafeCoerce Refl - -ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) -ssxReplicate SZ = ZKX -ssxReplicate (SS (n :: SNat n')) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n' - = SUnknown () :!% ssxReplicate n - -fromLinearIdx :: IShX sh -> Int -> IIxX sh -fromLinearIdx = \sh i -> case go sh i of - (idx, 0) -> idx - _ -> error $ "fromLinearIdx: out of range (" ++ show i ++ - " in array of shape " ++ show sh ++ ")" - where - -- returns (index in subarray, remaining index in enclosing array) - go :: IShX sh -> Int -> (IIxX sh, Int) - go ZSX i = (ZIX, i) - go (n :$% sh) i = - let (idx, i') = go sh i - (upi, locali) = i' `quotRem` fromSMayNat' n - in (locali :.% idx, upi) - -toLinearIdx :: IShX sh -> IIxX sh -> Int -toLinearIdx = \sh i -> fst (go sh i) - where - -- returns (index in subarray, size of subarray) - go :: IShX sh -> IIxX sh -> (Int, Int) - go ZSX ZIX = (0, 1) - go (n :$% sh) (i :.% ix) = - let (lidx, sz) = go sh ix - in (sz * i + lidx, fromSMayNat' n * sz) - -enumShape :: IShX sh -> [IIxX sh] -enumShape = \sh -> go sh id [] - where - go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a] - go ZSX f = (f ZIX :) - go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]] - -shapeLshape :: IShX sh -> S.ShapeL -shapeLshape ZSX = [] -shapeLshape (n :$% sh) = fromSMayNat' n : shapeLshape sh - -ssxLength :: StaticShX sh -> Int -ssxLength ZKX = 0 -ssxLength (_ :!% ssh) = 1 + ssxLength ssh - -ssxIotaFrom :: Int -> StaticShX sh -> [Int] -ssxIotaFrom _ ZKX = [] -ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh - -type Flatten sh = Flatten' 1 sh - -type family Flatten' acc sh where - Flatten' acc '[] = Just acc - Flatten' acc (Nothing : sh) = Nothing - Flatten' acc (Just n : sh) = Flatten' (acc * n) sh - -flattenSSX :: StaticShX sh -> SMayNat () SNat (Flatten sh) -flattenSSX = go (SNat @1) - where - go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh) - go acc ZKX = SKnown acc - go _ (SUnknown () :!% _) = SUnknown () - go acc (SKnown sn :!% sh) = go (mulSNat acc sn) sh - -flattenSh :: IShX sh -> SMayNat Int SNat (Flatten sh) -flattenSh = go (SNat @1) - where - go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh) - go acc ZSX = SKnown acc - go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh) - go acc (SKnown sn :$% sh) = go (mulSNat acc sn) sh - - goUnknown :: Int -> IShX sh -> Int - goUnknown acc ZSX = acc - goUnknown acc (SUnknown n :$% sh) = goUnknown (acc * n) sh - goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh - -staticShapeFrom :: IShX sh -> StaticShX sh -staticShapeFrom ZSX = ZKX -staticShapeFrom (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% staticShapeFrom sh - -lemRankApp :: StaticShX sh1 -> StaticShX sh2 - -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2 -lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this - -lemRankAppComm :: StaticShX sh1 -> StaticShX sh2 - -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1) -lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this - -lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh) -lemKnownNatRank ZSX = Dict -lemKnownNatRank (_ :$% sh) | Dict <- lemKnownNatRank sh = Dict - -lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh) -lemKnownNatRankSSX ZKX = Dict -lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict shape :: forall sh a. StaticShX sh -> XArray sh a -> IShX sh shape = \ssh (XArray arr) -> go ssh (S.shapeL arr) @@ -519,7 +68,7 @@ shape = \ssh (XArray arr) -> go ssh (S.shapeL arr) fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a fromVector sh v | Dict <- lemKnownNatRank sh - = XArray (S.fromVector (shapeLshape sh) v) + = XArray (S.fromVector (shxToList sh) v) toVector :: Storable a => XArray sh a -> VS.Vector a toVector (XArray arr) = S.toVector arr @@ -527,23 +76,18 @@ toVector (XArray arr) = S.toVector arr scalar :: Storable a => a -> XArray '[] a scalar = XArray . S.scalar -eqShX :: IShX sh1 -> IShX sh2 -> Bool -eqShX ZSX ZSX = True -eqShX (n :$% sh1) (m :$% sh2) = fromSMayNat' n == fromSMayNat' m && eqShX sh1 sh2 -eqShX _ _ = False - -- | Will throw if the array does not have the casted-to shape. cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2 => StaticShX sh1 -> IShX sh2 -> StaticShX sh' -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a cast ssh1 sh2 ssh' (XArray arr) | Refl <- lemRankApp ssh1 ssh' - , Refl <- lemRankApp (staticShapeFrom sh2) ssh' + , Refl <- lemRankApp (ssxFromShape sh2) ssh' = let arrsh :: IShX sh1 - (arrsh, _) = shAppSplit (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr)) - in if eqShX arrsh sh2 - then XArray arr - else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")" + (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr)) + in case shxEqual arrsh sh2 of + Just _ -> XArray arr + Nothing -> error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")" unScalar :: Storable a => XArray '[] a -> a unScalar (XArray a) = S.unScalar a @@ -551,24 +95,24 @@ unScalar (XArray a) = S.unScalar a replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a replicate sh ssh' (XArray arr) | Dict <- lemKnownNatRankSSX ssh' - , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh) ssh') - , Refl <- lemRankApp (staticShapeFrom sh) ssh' - = XArray (S.stretch (shapeLshape sh ++ S.shapeL arr) $ - S.reshape (map (const 1) (shapeLshape sh) ++ S.shapeL arr) $ + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh) ssh') + , Refl <- lemRankApp (ssxFromShape sh) ssh' + = XArray (S.stretch (shxToList sh ++ S.shapeL arr) $ + S.reshape (map (const 1) (shxToList sh) ++ S.shapeL arr) $ arr) replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a replicateScal sh x | Dict <- lemKnownNatRank sh - = XArray (S.constant (shapeLshape sh) x) + = XArray (S.constant (shxToList sh) x) generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a -generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh) +generate sh f = fromVector sh $ VS.generate (shxSize sh) (f . ixxFromLinear sh) -- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a) -- generateM sh f | Dict <- lemKnownNatRank sh = --- XArray . S.fromVector (shapeLshape sh) --- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh) +-- XArray . S.fromVector (shxShapeL sh) +-- <$> VS.generateM (shxSize sh) (f . ixxFromLinear sh) indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a indexPartial (XArray arr) ZIX = XArray arr @@ -580,24 +124,6 @@ index xarr i = let XArray arr' = indexPartial xarr i :: XArray '[] a in S.unScalar arr' -type family AddMaybe n m where - AddMaybe Nothing _ = Nothing - AddMaybe (Just _) Nothing = Nothing - AddMaybe (Just n) (Just m) = Just (n + m) - --- This should be a function in base -plusSNat :: SNat n -> SNat m -> SNat (n + m) -plusSNat n m = TypeNats.withSomeSNat (TypeNats.fromSNat n + TypeNats.fromSNat m) unsafeCoerce - --- This should be a function in base -mulSNat :: SNat n -> SNat m -> SNat (n * m) -mulSNat n m = TypeNats.withSomeSNat (TypeNats.fromSNat n * TypeNats.fromSNat m) unsafeCoerce - -smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m) -smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) -smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) -smnAddMaybe (SKnown n) (SKnown m) = SKnown (plusSNat n m) - append :: forall n m sh a. Storable a => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a append ssh (XArray a) (XArray b) @@ -639,9 +165,9 @@ rerank :: forall sh sh1 sh2 a b. -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b rerank ssh ssh1 ssh2 f xarr@(XArray arr) | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) - = let (sh, _) = shAppSplit (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr) - in if any (== 0) (shapeLshape sh) - then XArray (S.fromList (shapeLshape (shAppend sh (completeShXzeros ssh2))) []) + = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr) + in if any (== 0) (shxToList sh) + then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) else case () of () | Dict <- lemKnownNatRankSSX ssh , Dict <- lemKnownNatRankSSX ssh2 @@ -666,9 +192,9 @@ rerank2 :: forall sh sh1 sh2 a b c. -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2) | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) - = let (sh, _) = shAppSplit (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1) - in if any (== 0) (shapeLshape sh) - then XArray (S.fromList (shapeLshape (shAppend sh (completeShXzeros ssh2))) []) + = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1) + in if any (== 0) (shxToList sh) + then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) else case () of () | Dict <- lemKnownNatRankSSX ssh , Dict <- lemKnownNatRankSSX ssh2 @@ -678,211 +204,14 @@ rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2) (\a b -> let XArray r = f (XArray a) (XArray b) in r) arr1 arr2) -type family Elem x l where - Elem x '[] = 'False - Elem x (x : _) = 'True - Elem x (_ : ys) = Elem x ys - -type family AllElem' as bs where - AllElem' '[] bs = 'True - AllElem' (a : as) bs = Elem a bs && AllElem' as bs - -type AllElem as bs = Assert (AllElem' as bs) - (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs)) - -type family Count i n where - Count n n = '[] - Count i n = i : Count (i + 1) n - -type Permutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as) - -type family Index i sh where - Index 0 (n : sh) = n - Index i (_ : sh) = Index (i - 1) sh - -type family Permute is sh where - Permute '[] sh = '[] - Permute (i : is) sh = Index i sh : Permute is sh - -type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh - -data HList f list where - HNil :: HList f '[] - HCons :: f a -> HList f l -> HList f (a : l) -infixr 5 `HCons` -deriving instance (forall a. Show (f a)) => Show (HList f list) -deriving instance (forall a. Eq (f a)) => Eq (HList f list) - -foldHList :: Monoid m => (forall a. f a -> m) -> HList f list -> m -foldHList _ HNil = mempty -foldHList f (x `HCons` l) = f x <> foldHList f l - -snatLengthHList :: HList f list -> SNat (Rank list) -snatLengthHList HNil = SNat -snatLengthHList (_ `HCons` l) | SNat <- snatLengthHList l = SNat - -permFromList :: [Int] -> (forall list. HList SNat list -> r) -> r -permFromList [] k = k HNil -permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case - Just sn -> permFromList xs $ \list -> k (sn `HCons` list) - Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x - -permToList :: HList SNat list -> [Int] -permToList = foldHList (pure . fromSNat') - -type family TakeLen ref l where - TakeLen '[] l = '[] - TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs - -type family DropLen ref l where - DropLen '[] l = l - DropLen (_ : ref) (_ : xs) = DropLen ref xs - -lemRankPermute :: Proxy sh -> HList SNat is -> Rank (Permute is sh) :~: Rank is -lemRankPermute _ HNil = Refl -lemRankPermute p (_ `HCons` is) | Refl <- lemRankPermute p is = Refl - -lemRankDropLen :: forall is sh. (Rank is <= Rank sh) - => StaticShX sh -> HList SNat is -> Rank (DropLen is sh) :~: Rank sh - Rank is -lemRankDropLen ZKX HNil = Refl -lemRankDropLen (_ :!% sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl -lemRankDropLen (_ :!% _) HNil = Refl -lemRankDropLen ZKX (_ `HCons` _) = error "1 <= 0" - -lemIndexSucc :: Proxy i -> Proxy a -> Proxy l -> Index (i + 1) (a : l) :~: Index i l -lemIndexSucc _ _ _ = unsafeCoerce Refl - -listxTakeLen :: forall f is sh. HList SNat is -> ListX sh f -> ListX (TakeLen is sh) f -listxTakeLen HNil _ = ZX -listxTakeLen (_ `HCons` is) (n ::% sh) = n ::% listxTakeLen is sh -listxTakeLen (_ `HCons` _) ZX = error "Permutation longer than shape" - -listxDropLen :: forall f is sh. HList SNat is -> ListX sh f -> ListX (DropLen is sh) f -listxDropLen HNil sh = sh -listxDropLen (_ `HCons` is) (_ ::% sh) = listxDropLen is sh -listxDropLen (_ `HCons` _) ZX = error "Permutation longer than shape" - -listxPermute :: forall f is sh. HList SNat is -> ListX sh f -> ListX (Permute is sh) f -listxPermute HNil _ = ZX -listxPermute (i `HCons` (is :: HList SNat is')) (sh :: ListX sh f) = listxIndex (Proxy @is') (Proxy @sh) i sh (listxPermute is sh) - -listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> ListX (Permute is shT) f -> ListX (Index i sh : Permute is shT) f -listxIndex _ _ SZ (n ::% _) rest = n ::% rest -listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) rest - | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = listxIndex p pT i sh rest -listxIndex _ _ _ ZX _ = error "Index into empty shape" - -listxPermutePrefix :: forall f is sh. HList SNat is -> ListX sh f -> ListX (PermutePrefix is sh) f -listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) - -ssxTakeLen :: forall is sh. HList SNat is -> StaticShX sh -> StaticShX (TakeLen is sh) -ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat)) - -ssxDropLen :: HList SNat is -> StaticShX sh -> StaticShX (DropLen is sh) -ssxDropLen = coerce (listxDropLen @(SMayNat () SNat)) - -ssxPermute :: HList SNat is -> StaticShX sh -> StaticShX (Permute is sh) -ssxPermute = coerce (listxPermute @(SMayNat () SNat)) - -ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT) -ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2) - -ssxPermutePrefix :: HList SNat is -> StaticShX sh -> StaticShX (PermutePrefix is sh) -ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat)) - -shPermutePrefix :: HList SNat is -> IShX sh -> IShX (PermutePrefix is sh) -shPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat)) - --- TODO: test this thing more properly -invertPermutation :: HList SNat is - -> (forall is'. - Permutation is' - => HList SNat is' - -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh) - -> r) - -> r -invertPermutation = \perm k -> - genPerm perm $ \(invperm :: HList SNat is') -> - let sn = snatLengthHList invperm - in case (provePerm1 (Proxy @is') sn invperm, provePerm2 (SNat @0) sn invperm) of - (Just Refl, Just Refl) -> - k invperm - (\ssh -> case provePermInverse perm invperm ssh of - Just eq -> eq - Nothing -> error $ "invertPermutation: did not generate inverse? perm = " ++ show perm - ++ " ; invperm = " ++ show invperm) - _ -> error $ "invertPermutation: did not generate permutation? perm = " ++ show perm - ++ " ; invperm = " ++ show invperm - where - genPerm :: HList SNat is -> (forall is'. HList SNat is' -> r) -> r - genPerm perm = - let permList = foldHList (pure . fromSNat) perm - in toHList $ map snd (sort (zip permList [0..])) - where - toHList :: [Natural] -> (forall is'. HList SNat is' -> r) -> r - toHList [] k = k HNil - toHList (n : ns) k = toHList ns $ \l -> TypeNats.withSomeSNat n $ \sn -> k (HCons sn l) - - lemElemCount :: (0 <= n, Compare n m ~ LT) => proxy n -> proxy m -> Elem n (Count 0 m) :~: True - lemElemCount _ _ = unsafeCoerce Refl - - lemCount :: (OrdCond (Compare i n) True False True ~ True) => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n - lemCount _ _ = unsafeCoerce Refl - - lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True - lemElem _ _ = unsafeCoerce Refl - - provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> HList SNat is' - -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True) - provePerm1 _ _ HNil = Just (Refl) - provePerm1 p rtop@SNat (HCons sn@SNat perm) - | Just Refl <- provePerm1 p rtop perm - = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of - (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl - (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl - _ -> Nothing - | otherwise - = Nothing - - provePerm2 :: SNat i -> SNat n -> HList SNat is' -> Maybe (AllElem' (Count i n) is' :~: True) - provePerm2 = \i@(SNat :: SNat i) n@SNat perm -> - case cmpNat i n of - EQI -> Just Refl - LTI | Refl <- lemCount i n - , Just Refl <- provePerm2 (SNat @(i + 1)) n perm - -> checkElem i perm - | otherwise -> Nothing - GTI -> error "unreachable" - where - checkElem :: SNat i -> HList SNat is' -> Maybe (Elem i is' :~: True) - checkElem _ HNil = Nothing - checkElem i@SNat (HCons k@SNat perm :: HList SNat is') = - case sameNat i k of - Just Refl -> Just Refl - Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl - | otherwise -> Nothing - - provePermInverse :: HList SNat is -> HList SNat is' -> StaticShX sh -> Maybe (Permute is' (Permute is sh) :~: sh) - provePermInverse perm perminv ssh = geqStaticShX (ssxPermute perminv (ssxPermute perm ssh)) ssh - -applyPermX :: forall f is sh. HList SNat is -> ListX sh f -> ListX (PermutePrefix is sh) f -applyPermX perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) - -applyPermIxX :: forall i is sh. HList SNat is -> IxX sh i -> IxX (PermutePrefix is sh) i -applyPermIxX = coerce (applyPermX @(Const i)) - -applyPermShX :: forall i is sh. HList SNat is -> ShX sh i -> ShX (PermutePrefix is sh) i -applyPermShX = coerce (applyPermX @(SMayNat i SNat)) - -class KnownNatList l where makeNatList :: HList SNat l -instance KnownNatList '[] where makeNatList = HNil -instance (KnownNat n, KnownNatList l) => KnownNatList (n : l) where makeNatList = natSing `HCons` makeNatList +class KnownNatList l where makeNatList :: Perm l +instance KnownNatList '[] where makeNatList = PNil +instance (KnownNat n, KnownNatList l) => KnownNatList (n : l) where makeNatList = natSing `PCons` makeNatList -- | The list argument gives indices into the original dimension list. -transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh) +transpose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh) => StaticShX sh - -> HList SNat is + -> Perm is -> XArray sh a -> XArray (PermutePrefix is sh) a transpose ssh perm (XArray arr) @@ -890,7 +219,7 @@ transpose ssh perm (XArray arr) , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh) , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm , Refl <- lemRankDropLen ssh perm - = XArray (S.transpose (permToList perm) arr) + = XArray (S.transpose (permToList' perm) arr) -- | The list argument gives indices into the original dimension list. -- @@ -929,14 +258,14 @@ sumInner :: forall sh sh' a. (Storable a, NumElt a) => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a sumInner ssh ssh' arr | Refl <- lemAppNil @sh - = let (_, sh') = shAppSplit (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) - sh'F = flattenSh sh' :$% ZSX - ssh'F = staticShapeFrom sh'F + = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + sh'F = shxFlatten sh' :$% ZSX + ssh'F = ssxFromShape sh'F go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a go (XArray arr') | Refl <- lemRankApp ssh ssh'F - , let sn = snatLengthListX (let StaticShX l = ssh in l) + , let sn = listxLengthSNat (let StaticShX l = ssh in l) = XArray (numEltSum1Inner sn arr') in go $ @@ -949,10 +278,10 @@ sumOuter :: forall sh sh' a. (Storable a, NumElt a) => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a sumOuter ssh ssh' arr | Refl <- lemAppNil @sh - = let (sh, _) = shAppSplit (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) - shF = flattenSh sh :$% ZSX - in sumInner ssh' (staticShapeFrom shF) $ - transpose2 (staticShapeFrom shF) ssh' $ + = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + shF = shxFlatten sh :$% ZSX + in sumInner ssh' (ssxFromShape shF) $ + transpose2 (ssxFromShape shF) ssh' $ reshapePartial ssh ssh' shF $ arr @@ -988,7 +317,7 @@ toList1 (XArray arr) = S.toList arr empty :: forall sh a. Storable a => IShX sh -> XArray sh a empty sh | Dict <- lemKnownNatRank sh - = XArray (S.constant (shapeLshape sh) + = XArray (S.constant (shxToList sh) (error "Data.Array.Mixed.empty: shape was not empty")) slice :: SNat i -> SNat n -> XArray (Just (i + n + k) : sh) a -> XArray (Just n : sh) a @@ -1005,14 +334,14 @@ reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray s reshape ssh1 sh2 (XArray arr) | Dict <- lemKnownNatRankSSX ssh1 , Dict <- lemKnownNatRank sh2 - = XArray (S.reshape (shapeLshape sh2) arr) + = XArray (S.reshape (shxToList sh2) arr) -- | Throws if the given array and the target shape do not have the same number of elements. reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a reshapePartial ssh1 ssh' sh2 (XArray arr) | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh') - , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh2) ssh') - = XArray (S.reshape (shapeLshape sh2 ++ drop (lengthStaticShX ssh1) (S.shapeL arr)) arr) + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh2) ssh') + = XArray (S.reshape (shxToList sh2 ++ drop (ssxLength ssh1) (S.shapeL arr)) arr) -- this was benchmarked to be (slightly) faster than S.iota, S.generate and S.fromVector(VS.enumFromTo). iota :: (Enum a, Storable a) => SNat n -> XArray '[Just n] a diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 95fcfcf..cf6820b 100644 --- a/src/Data/Array/Nested/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -6,7 +6,7 @@ {-# LANGUAGE TypeOperators #-} {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Nested.Internal.Arith where +module Data.Array.Mixed.Internal.Arith where import Control.Monad (forM, guard) import qualified Data.Array.Internal as OI @@ -24,8 +24,8 @@ import GHC.TypeLits import Language.Haskell.TH import System.IO.Unsafe -import Data.Array.Nested.Internal.Arith.Foreign -import Data.Array.Nested.Internal.Arith.Lists +import Data.Array.Mixed.Internal.Arith.Foreign +import Data.Array.Mixed.Internal.Arith.Lists liftVEltwise1 :: Storable a diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs index ac83188..6fc7229 100644 --- a/src/Data/Array/Nested/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs @@ -1,6 +1,6 @@ {-# LANGUAGE ForeignFunctionInterface #-} {-# LANGUAGE TemplateHaskell #-} -module Data.Array.Nested.Internal.Arith.Foreign where +module Data.Array.Mixed.Internal.Arith.Foreign where import Control.Monad import Data.Int @@ -9,7 +9,7 @@ import Foreign.C.Types import Foreign.Ptr import Language.Haskell.TH -import Data.Array.Nested.Internal.Arith.Lists +import Data.Array.Mixed.Internal.Arith.Lists $(fmap concat . forM typesList $ \arithtype -> do diff --git a/src/Data/Array/Nested/Internal/Arith/Lists.hs b/src/Data/Array/Mixed/Internal/Arith/Lists.hs index ce2836d..a284bc1 100644 --- a/src/Data/Array/Nested/Internal/Arith/Lists.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Lists.hs @@ -1,12 +1,12 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE TemplateHaskell #-} -module Data.Array.Nested.Internal.Arith.Lists where +module Data.Array.Mixed.Internal.Arith.Lists where import Data.Char import Data.Int import Language.Haskell.TH -import Data.Array.Nested.Internal.Arith.Lists.TH +import Data.Array.Mixed.Internal.Arith.Lists.TH data ArithType = ArithType diff --git a/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs b/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs index 7142dfa..8b7d05f 100644 --- a/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs @@ -1,5 +1,5 @@ {-# LANGUAGE TemplateHaskellQuotes #-} -module Data.Array.Nested.Internal.Arith.Lists.TH where +module Data.Array.Mixed.Internal.Arith.Lists.TH where import Control.Monad import Control.Monad.IO.Class diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Mixed/Lemmas.hs new file mode 100644 index 0000000..30ec9c0 --- /dev/null +++ b/src/Data/Array/Mixed/Lemmas.hs @@ -0,0 +1,47 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE DataKinds #-} +module Data.Array.Mixed.Lemmas where + +import Data.Proxy +import Data.Type.Equality +import GHC.TypeLits + +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Types + + +lemRankApp :: forall sh1 sh2. + StaticShX sh1 -> StaticShX sh2 + -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2 +lemRankApp ZKX _ = Refl +lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2 + = lem2 (Proxy @(Rank sh1T)) Proxy Proxy $ + lem (Proxy @(Rank sh2)) (Proxy @(Rank sh1T)) (Proxy @(Rank (sh1T ++ sh2))) $ + lemRankApp ssh1 ssh2 + where + lem :: proxy a -> proxy b -> proxy c + -> c :~: b + a + -> b + a :~: c + lem _ _ _ Refl = Refl + + lem2 :: proxy a -> proxy b -> proxy c + -> (a + b :~: c) + -> c + 1 :~: (a + 1 + b) + lem2 _ _ _ Refl = Refl + +lemRankAppComm :: StaticShX sh1 -> StaticShX sh2 + -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1) +lemRankAppComm _ _ = unsafeCoerceRefl -- TODO improve this + +lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh) +lemKnownNatRank ZSX = Dict +lemKnownNatRank (_ :$% sh) | Dict <- lemKnownNatRank sh = Dict + +lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh) +lemKnownNatRankSSX ZKX = Dict +lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs new file mode 100644 index 0000000..2710018 --- /dev/null +++ b/src/Data/Array/Mixed/Permutation.hs @@ -0,0 +1,252 @@ +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Mixed.Permutation where + +import Data.Coerce (coerce) +import Data.Functor.Const +import Data.List (sort) +import Data.Proxy +import Data.Type.Bool +import Data.Type.Equality +import Data.Type.Ord +import GHC.TypeError +import GHC.TypeLits +import qualified GHC.TypeNats as TN + +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Types + + +-- * Permutations + +-- | A "backward" permutation of a dimension list. The operation on the +-- dimension list is most similar to 'Data.Vector.backpermute'; see 'Permute' +-- for code that implements this. +data Perm list where + PNil :: Perm '[] + PCons :: SNat a -> Perm l -> Perm (a : l) +infixr 5 `PCons` +deriving instance Show (Perm list) +deriving instance Eq (Perm list) + +permLengthSNat :: Perm list -> SNat (Rank list) +permLengthSNat PNil = SNat +permLengthSNat (_ `PCons` l) | SNat <- permLengthSNat l = SNat + +permFromList :: [Int] -> (forall list. Perm list -> r) -> r +permFromList [] k = k PNil +permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case + Just sn -> permFromList xs $ \list -> k (sn `PCons` list) + Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x + +permToList :: Perm list -> [Natural] +permToList PNil = mempty +permToList (x `PCons` l) = TN.fromSNat x : permToList l + +permToList' :: Perm list -> [Int] +permToList' = map fromIntegral . permToList + + +-- ** Applying permutations + +type family Elem x l where + Elem x '[] = 'False + Elem x (x : _) = 'True + Elem x (_ : ys) = Elem x ys + +type family AllElem' as bs where + AllElem' '[] bs = 'True + AllElem' (a : as) bs = Elem a bs && AllElem' as bs + +type AllElem as bs = Assert (AllElem' as bs) + (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs)) + +type family Count i n where + Count n n = '[] + Count i n = i : Count (i + 1) n + +type IsPermutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as) + +type family Index i sh where + Index 0 (n : sh) = n + Index i (_ : sh) = Index (i - 1) sh + +type family Permute is sh where + Permute '[] sh = '[] + Permute (i : is) sh = Index i sh : Permute is sh + +type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh + +type family TakeLen ref l where + TakeLen '[] l = '[] + TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs + +type family DropLen ref l where + DropLen '[] l = l + DropLen (_ : ref) (_ : xs) = DropLen ref xs + +listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f +listxTakeLen PNil _ = ZX +listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh +listxTakeLen (_ `PCons` _) ZX = error "IsPermutation longer than shape" + +listxDropLen :: forall f is sh. Perm is -> ListX sh f -> ListX (DropLen is sh) f +listxDropLen PNil sh = sh +listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh +listxDropLen (_ `PCons` _) ZX = error "IsPermutation longer than shape" + +listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f +listxPermute PNil _ = ZX +listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) = + listxIndex (Proxy @is') (Proxy @sh) i sh (listxPermute is sh) + +listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> ListX (Permute is shT) f -> ListX (Index i sh : Permute is shT) f +listxIndex _ _ SZ (n ::% _) rest = n ::% rest +listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) rest + | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = listxIndex p pT i sh rest +listxIndex _ _ _ ZX _ = error "Index into empty shape" + +listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f +listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) + +ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i +ixxPermutePrefix = coerce (listxPermutePrefix @(Const i)) + +ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh) +ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat)) + +ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh) +ssxDropLen = coerce (listxDropLen @(SMayNat () SNat)) + +ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh) +ssxPermute = coerce (listxPermute @(SMayNat () SNat)) + +ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT) +ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2) + +ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) +ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat)) + +shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh) +shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat)) + + +-- * Operations on permutations + +-- TODO: test this thing more properly +permInverse :: Perm is + -> (forall is'. + IsPermutation is' + => Perm is' + -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh) + -> r) + -> r +permInverse = \perm k -> + genPerm perm $ \(invperm :: Perm is') -> + let sn = permLengthSNat invperm + in case (provePerm1 (Proxy @is') sn invperm, provePerm2 (SNat @0) sn invperm) of + (Just Refl, Just Refl) -> + k invperm + (\ssh -> case provePermInverse perm invperm ssh of + Just eq -> eq + Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm + ++ " ; invperm = " ++ show invperm) + _ -> error $ "permInverse: did not generate permutation? perm = " ++ show perm + ++ " ; invperm = " ++ show invperm + where + genPerm :: Perm is -> (forall is'. Perm is' -> r) -> r + genPerm perm = + let permList = permToList' perm + in toHList $ map snd (sort (zip permList [0..])) + where + toHList :: [Natural] -> (forall is'. Perm is' -> r) -> r + toHList [] k = k PNil + toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l) + + lemElemCount :: (0 <= n, Compare n m ~ LT) => proxy n -> proxy m -> Elem n (Count 0 m) :~: True + lemElemCount _ _ = unsafeCoerceRefl + + lemCount :: (OrdCond (Compare i n) True False True ~ True) => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n + lemCount _ _ = unsafeCoerceRefl + + lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True + lemElem _ _ = unsafeCoerceRefl + + provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is' + -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True) + provePerm1 _ _ PNil = Just (Refl) + provePerm1 p rtop@SNat (PCons sn@SNat perm) + | Just Refl <- provePerm1 p rtop perm + = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of + (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl + (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl + _ -> Nothing + | otherwise + = Nothing + + provePerm2 :: SNat i -> SNat n -> Perm is' + -> Maybe (AllElem' (Count i n) is' :~: True) + provePerm2 = \i@(SNat :: SNat i) n@SNat perm -> + case cmpNat i n of + EQI -> Just Refl + LTI | Refl <- lemCount i n + , Just Refl <- provePerm2 (SNat @(i + 1)) n perm + -> checkElem i perm + | otherwise -> Nothing + GTI -> error "unreachable" + where + checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True) + checkElem _ PNil = Nothing + checkElem i@SNat (PCons k@SNat perm :: Perm is') = + case sameNat i k of + Just Refl -> Just Refl + Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl + | otherwise -> Nothing + + provePermInverse :: Perm is -> Perm is' -> StaticShX sh + -> Maybe (Permute is' (Permute is sh) :~: sh) + provePermInverse perm perminv ssh = + ssxGeq (ssxPermute perminv (ssxPermute perm ssh)) ssh + +type family MapSucc is where + MapSucc '[] = '[] + MapSucc (i : is) = i + 1 : MapSucc is + +permShift1 :: Perm l -> Perm (0 : MapSucc l) +permShift1 = (SNat @0 `PCons`) . permMapSucc + where + permMapSucc :: Perm l -> Perm (MapSucc l) + permMapSucc PNil = PNil + permMapSucc ((SNat :: SNat i) `PCons` ns) = SNat @(i + 1) `PCons` permMapSucc ns + + +-- * Lemmas + +lemRankPermute :: Proxy sh -> Perm is -> Rank (Permute is sh) :~: Rank is +lemRankPermute _ PNil = Refl +lemRankPermute p (_ `PCons` is) | Refl <- lemRankPermute p is = Refl + +lemRankDropLen :: forall is sh. (Rank is <= Rank sh) + => StaticShX sh -> Perm is -> Rank (DropLen is sh) :~: Rank sh - Rank is +lemRankDropLen ZKX PNil = Refl +lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl +lemRankDropLen (_ :!% _) PNil = Refl +lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0" + +lemIndexSucc :: Proxy i -> Proxy a -> Proxy l + -> Index (i + 1) (a : l) :~: Index i l +lemIndexSucc _ _ _ = unsafeCoerceRefl diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs new file mode 100644 index 0000000..a16da76 --- /dev/null +++ b/src/Data/Array/Mixed/Shape.hs @@ -0,0 +1,455 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Mixed.Shape where + +import Control.DeepSeq (NFData(..)) +import qualified Data.Foldable as Foldable +import Data.Functor.Const +import Data.Kind (Type, Constraint) +import Data.Monoid (Sum(..)) +import Data.Proxy +import Data.Type.Equality +import GHC.Generics (Generic) +import GHC.IsList (IsList) +import qualified GHC.IsList as IsList +import GHC.TypeLits + +import Data.Array.Mixed.Types +import Data.Coerce +import Data.Bifunctor (first) + + +-- | The length of a type-level list. If the argument is a shape, then the +-- result is the rank of that shape. +type family Rank sh where + Rank '[] = 0 + Rank (_ : sh) = Rank sh + 1 + + +-- * Mixed lists + +type role ListX nominal representational +type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type +data ListX sh f where + ZX :: ListX '[] f + (::%) :: f n -> ListX sh f -> ListX (n : sh) f +deriving instance (forall n. Eq (f n)) => Eq (ListX sh f) +deriving instance (forall n. Ord (f n)) => Ord (ListX sh f) +infixr 3 ::% + +instance (forall n. Show (f n)) => Show (ListX sh f) where + showsPrec _ = listxShow shows + +instance (forall n. NFData (f n)) => NFData (ListX sh f) where + rnf ZX = () + rnf (x ::% l) = rnf x `seq` rnf l + +data UnconsListXRes f sh1 = + forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n) +listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1) +listxUncons (i ::% shl') = Just (UnconsListXRes shl' i) +listxUncons ZX = Nothing + +listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g +listxFmap _ ZX = ZX +listxFmap f (x ::% xs) = f x ::% listxFmap f xs + +listxFold :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m +listxFold _ ZX = mempty +listxFold f (x ::% xs) = f x <> listxFold f xs + +listxLength :: ListX sh f -> Int +listxLength = getSum . listxFold (\_ -> Sum 1) + +listxLengthSNat :: ListX sh f -> SNat (Rank sh) +listxLengthSNat ZX = SNat +listxLengthSNat (_ ::% l) | SNat <- listxLengthSNat l = SNat + +listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS +listxShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> ListX sh' f -> ShowS + go _ ZX = id + go prefix (x ::% xs) = showString prefix . f x . go "," xs + +listxToList :: ListX sh' (Const i) -> [i] +listxToList ZX = [] +listxToList (Const i ::% is) = i : listxToList is + +listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f +listxAppend ZX idx' = idx' +listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx' + +listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f +listxDrop long ZX = long +listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short + + +-- * Mixed indices + +-- | This is a newtype over 'ListX'. +type role IxX nominal representational +type IxX :: [Maybe Nat] -> Type -> Type +newtype IxX sh i = IxX (ListX sh (Const i)) + deriving (Eq, Ord, Generic) + +pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i +pattern ZIX = IxX ZX + +pattern (:.%) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => i -> IxX sh i -> IxX sh1 i +pattern i :.% shl <- IxX (listxUncons -> Just (UnconsListXRes (IxX -> shl) (getConst -> i))) + where i :.% IxX shl = IxX (Const i ::% shl) +infixr 3 :.% + +{-# COMPLETE ZIX, (:.%) #-} + +type IIxX sh = IxX sh Int + +instance Show i => Show (IxX sh i) where + showsPrec _ (IxX l) = listxShow (\(Const i) -> shows i) l + +instance Functor (IxX sh) where + fmap f (IxX l) = IxX (listxFmap (Const . f . getConst) l) + +instance Foldable (IxX sh) where + foldMap f (IxX l) = listxFold (f . getConst) l + +instance NFData i => NFData (IxX sh i) + +ixxZero :: StaticShX sh -> IIxX sh +ixxZero ZKX = ZIX +ixxZero (_ :!% ssh) = 0 :.% ixxZero ssh + +ixxZero' :: IShX sh -> IIxX sh +ixxZero' ZSX = ZIX +ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh + +ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i +ixxAppend = coerce (listxAppend @_ @(Const i)) + +ixxDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i +ixxDrop = coerce (listxDrop @(Const i) @(Const i)) + +ixxFromLinear :: IShX sh -> Int -> IIxX sh +ixxFromLinear = \sh i -> case go sh i of + (idx, 0) -> idx + _ -> error $ "ixxFromLinear: out of range (" ++ show i ++ + " in array of shape " ++ show sh ++ ")" + where + -- returns (index in subarray, remaining index in enclosing array) + go :: IShX sh -> Int -> (IIxX sh, Int) + go ZSX i = (ZIX, i) + go (n :$% sh) i = + let (idx, i') = go sh i + (upi, locali) = i' `quotRem` fromSMayNat' n + in (locali :.% idx, upi) + +ixxToLinear :: IShX sh -> IIxX sh -> Int +ixxToLinear = \sh i -> fst (go sh i) + where + -- returns (index in subarray, size of subarray) + go :: IShX sh -> IIxX sh -> (Int, Int) + go ZSX ZIX = (0, 1) + go (n :$% sh) (i :.% ix) = + let (lidx, sz) = go sh ix + in (sz * i + lidx, fromSMayNat' n * sz) + + +-- * Mixed shapes + +data SMayNat i f n where + SUnknown :: i -> SMayNat i f Nothing + SKnown :: f n -> SMayNat i f (Just n) +deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n) +deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n) +deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n) + +instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where + rnf (SUnknown i) = rnf i + rnf (SKnown x) = rnf x + +fromSMayNat :: (n ~ Nothing => i -> r) + -> (forall m. n ~ Just m => f m -> r) + -> SMayNat i f n -> r +fromSMayNat f _ (SUnknown i) = f i +fromSMayNat _ g (SKnown s) = g s + +fromSMayNat' :: SMayNat Int SNat n -> Int +fromSMayNat' = fromSMayNat id fromSNat' + +type family AddMaybe n m where + AddMaybe Nothing _ = Nothing + AddMaybe (Just _) Nothing = Nothing + AddMaybe (Just n) (Just m) = Just (n + m) + +smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m) +smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) +smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) +smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) + + +-- | This is a newtype over 'ListX'. +type role ShX nominal representational +type ShX :: [Maybe Nat] -> Type -> Type +newtype ShX sh i = ShX (ListX sh (SMayNat i SNat)) + deriving (Eq, Ord, Generic) + +pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i +pattern ZSX = ShX ZX + +pattern (:$%) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => SMayNat i SNat n -> ShX sh i -> ShX sh1 i +pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i)) + where i :$% ShX shl = ShX (i ::% shl) +infixr 3 :$% + +{-# COMPLETE ZSX, (:$%) #-} + +type IShX sh = ShX sh Int + +instance Show i => Show (ShX sh i) where + showsPrec _ (ShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l + +instance Functor (ShX sh) where + fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l) + +instance NFData i => NFData (ShX sh i) where + rnf (ShX ZX) = () + rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l) + rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l) + +shxLength :: ShX sh i -> Int +shxLength (ShX l) = listxLength l + +-- | This is more than @geq@: it also checks that the integers (the unknown +-- dimensions) are the same. +shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh') +shxEqual ZSX ZSX = Just Refl +shxEqual (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh') + | Just Refl <- sameNat n m + , Just Refl <- shxEqual sh sh' + = Just Refl +shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh') + | i == j + , Just Refl <- shxEqual sh sh' + = Just Refl +shxEqual _ _ = Nothing + +-- | The number of elements in an array described by this shape. +shxSize :: IShX sh -> Int +shxSize ZSX = 1 +shxSize (n :$% sh) = fromSMayNat' n * shxSize sh + +shxToList :: IShX sh -> [Int] +shxToList ZSX = [] +shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh + +shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i +shxAppend = coerce (listxAppend @_ @(SMayNat i SNat)) + +shxTail :: ShX (n : sh) i -> ShX sh i +shxTail (_ :$% sh) = sh + +shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i +shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat)) + +shxDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i +shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j)) + +shxDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i +shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat)) + +shxTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i +shxTakeSSX _ = flip go + where + go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i + go ZKX _ = ZSX + go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh + +-- This is a weird operation, so it has a long name +shxCompleteZeros :: StaticShX sh -> IShX sh +shxCompleteZeros ZKX = ZSX +shxCompleteZeros (SUnknown () :!% ssh) = SUnknown 0 :$% shxCompleteZeros ssh +shxCompleteZeros (SKnown n :!% ssh) = SKnown n :$% shxCompleteZeros ssh + +shxSplitApp :: Proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i) +shxSplitApp _ ZKX idx = (ZSX, idx) +shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx) + +shxEnum :: IShX sh -> [IIxX sh] +shxEnum = \sh -> go sh id [] + where + go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a] + go ZSX f = (f ZIX :) + go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]] + + +-- * Static mixed shapes + +-- | The part of a shape that is statically known. (A newtype over 'ListX'.) +type StaticShX :: [Maybe Nat] -> Type +newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat)) + deriving (Eq, Ord) + +pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh +pattern ZKX = StaticShX ZX + +pattern (:!%) + :: forall {sh1}. + forall n sh. (n : sh ~ sh1) + => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1 +pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i)) + where i :!% StaticShX shl = StaticShX (i ::% shl) +infixr 3 :!% + +{-# COMPLETE ZKX, (:!%) #-} + +instance Show (StaticShX sh) where + showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l + +ssxLength :: StaticShX sh -> Int +ssxLength (StaticShX l) = listxLength l + +-- | This suffices as an implementation of @geq@ in the @Data.GADT.Compare@ +-- class of the @some@ package. +ssxGeq :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh') +ssxGeq ZKX ZKX = Just Refl +ssxGeq (SKnown n@SNat :!% sh) (SKnown m@SNat :!% sh') + | Just Refl <- sameNat n m + , Just Refl <- ssxGeq sh sh' + = Just Refl +ssxGeq (SUnknown () :!% sh) (SUnknown () :!% sh') + | Just Refl <- ssxGeq sh sh' + = Just Refl +ssxGeq _ _ = Nothing + +ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') +ssxAppend ZKX sh' = sh' +ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' + +ssxTail :: StaticShX (n : sh) -> StaticShX sh +ssxTail (_ :!% ssh) = ssh + +ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh' +ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) + +-- | This may fail if @sh@ has @Nothing@s in it. +ssxToShX' :: StaticShX sh -> Maybe (IShX sh) +ssxToShX' ZKX = Just ZSX +ssxToShX' (SKnown n :!% sh) = (SKnown n :$%) <$> ssxToShX' sh +ssxToShX' (SUnknown _ :!% _) = Nothing + +ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) +ssxReplicate SZ = ZKX +ssxReplicate (SS (n :: SNat n')) + | Refl <- lemReplicateSucc @(Nothing @Nat) @n' + = SUnknown () :!% ssxReplicate n + +ssxIotaFrom :: Int -> StaticShX sh -> [Int] +ssxIotaFrom _ ZKX = [] +ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh + +ssxFromShape :: IShX sh -> StaticShX sh +ssxFromShape ZSX = ZKX +ssxFromShape (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShape sh + + +-- | Evidence for the static part of a shape. This pops up only when you are +-- polymorphic in the element type of an array. +type KnownShX :: [Maybe Nat] -> Constraint +class KnownShX sh where knownShX :: StaticShX sh +instance KnownShX '[] where knownShX = ZKX +instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX +instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX + + +-- * Flattening + +type Flatten sh = Flatten' 1 sh + +type family Flatten' acc sh where + Flatten' acc '[] = Just acc + Flatten' acc (Nothing : sh) = Nothing + Flatten' acc (Just n : sh) = Flatten' (acc * n) sh + +-- This function is currently unused +ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh) +ssxFlatten = go (SNat @1) + where + go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh) + go acc ZKX = SKnown acc + go _ (SUnknown () :!% _) = SUnknown () + go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh + +shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh) +shxFlatten = go (SNat @1) + where + go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh) + go acc ZSX = SKnown acc + go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh) + go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh + + goUnknown :: Int -> IShX sh -> Int + goUnknown acc ZSX = acc + goUnknown acc (SUnknown n :$% sh) = goUnknown (acc * n) sh + goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh + + +-- | Very untyped: only length is checked (at runtime). +instance KnownShX sh => IsList (ListX sh (Const i)) where + type Item (ListX sh (Const i)) = i + fromList topl = go (knownShX @sh) topl + where + go :: StaticShX sh' -> [i] -> ListX sh' (Const i) + go ZKX [] = ZX + go (_ :!% sh) (i : is) = Const i ::% go sh is + go _ _ = error $ "IsList(ListX): Mismatched list length (type says " + ++ show (ssxLength (knownShX @sh)) ++ ", list has length " + ++ show (length topl) ++ ")" + toList = listxToList + +-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. +instance KnownShX sh => IsList (IxX sh i) where + type Item (IxX sh i) = i + fromList = IxX . IsList.fromList + toList = Foldable.toList + +-- | Untyped: length and known dimensions are checked (at runtime). +instance KnownShX sh => IsList (ShX sh Int) where + type Item (ShX sh Int) = Int + fromList topl = ShX (go (knownShX @sh) topl) + where + go :: StaticShX sh' -> [Int] -> ListX sh' (SMayNat Int SNat) + go ZKX [] = ZX + go (SKnown sn :!% sh) (i : is) + | i == fromSNat' sn = SKnown sn ::% go sh is + | otherwise = error $ "IsList(ShX): Value does not match typing (type says " + ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" + go (SUnknown () :!% sh) (i : is) = SUnknown i ::% go sh is + go _ _ = error $ "IsList(ShX): Mismatched list length (type says " + ++ show (ssxLength (knownShX @sh)) ++ ", list has length " + ++ show (length topl) ++ ")" + toList = shxToList diff --git a/src/Data/Array/Mixed/Types.hs b/src/Data/Array/Mixed/Types.hs new file mode 100644 index 0000000..d77513f --- /dev/null +++ b/src/Data/Array/Mixed/Types.hs @@ -0,0 +1,110 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Mixed.Types ( + -- * Reified evidence of a type class + Dict(..), + + -- * Type-level naturals + pattern SZ, pattern SS, + fromSNat', + snatPlus, snatMul, + + -- * Type-level lists + type (++), + lemAppNil, + lemAppAssoc, + Replicate, + lemReplicateSucc, + + -- * Unsafe + unsafeCoerceRefl, +) where + +import Data.Type.Equality +import Data.Proxy +import GHC.TypeLits +import qualified GHC.TypeNats as TN +import qualified Unsafe.Coerce + + +-- | Evidence for the constraint @c a@. +data Dict c a where + Dict :: c a => Dict c a + +fromSNat' :: SNat n -> Int +fromSNat' = fromIntegral . fromSNat + +pattern SZ :: () => (n ~ 0) => SNat n +pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl) + where SZ = SNat + +pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1 +pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl)) + where SS = snatSucc + +{-# COMPLETE SZ, SS #-} + +snatSucc :: SNat n -> SNat (n + 1) +snatSucc SNat = SNat + +data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1) +snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1) +snatPred snp1 = + withKnownNat snp1 $ + case cmpNat (Proxy @1) (Proxy @np1) of + LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) + EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) + GTI -> Nothing + +-- This should be a function in base +snatPlus :: SNat n -> SNat m -> SNat (n + m) +snatPlus n m = TN.withSomeSNat (TN.fromSNat n + TN.fromSNat m) Unsafe.Coerce.unsafeCoerce + +-- This should be a function in base +snatMul :: SNat n -> SNat m -> SNat (n * m) +snatMul n m = TN.withSomeSNat (TN.fromSNat n * TN.fromSNat m) Unsafe.Coerce.unsafeCoerce + + +-- | This is just @'Unsafe.Coerce.unsafeCoerce' 'Refl'@, but specialised to +-- only typecheck for actual type equalities. One cannot, e.g. accidentally +-- write this: +-- +-- @ +-- foo :: Proxy a -> Proxy b -> a :~: b +-- foo = unsafeCoerceRefl +-- @ +-- +-- which would have been permitted with normal 'Unsafe.Coerce.unsafeCoerce', +-- but would have resulted in interesting memory errors at runtime. +unsafeCoerceRefl :: a :~: b +unsafeCoerceRefl = Unsafe.Coerce.unsafeCoerce Refl + + +-- | Type-level list append. +type family l1 ++ l2 where + '[] ++ l2 = l2 + (x : xs) ++ l2 = x : xs ++ l2 + +lemAppNil :: l ++ '[] :~: l +lemAppNil = unsafeCoerceRefl + +lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) +lemAppAssoc _ _ _ = unsafeCoerceRefl + +type family Replicate n a where + Replicate 0 a = '[] + Replicate n a = a : Replicate (n - 1) a + +lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a +lemReplicateSucc = unsafeCoerceRefl diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index e3af0ee..1a4e094 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -58,8 +58,9 @@ module Data.Array.Nested ( type (++), Storable, SNat, pattern SNat, - HList, - Permutation, + pattern SZ, pattern SS, + Perm(..), + IsPermutation, KnownNatList(..), listSToList, shSToList, @@ -69,7 +70,10 @@ module Data.Array.Nested ( import Prelude hiding (mappend) import Data.Array.Mixed +import Data.Array.Mixed.Internal.Arith +import Data.Array.Mixed.Permutation +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Types import Data.Array.Nested.Internal -import Data.Array.Nested.Internal.Arith import Foreign.Storable import GHC.TypeLits diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 712c5f1..0870789 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -60,7 +60,11 @@ import Unsafe.Coerce import Data.Array.Mixed import qualified Data.Array.Mixed as X -import Data.Array.Nested.Internal.Arith +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Permutation +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Internal.Arith +import Data.Array.Mixed.Types -- Invariant in the API @@ -123,19 +127,19 @@ lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing) ssxFromSNat SZ = ZKX -ssxFromSNat (SS (n :: SNat nm1)) | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n +ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing) lemKnownReplicate sn = lemKnownShX (ssxFromSNat sn) -lemRankReplicate :: SNat n -> X.Rank (Replicate n (Nothing @Nat)) :~: n +lemRankReplicate :: SNat n -> Rank (Replicate n (Nothing @Nat)) :~: n lemRankReplicate SZ = Refl lemRankReplicate (SS (n :: SNat nm1)) - | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1 + | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 , Refl <- lemRankReplicate n = Refl -lemRankMapJust :: forall sh. ShS sh -> X.Rank (MapJust sh) :~: X.Rank sh +lemRankMapJust :: forall sh. ShS sh -> Rank (MapJust sh) :~: Rank sh lemRankMapJust ZSS = Refl lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl @@ -146,9 +150,9 @@ lemReplicatePlusApp sn _ _ = go sn go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a go SZ = Refl go (SS (n :: SNat n'm1)) - | Refl <- X.lemReplicateSucc @a @n'm1 + | Refl <- lemReplicateSucc @a @n'm1 , Refl <- go n - = sym (X.lemReplicateSucc @a @(n'm1 + m)) + = sym (lemReplicateSucc @a @(n'm1 + m)) lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True lemLeqPlus _ _ _ = Refl @@ -156,17 +160,17 @@ lemLeqPlus _ _ _ = Refl lemLeqSuccSucc :: (k + 1 <= n) => Proxy k -> Proxy n -> (k <=? n - 1) :~: True lemLeqSuccSucc _ _ = unsafeCoerce Refl -lemDropLenApp :: X.Rank l1 <= X.Rank l2 +lemDropLenApp :: Rank l1 <= Rank l2 => Proxy l1 -> Proxy l2 -> Proxy rest - -> X.DropLen l1 l2 ++ rest :~: X.DropLen l1 (l2 ++ rest) + -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest) lemDropLenApp _ _ _ = unsafeCoerce Refl -lemTakeLenApp :: X.Rank l1 <= X.Rank l2 +lemTakeLenApp :: Rank l1 <= Rank l2 => Proxy l1 -> Proxy l2 -> Proxy rest - -> X.TakeLen l1 l2 :~: X.TakeLen l1 (l2 ++ rest) + -> TakeLen l1 l2 :~: TakeLen l1 (l2 ++ rest) lemTakeLenApp _ _ _ = unsafeCoerce Refl -srankSh :: ShX sh f -> SNat (X.Rank sh) +srankSh :: ShX sh f -> SNat (Rank sh) srankSh ZSX = SNat srankSh (_ :$% sh) | SNat <- srankSh sh = SNat @@ -585,11 +589,11 @@ class Elt a where -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a - mcast :: forall sh1 sh2 sh'. X.Rank sh1 ~ X.Rank sh2 + mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a - mtranspose :: forall is sh. (X.Permutation is, X.Rank is <= X.Rank sh) - => HList SNat is -> Mixed sh a -> Mixed (X.PermutePrefix is sh) a + mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) + => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a -- ====== PRIVATE METHODS ====== -- @@ -635,20 +639,20 @@ class Elt a => KnownElt a where instance Storable a => Elt (Primitive a) where mshape (M_Primitive sh _) = sh mindex (M_Primitive _ a) i = Primitive (X.index a i) - mindexPartial (M_Primitive sh a) i = M_Primitive (X.shDropIx sh i) (X.indexPartial a i) + mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i) mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) mfromListOuter l@(arr1 :| _) = let sh = SUnknown (length l) :$% mshape arr1 - in M_Primitive sh (X.fromListOuter (X.staticShapeFrom sh) (map (\(M_Primitive _ a) -> a) (toList l))) - mtoListOuter (M_Primitive sh arr) = map (M_Primitive (X.shTail sh)) (X.toListOuter arr) + in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l))) + mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) mlift :: forall sh1 sh2. StaticShX sh2 -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) mlift ssh2 f (M_Primitive _ a) - | Refl <- X.lemAppNil @sh1 - , Refl <- X.lemAppNil @sh2 + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 , let result = f ZKX a = M_Primitive (X.shape ssh2 result) result @@ -657,36 +661,36 @@ instance Storable a => Elt (Primitive a) where -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a) mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b) - | Refl <- X.lemAppNil @sh1 - , Refl <- X.lemAppNil @sh2 - , Refl <- X.lemAppNil @sh3 + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 + , Refl <- lemAppNil @sh3 , let result = f ZKX a b = M_Primitive (X.shape ssh3 result) result - mcast :: forall sh1 sh2 sh'. X.Rank sh1 ~ X.Rank sh2 + mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) mcast ssh1 sh2 _ (M_Primitive sh1' arr) = - let (_, sh') = shAppSplit (Proxy @sh') ssh1 sh1' - in M_Primitive (shAppend sh2 sh') (X.cast ssh1 sh2 (X.staticShapeFrom sh') arr) + let (_, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' + in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr) mtranspose perm (M_Primitive sh arr) = - M_Primitive (X.shPermutePrefix perm sh) - (X.transpose (X.staticShapeFrom sh) perm arr) + M_Primitive (shxPermutePrefix perm sh) + (X.transpose (ssxFromShape sh) perm arr) mshapeTree _ = () mshapeTreeEq _ () () = True mshapeTreeEmpty _ () = False mshowShapeTree _ () = "()" - mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (X.toLinearIdx sh i) x + mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x -- TODO: this use of toVector is suboptimal mvecsWritePartial :: forall sh' sh s. IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do - let arrsh = X.shape (X.staticShapeFrom sh') arr - offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIxX' arrsh)) - VS.copy (VSM.slice offset (X.shapeSize arrsh) v) (X.toVector arr) + let arrsh = X.shape (ssxFromShape sh') arr + offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) + VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v @@ -701,7 +705,7 @@ deriving via Primitive () instance Elt () instance Storable a => KnownElt (Primitive a) where memptyArray sh = M_Primitive sh (X.empty sh) - mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh) + mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh) mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 -- [PRIMITIVE ELEMENT TYPES LIST] @@ -755,7 +759,7 @@ instance Elt a => Elt (Mixed sh' a) where -- moverlongShape method, a prefix of which is mshape. mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh mshape (M_Nest sh arr) - = fst (shAppSplit (Proxy @sh') (X.staticShapeFrom sh) (mshape arr)) + = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr)) mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a mindex (M_Nest _ arr) i = mindexPartial arr i @@ -763,8 +767,8 @@ instance Elt a => Elt (Mixed sh' a) where mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) mindexPartial (M_Nest sh arr) i - | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = M_Nest (X.shDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') + = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) mscalar = M_Nest ZSX @@ -773,95 +777,95 @@ instance Elt a => Elt (Mixed sh' a) where M_Nest (SUnknown (length l) :$% mshape arr) (mfromListOuter ((\(M_Nest _ a) -> a) <$> l)) - mtoListOuter (M_Nest sh arr) = map (M_Nest (X.shTail sh)) (mtoListOuter arr) + mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr) mlift :: forall sh1 sh2. StaticShX sh2 -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) mlift ssh2 f (M_Nest sh1 arr) = - let result = mlift (X.ssxAppend ssh2 ssh') f' arr - (sh2, _) = shAppSplit (Proxy @sh') ssh2 (mshape result) + let result = mlift (ssxAppend ssh2 ssh') f' arr + (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result) in M_Nest sh2 result where - ssh' = X.staticShapeFrom (snd (shAppSplit (Proxy @sh') (X.staticShapeFrom sh1) (mshape arr))) + ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr))) f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b f' sshT - | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) - , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) - = f (X.ssxAppend ssh' sshT) + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) + = f (ssxAppend ssh' sshT) mlift2 :: forall sh1 sh2 sh3. StaticShX sh3 -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a) mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) = - let result = mlift2 (X.ssxAppend ssh3 ssh') f' arr1 arr2 - (sh3, _) = shAppSplit (Proxy @sh') ssh3 (mshape result) + let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2 + (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result) in M_Nest sh3 result where - ssh' = X.staticShapeFrom (snd (shAppSplit (Proxy @sh') (X.staticShapeFrom sh1) (mshape arr1))) + ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b f' sshT - | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) - , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) - , Refl <- X.lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) - = f (X.ssxAppend ssh' sshT) + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) + = f (ssxAppend ssh' sshT) - mcast :: forall sh1 sh2 shT. X.Rank sh1 ~ X.Rank sh2 + mcast :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2 => StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a) mcast ssh1 sh2 _ (M_Nest sh1T arr) - | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') - , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') - = let (_, shT) = shAppSplit (Proxy @shT) ssh1 sh1T - in M_Nest (shAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr) - - mtranspose :: forall is sh. (X.Permutation is, X.Rank is <= X.Rank sh) - => HList SNat is -> Mixed sh (Mixed sh' a) - -> Mixed (X.PermutePrefix is sh) (Mixed sh' a) + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') + = let (_, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T + in M_Nest (shxAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr) + + mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) + => Perm is -> Mixed sh (Mixed sh' a) + -> Mixed (PermutePrefix is sh) (Mixed sh' a) mtranspose perm (M_Nest sh arr) - | let sh' = X.shDropSh @sh @sh' (mshape arr) sh - , Refl <- X.lemRankApp (X.staticShapeFrom sh) (X.staticShapeFrom sh') - , Refl <- lemLeqPlus (Proxy @(X.Rank is)) (Proxy @(X.Rank sh)) (Proxy @(X.Rank sh')) - , Refl <- X.lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') + | let sh' = shxDropSh @sh @sh' (mshape arr) sh + , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh') + , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh')) + , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') - = M_Nest (X.shPermutePrefix perm sh) + = M_Nest (shxPermutePrefix perm sh) (mtranspose perm arr) mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) - mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (X.staticShapeFrom (mshape arr))))) + mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr))))) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - mshapeTreeEmpty _ (sh, t) = X.shapeSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.shAppend sh sh') idx val vecs + mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs mvecsWritePartial :: forall sh1 sh2 s. IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) -> ST s () mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) - | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = mvecsWritePartial (X.shAppend sh12 sh') idx arr vecs + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') + = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs - mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (X.shAppend sh sh') vecs + mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where - memptyArray sh = M_Nest sh (memptyArray (X.shAppend sh (X.completeShXzeros (knownShX @sh')))) + memptyArray sh = M_Nest sh (memptyArray (shxAppend sh (shxCompleteZeros (knownShX @sh')))) mvecsUnsafeNew sh example - | X.shapeSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) - | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.shAppend sh sh') (mindex example (X.zeroIxX (X.staticShapeFrom sh'))) + | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) + | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh'))) where sh' = mshape example - mvecsNewEmpty _ = MV_Nest (X.completeShXzeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) + mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) -- | Create an array given a size and a function that computes the element at a @@ -882,10 +886,10 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where -- array. The type of 'mgenerate' allows this requirement to be broken very -- easily, hence the runtime check. mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a -mgenerate sh f = case X.enumShape sh of +mgenerate sh f = case shxEnum sh of [] -> memptyArray sh firstidx : restidxs -> - let firstelem = f (X.zeroIxX' sh) + let firstelem = f (ixxZero' sh) shapetree = mshapeTree firstelem in if mshapeTreeEmpty (Proxy @a) shapetree then memptyArray sh @@ -905,28 +909,28 @@ msumOuter1P :: forall sh n a. (Storable a, NumElt a) => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) msumOuter1P (M_Primitive (n :$% sh) arr) = let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX - in M_Primitive sh (X.sumOuter nssh (X.staticShapeFrom sh) arr) + in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr) msumOuter1 :: forall sh n a. (NumElt a, PrimElt a) => Mixed (n : sh) a -> Mixed sh a msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive mappend :: forall n m sh a. Elt a - => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a + => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 where sn :$% sh = mshape arr1 sm :$% _ = mshape arr2 - ssh = X.staticShapeFrom sh - snm :: SMayNat () SNat (X.AddMaybe n m) + ssh = ssxFromShape sh + snm :: SMayNat () SNat (AddMaybe n m) snm = case (sn, sm) of (SUnknown{}, _) -> SUnknown () (SKnown{}, SUnknown{}) -> SUnknown () - (SKnown n, SKnown m) -> SKnown (X.plusSNat n m) + (SKnown n, SKnown m) -> SKnown (snatPlus n m) f :: forall sh' b. Storable b - => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b - f ssh' = X.append (X.ssxAppend ssh ssh') + => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b + f ssh' = X.append (ssxAppend ssh ssh') mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) mfromVectorP sh v = M_Primitive sh (X.fromVector sh v) @@ -971,9 +975,9 @@ mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b) mrerankP ssh sh2 f (M_Primitive sh arr) = - let sh1 = shDropSSX sh ssh - in M_Primitive (X.shAppend (shTakeSSX (Proxy @sh1) sh ssh) sh2) - (X.rerank ssh (X.staticShapeFrom sh1) (X.staticShapeFrom sh2) + let sh1 = shxDropSSX sh ssh + in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2) + (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2) (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) arr) @@ -988,10 +992,10 @@ mrerank ssh sh2 f (toPrimitive -> arr) = mreplicate :: forall sh sh' a. Elt a => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a mreplicate sh arr = - let ssh' = X.staticShapeFrom (mshape arr) - in mlift (X.ssxAppend (X.staticShapeFrom sh) ssh') + let ssh' = ssxFromShape (mshape arr) + in mlift (ssxAppend (ssxFromShape sh) ssh') (\(sshT :: StaticShX shT) -> - case X.lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of + case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of Refl -> X.replicate sh (ssxAppend ssh' sshT)) arr @@ -1005,18 +1009,18 @@ mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x) mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a mslice i n arr = let _ :$% sh = mshape arr - in mlift (SKnown n :!% X.staticShapeFrom sh) (\_ -> X.slice i n) arr + in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a -msliceU i n arr = mlift (X.staticShapeFrom (mshape arr)) (\_ -> X.sliceU i n) arr +msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a -mrev1 arr = mlift (X.staticShapeFrom (mshape arr)) (\_ -> X.rev1) arr +mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a mreshape sh' arr = - mlift (X.staticShapeFrom sh') - (\sshIn -> X.reshapePartial (X.staticShapeFrom (mshape arr)) sshIn sh') + mlift (ssxFromShape sh') + (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh') arr miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a @@ -1095,26 +1099,26 @@ instance (FloatElt a, NumElt a, PrimElt a) => Floating (Mixed sh a) where log1pexp = mliftNumElt1 floatEltLog1pexp log1mexp = mliftNumElt1 floatEltLog1mexp -mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (X.Rank sh) a +mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a mtoRanked arr - | Refl <- X.lemAppNil @sh - , Refl <- X.lemAppNil @(Replicate (X.Rank sh) (Nothing @Nat)) + | Refl <- lemAppNil @sh + , Refl <- lemAppNil @(Replicate (Rank sh) (Nothing @Nat)) , Refl <- lemRankReplicate (srankSh (mshape arr)) - = Ranked (mcast (X.staticShapeFrom (mshape arr)) (convSh (mshape arr)) (Proxy @'[]) arr) + = Ranked (mcast (ssxFromShape (mshape arr)) (convSh (mshape arr)) (Proxy @'[]) arr) where - convSh :: IShX sh' -> IShX (Replicate (X.Rank sh') Nothing) + convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing) convSh ZSX = ZSX convSh (smn :$% (sh :: IShX sh'T)) - | Refl <- X.lemReplicateSucc @(Nothing @Nat) @(X.Rank sh'T) + | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T) = SUnknown (fromSMayNat' smn) :$% convSh sh -mcastToShaped :: forall sh sh' a. (Elt a, X.Rank sh ~ X.Rank sh') +mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') => Mixed sh a -> ShS sh' -> Shaped sh' a mcastToShaped arr targetsh - | Refl <- X.lemAppNil @sh - , Refl <- X.lemAppNil @(MapJust sh') + | Refl <- lemAppNil @sh + , Refl <- lemAppNil @(MapJust sh') , Refl <- lemRankMapJust targetsh - = Shaped (mcast (X.staticShapeFrom (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr) + = Shaped (mcast (ssxFromShape (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr) -- | A rank-typed array: the number of dimensions of the array (its /rank/) is @@ -1418,7 +1422,7 @@ zeroIxR :: SNat n -> IIxR n zeroIxR SZ = ZIR zeroIxR (SS n) = 0 :.: zeroIxR n -ixCvtXR :: IIxX sh -> IIxR (X.Rank sh) +ixCvtXR :: IIxX sh -> IIxR (Rank sh) ixCvtXR ZIX = ZIR ixCvtXR (n :.% idx) = n :.: ixCvtXR idx @@ -1429,7 +1433,7 @@ shCvtXR' ZSX = shCvtXR' (n :$% (idx :: IShX sh)) | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) = castWith (subst2 (lem1 @sh Refl)) - (X.fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx)) + (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx)) where lem1 :: forall sh' n' k. k : sh' :~: Replicate n' Nothing @@ -1443,13 +1447,13 @@ shCvtXR' (n :$% (idx :: IShX sh)) ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) ixCvtRX ZIR = ZIX ixCvtRX (n :.: (idx :: IxR m Int)) = - castWith (subst2 @IxX @Int (X.lemReplicateSucc @(Nothing @Nat) @m)) + castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m)) (n :.% ixCvtRX idx) shCvtRX :: IShR n -> IShX (Replicate n Nothing) shCvtRX ZSR = ZSX shCvtRX (n :$: (idx :: ShR m Int)) = - castWith (subst2 @ShX @Int (X.lemReplicateSucc @(Nothing @Nat) @m)) + castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) (SUnknown n :$% shCvtRX idx) shapeSizeR :: IShR n -> Int @@ -1506,7 +1510,7 @@ rsumOuter1P :: forall n a. (Storable a, NumElt a) => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) rsumOuter1P (Ranked arr) - | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) @n = Ranked (msumOuter1P arr) rsumOuter1 :: forall n a. (NumElt a, PrimElt a) @@ -1559,7 +1563,7 @@ rappend :: forall n a. Elt a rappend arr1 arr2 | sn@SNat <- snatFromShR (rshape arr1) , Dict <- lemKnownReplicate sn - , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n + , Refl <- lemReplicateSucc @(Nothing @Nat) @n = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) arr1 arr2 @@ -1582,7 +1586,7 @@ rtoVector = coerce mtoVector rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a rfromListOuter l - | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) @n = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a @@ -1593,7 +1597,7 @@ rfromList1Prim l = Ranked (mfromList1Prim l) rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] rtoListOuter (Ranked arr) - | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) @n = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr) rtoList1 :: Elt a => Ranked 1 a -> [a] @@ -1677,7 +1681,7 @@ rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x) rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a rslice i n arr - | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) @n = rlift (snatFromShR (rshape arr)) (\_ -> X.sliceU i n) arr @@ -1686,7 +1690,7 @@ rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a rrev1 arr = rlift (snatFromShR (rshape arr)) (\(_ :: StaticShX sh') -> - case X.lemReplicateSucc @(Nothing @Nat) @n of + case lemReplicateSucc @(Nothing @Nat) @n of Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')) arr @@ -1707,12 +1711,12 @@ rasXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing rasXArrayPrim (Ranked arr) = first shCvtXR' (masXArrayPrim arr) rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a) -rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (X.staticShapeFrom (X.shape (ssxFromSNat sn) arr)) arr) +rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a -rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (X.staticShapeFrom (X.shape (ssxFromSNat sn) arr)) arr) +rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) -rcastToShaped :: Elt a => Ranked (X.Rank sh) a -> ShS sh -> Shaped sh a +rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a rcastToShaped (Ranked arr) targetsh | Refl <- lemRankReplicate (srankSh (shCvtSX targetsh)) , Refl <- lemRankMapJust targetsh @@ -1809,7 +1813,7 @@ shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh shapeSizeS :: ShS sh -> Int shapeSizeS ZSS = 1 -shapeSizeS (n :$$ sh) = X.fromSNat' n * shapeSizeS sh +shapeSizeS (n :$$ sh) = fromSNat' n * shapeSizeS sh sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh @@ -1838,14 +1842,14 @@ slift :: forall sh1 sh2 a. Elt a => ShS sh2 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) -> Shaped sh1 a -> Shaped sh2 a -slift sh2 f (Shaped arr) = Shaped (mlift (X.staticShapeFrom (shCvtSX sh2)) f arr) +slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShape (shCvtSX sh2)) f arr) -- | See the documentation of 'mlift'. slift2 :: forall sh1 sh2 sh3 a. Elt a => ShS sh3 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b) -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a -slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (X.staticShapeFrom (shCvtSX sh3)) f arr1 arr2) +slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShape (shCvtSX sh3)) f arr1 arr2) ssumOuter1P :: forall sh n a. (Storable a, NumElt a) => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) @@ -1855,28 +1859,28 @@ ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a) => Shaped (n : sh) a -> Shaped sh a ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive -lemCommMapJustTakeLen :: HList SNat is -> ShS sh -> X.TakeLen is (MapJust sh) :~: MapJust (X.TakeLen is sh) -lemCommMapJustTakeLen HNil _ = Refl -lemCommMapJustTakeLen (_ `HCons` is) (_ :$$ sh) | Refl <- lemCommMapJustTakeLen is sh = Refl -lemCommMapJustTakeLen (_ `HCons` _) ZSS = error "TakeLen of empty" +lemCommMapJustTakeLen :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh) +lemCommMapJustTakeLen PNil _ = Refl +lemCommMapJustTakeLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemCommMapJustTakeLen is sh = Refl +lemCommMapJustTakeLen (_ `PCons` _) ZSS = error "TakeLen of empty" -lemCommMapJustDropLen :: HList SNat is -> ShS sh -> X.DropLen is (MapJust sh) :~: MapJust (X.DropLen is sh) -lemCommMapJustDropLen HNil _ = Refl -lemCommMapJustDropLen (_ `HCons` is) (_ :$$ sh) | Refl <- lemCommMapJustDropLen is sh = Refl -lemCommMapJustDropLen (_ `HCons` _) ZSS = error "DropLen of empty" +lemCommMapJustDropLen :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh) +lemCommMapJustDropLen PNil _ = Refl +lemCommMapJustDropLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemCommMapJustDropLen is sh = Refl +lemCommMapJustDropLen (_ `PCons` _) ZSS = error "DropLen of empty" -lemCommMapJustIndex :: SNat i -> ShS sh -> X.Index i (MapJust sh) :~: Just (X.Index i sh) +lemCommMapJustIndex :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh) lemCommMapJustIndex SZ (_ :$$ _) = Refl lemCommMapJustIndex (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) | Refl <- lemCommMapJustIndex i sh - , Refl <- X.lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) - , Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) + , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') = Refl lemCommMapJustIndex _ ZSS = error "Index of empty" -lemCommMapJustPermute :: HList SNat is -> ShS sh -> X.Permute is (MapJust sh) :~: MapJust (X.Permute is sh) -lemCommMapJustPermute HNil _ = Refl -lemCommMapJustPermute (i `HCons` is) sh +lemCommMapJustPermute :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh) +lemCommMapJustPermute PNil _ = Refl +lemCommMapJustPermute (i `PCons` is) sh | Refl <- lemCommMapJustPermute is sh , Refl <- lemCommMapJustIndex i sh = Refl @@ -1885,53 +1889,53 @@ listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f listsAppend ZS idx' = idx' listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' -listsTakeLen :: forall f is sh. HList SNat is -> ListS sh f -> ListS (X.TakeLen is sh) f -listsTakeLen HNil _ = ZS -listsTakeLen (_ `HCons` is) (n ::$ sh) = n ::$ listsTakeLen is sh -listsTakeLen (_ `HCons` _) ZS = error "Permutation longer than shape" +listsTakeLen :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f +listsTakeLen PNil _ = ZS +listsTakeLen (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLen is sh +listsTakeLen (_ `PCons` _) ZS = error "Permutation longer than shape" -listsDropLen :: forall f is sh. HList SNat is -> ListS sh f -> ListS (DropLen is sh) f -listsDropLen HNil sh = sh -listsDropLen (_ `HCons` is) (_ ::$ sh) = listsDropLen is sh -listsDropLen (_ `HCons` _) ZS = error "Permutation longer than shape" +listsDropLen :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f +listsDropLen PNil sh = sh +listsDropLen (_ `PCons` is) (_ ::$ sh) = listsDropLen is sh +listsDropLen (_ `PCons` _) ZS = error "Permutation longer than shape" -listsPermute :: forall f is sh. HList SNat is -> ListS sh f -> ListS (X.Permute is sh) f -listsPermute HNil _ = ZS -listsPermute (i `HCons` (is :: HList SNat is')) (sh :: ListS sh f) = listsIndex (Proxy @is') (Proxy @sh) i sh (listsPermute is sh) +listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f +listsPermute PNil _ = ZS +listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) = listsIndex (Proxy @is') (Proxy @sh) i sh (listsPermute is sh) -listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> ListS (X.Permute is shT) f -> ListS (X.Index i sh : X.Permute is shT) f +listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> ListS (Permute is shT) f -> ListS (Index i sh : Permute is shT) f listsIndex _ _ SZ (n ::$ _) rest = n ::$ rest listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) rest - | Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') = listsIndex p pT i sh rest listsIndex _ _ _ ZS _ = error "Index into empty shape" -shsTakeLen :: HList SNat is -> ShS sh -> ShS (X.TakeLen is sh) +shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh) shsTakeLen = coerce (listsTakeLen @SNat) -shsPermute :: HList SNat is -> ShS sh -> ShS (X.Permute is sh) +shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh) shsPermute = coerce (listsPermute @SNat) -shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (X.Permute is shT) -> ShS (X.Index i sh : X.Permute is shT) +shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (Permute is shT) -> ShS (Index i sh : Permute is shT) shsIndex pis pshT = coerce (listsIndex @SNat pis pshT) -applyPermS :: forall f is sh. HList SNat is -> ListS sh f -> ListS (PermutePrefix is sh) f +applyPermS :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f applyPermS perm sh = listsAppend (listsPermute perm (listsTakeLen perm sh)) (listsDropLen perm sh) -applyPermIxS :: forall i is sh. HList SNat is -> IxS sh i -> IxS (PermutePrefix is sh) i +applyPermIxS :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i applyPermIxS = coerce (applyPermS @(Const i)) -applyPermShS :: forall is sh. HList SNat is -> ShS sh -> ShS (PermutePrefix is sh) +applyPermShS :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh) applyPermShS = coerce (applyPermS @SNat) -stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, Elt a) - => HList SNat is -> Shaped sh a -> Shaped (X.PermutePrefix is sh) a +stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a) + => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a stranspose perm sarr@(Shaped arr) | Refl <- lemRankMapJust (sshape sarr) , Refl <- lemCommMapJustTakeLen perm (sshape sarr) , Refl <- lemCommMapJustDropLen perm (sshape sarr) , Refl <- lemCommMapJustPermute perm (shsTakeLen perm (sshape sarr)) - , Refl <- lemCommMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(X.DropLen is sh)) + , Refl <- lemCommMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh)) = Shaped (mtranspose perm arr) sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a @@ -1969,7 +1973,7 @@ stoList1 = map sunScalar . stoListOuter sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a sfromListPrim sn l - | Refl <- X.lemAppNil @'[Just n] + | Refl <- lemAppNil @'[Just n] = let ssh = SUnknown () :!% ZKX xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l) in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr @@ -1989,7 +1993,7 @@ srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) srerankP sh sh2 f sarr@(Shaped arr) | Refl <- lemCommMapJustApp sh (Proxy @sh1) , Refl <- lemCommMapJustApp sh (Proxy @sh2) - = Shaped (mrerankP (X.staticShapeFrom (shTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (X.staticShapeFrom (shCvtSX sh)))) + = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh)))) (shCvtSX sh2) (\a -> let Shaped r = f (Shaped a) in r) arr) @@ -2033,12 +2037,12 @@ sasXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a) sasXArrayPrim (Shaped arr) = first shCvtXS' (masXArrayPrim arr) sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a) -sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (X.staticShapeFrom (shCvtSX sh)) arr) +sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShape (shCvtSX sh)) arr) sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a -sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (X.staticShapeFrom (shCvtSX sh)) arr) +sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShape (shCvtSX sh)) arr) -stoRanked :: Elt a => Shaped sh a -> Ranked (X.Rank sh) a +stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a stoRanked sarr@(Shaped arr) | Refl <- lemRankMapJust (sshape sarr) = mtoRanked arr |