aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs89
-rw-r--r--src/Data/Array/Nested/Permutation.hs20
2 files changed, 73 insertions, 36 deletions
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index c8c9a7b..dc4063c 100644
--- a/src/Data/Array/Nested/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -301,7 +301,7 @@ ixxToLinear = \sh i -> go sh i 0
data SMayNat i n where
SUnknown :: i -> SMayNat i Nothing
- SKnown :: {-# UNPACK #-} SNat n -> SMayNat i (Just n)
+ SKnown :: SNat n -> SMayNat i (Just n)
deriving instance Show i => Show (SMayNat i n)
deriving instance Eq i => Eq (SMayNat i n)
deriving instance Ord i => Ord (SMayNat i n)
@@ -340,10 +340,10 @@ type role ListH nominal representational
type ListH :: [Maybe Nat] -> Type -> Type
data ListH sh i where
ZH :: ListH '[] i
- (::#) :: forall n sh i. SMayNat i n -> ListH sh i -> ListH (n : sh) i
+ ConsUnknown :: forall sh i. i -> ListH sh i -> ListH (Nothing : sh) i
+ ConsKnown :: forall n sh i. {-# UNPACK #-} SNat n -> ListH sh i -> ListH (Just n : sh) i
deriving instance Eq i => Eq (ListH sh i)
deriving instance Ord i => Ord (ListH sh i)
-infixr 3 ::#
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
deriving instance Show i => Show (ListH sh i)
@@ -352,14 +352,16 @@ instance Show i => Show (ListH sh i) where
showsPrec _ = listhShow shows
#endif
-instance (forall n. NFData (SMayNat i n)) => NFData (ListH sh i) where
+instance NFData i => NFData (ListH sh i) where
rnf ZH = ()
- rnf (x ::# l) = rnf x `seq` rnf l
+ rnf (x `ConsUnknown` l) = rnf x `seq` rnf l
+ rnf (SNat `ConsKnown` l) = rnf l
data UnconsListHRes i sh1 =
forall n sh. (n : sh ~ sh1) => UnconsListHRes (ListH sh i) (SMayNat i n)
listhUncons :: ListH sh1 i -> Maybe (UnconsListHRes i sh1)
-listhUncons (i ::# shl') = Just (UnconsListHRes shl' i)
+listhUncons (i `ConsUnknown` shl') = Just (UnconsListHRes shl' (SUnknown i))
+listhUncons (i `ConsKnown` shl') = Just (UnconsListHRes shl' (SKnown i))
listhUncons ZH = Nothing
-- | This checks only whether the types are equal; if the elements of the list
@@ -367,7 +369,10 @@ listhUncons ZH = Nothing
-- 'testEquality', except on the penultimate type parameter.
listhEqType :: ListH sh i -> ListH sh' i -> Maybe (sh :~: sh')
listhEqType ZH ZH = Just Refl
-listhEqType (n ::# sh) (m ::# sh')
+listhEqType (_ `ConsUnknown` sh) (_ `ConsUnknown` sh')
+ | Just Refl <- listhEqType sh sh'
+ = Just Refl
+listhEqType (n `ConsKnown` sh) (m `ConsKnown` sh')
| Just Refl <- testEquality n m
, Just Refl <- listhEqType sh sh'
= Just Refl
@@ -378,9 +383,12 @@ listhEqType _ _ = Nothing
-- in the @some@ package (except on the penultimate type parameter).
listhEqual :: Eq i => ListH sh i -> ListH sh' i -> Maybe (sh :~: sh')
listhEqual ZH ZH = Just Refl
-listhEqual (n ::# sh) (m ::# sh')
+listhEqual (n `ConsUnknown` sh) (m `ConsUnknown` sh')
+ | n == m
+ , Just Refl <- listhEqual sh sh'
+ = Just Refl
+listhEqual (n `ConsKnown` sh) (m `ConsKnown` sh')
| Just Refl <- testEquality n m
- , n == m
, Just Refl <- listhEqual sh sh'
= Just Refl
listhEqual _ _ = Nothing
@@ -388,19 +396,24 @@ listhEqual _ _ = Nothing
{-# INLINE listhFmap #-}
listhFmap :: (forall n. SMayNat i n -> SMayNat j n) -> ListH sh i -> ListH sh j
listhFmap _ ZH = ZH
-listhFmap f (x ::# xs) = f x ::# listhFmap f xs
+listhFmap f (x `ConsUnknown` xs) = case f (SUnknown x) of
+ SUnknown y -> y `ConsUnknown` listhFmap f xs
+listhFmap f (x `ConsKnown` xs) = case f (SKnown x) of
+ SKnown y -> y `ConsKnown` listhFmap f xs
{-# INLINE listhFoldMap #-}
listhFoldMap :: Monoid m => (forall n. SMayNat i n -> m) -> ListH sh i -> m
listhFoldMap _ ZH = mempty
-listhFoldMap f (x ::# xs) = f x <> listhFoldMap f xs
+listhFoldMap f (x `ConsUnknown` xs) = f (SUnknown x) <> listhFoldMap f xs
+listhFoldMap f (x `ConsKnown` xs) = f (SKnown x) <> listhFoldMap f xs
listhLength :: ListH sh i -> Int
listhLength = getSum . listhFoldMap (\_ -> Sum 1)
listhRank :: ListH sh i -> SNat (Rank sh)
listhRank ZH = SNat
-listhRank (_ ::# l) | SNat <- listhRank l = SNat
+listhRank (_ `ConsUnknown` l) | SNat <- listhRank l = SNat
+listhRank (_ `ConsKnown` l) | SNat <- listhRank l = SNat
{-# INLINE listhShow #-}
listhShow :: forall sh i. (forall n. SMayNat i n -> ShowS) -> ListH sh i -> ShowS
@@ -408,29 +421,44 @@ listhShow f l = showString "[" . go "" l . showString "]"
where
go :: String -> ListH sh' i -> ShowS
go _ ZH = id
- go prefix (x ::# xs) = showString prefix . f x . go "," xs
+ go prefix (x `ConsUnknown` xs) = showString prefix . f (SUnknown x) . go "," xs
+ go prefix (x `ConsKnown` xs) = showString prefix . f (SKnown x) . go "," xs
listhHead :: ListH (mn ': sh) i -> SMayNat i mn
-listhHead (i ::# _) = i
+listhHead (i `ConsUnknown` _) = SUnknown i
+listhHead (i `ConsKnown` _) = SKnown i
listhTail :: ListH (n : sh) i -> ListH sh i
-listhTail (_ ::# sh) = sh
+listhTail (_ `ConsUnknown` sh) = sh
+listhTail (_ `ConsKnown` sh) = sh
listhAppend :: ListH sh i -> ListH sh' i -> ListH (sh ++ sh') i
listhAppend ZH idx' = idx'
-listhAppend (i ::# idx) idx' = i ::# listhAppend idx idx'
+listhAppend (i `ConsUnknown` idx) idx' = i `ConsUnknown` listhAppend idx idx'
+listhAppend (i `ConsKnown` idx) idx' = i `ConsKnown` listhAppend idx idx'
listhDrop :: forall i j sh sh'. ListH sh j -> ListH (sh ++ sh') i -> ListH sh' i
listhDrop ZH long = long
-listhDrop (_ ::# short) long = case long of _ ::# long' -> listhDrop short long'
+listhDrop (_ `ConsUnknown` short) long = case long of
+ _ `ConsUnknown` long' -> listhDrop short long'
+listhDrop (_ `ConsKnown` short) long = case long of
+ _ `ConsKnown` long' -> listhDrop short long'
listhInit :: forall i n sh. ListH (n : sh) i -> ListH (Init (n : sh)) i
-listhInit (i ::# sh@(_ ::# _)) = i ::# listhInit sh
-listhInit (_ ::# ZH) = ZH
+listhInit (i `ConsUnknown` sh@(_ `ConsUnknown` _)) = i `ConsUnknown` listhInit sh
+listhInit (i `ConsUnknown` sh@(_ `ConsKnown` _)) = i `ConsUnknown` listhInit sh
+listhInit (_ `ConsUnknown` ZH) = ZH
+listhInit (i `ConsKnown` sh@(_ `ConsUnknown` _)) = i `ConsKnown` listhInit sh
+listhInit (i `ConsKnown` sh@(_ `ConsKnown` _)) = i `ConsKnown` listhInit sh
+listhInit (_ `ConsKnown` ZH) = ZH
listhLast :: forall i n sh. ListH (n : sh) i -> SMayNat i (Last (n : sh))
-listhLast (_ ::# sh@(_ ::# _)) = listhLast sh
-listhLast (x ::# ZH) = x
+listhLast (_ `ConsUnknown` sh@(_ `ConsUnknown` _)) = listhLast sh
+listhLast (_ `ConsUnknown` sh@(_ `ConsKnown` _)) = listhLast sh
+listhLast (x `ConsUnknown` ZH) = SUnknown x
+listhLast (_ `ConsKnown` sh@(_ `ConsUnknown` _)) = listhLast sh
+listhLast (_ `ConsKnown` sh@(_ `ConsKnown` _)) = listhLast sh
+listhLast (x `ConsKnown` ZH) = SKnown x
-- * Mixed shapes
@@ -448,7 +476,7 @@ pattern (:$%)
forall n sh. (n : sh ~ sh1)
=> SMayNat i n -> ShX sh i -> ShX sh1 i
pattern i :$% shl <- ShX (listhUncons -> Just (UnconsListHRes (ShX -> shl) i))
- where i :$% ShX shl = ShX (i ::# shl)
+ where i :$% ShX shl = case i of; SUnknown x -> ShX (x `ConsUnknown` shl); SKnown x -> ShX (x `ConsKnown` shl)
infixr 3 :$%
{-# COMPLETE ZSX, (:$%) #-}
@@ -468,8 +496,8 @@ instance Functor (ShX sh) where
instance NFData i => NFData (ShX sh i) where
rnf (ShX ZH) = ()
- rnf (ShX (SUnknown i ::# l)) = rnf i `seq` rnf (ShX l)
- rnf (ShX (SKnown SNat ::# l)) = rnf (ShX l)
+ rnf (ShX (i `ConsUnknown` l)) = rnf i `seq` rnf (ShX l)
+ rnf (ShX (SNat `ConsKnown` l)) = rnf (ShX l)
-- | This checks only whether the types are equal; unknown dimensions might
-- still differ. This corresponds to 'testEquality', except on the penultimate
@@ -561,7 +589,7 @@ shxDropSSX = coerce (listhDrop @i @())
shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i
shxDropIx (IxX ZX) long = long
-shxDropIx (IxX (_ ::% short)) long = case long of ShX (_ ::# long') -> shxDropIx (IxX short) (ShX long')
+shxDropIx (IxX (_ ::% short)) long = case long of _ :$% long' -> shxDropIx (IxX short) long'
shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i
shxDropSh = coerce (listhDrop @i @i)
@@ -638,7 +666,8 @@ pattern (:!%)
forall n sh. (n : sh ~ sh1)
=> SMayNat () n -> StaticShX sh -> StaticShX sh1
pattern i :!% shl <- StaticShX (listhUncons -> Just (UnconsListHRes (StaticShX -> shl) i))
- where i :!% StaticShX shl = StaticShX (i ::# shl)
+ where i :!% StaticShX shl = case i of; SUnknown () -> StaticShX (() `ConsUnknown` shl); SKnown x -> StaticShX (x `ConsKnown` shl)
+
infixr 3 :!%
{-# COMPLETE ZKX, (:!%) #-}
@@ -652,8 +681,8 @@ instance Show (StaticShX sh) where
instance NFData (StaticShX sh) where
rnf (StaticShX ZH) = ()
- rnf (StaticShX (SUnknown () ::# l)) = rnf (StaticShX l)
- rnf (StaticShX (SKnown SNat ::# l)) = rnf (StaticShX l)
+ rnf (StaticShX (() `ConsUnknown` l)) = rnf (StaticShX l)
+ rnf (StaticShX (SNat `ConsKnown` l)) = rnf (StaticShX l)
instance TestEquality StaticShX where
testEquality (StaticShX l1) (StaticShX l2) = listhEqType l1 l2
@@ -680,11 +709,11 @@ ssxTail (_ :!% ssh) = ssh
ssxTakeIx :: forall sh sh' i. Proxy sh' -> IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh
ssxTakeIx _ (IxX ZX) _ = ZKX
-ssxTakeIx proxy (IxX (_ ::% long)) short = case short of StaticShX (i ::# short') -> i :!% ssxTakeIx proxy (IxX long) (StaticShX short')
+ssxTakeIx proxy (IxX (_ ::% long)) short = case short of i :!% short' -> i :!% ssxTakeIx proxy (IxX long) short'
ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'
ssxDropIx (IxX ZX) long = long
-ssxDropIx (IxX (_ ::% short)) long = case long of StaticShX (_ ::# long') -> ssxDropIx (IxX short) (StaticShX long')
+ssxDropIx (IxX (_ ::% short)) long = case long of _ :!% long' -> ssxDropIx (IxX short) long'
ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'
ssxDropSh = coerce (listhDrop @() @i)
diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs
index 93c46ed..2e0c1ca 100644
--- a/src/Data/Array/Nested/Permutation.hs
+++ b/src/Data/Array/Nested/Permutation.hs
@@ -174,23 +174,31 @@ type family DropLen ref l where
listhTakeLen :: forall i is sh. Perm is -> ListH sh i -> ListH (TakeLen is sh) i
listhTakeLen PNil _ = ZH
-listhTakeLen (_ `PCons` is) (n ::# sh) = n ::# listhTakeLen is sh
+listhTakeLen (_ `PCons` is) (n `ConsUnknown` sh) = n `ConsUnknown` listhTakeLen is sh
+listhTakeLen (_ `PCons` is) (n `ConsKnown` sh) = n `ConsKnown` listhTakeLen is sh
listhTakeLen (_ `PCons` _) ZH = error "Permutation longer than shape"
listhDropLen :: forall i is sh. Perm is -> ListH sh i -> ListH (DropLen is sh) i
listhDropLen PNil sh = sh
-listhDropLen (_ `PCons` is) (_ ::# sh) = listhDropLen is sh
+listhDropLen (_ `PCons` is) (_ `ConsUnknown` sh) = listhDropLen is sh
+listhDropLen (_ `PCons` is) (_ `ConsKnown` sh) = listhDropLen is sh
listhDropLen (_ `PCons` _) ZH = error "Permutation longer than shape"
listhPermute :: forall i is sh. Perm is -> ListH sh i -> ListH (Permute is sh) i
listhPermute PNil _ = ZH
listhPermute (i `PCons` (is :: Perm is')) (sh :: ListH sh i) =
- listhIndex i sh ::# listhPermute is sh
+ case listhIndex i sh of
+ SUnknown x -> x `ConsUnknown` listhPermute is sh
+ SKnown x -> x `ConsKnown` listhPermute is sh
listhIndex :: forall i k sh. SNat k -> ListH sh i -> SMayNat i (Index k sh)
-listhIndex SZ (n ::# _) = n
-listhIndex (SS (i :: SNat k')) ((_ :: SMayNat i n) ::# (sh :: ListH sh' i))
- | Refl <- lemIndexSucc (Proxy @k') (Proxy @n) (Proxy @sh')
+listhIndex SZ (n `ConsUnknown` _) = SUnknown n
+listhIndex SZ (n `ConsKnown` _) = SKnown n
+listhIndex (SS (i :: SNat k')) ((_ :: i) `ConsUnknown` (sh :: ListH sh' i))
+ | Refl <- lemIndexSucc (Proxy @k') (Proxy @Nothing) (Proxy @sh')
+ = listhIndex i sh
+listhIndex (SS (i :: SNat k')) ((_ :: SNat n) `ConsKnown` (sh :: ListH sh' i))
+ | Refl <- lemIndexSucc (Proxy @k') (Proxy @(Just n)) (Proxy @sh')
= listhIndex i sh
listhIndex _ ZH = error "Index into empty shape"