aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r--src/Data/Array/Mixed.hs27
1 files changed, 27 insertions, 0 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index df506d6..33d9f56 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -10,6 +10,7 @@
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
@@ -24,6 +25,7 @@ module Data.Array.Mixed where
import qualified Data.Array.RankedS as S
import qualified Data.Array.Ranked as ORB
+import Data.Bifunctor (first)
import Data.Coerce
import Data.Functor.Const
import Data.Kind
@@ -90,6 +92,7 @@ type family Replicate n a where
Replicate n a = a : Replicate (n - 1) a
+type role ListX nominal representational
type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type
data ListX sh f where
ZX :: ListX '[] f
@@ -114,6 +117,7 @@ foldListX _ ZX = mempty
foldListX f (x ::% xs) = f x <> foldListX f xs
+type role IxX nominal representational
type IxX :: [Maybe Nat] -> Type -> Type
newtype IxX sh i = IxX (ListX sh (Const i))
deriving (Show, Eq, Ord)
@@ -154,6 +158,7 @@ fromSMayNat _ g (SKnown s) = g s
fromSMayNat' :: SMayNat Int SNat n -> Int
fromSMayNat' = fromSMayNat id fromSNat'
+type role ShX nominal representational
type ShX :: [Maybe Nat] -> Type -> Type
newtype ShX sh i = ShX (ListX sh (SMayNat i SNat))
deriving (Show, Eq, Ord)
@@ -249,6 +254,10 @@ shTail (_ :$% sh) = sh
ssxTail :: StaticShX (n : sh) -> StaticShX sh
ssxTail (_ :!% ssh) = ssh
+shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh')
+shAppSplit _ ZKX idx = (ZSX, idx)
+shAppSplit p (_ :!% ssh) (i :$% idx) = first (i :$%) (shAppSplit p ssh idx)
+
ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
ssxAppend ZKX sh' = sh'
ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh'
@@ -354,6 +363,24 @@ toVector (XArray arr) = S.toVector arr
scalar :: Storable a => a -> XArray '[] a
scalar = XArray . S.scalar
+eqShX :: IShX sh1 -> IShX sh2 -> Bool
+eqShX ZSX ZSX = True
+eqShX (n :$% sh1) (m :$% sh2) = fromSMayNat' n == fromSMayNat' m && eqShX sh1 sh2
+eqShX _ _ = False
+
+-- | Will throw if the array does not have the casted-to shape.
+cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2
+ => StaticShX sh1 -> IShX sh2 -> StaticShX sh'
+ -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
+cast ssh1 sh2 ssh' (XArray arr)
+ | Refl <- lemRankApp ssh1 ssh'
+ , Refl <- lemRankApp (staticShapeFrom sh2) ssh'
+ = let arrsh :: IShX sh1
+ (arrsh, _) = shAppSplit (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr))
+ in if eqShX arrsh sh2
+ then XArray arr
+ else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")"
+
unScalar :: Storable a => XArray '[] a -> a
unScalar (XArray a) = S.unScalar a