From a3299c09e0fd12cf73c4a0a9a2ae37b8f69f9b10 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 11 Dec 2024 19:56:28 +0100 Subject: Simpler API to mcast --- src/Data/Array/Mixed/Shape.hs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) (limited to 'src/Data/Array/Mixed') diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs index b5a4cb9..e5f8b67 100644 --- a/src/Data/Array/Mixed/Shape.hs +++ b/src/Data/Array/Mixed/Shape.hs @@ -378,6 +378,20 @@ shxEnum = \sh -> go sh id [] go ZSX f = (f ZIX :) go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]] +shxCast :: IShX sh -> StaticShX sh' -> Maybe (IShX sh') +shxCast ZSX ZKX = Just ZSX +shxCast (SKnown n :$% sh) (SKnown m :!% ssh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast sh ssh +shxCast (SUnknown n :$% sh) (SKnown m :!% ssh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast sh ssh +shxCast (SKnown n :$% sh) (SUnknown () :!% ssh) = (SUnknown (fromSNat' n) :$%) <$> shxCast sh ssh +shxCast (SUnknown n :$% sh) (SUnknown () :!% ssh) = (SUnknown n :$%) <$> shxCast sh ssh +shxCast _ _ = Nothing + +-- | Partial version of 'shxCast'. +shxCast' :: IShX sh -> StaticShX sh' -> IShX sh' +shxCast' sh ssh = case shxCast sh ssh of + Just sh' -> sh' + Nothing -> error $ "shxCast': Mismatch: (" ++ show sh ++ ") does not match (" ++ show ssh ++ ")" + -- * Static mixed shapes -- cgit v1.2.3-70-g09d2