{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Nested.Mixed.Shape (
  module Data.Array.Nested.Mixed.Shape,
  Rank,
) where

import Control.DeepSeq (NFData(..))
import Control.Exception (assert)
import Data.Bifunctor (first)
import Data.Coerce
import Data.Foldable qualified as Foldable
import Data.Kind (Constraint, Type)
import Data.Monoid (Sum(..))
import Data.Proxy
import Data.Type.Equality
import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
#if !MIN_VERSION_GLASGOW_HASKELL(9,8,0,0)
import GHC.TypeLits.Orphans ()
#endif

import Data.Array.Nested.Mixed.ListX
import Data.Array.Nested.Types


-- * Mixed indices

-- | An index into a mixed-typed array.
type role IxX nominal representational
type IxX :: [Maybe Nat] -> Type -> Type
newtype IxX sh i = IxX (ListX sh i)
  deriving (Eq, Ord, NFData, Functor, Foldable)

-- | The empty index, for a rank-0 shape.
pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i
pattern ZIX = IxX ZX

-- | Prepend one component to an index.
pattern (:.%) :: forall {sh1} {i}. forall n sh. (n : sh ~ sh1) => i -> IxX sh i -> IxX sh1 i
pattern i :.% l <- IxX (i ::% (IxX -> l))
  where i :.% IxX l = IxX (i ::% l)
infixr 3 :.%

{-# COMPLETE ZIX, (:.%) #-}

-- For convenience, this contains regular 'Int's instead of bounded integers
-- (traditionally called \"@Fin@\").
type IIxX sh = IxX sh Int

#ifdef OXAR_DEFAULT_SHOW_INSTANCES
deriving instance Show i => Show (IxX sh i)
#else
instance Show i => Show (IxX sh i) where
  showsPrec _ (IxX l) = listxShow shows l
#endif

-- | Build an index from a plain list. The length is only checked with
-- 'assert', which GHC may compile away (e.g. under optimisation).
{-# INLINE ixxFromList #-}
ixxFromList :: StaticShX sh -> [i] -> IxX sh i
ixxFromList sh l = assert (ssxLength sh == length l) $ IsList.fromList l

-- | The rank of an index, as a type-level natural singleton.
ixxRank :: IxX sh i -> SNat (Rank sh)
ixxRank ZIX = SNat
ixxRank (_ :.% l) | SNat <- ixxRank l = SNat

-- | The all-zeros index for a static shape.
ixxZero :: StaticShX sh -> IIxX sh
ixxZero ZKX = ZIX
ixxZero (_ :!% ssh) = 0 :.% ixxZero ssh

-- | The all-zeros index for a shape.
ixxZero' :: IShX sh -> IIxX sh
ixxZero' ZSX = ZIX
ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh

-- | The first component of a non-empty index.
ixxHead :: IxX (mn ': sh) i -> i
ixxHead (i :.% _) = i

-- | All components of a non-empty index except the first.
ixxTail :: IxX (n : sh) i -> IxX sh i
ixxTail (_ :.% sh) = sh

-- | Concatenate two indices.
ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i
ixxAppend (IxX l1) (IxX l2) = IxX $ lazilyConcat (++) l1 l2

-- | Drop as many leading components from the second index as the first index
-- has; only the length of the first index is used.
ixxDrop :: forall i j sh sh'. IxX sh j -> IxX (sh ++ sh') i -> IxX sh' i
ixxDrop ZIX long = long
ixxDrop (_ :.% short) long = case long of _ :.% long' -> ixxDrop short long'

-- | All components of a non-empty index except the last.
ixxInit :: forall i n sh. IxX (n : sh) i -> IxX (Init (n : sh)) i
ixxInit (i :.% sh@(_ :.% _)) = i :.% ixxInit sh
ixxInit (_ :.% ZIX) = ZIX

-- | The last component of a non-empty index.
ixxLast :: forall i n sh. IxX (n : sh) i -> i
ixxLast (_ :.% sh@(_ :.% _)) = ixxLast sh
ixxLast (x :.% ZIX) = x

-- | Pair up the components of two indices of the same shape.
ixxZip :: IxX sh i -> IxX sh j -> IxX sh (i, j)
ixxZip ZIX ZIX = ZIX
ixxZip (i :.% is) (j :.% js) = (i, j) :.% ixxZip is js

-- | Combine the components of two indices of the same shape.
{-# INLINE ixxZipWith #-}
ixxZipWith :: (i -> j -> k) -> IxX sh i -> IxX sh j -> IxX sh k
ixxZipWith _ ZIX ZIX = ZIX
ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js

-- | Reinterpret an index at another shape type. Only the rank is checked
-- (with 'error' on mismatch); component values are not compared against the
-- target shape.
ixxCast :: StaticShX sh' -> IxX sh i -> IxX sh' i
ixxCast ZKX ZIX = ZIX
ixxCast (_ :!% sh) (i :.% idx) = i :.% ixxCast sh idx
ixxCast _ _ = error "ixxCast: ranks don't match"

-- | Given a multidimensional index, get the corresponding linear
-- index into the buffer. The last dimension varies fastest (row-major),
-- as the Horner-style accumulation below shows.
{-# INLINEABLE ixxToLinear #-}
ixxToLinear :: Num i => IShX sh -> IxX sh i -> i
ixxToLinear = \sh i -> go sh i 0
  where
    go :: Num i => IShX sh -> IxX sh i -> i -> i
    go ZSX ZIX !a = a
    go (n :$% sh) (i :.% ix) a = go sh ix (fromIntegral (fromSMayNat' n) * a + i)

-- | Inverse of 'ixxToLinear': turn a linear (row-major) index into a
-- multidimensional index for the given shape. Calls 'error' with an
-- "out of range" message when the linear index is negative or too large;
-- in particular every index into an empty array (any dimension 0) is out
-- of range.
{-# INLINEABLE ixxFromLinear #-}
ixxFromLinear :: Num i => IShX sh -> Int -> IxX sh i
ixxFromLinear = \sh -> -- give this function arity 1 so that suffixes is shared when it's called many times
  let suffixes = drop 1 (scanr (*) 1 (shxToList sh))
  in fromLin0 sh suffixes
  where
    -- Unfold first iteration of fromLin to do the range check.
    -- Don't inline this function at first to allow GHC to inline the outer
    -- function and realise that 'suffixes' is shared. But then later inline it
    -- anyway, to avoid the function call. Removing the pragma makes GHC
    -- somehow unable to recognise that 'suffixes' can be shared in a loop.
    {-# NOINLINE [0] fromLin0 #-}
    fromLin0 :: Num i => IShX sh -> [Int] -> Int -> IxX sh i
    fromLin0 sh suffixes i =
      if i < 0 then outrange sh i else
      case (sh, suffixes) of
        (ZSX, _) | i > 0 -> outrange sh i
                 | otherwise -> ZIX
        ((fromSMayNat' -> n) :$% sh', suff : suffs)
          -- suff is the product of all later dimensions. If it is 0, the
          -- array is empty and no index is in range; checking here also
          -- avoids a division by zero in the quotRem below. When suff is
          -- non-zero, every later suffix product is non-zero too, so fromLin
          -- never divides by zero.
          | suff == 0 -> outrange sh i
          | otherwise ->
              let (q, r) = i `quotRem` suff
              in if q >= n then outrange sh i else
                 fromIntegral q :.% fromLin sh' suffs r
        _ -> error "impossible"

    -- Remaining iterations; the range check in fromLin0 guarantees that
    -- every quotient here is within its dimension.
    fromLin :: Num i => IShX sh -> [Int] -> Int -> IxX sh i
    fromLin ZSX _ !_ = ZIX
    fromLin (_ :$% sh') (suff : suffs) i =
      let (q, r) = i `quotRem` suff  -- suff == shxSize sh'
      in fromIntegral q :.% fromLin sh' suffs r
    fromLin _ _ _ = error "impossible"

    {-# NOINLINE outrange #-}
    outrange :: IShX sh -> Int -> a
    outrange sh i = error $ "ixxFromLinear: out of range (" ++ show i ++
                            " in array of shape " ++ show sh ++ ")"

-- | All indices of a shape, in order of increasing linear index.
shxEnum :: IShX sh -> [IIxX sh]
shxEnum = shxEnum'

-- | 'shxEnum', generalised over the index component type.
{-# INLINABLE shxEnum' #-}  -- ensure this can be specialised at use site
shxEnum' :: Num i => IShX sh -> [IxX sh i]
shxEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shxSize sh - 1]]
  where
    suffixes = drop 1 (scanr (*) 1 (shxToList sh))

    -- Only reached for shapes of non-zero size (the enumeration above is
    -- empty otherwise), so no suffix product is 0 here.
    fromLin :: Num i => IShX sh -> [Int] -> Int# -> IxX sh i
    fromLin ZSX _ _ = ZIX
    fromLin (_ :$% sh') (I# suff# : suffs) i# =
      let !(# q#, r# #) = i# `quotRemInt#` suff#  -- suff == shxSize sh'
      in fromIntegral (I# q#) :.% fromLin sh' suffs r#
    fromLin _ _ _ = error "impossible"


-- * Mixed shapes

-- | A single dimension: either statically known ('SKnown', carrying a type
-- level natural singleton) or only known at runtime ('SUnknown', carrying a
-- value of type @i@).
data SMayNat i n where
  SUnknown :: i -> SMayNat i Nothing
  SKnown :: SNat n -> SMayNat i (Just n)
deriving instance Show i => Show (SMayNat i n)
deriving instance Eq i => Eq (SMayNat i n)
deriving instance Ord i => Ord (SMayNat i n)

instance NFData i => NFData (SMayNat i n) where
  rnf (SUnknown i) = rnf i
  rnf (SKnown SNat) = ()

instance TestEquality (SMayNat i) where
  testEquality SUnknown{} SUnknown{} = Just Refl
  testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl
  testEquality _ _ = Nothing

-- | Eliminate an 'SMayNat' by case analysis.
{-# INLINE fromSMayNat #-}
fromSMayNat :: (n ~ Nothing => i -> r)
            -> (forall m. n ~ Just m => SNat m -> r)
            -> SMayNat i n -> r
fromSMayNat f _ (SUnknown i) = f i
fromSMayNat _ g (SKnown s) = g s

-- | The size of a dimension as an 'Int', whether or not it is statically
-- known.
{-# INLINE fromSMayNat' #-}
fromSMayNat' :: SMayNat Int n -> Int
fromSMayNat' = fromSMayNat id fromSNat'

-- | Type-level sum of two optional sizes; known only if both are known.
type family AddMaybe n m where
  AddMaybe Nothing _ = Nothing
  AddMaybe (Just _) Nothing = Nothing
  AddMaybe (Just n) (Just m) = Just (n + m)

-- | Add two dimensions; the result is statically known only if both are.
smnAddMaybe :: SMayNat Int n -> SMayNat Int m -> SMayNat Int (AddMaybe n m)
smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m)
smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m)
smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m)


-- | A mixed shape: each dimension is either statically known
-- ('ConsKnown') or stored at runtime as a value of type @i@ ('ConsUnknown').
type role ShX nominal representational
type ShX :: [Maybe Nat] -> Type -> Type
data ShX sh i where
  ZSX :: ShX '[] i
  ConsUnknown :: forall sh i. i -> ShX sh i -> ShX (Nothing : sh) i
  -- TODO: bring this UNPACK back when GHC no longer crashes:
  -- ConsKnown :: forall n sh i. {-# UNPACK #-} SNat n -> ShX sh i -> ShX (Just n : sh) i
  ConsKnown :: forall n sh i. SNat n -> ShX sh i -> ShX (Just n : sh) i
deriving instance Ord i => Ord (ShX sh i)

-- A manually defined instance and this INLINEABLE is needed to specialize
-- mdot1Inner (otherwise GHC warns specialization breaks down here).
instance Eq i => Eq (ShX sh i) where
  {-# INLINEABLE (==) #-}
  ZSX == ZSX = True
  ConsUnknown i1 sh1 == ConsUnknown i2 sh2 = i1 == i2 && sh1 == sh2
  ConsKnown _ sh1 == ConsKnown _ sh2 = sh1 == sh2

#ifdef OXAR_DEFAULT_SHOW_INSTANCES
deriving instance Show i => Show (ShX sh i)
#else
instance Show i => Show (ShX sh i) where
  showsPrec _ l = shxShow (fromSMayNat shows (shows . fromSNat)) l
#endif

instance NFData i => NFData (ShX sh i) where
  rnf ZSX = ()
  rnf (x `ConsUnknown` l) = rnf x `seq` rnf l
  rnf (SNat `ConsKnown` l) = rnf l

instance Functor (ShX sh) where
  {-# INLINE fmap #-}
  fmap f l = shxFmap (fromSMayNat (SUnknown . f) SKnown) l

-- | Result of 'shxUncons': the first dimension and the remaining shape.
data UnconsShXRes i sh1 =
  forall n sh. (n : sh ~ sh1) => UnconsShXRes (SMayNat i n) (ShX sh i)

-- | Split off the first dimension of a shape; 'Nothing' for the empty shape.
shxUncons :: ShX sh1 i -> Maybe (UnconsShXRes i sh1)
shxUncons (i `ConsUnknown` shl') = Just (UnconsShXRes (SUnknown i) shl')
shxUncons (i `ConsKnown` shl') = Just (UnconsShXRes (SKnown i) shl')
shxUncons ZSX = Nothing

-- | This checks only whether the types are equal; if the elements of the list
-- are not singletons, their values may still differ. This corresponds to
-- 'testEquality', except on the penultimate type parameter.
shxEqType :: ShX sh i -> ShX sh' i -> Maybe (sh :~: sh')
shxEqType ZSX ZSX = Just Refl
shxEqType (_ `ConsUnknown` sh) (_ `ConsUnknown` sh')
  | Just Refl <- shxEqType sh sh'
  = Just Refl
shxEqType (n `ConsKnown` sh) (m `ConsKnown` sh')
  | Just Refl <- testEquality n m
  , Just Refl <- shxEqType sh sh'
  = Just Refl
shxEqType _ _ = Nothing

-- | This checks whether the two lists actually contain equal values. This is
-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@
-- in the @some@ package (except on the penultimate type parameter).
shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh')
shxEqual ZSX ZSX = Just Refl
shxEqual (n `ConsUnknown` sh) (m `ConsUnknown` sh')
  | n == m
  , Just Refl <- shxEqual sh sh'
  = Just Refl
shxEqual (n `ConsKnown` sh) (m `ConsKnown` sh')
  | Just Refl <- testEquality n m
  , Just Refl <- shxEqual sh sh'
  = Just Refl
shxEqual _ _ = Nothing

-- | Map over the dimensions of a shape; the function must preserve whether
-- each dimension is known (enforced by its type).
{-# INLINE shxFmap #-}
shxFmap :: (forall n. SMayNat i n -> SMayNat j n) -> ShX sh i -> ShX sh j
shxFmap _ ZSX = ZSX
shxFmap f (x `ConsUnknown` xs) = case f (SUnknown x) of SUnknown y -> y `ConsUnknown` shxFmap f xs
shxFmap f (x `ConsKnown` xs) = case f (SKnown x) of SKnown y -> y `ConsKnown` shxFmap f xs

-- | Fold the dimensions of a shape, from first to last, into a monoid.
{-# INLINE shxFoldMap #-}
shxFoldMap :: Monoid m => (forall n. SMayNat i n -> m) -> ShX sh i -> m
shxFoldMap _ ZSX = mempty
shxFoldMap f (x `ConsUnknown` xs) = f (SUnknown x) <> shxFoldMap f xs
shxFoldMap f (x `ConsKnown` xs) = f (SKnown x) <> shxFoldMap f xs

-- | The number of dimensions, as a runtime 'Int'.
shxLength :: ShX sh i -> Int
shxLength = getSum . shxFoldMap (\_ -> Sum 1)

-- | The number of dimensions, as a type-level natural singleton.
shxRank :: ShX sh i -> SNat (Rank sh)
shxRank ZSX = SNat
shxRank (_ `ConsUnknown` l) | SNat <- shxRank l = SNat
shxRank (_ `ConsKnown` l) | SNat <- shxRank l = SNat

-- | Render a shape as a bracketed, comma-separated list, using the given
-- function for each dimension.
{-# INLINE shxShow #-}
shxShow :: forall sh i. (forall n. SMayNat i n -> ShowS) -> ShX sh i -> ShowS
shxShow f l = showString "[" . go "" l . showString "]"
  where
    go :: String -> ShX sh' i -> ShowS
    go _ ZSX = id
    go prefix (x `ConsUnknown` xs) = showString prefix . f (SUnknown x) . go "," xs
    go prefix (x `ConsKnown` xs) = showString prefix . f (SKnown x) . go "," xs

-- | The first dimension of a non-empty shape.
shxHead :: ShX (mn ': sh) i -> SMayNat i mn
shxHead (i `ConsUnknown` _) = SUnknown i
shxHead (i `ConsKnown` _) = SKnown i

-- | All dimensions of a non-empty shape except the first.
shxTail :: ShX (n : sh) i -> ShX sh i
shxTail (_ `ConsUnknown` sh) = sh
shxTail (_ `ConsKnown` sh) = sh

-- | Concatenate two shapes.
shxAppend :: ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i
shxAppend ZSX idx' = idx'
shxAppend (i `ConsUnknown` idx) idx' = i `ConsUnknown` shxAppend idx idx'
shxAppend (i `ConsKnown` idx) idx' = i `ConsKnown` shxAppend idx idx'

-- | Drop as many leading dimensions from the second shape as the first shape
-- has; only the structure of the first shape is used.
shxDropSh :: forall sh sh' i j. ShX sh j -> ShX (sh ++ sh') i -> ShX sh' i
shxDropSh ZSX long = long
shxDropSh (_ `ConsUnknown` short) long = case long of
  _ `ConsUnknown` long' -> shxDropSh short long'
shxDropSh (_ `ConsKnown` short) long = case long of
  _ `ConsKnown` long' -> shxDropSh short long'

-- | 'shxDropSh' with the prefix given as a static shape.
shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i
shxDropSSX = coerce (shxDropSh @_ @_ @i @())

-- | All dimensions of a non-empty shape except the last.
shxInit :: forall i n sh. ShX (n : sh) i -> ShX (Init (n : sh)) i
shxInit (i `ConsUnknown` sh@(_ `ConsUnknown` _)) = i `ConsUnknown` shxInit sh
shxInit (i `ConsUnknown` sh@(_ `ConsKnown` _)) = i `ConsUnknown` shxInit sh
shxInit (_ `ConsUnknown` ZSX) = ZSX
shxInit (i `ConsKnown` sh@(_ `ConsUnknown` _)) = i `ConsKnown` shxInit sh
shxInit (i `ConsKnown` sh@(_ `ConsKnown` _)) = i `ConsKnown` shxInit sh
shxInit (_ `ConsKnown` ZSX) = ZSX

-- | The last dimension of a non-empty shape.
shxLast :: forall i n sh. ShX (n : sh) i -> SMayNat i (Last (n : sh))
shxLast (_ `ConsUnknown` sh@(_ `ConsUnknown` _)) = shxLast sh
shxLast (_ `ConsUnknown` sh@(_ `ConsKnown` _)) = shxLast sh
shxLast (x `ConsUnknown` ZSX) = SUnknown x
shxLast (_ `ConsKnown` sh@(_ `ConsUnknown` _)) = shxLast sh
shxLast (_ `ConsKnown` sh@(_ `ConsKnown` _)) = shxLast sh
shxLast (x `ConsKnown` ZSX) = SKnown x

-- | Prepend one dimension to a shape.
pattern (:$%) :: forall {sh1} {i}. forall n sh. (n : sh ~ sh1) => SMayNat i n -> ShX sh i -> ShX sh1 i
pattern i :$% shl <- (shxUncons -> Just (UnconsShXRes i shl))
  where i :$% shl = case i of
          SUnknown x -> x `ConsUnknown` shl
          SKnown x -> x `ConsKnown` shl
infixr 3 :$%

{-# COMPLETE ZSX, (:$%) #-}

-- | A shape whose runtime dimensions are 'Int's.
type IShX sh = ShX sh Int

-- | The number of elements in an array described by this shape.
shxSize :: IShX sh -> Int
shxSize ZSX = 1
shxSize (n :$% sh) = fromSMayNat' n * shxSize sh

-- We don't report the size of the list in case of errors in order not to retain the list.
-- | Build a shape from a plain list of dimension sizes, validated against
-- the static shape. The list length must equal the rank, and each statically
-- known dimension must equal the corresponding list entry. Any mismatch is
-- reported with 'error'.
{-# INLINEABLE shxFromList #-}
shxFromList :: StaticShX sh -> [Int] -> IShX sh
shxFromList (StaticShX ssh0) xs0 = fill ssh0 xs0
  where
    fill :: ShX sh' () -> [Int] -> ShX sh' Int
    fill ZSX rest = case rest of
      [] -> ZSX
      _ -> error $ "shxFromList: List too long (type says " ++ show (shxLength ssh0) ++ ")"
    fill (ConsKnown sn ssh) (x : xs)
      | x /= fromSNat' sn = error "shxFromList: Value does not match typing"
      | otherwise = ConsKnown sn (fill ssh xs)
    fill (ConsUnknown () ssh) (x : xs) = ConsUnknown x (fill ssh xs)
    fill _ [] = error $ "shxFromList: List too short (type says " ++ show (shxLength ssh0) ++ ")"

-- | The dimension sizes of a shape as a plain list, first dimension first.
-- Produced via 'build', so the result can take part in foldr/build fusion.
{-# INLINEABLE shxToList #-}
shxToList :: IShX sh -> [Int]
shxToList sh0 = build (\(c :: Int -> r -> r) (nil :: r) ->
  let walk :: ShX s Int -> r
      walk ZSX = nil
      walk (ConsKnown sn rest) = fromSNat' sn `c` walk rest
      walk (ConsUnknown n rest) = n `c` walk rest
  in walk sh0)

-- | Reconstruct a full shape from a static shape in which every dimension is
-- statically known (as the 'MapJust' in the type guarantees).
--
-- If it ever matters for performance, this is unsafeCoercible.
shxFromSSX :: StaticShX (MapJust sh) -> ShX (MapJust sh) i
shxFromSSX ZKX = ZSX
shxFromSSX (SKnown n :!% rest :: StaticShX (MapJust sh))
  | Refl <- lemMapJustCons @sh Refl
  = SKnown n :$% shxFromSSX rest
shxFromSSX (SUnknown _ :!% _) = error "unreachable"

-- | This may fail if @sh@ has @Nothing@s in it.
shxFromSSX2 :: StaticShX sh -> Maybe (ShX sh i)
shxFromSSX2 ZKX = Just ZSX
shxFromSSX2 (SUnknown _ :!% _) = Nothing
shxFromSSX2 (SKnown n :!% rest) = fmap (SKnown n :$%) (shxFromSSX2 rest)

-- | The prefix of a shape described by the static shape; the proxy only
-- fixes the type of the discarded suffix.
shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i
shxTakeSSX _ ZKX _ = ZSX
shxTakeSSX pxy (_ :!% prefix) (d :$% rest) = d :$% shxTakeSSX pxy prefix rest

-- | The prefix of a shape with the same structure as the first shape; the
-- proxy only fixes the type of the discarded suffix.
shxTakeSh :: forall sh sh' i proxy. proxy sh' -> ShX sh i -> ShX (sh ++ sh') i -> ShX sh i
shxTakeSh _ ZSX _ = ZSX
shxTakeSh pxy (_ :$% prefix) (d :$% rest) = d :$% shxTakeSh pxy prefix rest

-- | The leading dimensions of a shape, as many as the index has components.
{-# INLINEABLE shxTakeIx #-}
shxTakeIx :: forall sh sh' i j. Proxy sh' -> IxX sh j -> ShX (sh ++ sh') i -> ShX sh i
shxTakeIx _ (IxX ZX) _ = ZSX
shxTakeIx pxy (IxX (_ ::% ixRest)) (d :$% shRest) = d :$% shxTakeIx pxy (IxX ixRest) shRest

-- | Drop as many leading dimensions from a shape as the index has
-- components.
{-# INLINEABLE shxDropIx #-}
shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i
shxDropIx ZIX sh = sh
shxDropIx (_ :.% ixRest) (_ :$% shRest) = shxDropIx ixRest shRest

-- | Combine two shapes of the same type dimension by dimension.
{-# INLINE shxZipWith #-}
shxZipWith :: (forall n. SMayNat i n -> SMayNat j n -> SMayNat k n) -> ShX sh i -> ShX sh j -> ShX sh k
shxZipWith _ ZSX ZSX = ZSX
shxZipWith f (x :$% xs) (y :$% ys) = f x y :$% shxZipWith f xs ys

-- | Turn a static shape into a full shape, using 0 for every dimension that
-- is not statically known.
--
-- This is a weird operation, so it has a long name
shxCompleteZeros :: StaticShX sh -> IShX sh
shxCompleteZeros ZKX = ZSX
shxCompleteZeros (d :!% rest) = complete d :$% shxCompleteZeros rest
  where
    complete :: SMayNat () n -> SMayNat Int n
    complete (SUnknown ()) = SUnknown 0
    complete (SKnown sn) = SKnown sn

-- | Split a shape into the prefix described by the static shape and the
-- remaining suffix.
shxSplitApp :: proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i)
shxSplitApp _ ZKX sh = (ZSX, sh)
shxSplitApp pxy (_ :!% prefix) (d :$% rest) = first (d :$%) $ shxSplitApp pxy prefix rest

-- | Check a shape against a static shape. Returns 'Nothing' if the ranks
-- differ or a statically known target dimension does not equal the actual
-- size. Dimensions the target leaves unknown become 'SUnknown'.
shxCast :: StaticShX sh' -> IShX sh -> Maybe (IShX sh')
shxCast ZKX ZSX = Just ZSX
shxCast (t :!% ssh) (d :$% sh) = do
  d' <- castDim t d
  rest <- shxCast ssh sh
  pure (d' :$% rest)
  where
    -- Convert one dimension to the form required by the target.
    castDim :: SMayNat () m -> SMayNat Int n -> Maybe (SMayNat Int m)
    castDim (SKnown want) (SKnown have)
      | Just Refl <- testEquality have want = Just (SKnown have)
    castDim (SKnown want) (SUnknown have)
      | have == fromSNat' want = Just (SKnown want)
    castDim (SUnknown ()) (SKnown have) = Just (SUnknown (fromSNat' have))
    castDim (SUnknown ()) (SUnknown have) = Just (SUnknown have)
    castDim _ _ = Nothing
shxCast _ _ = Nothing

-- | Partial version of 'shxCast'.
shxCast' :: StaticShX sh' -> IShX sh -> IShX sh'
shxCast' ssh sh =
  maybe (error $ "shxCast': Mismatch: (" ++ show sh ++ ") does not match (" ++ show ssh ++ ")")
        id
        (shxCast ssh sh)


-- * Static mixed shapes

-- | The part of a shape that is statically known. (A newtype over 'ShX'.)
type StaticShX :: [Maybe Nat] -> Type
newtype StaticShX sh = StaticShX (ShX sh ())
  deriving (NFData)

-- A 'StaticShX' stores only '()' for unknown dimensions and 'SNat'
-- singletons for known ones, so its value is fully determined by its type:
-- all values of one type are equal.
instance Eq (StaticShX sh) where _ == _ = True
instance Ord (StaticShX sh) where compare _ _ = EQ

-- | The empty static shape.
pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh
pattern ZKX = StaticShX ZSX

-- | Prepend one dimension to a static shape.
pattern (:!%) :: forall {sh1}. forall n sh. (n : sh ~ sh1) => SMayNat () n -> StaticShX sh -> StaticShX sh1
pattern i :!% shl <- StaticShX (shxUncons -> Just (UnconsShXRes i (StaticShX -> shl)))
  where i :!% StaticShX shl = case i of
          SUnknown () -> StaticShX (() `ConsUnknown` shl)
          SKnown x -> StaticShX (x `ConsKnown` shl)
infixr 3 :!%

{-# COMPLETE ZKX, (:!%) #-}

#ifdef OXAR_DEFAULT_SHOW_INSTANCES
deriving instance Show (StaticShX sh)
#else
instance Show (StaticShX sh) where
  showsPrec _ (StaticShX l) = shxShow (fromSMayNat shows (shows . fromSNat)) l
#endif

instance TestEquality StaticShX where
  testEquality (StaticShX l1) (StaticShX l2) = shxEqType l1 l2

-- | The number of dimensions, as a runtime 'Int'.
ssxLength :: StaticShX sh -> Int
ssxLength (StaticShX l) = shxLength l

-- | The number of dimensions, as a type-level natural singleton.
ssxRank :: StaticShX sh -> SNat (Rank sh)
ssxRank (StaticShX l) = shxRank l

-- | @ssxEqType = 'testEquality'@. Provided for consistency.
ssxEqType :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh')
ssxEqType = testEquality

-- | Concatenate two static shapes.
ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
ssxAppend = coerce (shxAppend @_ @())

-- | The first dimension of a non-empty static shape.
ssxHead :: StaticShX (n : sh) -> SMayNat () n
ssxHead (StaticShX list) = shxHead list

-- | All dimensions of a non-empty static shape except the first.
ssxTail :: StaticShX (n : sh) -> StaticShX sh
ssxTail (StaticShX list) = StaticShX (shxTail list)

-- | The leading dimensions of a static shape, as many as the index has
-- components.
ssxTakeIx :: forall sh sh' i. Proxy sh' -> IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh
ssxTakeIx _ (IxX ZX) _ = ZKX
ssxTakeIx proxy (IxX (_ ::% short)) long = case long of
  i :!% long' -> i :!% ssxTakeIx proxy (IxX short) long'

-- | Drop as many leading dimensions from a static shape as the index has
-- components.
ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'
ssxDropIx (IxX ZX) long = long
ssxDropIx (IxX (_ ::% short)) long = case long of
  _ :!% long' -> ssxDropIx (IxX short) long'

-- | Drop as many leading dimensions from a static shape as the shape has.
ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'
ssxDropSh = coerce (shxDropSh @_ @_ @() @i)

-- | Drop as many leading dimensions from a static shape as the first has.
ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh'
ssxDropSSX = coerce (shxDropSh @_ @_ @() @())

-- | All dimensions of a non-empty static shape except the last.
ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh))
ssxInit = coerce (shxInit @())

-- | The last dimension of a non-empty static shape.
ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () (Last (n : sh))
ssxLast = coerce (shxLast @())

-- | A static shape of @n@ unknown dimensions.
ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)
ssxReplicate SZ = ZKX
ssxReplicate (SS (n :: SNat n'))
  | Refl <- lemReplicateSucc @(Nothing @Nat) n
  = SUnknown () :!% ssxReplicate n

-- | @[i, i + 1, ...]@, with one element per dimension of the static shape.
ssxIotaFrom :: StaticShX sh -> Int -> [Int]
ssxIotaFrom ZKX _ = []
ssxIotaFrom (_ :!% ssh) i = i : ssxIotaFrom ssh (i + 1)

-- | The static part of a shape, forgetting the runtime sizes of its unknown
-- dimensions.
ssxFromShX :: ShX sh i -> StaticShX sh
ssxFromShX ZSX = ZKX
ssxFromShX (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShX sh

-- | Same as 'ssxReplicate' (the two definitions were identical); kept as an
-- alias for existing callers.
ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing)
ssxFromSNat = ssxReplicate


-- | Evidence for the static part of a shape. This pops up only when you are
-- polymorphic in the element type of an array.
type KnownShX :: [Maybe Nat] -> Constraint
class KnownShX sh where knownShX :: StaticShX sh
instance KnownShX '[] where knownShX = ZKX
instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX
instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX

-- | Bring a 'KnownShX' constraint into scope from a runtime 'StaticShX'.
withKnownShX :: forall sh r. StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX = withDict @(KnownShX sh)


-- * Flattening

-- | The total element count of a shape at the type level: 'Just' the product
-- of all dimensions when every dimension is known, 'Nothing' otherwise.
type Flatten sh = Flatten' 1 sh

-- | 'Flatten' with an accumulated product of the dimensions seen so far.
type family Flatten' acc sh where
  Flatten' acc '[] = Just acc
  Flatten' acc (Nothing : sh) = Nothing
  Flatten' acc (Just n : sh) = Flatten' (acc * n) sh

-- | The static part of the total element count of a static shape.
--
-- This function is currently unused
ssxFlatten :: StaticShX sh -> SMayNat () (Flatten sh)
ssxFlatten = go (SNat @1)
  where
    go :: SNat acc -> StaticShX sh -> SMayNat () (Flatten' acc sh)
    go acc ZKX = SKnown acc
    go _ (SUnknown () :!% _) = SUnknown ()
    go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh

-- | The total element count of a shape; statically known exactly when all
-- dimensions are.
shxFlatten :: IShX sh -> SMayNat Int (Flatten sh)
shxFlatten = go (SNat @1)
  where
    go :: SNat acc -> IShX sh -> SMayNat Int (Flatten' acc sh)
    go acc ZSX = SKnown acc
    go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh)
    go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh

    -- Once an unknown dimension is seen, multiply the rest at runtime.
    goUnknown :: Int -> IShX sh -> Int
    goUnknown acc ZSX = acc
    goUnknown acc (SUnknown n :$% sh) = goUnknown (acc * n) sh
    goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh


-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
instance IsList (IxX sh i) where
  type Item (IxX sh i) = i
  fromList = IxX . IsList.fromList
  toList = Foldable.toList

-- | Untyped: length and known dimensions are checked (at runtime).
instance KnownShX sh => IsList (IShX sh) where
  type Item (IShX sh) = Int
  fromList = shxFromList (knownShX @sh)
  toList = shxToList