diff options
Diffstat (limited to 'src/Data/Array/Nested/Shaped/Shape.hs')
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 166 |
1 files changed, 111 insertions, 55 deletions
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index bfc6ad2..a9bfe14 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -1,10 +1,9 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE NoStarIsType #-} @@ -39,7 +38,6 @@ import Data.Monoid (Sum(..)) import Data.Proxy import Data.Type.Equality import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict) -import GHC.Generics (Generic) import GHC.IsList (IsList) import GHC.IsList qualified as IsList import GHC.TypeLits @@ -136,6 +134,17 @@ listsFromList topsh topl = go topsh topl ++ show (shsLength topsh) ++ ", list has length " ++ show (length topl) ++ ")" +{-# INLINEABLE listsFromListS #-} +listsFromListS :: ListS sh (Const i0) -> [i] -> ListS sh (Const i) +listsFromListS topl0 topl = go topl0 topl + where + go :: ListS sh (Const i0) -> [i] -> ListS sh (Const i) + go ZS [] = ZS + go (_ ::$ l0) (i : is) = Const i ::$ go l0 is + go _ _ = error $ "listsFromListS: Mismatched list length (the model says " + ++ show (listsLength topl0) ++ ", list has length " + ++ show (length topl) ++ ")" + {-# INLINEABLE listsToList #-} listsToList :: ListS sh (Const i) -> [i] listsToList list = build (\(cons :: i -> is -> is) (nil :: is) -> @@ -185,16 +194,16 @@ listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f listsPermute PNil _ = ZS listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) = - case listsIndex (Proxy @is') (Proxy @sh) i sh of + case listsIndex i sh of item -> item ::$ listsPermute is sh --- TODO: remove this SNat when the KnownNat constaint in ListS is removed -listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> f (Index i sh) -listsIndex _ _ SZ (n ::$ _) = n -listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) +-- TODO: try to remove this SNat now that the KnownNat constraint in ListS is removed +listsIndex :: forall f i sh. SNat i -> ListS sh f -> f (Index i sh) +listsIndex SZ (n ::$ _) = n +listsIndex (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = listsIndex p pT i sh -listsIndex _ _ _ ZS = error "Index into empty shape" + = listsIndex i sh +listsIndex _ ZS = error "Index into empty shape" listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh) @@ -205,7 +214,7 @@ listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm pe type role IxS nominal representational type IxS :: [Nat] -> Type -> Type newtype IxS sh i = IxS (ListS sh (Const i)) - deriving (Eq, Ord, Generic) + deriving (Eq, Ord, NFData) pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i pattern ZIS = IxS ZS @@ -247,8 +256,6 @@ instance Foldable (IxS sh) where null ZIS = False null _ = True -instance NFData i => NFData (IxS sh i) - ixsLength :: IxS sh i -> Int ixsLength (IxS l) = listsLength l @@ -258,6 +265,10 @@ ixsRank (IxS l) = listsRank l ixsFromList :: forall sh i. ShS sh -> [i] -> IxS sh i ixsFromList = coerce (listsFromList @_ @i) +{-# INLINEABLE ixsFromIxS #-} +ixsFromIxS :: forall sh i0 i. IxS sh i0 -> [i] -> IxS sh i +ixsFromIxS = coerce (listsFromListS @_ @i0 @i) + {-# INLINEABLE ixsToList #-} ixsToList :: forall sh i. IxS sh i -> [i] ixsToList = coerce (listsToList @_ @i) @@ -278,11 +289,9 @@ ixsInit (IxS list) = IxS (listsInit list) ixsLast :: IxS (n : sh) i -> i ixsLast (IxS list) = getConst (listsLast list) --- TODO: this takes a ShS because there are KnownNats inside IxS. -ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i -ixsCast ZSS ZIS = ZIS -ixsCast (_ :$$ sh) (i :.$ idx) = i :.$ ixsCast sh idx -ixsCast _ _ = error "ixsCast: ranks don't match" +ixsCast :: IxS sh i -> IxS sh i +ixsCast ZIS = ZIS +ixsCast (i :.$ idx) = i :.$ ixsCast idx ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i ixsAppend = coerce (listsAppend @_ @(Const i)) @@ -318,21 +327,34 @@ ixsToLinear = \sh i -> go sh i 0 -- can also retrieve the array shape from a 'KnownShS' dictionary. type role ShS nominal type ShS :: [Nat] -> Type -newtype ShS sh = ShS (ListS sh SNat) - deriving (Generic) +newtype ShS sh = ShS (ShX (MapJust sh) Int) + deriving (NFData) instance Eq (ShS sh) where _ == _ = True instance Ord (ShS sh) where compare _ _ = EQ pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh -pattern ZSS = ShS ZS +pattern ZSS <- ShS (matchZSX -> Just Refl) + where ZSS = ShS ZSX + +matchZSX :: forall sh i. ShX (MapJust sh) i -> Maybe (sh :~: '[]) +matchZSX ZSX | Refl <- lemMapJustEmpty @sh Refl = Just Refl +matchZSX _ = Nothing pattern (:$$) :: forall {sh1}. forall n sh. (n : sh ~ sh1) => SNat n -> ShS sh -> ShS sh1 -pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i)) - where i :$$ ShS shl = ShS (i ::$ shl) +pattern i :$$ shl <- (shsUncons -> Just (UnconsShSRes i shl)) + where i :$$ ShS shl = ShS (SKnown i :$% shl) + +data UnconsShSRes sh1 = + forall n sh. (n : sh ~ sh1) => UnconsShSRes (SNat n) (ShS sh) +shsUncons :: forall sh1. ShS sh1 -> Maybe (UnconsShSRes sh1) +shsUncons (ShS (SKnown x :$% sh')) + | Refl <- lemMapJustCons @sh1 Refl + = Just (UnconsShSRes x (ShS sh')) +shsUncons (ShS _) = Nothing infixr 3 :$$ @@ -342,15 +364,13 @@ infixr 3 :$$ deriving instance Show (ShS sh) #else instance Show (ShS sh) where - showsPrec _ (ShS l) = listsShow (shows . fromSNat) l + showsPrec d (ShS shx) = showsPrec d shx #endif -instance NFData (ShS sh) where - rnf (ShS ZS) = () - rnf (ShS (SNat ::$ l)) = rnf (ShS l) - instance TestEquality ShS where - testEquality (ShS l1) (ShS l2) = listsEqType l1 l2 + testEquality (ShS shx1) (ShS shx2) = case shxEqType shx1 shx2 of + Nothing -> Nothing + Just Refl -> Just unsafeCoerceRefl -- | @'shsEqual' = 'testEquality'@. (Because 'ShS' is a singleton, types are -- equal if and only if values are equal.) @@ -358,17 +378,20 @@ shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh') shsEqual = testEquality shsLength :: ShS sh -> Int -shsLength (ShS l) = listsLength l +shsLength (ShS shx) = shxLength shx -shsRank :: ShS sh -> SNat (Rank sh) -shsRank (ShS l) = listsRank l +shsRank :: forall sh. ShS sh -> SNat (Rank sh) +shsRank (ShS shx) = + gcastWith (unsafeCoerceRefl + :: Rank (MapJust sh) :~: Rank sh) $ + shxRank shx shsSize :: ShS sh -> Int -shsSize ZSS = 1 -shsSize (n :$$ sh) = fromSNat' n * shsSize sh +shsSize (ShS sh) = shxSize sh -- | This is a partial @const@ that fails when the second argument --- doesn't match the first. +-- doesn't match the first. It also has a better error message comparing +-- to just coercing 'shxFromList'. shsFromList :: ShS sh -> [Int] -> ShS sh shsFromList topsh topl = go topsh topl `seq` topsh where @@ -384,38 +407,71 @@ shsFromList topsh topl = go topsh topl `seq` topsh {-# INLINEABLE shsToList #-} shsToList :: ShS sh -> [Int] -shsToList topsh = build (\(cons :: Int -> is -> is) (nil :: is) -> - let go :: ShS sh -> is - go ZSS = nil - go (sn :$$ sh) = fromSNat' sn `cons` go sh - in go topsh) +shsToList = coerce shxToList shsHead :: ShS (n : sh) -> SNat n -shsHead (ShS list) = listsHead list +shsHead (ShS shx) = case shxHead shx of + SKnown SNat -> SNat -shsTail :: ShS (n : sh) -> ShS sh -shsTail (ShS list) = ShS (listsTail list) +shsTail :: forall n sh. ShS (n : sh) -> ShS sh +shsTail = coerce (shxTail @_ @_ @Int) -shsInit :: ShS (n : sh) -> ShS (Init (n : sh)) -shsInit (ShS list) = ShS (listsInit list) +shsInit :: forall n sh. ShS (n : sh) -> ShS (Init (n : sh)) +shsInit = + gcastWith (unsafeCoerceRefl + :: Init (Just n : MapJust sh) :~: MapJust (Init (n : sh))) $ + coerce (shxInit @_ @_ @Int) -shsLast :: ShS (n : sh) -> SNat (Last (n : sh)) -shsLast (ShS list) = listsLast list +shsLast :: forall n sh. ShS (n : sh) -> SNat (Last (n : sh)) +shsLast (ShS shx) = + gcastWith (unsafeCoerceRefl + :: Last (Just n : MapJust sh) :~: Just (Last (n : sh))) $ + case shxLast shx of + SKnown SNat -> SNat shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh') -shsAppend = coerce (listsAppend @_ @SNat) +shsAppend = + gcastWith (unsafeCoerceRefl + :: MapJust sh ++ MapJust sh' :~: MapJust (sh ++ sh')) $ + coerce (shxAppend @_ @_ @Int) + +shsTakeLen :: forall is sh. Perm is -> ShS sh -> ShS (TakeLen is sh) +shsTakeLen = + gcastWith (unsafeCoerceRefl + :: TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)) $ + coerce shxTakeLen -shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh) -shsTakeLen = coerce (listsTakeLenPerm @SNat) +shsDropLen :: forall is sh. Perm is -> ShS sh -> ShS (DropLen is sh) +shsDropLen = + gcastWith (unsafeCoerceRefl + :: DropLen is (MapJust sh) :~: MapJust (DropLen is sh)) $ + coerce shxDropLen -shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh) -shsPermute = coerce (listsPermute @SNat) +shsPermute :: forall is sh. Perm is -> ShS sh -> ShS (Permute is sh) +shsPermute = + gcastWith (unsafeCoerceRefl + :: Permute is (MapJust sh) :~: MapJust (Permute is sh)) $ + coerce shxPermute -shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh) -shsIndex pis pshT i sh = coerce (listsIndex @SNat pis pshT i (coerce sh)) +shsIndex :: forall i sh. SNat i -> ShS sh -> SNat (Index i sh) +shsIndex i sh = + gcastWith (unsafeCoerceRefl + :: Index i (MapJust sh) :~: Just (Index i sh)) $ + case shxIndex i (coerce sh) of + SKnown SNat -> SNat shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh) -shsPermutePrefix = coerce (listsPermutePrefix @SNat) +shsPermutePrefix perm (ShS shx) + {- TODO: here and elsewhere, solve the module dependency cycle and add this: + | Refl <- lemTakeLenMapJust perm sh + , Refl <- lemDropLenMapJust perm sh + , Refl <- lemPermuteMapJust perm sh + , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm sh)) (shsDropLen perm sh) -} + = gcastWith (unsafeCoerceRefl + :: Permute is (TakeLen is (MapJust sh)) + ++ DropLen is (MapJust sh) + :~: MapJust (Permute is (TakeLen is sh) ++ DropLen is sh)) $ + ShS (shxPermutePrefix perm shx) type family Product sh where Product '[] = 1 |
