diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/Data/Array/Nested/Convert.hs | 55 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 20 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 20 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Base.hs | 7 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 17 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Types.hs | 7 | 
6 files changed, 107 insertions, 19 deletions
| diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index 861bf20..fd59ba6 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -7,11 +7,18 @@  {-# LANGUAGE TypeAbstractions #-}  {-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}  module Data.Array.Nested.Convert ( -  -- * Shape/index/list casting functions +  -- * Shape\/index\/list casting functions +  -- ** To ranked    ixrFromIxS, ixrFromIxX, shrFromShS, shrFromShX, shrFromShX2, -  ixsFromIxX, shsFromShX, +  listrCast, ixrCast, shrCast, +  -- ** To shaped +  ixsFromIxR, ixsFromIxR', ixsFromIxX, ixsFromIxX', shsFromShX, shsFromSSX, +  ixsCast, +  -- ** To mixed    ixxFromIxR, ixxFromIxS, shxFromShR, shxFromShS, +  ixxCast, shxCast, shxCast',    -- * Array conversions    convert, @@ -57,24 +64,50 @@ shrFromShS ZSS = ZSR  shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh  -- shrFromShX re-exported -  -- shrFromShX2 re-exported +-- listrCast re-exported +-- ixrCast re-exported +-- shrCast re-exported  -- * To shaped --- ixsFromIxR :: IIxR (Rank sh) -> IIxS sh --- ixsFromIxR = \ix -> go ix _ ---   where ---     go :: IIxR n -> (forall sh. KnownShS sh => IIxS sh -> r) -> r ---     go ZIR k = k ZIS ---     go (i :.: ix) k = go ix (i :.$) +-- TODO: these take a ShS because there are KnownNats inside IxS. + +ixsFromIxR :: ShS sh -> IxR (Rank sh) i -> IxS sh i +ixsFromIxR ZSS ZIR = ZIS +ixsFromIxR (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR sh idx +ixsFromIxR _ _ = error "unreachable" +-- | Performs a runtime check that @n@ matches @Rank sh@. Equivalent to the +-- following, but more efficient: +-- +-- > ixsFromIxR' sh idx = ixsFromIxR sh (ixrCast (shsRank sh) idx) +ixsFromIxR' :: ShS sh -> IxR n i -> IxS sh i +ixsFromIxR' ZSS ZIR = ZIS +ixsFromIxR' (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR' sh idx +ixsFromIxR' _ _ = error "ixsFromIxR': index rank does not match shape rank" + +-- TODO: this takes a ShS because there are KnownNats inside IxS.  ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i  ixsFromIxX ZSS ZIX = ZIS  ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx +-- | Performs a runtime check that @Rank sh'@ match @Rank sh@. Equivalent to +-- the following, but more efficient: +-- +-- > ixsFromIxX' sh idx = ixsFromIxX sh (ixxCast (shxFromShS sh) idx) +ixsFromIxX' :: ShS sh -> IxX sh' i -> IxS sh i +ixsFromIxX' ZSS ZIX = ZIS +ixsFromIxX' (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX' sh idx +ixsFromIxX' _ _ = error "ixsFromIxX': index rank does not match shape rank" +  -- shsFromShX re-exported +shsFromSSX :: StaticShX (MapJust sh) -> ShS sh +shsFromSSX = shsFromShX Prelude.. shxFromSSX + +-- ixsCast re-exported +  -- * To mixed  ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i @@ -97,6 +130,10 @@ shxFromShS :: ShS sh -> IShX (MapJust sh)  shxFromShS ZSS = ZSX  shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh +-- ixxCast re-exported +-- shxCast re-exported +-- shxCast' re-exported +  -- * Array conversions diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index 2ee3600..e63277f 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -244,6 +244,11 @@ ixxInit = coerce (listxInit @(Const i))  ixxLast :: forall n sh i. IxX (n : sh) i -> i  ixxLast = coerce (listxLast @(Const i)) +ixxCast :: StaticShX sh' -> IxX sh i -> IxX sh' i +ixxCast ZKX ZIX = ZIX +ixxCast (_ :!% sh) (i :.% idx) = i :.% ixxCast sh idx +ixxCast _ _ = error "ixxCast: ranks don't match" +  ixxZip :: IxX sh i -> IxX sh j -> IxX sh (i, j)  ixxZip ZIX ZIX = ZIX  ixxZip (i :.% is) (j :.% js) = (i, j) :.% ixxZip is js @@ -409,11 +414,18 @@ shxToList :: IShX sh -> [Int]  shxToList ZSX = []  shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh +shxFromSSX :: StaticShX (MapJust sh) -> ShX (MapJust sh) i +shxFromSSX ZKX = ZSX +shxFromSSX (SKnown n :!% sh :: StaticShX (MapJust sh)) +  | Refl <- lemMapJustCons @sh Refl +  = SKnown n :$% shxFromSSX sh +shxFromSSX (SUnknown _ :!% _) = error "unreachable" +  -- | This may fail if @sh@ has @Nothing@s in it. -shxFromSSX' :: StaticShX sh -> Maybe (ShX sh i) -shxFromSSX' ZKX = Just ZSX -shxFromSSX' (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX' sh -shxFromSSX' (SUnknown _ :!% _) = Nothing +shxFromSSX2 :: StaticShX sh -> Maybe (ShX sh i) +shxFromSSX2 ZKX = Just ZSX +shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh +shxFromSSX2 (SUnknown _ :!% _) = Nothing  shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i  shxAppend = coerce (listxAppend @_ @(SMayNat i SNat)) diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index 3edebf6..8b670e5 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -131,6 +131,10 @@ listrLast (_ ::: sh@(_ ::: _)) = listrLast sh  listrLast (n ::: ZR) = n  listrLast ZR = error "unreachable" +-- | Performs a runtime check that the lengths are identical. +listrCast :: SNat n' -> ListR n i -> ListR n' i +listrCast = listrCastWithName "listrCast" +  listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i  listrIndex SZ (x ::: _) = x  listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs @@ -230,6 +234,10 @@ ixrInit (IxR list) = IxR (listrInit list)  ixrLast :: IxR (n + 1) i -> i  ixrLast (IxR list) = listrLast list +-- | Performs a runtime check that the lengths are identical. +ixrCast :: SNat n' -> IxR n i -> IxR n' i +ixrCast n (IxR idx) = IxR (listrCastWithName "ixrCast" n idx) +  ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i  ixrAppend = coerce (listrAppend @_ @i) @@ -310,6 +318,10 @@ shrInit (ShR list) = ShR (listrInit list)  shrLast :: ShR (n + 1) i -> i  shrLast (ShR list) = listrLast list +-- | Performs a runtime check that the lengths are identical. +shrCast :: SNat n' -> ShR n i -> ShR n' i +shrCast n (ShR sh) = ShR (listrCastWithName "shrCast" n sh) +  shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i  shrAppend = coerce (listrAppend @_ @i) @@ -347,3 +359,11 @@ instance KnownNat n => IsList (ShR n i) where    type Item (ShR n i) = i    fromList = ShR . IsList.fromList    toList = Foldable.toList + + +-- * Internal helper functions + +listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i +listrCastWithName _ SZ ZR = ZR +listrCastWithName name (SS n) (i ::: idx) = i ::: listrCastWithName name n idx +listrCastWithName name _ _ = error $ name ++ ": ranks don't match" diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index a24a91a..879e6b5 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -249,12 +249,7 @@ sshape (Shaped arr) = shsFromShX (mshape arr)  shsFromShX :: forall sh i. ShX (MapJust sh) i -> ShS sh  shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS  shsFromShX (SKnown n@SNat :$% (idx :: ShX mjshT i)) = -  castWith (subst1 (lem Refl)) $ +  castWith (subst1 (sym (lemMapJustCons Refl))) $      n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))                                     idx) -  where -    lem :: forall sh1 sh' n. -           Just n : sh1 :~: MapJust sh' -        -> n : Tail sh' :~: sh' -    lem Refl = unsafeCoerceRefl  shsFromShX (SUnknown _ :$% _) = error "impossible" diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index d34f3ec..ab16422 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -248,6 +248,12 @@ ixsInit (IxS list) = IxS (listsInit list)  ixsLast :: IxS (n : sh) i -> i  ixsLast (IxS list) = getConst (listsLast list) +-- TODO: this takes a ShS because there are KnownNats inside IxS. +ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i +ixsCast ZSS ZIS = ZIS +ixsCast (_ :$$ sh) (i :.$ idx) = i :.$ ixsCast sh idx +ixsCast _ _ = error "ixsCast: ranks don't match" +  ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i  ixsAppend = coerce (listsAppend @_ @(Const i)) @@ -374,6 +380,17 @@ shsOrthotopeShape :: ShS sh -> Dict O.Shape sh  shsOrthotopeShape ZSS = Dict  shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict +-- | This function is a hack made possible by the 'KnownNat' inside 'ListS'. +-- This function may be removed in a future release. +shsFromListS :: ListS sh f -> ShS sh +shsFromListS ZS = ZSS +shsFromListS (_ ::$ l) = SNat :$$ shsFromListS l + +-- | This function is a hack made possible by the 'KnownNat' inside 'IxS'. This +-- function may be removed in a future release. +shsFromIxS :: IxS sh i -> ShS sh +shsFromIxS (IxS l) = shsFromListS l +  -- | Untyped: length is checked at runtime.  instance KnownShS sh => IsList (ListS sh (Const i)) where diff --git a/src/Data/Array/Nested/Types.hs b/src/Data/Array/Nested/Types.hs index b8a9aea..df466cf 100644 --- a/src/Data/Array/Nested/Types.hs +++ b/src/Data/Array/Nested/Types.hs @@ -31,6 +31,7 @@ module Data.Array.Nested.Types (    Replicate,    lemReplicateSucc,    MapJust, +  lemMapJustEmpty, lemMapJustCons,    Head,    Tail,    Init, @@ -116,6 +117,12 @@ type family MapJust l = r | r -> l where    MapJust '[] = '[]    MapJust (x : xs) = Just x : MapJust xs +lemMapJustEmpty :: MapJust sh :~: '[] -> sh :~: '[] +lemMapJustEmpty Refl = unsafeCoerceRefl + +lemMapJustCons :: MapJust sh :~: Just n : sh' -> sh :~: n : Tail sh +lemMapJustCons Refl = unsafeCoerceRefl +  type family Head l where    Head (x : _) = x | 
