diff options
Diffstat (limited to 'src/Data/Array/Nested/Mixed/Shape.hs')
-rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 88 |
1 files changed, 52 insertions, 36 deletions
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index 2f35ff9..852dd5e 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -31,7 +31,6 @@ import Data.Functor.Const import Data.Functor.Product import Data.Kind (Constraint, Type) import Data.Monoid (Sum(..)) -import Data.Proxy import Data.Type.Equality import GHC.Exts (withDict) import GHC.Generics (Generic) @@ -146,9 +145,9 @@ listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f listxAppend ZX idx' = idx' listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx' -listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f -listxDrop long ZX = long -listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short +listxDrop :: forall f g sh sh'. ListX sh g -> ListX (sh ++ sh') f -> ListX sh' f +listxDrop ZX long = long +listxDrop (_ ::% short) long = case long of _ ::% long' -> listxDrop short long' listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh @@ -172,7 +171,7 @@ listxZipWith f (i ::% is) (j ::% js) = -- * Mixed indices --- | This is a newtype over 'ListX'. +-- | An index into a mixed-typed array. type role IxX nominal representational type IxX :: [Maybe Nat] -> Type -> Type newtype IxX sh i = IxX (ListX sh (Const i)) @@ -191,6 +190,8 @@ infixr 3 :.% {-# COMPLETE ZIX, (:.%) #-} +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). type IIxX sh = IxX sh Int #ifdef OXAR_DEFAULT_SHOW_INSTANCES @@ -234,7 +235,7 @@ ixxTail (IxX list) = IxX (listxTail list) ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i ixxAppend = coerce (listxAppend @_ @(Const i)) -ixxDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i +ixxDrop :: forall sh sh' i. IxX sh i -> IxX (sh ++ sh') i -> IxX sh' i ixxDrop = coerce (listxDrop @(Const i) @(Const i)) ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i @@ -243,6 +244,11 @@ ixxInit = coerce (listxInit @(Const i)) ixxLast :: forall n sh i. IxX (n : sh) i -> i ixxLast = coerce (listxLast @(Const i)) +ixxCast :: StaticShX sh' -> IxX sh i -> IxX sh' i +ixxCast ZKX ZIX = ZIX +ixxCast (_ :!% sh) (i :.% idx) = i :.% ixxCast sh idx +ixxCast _ _ = error "ixxCast: ranks don't match" + ixxZip :: IxX sh i -> IxX sh j -> IxX sh (i, j) ixxZip ZIX ZIX = ZIX ixxZip (i :.% is) (j :.% js) = (i, j) :.% ixxZip is js @@ -390,10 +396,10 @@ shxSize :: IShX sh -> Int shxSize ZSX = 1 shxSize (n :$% sh) = fromSMayNat' n * shxSize sh -shxFromList :: StaticShX sh -> [Int] -> ShX sh Int +shxFromList :: StaticShX sh -> [Int] -> IShX sh shxFromList topssh topl = go topssh topl where - go :: StaticShX sh' -> [Int] -> ShX sh' Int + go :: StaticShX sh' -> [Int] -> IShX sh' go ZKX [] = ZSX go (SKnown sn :!% sh) (i : is) | i == fromSNat' sn = SKnown sn :$% go sh is @@ -408,11 +414,18 @@ shxToList :: IShX sh -> [Int] shxToList ZSX = [] shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh +shxFromSSX :: StaticShX (MapJust sh) -> ShX (MapJust sh) i +shxFromSSX ZKX = ZSX +shxFromSSX (SKnown n :!% sh :: StaticShX (MapJust sh)) + | Refl <- lemMapJustCons @sh Refl + = SKnown n :$% shxFromSSX sh +shxFromSSX (SUnknown _ :!% _) = error "unreachable" + -- | This may fail if @sh@ has @Nothing@s in it. -shxFromSSX' :: StaticShX sh -> Maybe (IShX sh) -shxFromSSX' ZKX = Just ZSX -shxFromSSX' (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX' sh -shxFromSSX' (SUnknown _ :!% _) = Nothing +shxFromSSX2 :: StaticShX sh -> Maybe (ShX sh i) +shxFromSSX2 ZKX = Just ZSX +shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh +shxFromSSX2 (SUnknown _ :!% _) = Nothing shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i shxAppend = coerce (listxAppend @_ @(SMayNat i SNat)) @@ -423,13 +436,13 @@ shxHead (ShX list) = listxHead list shxTail :: ShX (n : sh) i -> ShX sh i shxTail (ShX list) = ShX (listxTail list) -shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i +shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat)) -shxDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i +shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j)) -shxDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i +shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat)) shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i @@ -438,12 +451,9 @@ shxInit = coerce (listxInit @(SMayNat i SNat)) shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh)) shxLast = coerce (listxLast @(SMayNat i SNat)) -shxTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i -shxTakeSSX _ = flip go - where - go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i - go ZKX _ = ZSX - go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh +shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i +shxTakeSSX _ ZKX _ = ZSX +shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh shxZipWith :: (forall n. SMayNat i SNat n -> SMayNat j SNat n -> SMayNat k SNat n) -> ShX sh i -> ShX sh j -> ShX sh k @@ -456,7 +466,7 @@ shxCompleteZeros ZKX = ZSX shxCompleteZeros (SUnknown () :!% ssh) = SUnknown 0 :$% shxCompleteZeros ssh shxCompleteZeros (SKnown n :!% ssh) = SKnown n :$% shxCompleteZeros ssh -shxSplitApp :: Proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i) +shxSplitApp :: proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i) shxSplitApp _ ZKX idx = (ZSX, idx) shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx) @@ -467,17 +477,17 @@ shxEnum = \sh -> go sh id [] go ZSX f = (f ZIX :) go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]] -shxCast :: IShX sh -> StaticShX sh' -> Maybe (IShX sh') -shxCast ZSX ZKX = Just ZSX -shxCast (SKnown n :$% sh) (SKnown m :!% ssh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast sh ssh -shxCast (SUnknown n :$% sh) (SKnown m :!% ssh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast sh ssh -shxCast (SKnown n :$% sh) (SUnknown () :!% ssh) = (SUnknown (fromSNat' n) :$%) <$> shxCast sh ssh -shxCast (SUnknown n :$% sh) (SUnknown () :!% ssh) = (SUnknown n :$%) <$> shxCast sh ssh +shxCast :: StaticShX sh' -> IShX sh -> Maybe (IShX sh') +shxCast ZKX ZSX = Just ZSX +shxCast (SKnown m :!% ssh) (SKnown n :$% sh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast ssh sh +shxCast (SKnown m :!% ssh) (SUnknown n :$% sh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast ssh sh +shxCast (SUnknown () :!% ssh) (SKnown n :$% sh) = (SUnknown (fromSNat' n) :$%) <$> shxCast ssh sh +shxCast (SUnknown () :!% ssh) (SUnknown n :$% sh) = (SUnknown n :$%) <$> shxCast ssh sh shxCast _ _ = Nothing -- | Partial version of 'shxCast'. -shxCast' :: IShX sh -> StaticShX sh' -> IShX sh' -shxCast' sh ssh = case shxCast sh ssh of +shxCast' :: StaticShX sh' -> IShX sh -> IShX sh' +shxCast' ssh sh = case shxCast ssh sh of Just sh' -> sh' Nothing -> error $ "shxCast': Mismatch: (" ++ show sh ++ ") does not match (" ++ show ssh ++ ")" @@ -537,9 +547,15 @@ ssxHead (StaticShX list) = listxHead list ssxTail :: StaticShX (n : sh) -> StaticShX sh ssxTail (_ :!% ssh) = ssh -ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh' +ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh' +ssxDropSSX = coerce (listxDrop @(SMayNat () SNat) @(SMayNat () SNat)) + +ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) +ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' +ssxDropSh = coerce (listxDrop @(SMayNat () SNat) @(SMayNat i SNat)) + ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh)) ssxInit = coerce (listxInit @(SMayNat () SNat)) @@ -552,11 +568,11 @@ ssxReplicate (SS (n :: SNat n')) | Refl <- lemReplicateSucc @(Nothing @Nat) @n' = SUnknown () :!% ssxReplicate n -ssxIotaFrom :: Int -> StaticShX sh -> [Int] -ssxIotaFrom _ ZKX = [] -ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh +ssxIotaFrom :: StaticShX sh -> Int -> [Int] +ssxIotaFrom ZKX _ = [] +ssxIotaFrom (_ :!% ssh) i = i : ssxIotaFrom ssh (i+1) -ssxFromShX :: IShX sh -> StaticShX sh +ssxFromShX :: ShX sh i -> StaticShX sh ssxFromShX ZSX = ZKX ssxFromShX (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShX sh @@ -574,7 +590,7 @@ instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SK instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX withKnownShX :: forall sh r. StaticShX sh -> (KnownShX sh => r) -> r -withKnownShX k = withDict @(KnownShX sh) k +withKnownShX = withDict @(KnownShX sh) -- * Flattening |