diff options
Diffstat (limited to 'src/Data/Array/Nested/Mixed/Shape.hs')
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 73 |
1 files changed, 37 insertions, 36 deletions
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index b1b4f81..3f4ee9a 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -2,6 +2,7 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ImportQualifiedPost #-} @@ -297,30 +298,30 @@ ixxToLinear = \sh i -> go sh i 0 -- * Mixed shapes -data SMayNat i f n where - SUnknown :: i -> SMayNat i f Nothing - SKnown :: f n -> SMayNat i f (Just n) -deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n) -deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n) -deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n) +data SMayNat i n where + SUnknown :: i -> SMayNat i Nothing + SKnown :: {-# UNPACK #-} SNat n -> SMayNat i (Just n) +deriving instance Show i => Show (SMayNat i n) +deriving instance Eq i => Eq (SMayNat i n) +deriving instance Ord i => Ord (SMayNat i n) -instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where +instance (NFData i, forall m. NFData (SNat m)) => NFData (SMayNat i n) where rnf (SUnknown i) = rnf i rnf (SKnown x) = rnf x -instance TestEquality f => TestEquality (SMayNat i f) where +instance TestEquality (SMayNat i) where testEquality SUnknown{} SUnknown{} = Just Refl testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl testEquality _ _ = Nothing {-# INLINE fromSMayNat #-} fromSMayNat :: (n ~ Nothing => i -> r) - -> (forall m. n ~ Just m => f m -> r) - -> SMayNat i f n -> r + -> (forall m. n ~ Just m => SNat m -> r) + -> SMayNat i n -> r fromSMayNat f _ (SUnknown i) = f i fromSMayNat _ g (SKnown s) = g s -fromSMayNat' :: SMayNat Int SNat n -> Int +fromSMayNat' :: SMayNat Int n -> Int fromSMayNat' = fromSMayNat id fromSNat' type family AddMaybe n m where @@ -328,7 +329,7 @@ type family AddMaybe n m where AddMaybe (Just _) Nothing = Nothing AddMaybe (Just n) (Just m) = Just (n + m) -smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m) +smnAddMaybe :: SMayNat Int n -> SMayNat Int m -> SMayNat Int (AddMaybe n m) smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) @@ -337,7 +338,7 @@ smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) -- | This is a newtype over 'ListX'. type role ShX nominal representational type ShX :: [Maybe Nat] -> Type -> Type -newtype ShX sh i = ShX (ListX sh (SMayNat i SNat)) +newtype ShX sh i = ShX (ListX sh (SMayNat i)) deriving (Eq, Ord, Generic) pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i @@ -346,7 +347,7 @@ pattern ZSX = ShX ZX pattern (:$%) :: forall {sh1} {i}. forall n sh. (n : sh ~ sh1) - => SMayNat i SNat n -> ShX sh i -> ShX sh1 i + => SMayNat i n -> ShX sh i -> ShX sh1 i pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i)) where i :$% ShX shl = ShX (i ::% shl) infixr 3 :$% @@ -447,35 +448,35 @@ shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh shxFromSSX2 (SUnknown _ :!% _) = Nothing shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i -shxAppend = coerce (listxAppend @_ @(SMayNat i SNat)) +shxAppend = coerce (listxAppend @_ @(SMayNat i)) -shxHead :: ShX (n : sh) i -> SMayNat i SNat n +shxHead :: ShX (n : sh) i -> SMayNat i n shxHead (ShX list) = listxHead list shxTail :: ShX (n : sh) i -> ShX sh i shxTail (ShX list) = ShX (listxTail list) shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i -shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat)) +shxDropSSX = coerce (listxDrop @(SMayNat i) @(SMayNat ())) shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i -shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j)) +shxDropIx = coerce (listxDrop @(SMayNat i) @(Const j)) shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i -shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat)) +shxDropSh = coerce (listxDrop @(SMayNat i) @(SMayNat i)) shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i -shxInit = coerce (listxInit @(SMayNat i SNat)) +shxInit = coerce (listxInit @(SMayNat i)) -shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh)) -shxLast = coerce (listxLast @(SMayNat i SNat)) +shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i (Last (n : sh)) +shxLast = coerce (listxLast @(SMayNat i)) shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i shxTakeSSX _ ZKX _ = ZSX shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh {-# INLINE shxZipWith #-} -shxZipWith :: (forall n. SMayNat i SNat n -> SMayNat j SNat n -> SMayNat k SNat n) +shxZipWith :: (forall n. SMayNat i n -> SMayNat j n -> SMayNat k n) -> ShX sh i -> ShX sh j -> ShX sh k shxZipWith _ ZSX ZSX = ZSX shxZipWith f (i :$% is) (j :$% js) = f i j :$% shxZipWith f is js @@ -525,7 +526,7 @@ shxCast' ssh sh = case shxCast ssh sh of -- | The part of a shape that is statically known. (A newtype over 'ListX'.) type StaticShX :: [Maybe Nat] -> Type -newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat)) +newtype StaticShX sh = StaticShX (ListX sh (SMayNat ())) deriving (Eq, Ord) pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh @@ -534,7 +535,7 @@ pattern ZKX = StaticShX ZX pattern (:!%) :: forall {sh1}. forall n sh. (n : sh ~ sh1) - => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1 + => SMayNat () n -> StaticShX sh -> StaticShX sh1 pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i)) where i :!% StaticShX shl = StaticShX (i ::% shl) infixr 3 :!% @@ -570,26 +571,26 @@ ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') ssxAppend ZKX sh' = sh' ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' -ssxHead :: StaticShX (n : sh) -> SMayNat () SNat n +ssxHead :: StaticShX (n : sh) -> SMayNat () n ssxHead (StaticShX list) = listxHead list ssxTail :: StaticShX (n : sh) -> StaticShX sh ssxTail (_ :!% ssh) = ssh ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh' -ssxDropSSX = coerce (listxDrop @(SMayNat () SNat) @(SMayNat () SNat)) +ssxDropSSX = coerce (listxDrop @(SMayNat ()) @(SMayNat ())) ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' -ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) +ssxDropIx = coerce (listxDrop @(SMayNat ()) @(Const i)) ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' -ssxDropSh = coerce (listxDrop @(SMayNat () SNat) @(SMayNat i SNat)) +ssxDropSh = coerce (listxDrop @(SMayNat ()) @(SMayNat i)) ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh)) -ssxInit = coerce (listxInit @(SMayNat () SNat)) +ssxInit = coerce (listxInit @(SMayNat ())) -ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh)) -ssxLast = coerce (listxLast @(SMayNat () SNat)) +ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () (Last (n : sh)) +ssxLast = coerce (listxLast @(SMayNat ())) ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) ssxReplicate SZ = ZKX @@ -632,18 +633,18 @@ type family Flatten' acc sh where Flatten' acc (Just n : sh) = Flatten' (acc * n) sh -- This function is currently unused -ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh) +ssxFlatten :: StaticShX sh -> SMayNat () (Flatten sh) ssxFlatten = go (SNat @1) where - go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh) + go :: SNat acc -> StaticShX sh -> SMayNat () (Flatten' acc sh) go acc ZKX = SKnown acc go _ (SUnknown () :!% _) = SUnknown () go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh -shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh) +shxFlatten :: IShX sh -> SMayNat Int (Flatten sh) shxFlatten = go (SNat @1) where - go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh) + go :: SNat acc -> IShX sh -> SMayNat Int (Flatten' acc sh) go acc ZSX = SKnown acc go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh) go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh |
