diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-18 13:24:32 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-18 13:24:32 +0200 |
commit | 2e28993ef478ff8c1eed549010383baf51ddec90 (patch) | |
tree | 7b9ee20fe2b17b5cbf7d3798f6b80d095257a24c /src/Data/Array/Mixed.hs | |
parent | 4adbbd8e2e635cc4c647be40f0dd258668dd2452 (diff) |
More WIP
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r-- | src/Data/Array/Mixed.hs | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index df506d6..33d9f56 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -10,6 +10,7 @@ {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RoleAnnotations #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} @@ -24,6 +25,7 @@ module Data.Array.Mixed where import qualified Data.Array.RankedS as S import qualified Data.Array.Ranked as ORB +import Data.Bifunctor (first) import Data.Coerce import Data.Functor.Const import Data.Kind @@ -90,6 +92,7 @@ type family Replicate n a where Replicate n a = a : Replicate (n - 1) a +type role ListX nominal representational type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type data ListX sh f where ZX :: ListX '[] f @@ -114,6 +117,7 @@ foldListX _ ZX = mempty foldListX f (x ::% xs) = f x <> foldListX f xs +type role IxX nominal representational type IxX :: [Maybe Nat] -> Type -> Type newtype IxX sh i = IxX (ListX sh (Const i)) deriving (Show, Eq, Ord) @@ -154,6 +158,7 @@ fromSMayNat _ g (SKnown s) = g s fromSMayNat' :: SMayNat Int SNat n -> Int fromSMayNat' = fromSMayNat id fromSNat' +type role ShX nominal representational type ShX :: [Maybe Nat] -> Type -> Type newtype ShX sh i = ShX (ListX sh (SMayNat i SNat)) deriving (Show, Eq, Ord) @@ -249,6 +254,10 @@ shTail (_ :$% sh) = sh ssxTail :: StaticShX (n : sh) -> StaticShX sh ssxTail (_ :!% ssh) = ssh +shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh') +shAppSplit _ ZKX idx = (ZSX, idx) +shAppSplit p (_ :!% ssh) (i :$% idx) = first (i :$%) (shAppSplit p ssh idx) + ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') ssxAppend ZKX sh' = sh' ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' @@ -354,6 +363,24 @@ toVector (XArray arr) = S.toVector arr scalar :: Storable a => a -> XArray '[] a scalar = XArray . S.scalar +eqShX :: IShX sh1 -> IShX sh2 -> Bool +eqShX ZSX ZSX = True +eqShX (n :$% sh1) (m :$% sh2) = fromSMayNat' n == fromSMayNat' m && eqShX sh1 sh2 +eqShX _ _ = False + +-- | Will throw if the array does not have the casted-to shape. +cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2 + => StaticShX sh1 -> IShX sh2 -> StaticShX sh' + -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a +cast ssh1 sh2 ssh' (XArray arr) + | Refl <- lemRankApp ssh1 ssh' + , Refl <- lemRankApp (staticShapeFrom sh2) ssh' + = let arrsh :: IShX sh1 + (arrsh, _) = shAppSplit (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr)) + in if eqShX arrsh sh2 + then XArray arr + else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")" + unScalar :: Storable a => XArray '[] a -> a unScalar (XArray a) = S.unScalar a |