{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-- | Arrays whose shape type is a list of @'Maybe' 'Nat'@: @'Just' n@ marks a
-- dimension whose size is known statically, 'Nothing' one whose size is only
-- known at runtime. Storage is an orthotope "Data.Array.RankedS" array.
module Data.Array.Mixed where

import Control.DeepSeq (NFData(..))
import qualified Data.Array.RankedS as S
import qualified Data.Array.Ranked as ORB
import Data.Bifunctor (first)
import Data.Coerce
import qualified Data.Foldable as Foldable
import Data.Functor.Const
import Data.Kind
import Data.List (sort)
import Data.Monoid (Sum(..))
import Data.Proxy
import Data.Type.Bool
import Data.Type.Equality
import Data.Type.Ord
import qualified Data.Vector.Storable as VS
import Foreign.Storable (Storable)
import GHC.Generics (Generic)
import GHC.IsList (IsList)
import qualified GHC.IsList as IsList
import GHC.TypeError
import GHC.TypeLits
import qualified GHC.TypeNats as TypeNats
import Unsafe.Coerce (unsafeCoerce)

import Data.Array.Nested.Internal.Arith


-- | Evidence for the constraint @c a@.
data Dict c a where
  Dict :: c a => Dict c a

-- | The value of an 'SNat' as an 'Int', converted with 'fromIntegral'.
fromSNat' :: SNat n -> Int
fromSNat' = fromIntegral . fromSNat

-- | Matches the 'SNat' for 0, refining @n ~ 0@; as an expression it builds
-- that 'SNat'.
pattern SZ :: () => (n ~ 0) => SNat n
pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl)
  where SZ = SNat

-- | Matches a nonzero 'SNat' and binds its predecessor (refining
-- @n + 1 ~ np1@); as an expression it builds the successor via 'snatSucc'.
pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1
pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl))
  where SS = snatSucc

-- Tells the pattern-match checker that SZ and SS together cover every SNat.
{-# COMPLETE SZ, SS #-}

-- | Successor of a type-level natural singleton.
snatSucc :: SNat n -> SNat (n + 1)
snatSucc SNat = SNat

-- | Result of 'snatPred': the predecessor @n@ together with a proof that
-- @n + 1@ equals @np1@.
data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1)

-- | Predecessor of an 'SNat'. Returns 'Nothing' when the argument is 0.
-- 'cmpNat' compares 1 with @np1@; both 'LTI' and 'EQI' mean @np1 >= 1@.
snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1)
snatPred snp1 =
  withKnownNat snp1 $
    case cmpNat (Proxy @1) (Proxy @np1) of
      LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
      EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
      GTI -> Nothing

-- | Type-level list append.
type family l1 ++ l2 where
  '[] ++ l2 = l2
  (x : xs) ++ l2 = x : xs ++ l2

-- | Right identity of '++'. Asserted with 'unsafeCoerce'; GHC does not check it.
lemAppNil :: l ++ '[] :~: l
lemAppNil = unsafeCoerce Refl

-- | Associativity of '++'. Asserted with 'unsafeCoerce'; GHC does not check it.
lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
lemAppAssoc _ _ _ = unsafeCoerce Refl

-- | @Replicate n a@ is the type-level list of @n@ copies of @a@.
type family Replicate n a where
  Replicate 0 a = '[]
  Replicate n a = a : Replicate (n - 1) a

-- A list indexed by a shape type @sh@. The element at the position of shape
-- entry @n@ has type @f n@. 'ZX' is the empty list; '::%' is cons (infixr 3).
type role ListX nominal representational
type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type
data ListX sh f where
  ZX :: ListX '[] f
  (::%) :: f n -> ListX sh f -> ListX (n : sh) f
deriving instance (forall n. Eq (f n)) => Eq (ListX sh f)
deriving instance (forall n. Ord (f n)) => Ord (ListX sh f)
infixr 3 ::%

-- | Shows as @[x,y,...]@; the precedence argument is ignored.
instance (forall n. Show (f n)) => Show (ListX sh f) where
  showsPrec _ = showListX shows

-- | Forces every element.
instance (forall n. NFData (f n)) => NFData (ListX sh f) where
  rnf ZX = ()
  rnf (x ::% l) = rnf x `seq` rnf l

-- | Result of 'unconsListX': the tail and the head of a non-empty 'ListX',
-- with the shape split as @n : sh@.
data UnconsListXRes f sh1 = forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n)

-- | Splits off the head of a 'ListX'; 'Nothing' for the empty list. Used by
-- the view patterns of the cons pattern synonyms in this module.
unconsListX :: ListX sh1 f -> Maybe (UnconsListXRes f sh1)
unconsListX (i ::% shl') = Just (UnconsListXRes shl' i)
unconsListX ZX = Nothing

-- | Applies a shape-polymorphic function to every element.
fmapListX :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g
fmapListX _ ZX = ZX
fmapListX f (x ::% xs) = f x ::% fmapListX f xs

-- | Maps every element into a monoid and combines them from left to right.
foldListX :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m
foldListX _ ZX = mempty
foldListX f (x ::% xs) = f x <> foldListX f xs

-- | Number of elements, counted at runtime.
lengthListX :: ListX sh f -> Int
lengthListX = getSum . foldListX (\_ -> Sum 1)

-- | Length of the list as a singleton @'SNat' ('Rank' sh)@.
snatLengthListX :: ListX sh f -> SNat (Rank sh)
snatLengthListX ZX = SNat
snatLengthListX (_ ::% l) | SNat <- snatLengthListX l = SNat

-- | Renders the elements with the given printer, comma-separated and
-- enclosed in square brackets.
showListX :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS
showListX f l = showString "[" . go "" l . showString "]"
  where
    -- The first element gets an empty prefix; later ones get ",".
    go :: String -> ListX sh' f -> ShowS
    go _ ZX = id
    go prefix (x ::% xs) = showString prefix . f x . go "," xs

-- | Unwraps a list of 'Const' values into a plain list.
listXToList :: ListX sh' (Const i) -> [i]
listXToList ZX = []
listXToList (Const i ::% is) = i : listXToList is

-- An index for shape type @sh@: one @i@ per dimension.
type role IxX nominal representational
type IxX :: [Maybe Nat] -> Type -> Type
newtype IxX sh i = IxX (ListX sh (Const i))
  deriving (Eq, Ord, Generic)

-- | The empty index.
pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i
pattern ZIX = IxX ZX

-- | Cons for indices (infixr 3).
pattern (:.%) :: forall {sh1} {i}. forall n sh. (n : sh ~ sh1) => i -> IxX sh i -> IxX sh1 i
pattern i :.% shl <- IxX (unconsListX -> Just (UnconsListXRes (IxX -> shl) (getConst -> i)))
  where i :.% IxX shl = IxX (Const i ::% shl)
infixr 3 :.%

{-# COMPLETE ZIX, (:.%) #-}

-- | Indices with 'Int' components.
type IIxX sh = IxX sh Int

-- | Shows as @[i,j,...]@.
instance Show i => Show (IxX sh i) where
  showsPrec _ (IxX l) = showListX (\(Const i) -> shows i) l

instance Functor (IxX sh) where
  fmap f (IxX l) = IxX (fmapListX (Const . f . getConst) l)

instance Foldable (IxX sh) where
  foldMap f (IxX l) = foldListX (f . getConst) l

-- | Uses the default ('Generic'-based) 'rnf'.
instance NFData i => NFData (IxX sh i)

-- | The size of one dimension, indexed by its @'Maybe' 'Nat'@ shape entry.
-- 'SUnknown' carries a runtime value @i@ (entry 'Nothing'). 'SKnown' carries
-- an @f n@ (entry @'Just' n@); in this module @f@ is 'SNat'.
data SMayNat i f n where
  SUnknown :: i -> SMayNat i f Nothing
  SKnown :: f n -> SMayNat i f (Just n)
deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n)
deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n)
deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n)

instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where
  rnf (SUnknown i) = rnf i
  rnf (SKnown x) = rnf x

-- | Eliminator for 'SMayNat'. Each continuation receives the type refinement
-- for its case.
fromSMayNat :: (n ~ Nothing => i -> r) -> (forall m. n ~ Just m => f m -> r) -> SMayNat i f n -> r
fromSMayNat f _ (SUnknown i) = f i
fromSMayNat _ g (SKnown s) = g s

-- | The runtime size of a dimension, whether or not it is statically known.
fromSMayNat' :: SMayNat Int SNat n -> Int
fromSMayNat' = fromSMayNat id fromSNat'

-- A shape for shape type @sh@: one 'SMayNat' per dimension, where
-- statically unknown sizes are stored as an @i@.
type role ShX nominal representational
type ShX :: [Maybe Nat] -> Type -> Type
newtype ShX sh i = ShX (ListX sh (SMayNat i SNat))
  deriving (Eq, Ord, Generic)

-- | The empty shape.
pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i
pattern ZSX = ShX ZX

-- | Cons for shapes (infixr 3).
pattern (:$%) :: forall {sh1} {i}. forall n sh. (n : sh ~ sh1) => SMayNat i SNat n -> ShX sh i -> ShX sh1 i
pattern i :$% shl <- ShX (unconsListX -> Just (UnconsListXRes (ShX -> shl) i))
  where i :$% ShX shl = ShX (i ::% shl)
infixr 3 :$%

{-# COMPLETE ZSX, (:$%) #-}

-- | Shapes whose runtime sizes are 'Int's.
type IShX sh = ShX sh Int

-- | Shows every dimension's size as a number, whether static or not.
instance Show i => Show (ShX sh i) where
  showsPrec _ (ShX l) = showListX (fromSMayNat shows (shows . fromSNat)) l

-- | Maps over the runtime ('SUnknown') sizes only; static sizes are kept.
instance Functor (ShX sh) where
  fmap f (ShX l) = ShX (fmapListX (fromSMayNat (SUnknown . f) SKnown) l)

instance NFData i => NFData (ShX sh i) where
  rnf (ShX ZX) = ()
  rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l)
  rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l)

-- | Number of dimensions.
lengthShX :: ShX sh i -> Int
lengthShX (ShX l) = lengthListX l

-- | The size of every dimension, as a plain list.
shXToList :: IShX sh -> [Int]
shXToList ZSX = []
shXToList (smn :$% sh) = fromSMayNat' smn : shXToList sh

-- | The part of a shape that is statically known. Dimensions of unknown size
-- carry @()@.
type StaticShX :: [Maybe Nat] -> Type
newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat))
  deriving (Eq, Ord)

-- | The empty static shape.
pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh
pattern ZKX = StaticShX ZX

-- | Cons for static shapes (infixr 3).
pattern (:!%) :: forall {sh1}. forall n sh. (n : sh ~ sh1) => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1
pattern i :!% shl <- StaticShX (unconsListX -> Just (UnconsListXRes (StaticShX -> shl) i))
  where i :!% StaticShX shl = StaticShX (i ::% shl)
infixr 3 :!%

{-# COMPLETE ZKX, (:!%) #-}

-- | Unknown dimensions are shown as @()@.
instance Show (StaticShX sh) where
  showsPrec _ (StaticShX l) = showListX (fromSMayNat shows (shows . fromSNat)) l

-- | Number of dimensions.
lengthStaticShX :: StaticShX sh -> Int
lengthStaticShX (StaticShX l) = lengthListX l

-- | Decides equality of two static shapes and returns a type equality proof.
-- Known dimensions must have equal sizes ('sameNat'). An unknown dimension
-- matches only another unknown dimension.
geqStaticShX :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh')
geqStaticShX ZKX ZKX = Just Refl
geqStaticShX (SKnown n@SNat :!% sh) (SKnown m@SNat :!% sh')
  | Just Refl <- sameNat n m
  , Just Refl <- geqStaticShX sh sh'
  = Just Refl
geqStaticShX (SUnknown () :!% sh) (SUnknown () :!% sh')
  | Just Refl <- geqStaticShX sh sh'
  = Just Refl
geqStaticShX _ _ = Nothing

-- | Evidence for the static part of a shape. This pops up only when you are
-- polymorphic in the element type of an array.
type KnownShX :: [Maybe Nat] -> Constraint
class KnownShX sh where
  knownShX :: StaticShX sh

instance KnownShX '[] where
  knownShX = ZKX

-- A statically known dimension takes its size from the 'KnownNat' instance.
instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where
  knownShX = SKnown natSing :!% knownShX

instance KnownShX sh => KnownShX (Nothing : sh) where
  knownShX = SUnknown () :!% knownShX

-- | Very untyped: only length is checked (at runtime).
instance KnownShX sh => IsList (ListX sh (Const i)) where
  type Item (ListX sh (Const i)) = i
  fromList topl = go (knownShX @sh) topl
    where
      -- Walks the static shape and the input list in lockstep.
      go :: StaticShX sh' -> [i] -> ListX sh' (Const i)
      go ZKX [] = ZX
      go (_ :!% sh) (i : is) = Const i ::% go sh is
      go _ _ = error $ "IsList(ListX): Mismatched list length (type says "
                         ++ show (lengthStaticShX (knownShX @sh)) ++ ", list has length "
                         ++ show (length topl) ++ ")"
  toList = listXToList

-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
instance KnownShX sh => IsList (IxX sh i) where
  type Item (IxX sh i) = i
  fromList = IxX . IsList.fromList
  toList = Foldable.toList

-- | Untyped: length and known dimensions are checked (at runtime).
instance KnownShX sh => IsList (ShX sh Int) where
  type Item (ShX sh Int) = Int
  fromList topl = ShX (go (knownShX @sh) topl)
    where
      -- Walks the static shape and the list in lockstep. Each statically known
      -- size must equal the corresponding list entry.
      go :: StaticShX sh' -> [Int] -> ListX sh' (SMayNat Int SNat)
      go ZKX [] = ZX
      go (SKnown sn :!% sh) (i : is)
        | i == fromSNat' sn = SKnown sn ::% go sh is
        | otherwise = error $ "IsList(ShX): Value does not match typing (type says "
                                ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")"
      go (SUnknown () :!% sh) (i : is) = SUnknown i ::% go sh is
      go _ _ = error $ "IsList(ShX): Mismatched list length (type says "
                         ++ show (lengthStaticShX (knownShX @sh)) ++ ", list has length "
                         ++ show (length topl) ++ ")"
  toList = shXToList

-- | The number of dimensions of a shape type.
type family Rank sh where
  Rank '[] = 0
  Rank (_ : sh) = Rank sh + 1

-- | An array with shape type @sh@, stored as an orthotope 'S.Array' of rank
-- @'Rank' sh@. The wrapper itself does not check that the runtime shape agrees
-- with the statically known sizes in @sh@.
type XArray :: [Maybe Nat] -> Type -> Type
newtype XArray sh a = XArray (S.Array (Rank sh) a)
  deriving (Show, Eq, Generic)

-- | Only on scalars, because lexicographical ordering is strange on multi-dimensional arrays.
deriving instance (Ord a, Storable a) => Ord (XArray '[] a)

-- | Uses the default ('Generic'-based) 'rnf'.
instance NFData a => NFData (XArray sh a)

-- | The all-zeros index for a static shape.
zeroIxX :: StaticShX sh -> IIxX sh
zeroIxX ZKX = ZIX
zeroIxX (_ :!% ssh) = 0 :.% zeroIxX ssh

-- | The all-zeros index for a shape.
zeroIxX' :: IShX sh -> IIxX sh
zeroIxX' ZSX = ZIX
zeroIxX' (_ :$% sh) = 0 :.% zeroIxX' sh

-- This is a weird operation, so it has a long name
-- | Turns a static shape into a full shape. Dimensions that are not statically
-- known get size 0.
completeShXzeros :: StaticShX sh -> IShX sh
completeShXzeros ZKX = ZSX
completeShXzeros (SUnknown () :!% ssh) = SUnknown 0 :$% completeShXzeros ssh
completeShXzeros (SKnown n :!% ssh) = SKnown n :$% completeShXzeros ssh

-- | Concatenation of two 'ListX'es.
listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f
listxAppend ZX idx' = idx'
listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx'

-- | Concatenation of indices ('listxAppend' under 'coerce').
ixAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i
ixAppend = coerce (listxAppend @_ @(Const i))

-- | Concatenation of shapes ('listxAppend' under 'coerce').
shAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i
shAppend = coerce (listxAppend @_ @(SMayNat i SNat))

-- | Drops as many leading elements from the first list as the second list
-- has. Only the length of the second list is used, not its elements.
listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f
listxDrop long ZX = long
listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short

-- | 'listxDrop' specialised to indices.
ixDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i
ixDrop = coerce (listxDrop @(Const i) @(Const i))

-- | Drops the leading dimensions of a shape described by a static shape.
shDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i
shDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat))

-- | Drops as many leading dimensions of a shape as the index has components.
shDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i
shDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j))

-- | Drops as many leading dimensions of a shape as the second shape has.
shDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i
shDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat))

-- | Takes the prefix of a shape whose length is given by the static shape.
shTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i
shTakeSSX _ = flip go
  where
    go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i
    go ZKX _ = ZSX
    go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh

-- | Drops as many leading dimensions of a static shape as the index has
-- components.
ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh'
ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i))

-- TODO: generalise all these things to arbitrary @i@
-- | Drops the first dimension of a shape.
shTail :: IShX (n : sh) -> IShX sh
shTail (_ :$% sh) = sh

-- | Drops the first dimension of a static shape.
ssxTail :: StaticShX (n : sh) -> StaticShX sh
ssxTail (_ :!% ssh) = ssh

-- | Splits a shape of type @sh ++ sh'@ after as many dimensions as the static
-- shape of @sh@ has.
shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh')
shAppSplit _ ZKX idx = (ZSX, idx)
shAppSplit p (_ :!% ssh) (i :$% idx) = first (i :$%) (shAppSplit p ssh idx)

-- | Concatenation of static shapes.
ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
ssxAppend ZKX sh' = sh'
ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh'

-- | Total number of elements: the product of all dimension sizes (1 for the
-- empty shape).
shapeSize :: IShX sh -> Int
shapeSize ZSX = 1
shapeSize (n :$% sh) = fromSMayNat' n * shapeSize sh

-- | This may fail if @sh@ has @Nothing@s in it.
-- It returns 'Nothing' as soon as it meets a dimension whose size is not
-- statically known ('SUnknown').
ssxToShape' :: StaticShX sh -> Maybe (IShX sh)
ssxToShape' ZKX = Just ZSX
ssxToShape' (SKnown n :!% sh) = (SKnown n :$%) <$> ssxToShape' sh
ssxToShape' (SUnknown _ :!% _) = Nothing

-- | Unfolding step for 'Replicate'. Asserted with 'unsafeCoerce'; GHC does
-- not check it.
lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
lemReplicateSucc = unsafeCoerce Refl

-- | The static shape of @n@ dimensions whose sizes are not statically known.
ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)
ssxReplicate SZ = ZKX
ssxReplicate (SS (n :: SNat n'))
  | Refl <- lemReplicateSucc @(Nothing @Nat) @n'
  = SUnknown () :!% ssxReplicate n

-- | Converts a linear index into a multi-dimensional index for the given
-- shape. The last dimension varies fastest, which is the order used by
-- 'toLinearIdx' and 'enumShape'.
--
-- Throws if the linear index is negative or not smaller than the number of
-- elements of the shape. For a shape with a zero-sized dimension, the
-- division inside @go@ throws a divide-by-zero exception instead.
fromLinearIdx :: IShX sh -> Int -> IIxX sh
fromLinearIdx = \sh i -> case go sh i of
    -- A nonzero carry out of the outermost dimension means @i@ was too large.
    -- Negative @i@ must be rejected explicitly: 'quotRem' truncates towards
    -- zero, so e.g. -1 in shape [3] would give index [-1] with a zero carry.
    (idx, 0) | i >= 0 -> idx
    _ -> error $ "fromLinearIdx: out of range (" ++ show i ++ " in array of shape " ++ show sh ++ ")"
  where
    -- returns (index in subarray, remaining index in enclosing array)
    go :: IShX sh -> Int -> (IIxX sh, Int)
    go ZSX i = (ZIX, i)
    go (n :$% sh) i =
      let (idx, i') = go sh i
          (upi, locali) = i' `quotRem` fromSMayNat' n
      in (locali :.% idx, upi)

-- | Linear index of a multi-dimensional index (last dimension fastest).
-- Inverse of 'fromLinearIdx'. The index components are not bounds-checked.
toLinearIdx :: IShX sh -> IIxX sh -> Int
toLinearIdx = \sh i -> fst (go sh i)
  where
    -- returns (index in subarray, size of subarray)
    go :: IShX sh -> IIxX sh -> (Int, Int)
    go ZSX ZIX = (0, 1)
    go (n :$% sh) (i :.% ix) =
      let (lidx, sz) = go sh ix
      in (sz * i + lidx, fromSMayNat' n * sz)

-- | All indices of a shape, in lexicographic order (first dimension
-- outermost). Built with difference lists.
enumShape :: IShX sh -> [IIxX sh]
enumShape = \sh -> go sh id []
  where
    go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a]
    go ZSX f = (f ZIX :)
    go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]]

-- | The shape as an orthotope 'S.ShapeL', i.e. the list of dimension sizes.
shapeLshape :: IShX sh -> S.ShapeL
shapeLshape ZSX = []
shapeLshape (n :$% sh) = fromSMayNat' n : shapeLshape sh

-- | Number of dimensions of a static shape.
ssxLength :: StaticShX sh -> Int
ssxLength ZKX = 0
ssxLength (_ :!% ssh) = 1 + ssxLength ssh

-- | @[i, i+1, ...]@ with one entry per dimension of the static shape.
ssxIotaFrom :: Int -> StaticShX sh -> [Int]
ssxIotaFrom _ ZKX = []
ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh

-- | Size of the shape flattened into one dimension. It is @'Just'@ the product
-- of all sizes when every size is statically known, and 'Nothing' otherwise.
type Flatten sh = Flatten' 1 sh

-- | 'Flatten' with an explicit accumulator.
type family Flatten' acc sh where
  Flatten' acc '[] = Just acc
  Flatten' acc (Nothing : sh) = Nothing
  Flatten' acc (Just n : sh) = Flatten' (acc * n) sh

-- | Static part of the flattened shape (see 'Flatten').
flattenSSX :: StaticShX sh -> SMayNat () SNat (Flatten sh)
flattenSSX = go (SNat @1)
  where
    go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh)
    go acc ZKX = SKnown acc
    go _ (SUnknown () :!% _) = SUnknown ()
    go acc (SKnown sn :!% sh) = go (mulSNat acc sn) sh

-- | Total size of a shape, tagged as statically known exactly when 'Flatten'
-- says so.
flattenSh :: IShX sh -> SMayNat Int SNat (Flatten sh)
flattenSh = go (SNat @1)
  where
    go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh)
    go acc ZSX = SKnown acc
    go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh)
    go acc (SKnown sn :$% sh) = go (mulSNat acc sn) sh

    -- Once an unknown dimension is seen, continue the product at runtime only.
    goUnknown :: Int -> IShX sh -> Int
    goUnknown acc ZSX = acc
    goUnknown acc (SUnknown n :$% sh) = goUnknown (acc * n) sh
    goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh

-- | Forgets the runtime sizes of a shape and keeps only its static part.
staticShapeFrom :: IShX sh -> StaticShX sh
staticShapeFrom ZSX = ZKX
staticShapeFrom (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% staticShapeFrom sh

-- | Rank of an append. Asserted with 'unsafeCoerce'.
lemRankApp :: StaticShX sh1 -> StaticShX sh2 -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this

-- | Rank of an append does not depend on the order. Asserted with
-- 'unsafeCoerce'.
lemRankAppComm :: StaticShX sh1 -> StaticShX sh2 -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)
lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this

-- | 'KnownNat' evidence for the rank of a shape, by induction on the shape.
lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh)
lemKnownNatRank ZSX = Dict
lemKnownNatRank (_ :$% sh) | Dict <- lemKnownNatRank sh = Dict

-- | 'KnownNat' evidence for the rank of a static shape.
lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh)
lemKnownNatRankSSX ZKX = Dict
lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict

-- | The runtime shape of an array. Statically known sizes are taken from the
-- 'StaticShX' argument and are not compared with the array. All other sizes
-- are read from 'S.shapeL'. Errors with "Invalid shapeL" if the two disagree
-- on the number of dimensions.
shape :: forall sh a. StaticShX sh -> XArray sh a -> IShX sh
shape = \ssh (XArray arr) -> go ssh (S.shapeL arr)
  where
    go :: StaticShX sh' -> [Int] -> IShX sh'
    go ZKX [] = ZSX
    go (n :!% ssh) (i : l) = fromSMayNat (\_ -> SUnknown i) SKnown n :$% go ssh l
    go _ _ = error "Invalid shapeL"

-- | Wraps a flat vector as an array of the given shape, using 'S.fromVector'.
fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a
fromVector sh v
  | Dict <- lemKnownNatRank sh
  = XArray (S.fromVector (shapeLshape sh) v)

-- | The elements of the array as a flat vector, via 'S.toVector'.
toVector :: Storable a => XArray sh a -> VS.Vector a
toVector (XArray arr) = S.toVector arr

-- | A rank-0 array holding one value.
scalar :: Storable a => a -> XArray '[] a
scalar = XArray . S.scalar

-- | Compares the runtime sizes of two shapes. Whether a size is statically
-- known is ignored. Shapes of different rank are unequal.
eqShX :: IShX sh1 -> IShX sh2 -> Bool
eqShX ZSX ZSX = True
eqShX (n :$% sh1) (m :$% sh2) = fromSMayNat' n == fromSMayNat' m && eqShX sh1 sh2
eqShX _ _ = False

-- | Will throw if the array does not have the casted-to shape.
-- Only the leading @sh1@ part of the array's shape is compared (with 'eqShX')
-- against @sh2@. The data is reused unchanged.
cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2
     => StaticShX sh1 -> IShX sh2 -> StaticShX sh'
     -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
cast ssh1 sh2 ssh' (XArray arr)
  | Refl <- lemRankApp ssh1 ssh'
  , Refl <- lemRankApp (staticShapeFrom sh2) ssh'
  = let arrsh :: IShX sh1
        (arrsh, _) = shAppSplit (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr))
    in if eqShX arrsh sh2
         then XArray arr
         else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")"

-- | The single element of a rank-0 array.
unScalar :: Storable a => XArray '[] a -> a
unScalar (XArray a) = S.unScalar a

-- | Replicates an array along new leading dimensions @sh@. The array is
-- reshaped with size-1 leading dimensions and then stretched to @sh@.
replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a
replicate sh ssh' (XArray arr)
  | Dict <- lemKnownNatRankSSX ssh'
  , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh) ssh')
  , Refl <- lemRankApp (staticShapeFrom sh) ssh'
  = XArray (S.stretch (shapeLshape sh ++ S.shapeL arr) $
              S.reshape (map (const 1) (shapeLshape sh) ++ S.shapeL arr) $
                arr)

-- | An array of the given shape in which every element is @x@.
replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a
replicateScal sh x
  | Dict <- lemKnownNatRank sh
  = XArray (S.constant (shapeLshape sh) x)

-- | Builds an array by evaluating the function at every index. Indices are
-- produced from linear positions with 'fromLinearIdx'.
generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a
generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh)

-- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a)
-- generateM sh f | Dict <- lemKnownNatRank sh =
--   XArray . S.fromVector (shapeLshape sh)
--     <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh)

-- | Indexes away the leading dimensions, one 'S.index' per index component.
indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a
indexPartial (XArray arr) ZIX = XArray arr
indexPartial (XArray arr) (i :.% idx) = indexPartial (XArray (S.index arr i)) idx

-- | The element at a full index.
index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a
index xarr i
  | Refl <- lemAppNil @sh
  = let XArray arr' = indexPartial xarr i :: XArray '[] a
    in S.unScalar arr'

-- | Sum of two optional static sizes. It is 'Nothing' if either is unknown.
type family AddMaybe n m where
  AddMaybe Nothing _ = Nothing
  AddMaybe (Just _) Nothing = Nothing
  AddMaybe (Just n) (Just m) = Just (n + m)

-- This should be a function in base
-- | Sum of two 'SNat's. The result index is asserted with 'unsafeCoerce'.
plusSNat :: SNat n -> SNat m -> SNat (n + m)
plusSNat n m = TypeNats.withSomeSNat (TypeNats.fromSNat n + TypeNats.fromSNat m) unsafeCoerce

-- This should be a function in base
-- | Product of two 'SNat's. The result index is asserted with 'unsafeCoerce'.
mulSNat :: SNat n -> SNat m -> SNat (n * m)
mulSNat n m = TypeNats.withSomeSNat (TypeNats.fromSNat n * TypeNats.fromSNat m) unsafeCoerce

-- | Adds two dimension sizes. The result is statically known only if both
-- sizes are.
smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m)
smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m)
smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m)
smnAddMaybe (SKnown n) (SKnown m) = SKnown (plusSNat n m)

-- | Concatenates two arrays along their outermost dimension, using
-- 'S.append'.
append :: forall n m sh a. Storable a => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a
append ssh (XArray a) (XArray b)
  | Dict <- lemKnownNatRankSSX ssh
  = XArray (S.append a b)

-- | If the prefix of the shape of the input array (@sh@) is empty (i.e.
-- contains a zero), then there is no way to deduce the full shape of the output -- array (more precisely, the @sh2@ part): that could only come from calling -- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in -- this case; we choose to fill the shape with zeros wherever we cannot deduce -- what it should be. -- -- For example, if: -- -- @ -- arr :: XArray '[Just 3, Just 0, Just 4, Just 2, Nothing] Int -- of shape [3, 0, 4, 2, 21] -- f :: XArray '[Just 2, Nothing] Int -> XArray '[Just 5, Nothing, Just 17] Float -- @ -- -- then: -- -- @ -- rerank _ _ _ f arr :: XArray '[Just 3, Just 0, Just 4, Just 5, Nothing, Just 17] Float -- @ -- -- and this result will have shape @[3, 0, 4, 5, 0, 17]@. Note the second @0@ -- in this shape: we don't know if @f@ intended to return an array with shape 0 -- here (it probably didn't), but there is no better number to put here absent -- a subarray of the input to pass to @f@. -- -- In this particular case the fact that @sh@ is empty was evident from the -- type-level information, but the same situation occurs when @sh@ consists of -- @Nothing@s, and some of those happen to be zero at runtime. rerank :: forall sh sh1 sh2 a b. (Storable a, Storable b) => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 -> (XArray sh1 a -> XArray sh2 b) -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b rerank ssh ssh1 ssh2 f xarr@(XArray arr) | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) = let (sh, _) = shAppSplit (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr) in if any (== 0) (shapeLshape sh) then XArray (S.fromList (shapeLshape (shAppend sh (completeShXzeros ssh2))) []) else case () of () | Dict <- lemKnownNatRankSSX ssh , Dict <- lemKnownNatRankSSX ssh2 , Refl <- lemRankApp ssh ssh1 , Refl <- lemRankApp ssh ssh2 -> XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2) (\a -> let XArray r = f (XArray a) in r) arr) rerankTop :: forall sh1 sh2 sh a b. 
(Storable a, Storable b) => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh -> (XArray sh1 a -> XArray sh2 b) -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh -- | The caveat about empty arrays at @rerank@ applies here too. rerank2 :: forall sh sh1 sh2 a b c. (Storable a, Storable b, Storable c) => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2) | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) = let (sh, _) = shAppSplit (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1) in if any (== 0) (shapeLshape sh) then XArray (S.fromList (shapeLshape (shAppend sh (completeShXzeros ssh2))) []) else case () of () | Dict <- lemKnownNatRankSSX ssh , Dict <- lemKnownNatRankSSX ssh2 , Refl <- lemRankApp ssh ssh1 , Refl <- lemRankApp ssh ssh2 -> XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2) (\a b -> let XArray r = f (XArray a) (XArray b) in r) arr1 arr2) type family Elem x l where Elem x '[] = 'False Elem x (x : _) = 'True Elem x (_ : ys) = Elem x ys type family AllElem' as bs where AllElem' '[] bs = 'True AllElem' (a : as) bs = Elem a bs && AllElem' as bs type AllElem as bs = Assert (AllElem' as bs) (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs)) type family Count i n where Count n n = '[] Count i n = i : Count (i + 1) n type Permutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as) type family Index i sh where Index 0 (n : sh) = n Index i (_ : sh) = Index (i - 1) sh type family Permute is sh where Permute '[] sh = '[] Permute (i : is) sh = Index i sh : Permute is sh type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh data HList f list where HNil :: HList f '[] HCons :: f a -> HList f l -> HList f (a : l) infixr 
5 `HCons` deriving instance (forall a. Show (f a)) => Show (HList f list) deriving instance (forall a. Eq (f a)) => Eq (HList f list) foldHList :: Monoid m => (forall a. f a -> m) -> HList f list -> m foldHList _ HNil = mempty foldHList f (x `HCons` l) = f x <> foldHList f l snatLengthHList :: HList f list -> SNat (Rank list) snatLengthHList HNil = SNat snatLengthHList (_ `HCons` l) | SNat <- snatLengthHList l = SNat type family TakeLen ref l where TakeLen '[] l = '[] TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs type family DropLen ref l where DropLen '[] l = l DropLen (_ : ref) (_ : xs) = DropLen ref xs lemRankPermute :: Proxy sh -> HList SNat is -> Rank (Permute is sh) :~: Rank is lemRankPermute _ HNil = Refl lemRankPermute p (_ `HCons` is) | Refl <- lemRankPermute p is = Refl lemRankDropLen :: forall is sh. (Rank is <= Rank sh) => StaticShX sh -> HList SNat is -> Rank (DropLen is sh) :~: Rank sh - Rank is lemRankDropLen ZKX HNil = Refl lemRankDropLen (_ :!% sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl lemRankDropLen (_ :!% _) HNil = Refl lemRankDropLen ZKX (_ `HCons` _) = error "1 <= 0" lemIndexSucc :: Proxy i -> Proxy a -> Proxy l -> Index (i + 1) (a : l) :~: Index i l lemIndexSucc _ _ _ = unsafeCoerce Refl listxTakeLen :: forall f is sh. HList SNat is -> ListX sh f -> ListX (TakeLen is sh) f listxTakeLen HNil _ = ZX listxTakeLen (_ `HCons` is) (n ::% sh) = n ::% listxTakeLen is sh listxTakeLen (_ `HCons` _) ZX = error "Permutation longer than shape" listxDropLen :: forall f is sh. HList SNat is -> ListX sh f -> ListX (DropLen is sh) f listxDropLen HNil sh = sh listxDropLen (_ `HCons` is) (_ ::% sh) = listxDropLen is sh listxDropLen (_ `HCons` _) ZX = error "Permutation longer than shape" listxPermute :: forall f is sh. 
HList SNat is -> ListX sh f -> ListX (Permute is sh) f listxPermute HNil _ = ZX listxPermute (i `HCons` (is :: HList SNat is')) (sh :: ListX sh f) = listxIndex (Proxy @is') (Proxy @sh) i sh (listxPermute is sh) listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> ListX (Permute is shT) f -> ListX (Index i sh : Permute is shT) f listxIndex _ _ SZ (n ::% _) rest = n ::% rest listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) rest | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') = listxIndex p pT i sh rest listxIndex _ _ _ ZX _ = error "Index into empty shape" listxPermutePrefix :: forall f is sh. HList SNat is -> ListX sh f -> ListX (PermutePrefix is sh) f listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) ssxTakeLen :: forall is sh. HList SNat is -> StaticShX sh -> StaticShX (TakeLen is sh) ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat)) ssxDropLen :: HList SNat is -> StaticShX sh -> StaticShX (DropLen is sh) ssxDropLen = coerce (listxDropLen @(SMayNat () SNat)) ssxPermute :: HList SNat is -> StaticShX sh -> StaticShX (Permute is sh) ssxPermute = coerce (listxPermute @(SMayNat () SNat)) ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT) ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2) ssxPermutePrefix :: HList SNat is -> StaticShX sh -> StaticShX (PermutePrefix is sh) ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat)) shPermutePrefix :: HList SNat is -> IShX sh -> IShX (PermutePrefix is sh) shPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat)) -- TODO: test this thing more properly invertPermutation :: HList SNat is -> (forall is'. Permutation is' => HList SNat is' -> (forall sh. 
Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh) -> r) -> r invertPermutation = \perm k -> genPerm perm $ \(invperm :: HList SNat is') -> let sn = snatLengthHList invperm in case (provePerm1 (Proxy @is') sn invperm, provePerm2 (SNat @0) sn invperm) of (Just Refl, Just Refl) -> k invperm (\ssh -> case provePermInverse perm invperm ssh of Just eq -> eq Nothing -> error $ "invertPermutation: did not generate inverse? perm = " ++ show perm ++ " ; invperm = " ++ show invperm) _ -> error $ "invertPermutation: did not generate permutation? perm = " ++ show perm ++ " ; invperm = " ++ show invperm where genPerm :: HList SNat is -> (forall is'. HList SNat is' -> r) -> r genPerm perm = let permList = foldHList (pure . fromSNat) perm in toHList $ map snd (sort (zip permList [0..])) where toHList :: [Natural] -> (forall is'. HList SNat is' -> r) -> r toHList [] k = k HNil toHList (n : ns) k = toHList ns $ \l -> TypeNats.withSomeSNat n $ \sn -> k (HCons sn l) lemElemCount :: (0 <= n, Compare n m ~ LT) => proxy n -> proxy m -> Elem n (Count 0 m) :~: True lemElemCount _ _ = unsafeCoerce Refl lemCount :: (OrdCond (Compare i n) True False True ~ True) => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n lemCount _ _ = unsafeCoerce Refl lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True lemElem _ _ = unsafeCoerce Refl provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> HList SNat is' -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True) provePerm1 _ _ HNil = Just (Refl) provePerm1 p rtop@SNat (HCons sn@SNat perm) | Just Refl <- provePerm1 p rtop perm = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl _ -> Nothing | otherwise = Nothing provePerm2 :: SNat i -> SNat n -> HList SNat is' -> Maybe (AllElem' (Count i n) is' :~: True) provePerm2 = \i@(SNat :: SNat i) n@SNat perm -> case cmpNat i n of EQI -> Just Refl LTI | 
Refl <- lemCount i n , Just Refl <- provePerm2 (SNat @(i + 1)) n perm -> checkElem i perm | otherwise -> Nothing GTI -> error "unreachable" where checkElem :: SNat i -> HList SNat is' -> Maybe (Elem i is' :~: True) checkElem _ HNil = Nothing checkElem i@SNat (HCons k@SNat perm :: HList SNat is') = case sameNat i k of Just Refl -> Just Refl Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl | otherwise -> Nothing provePermInverse :: HList SNat is -> HList SNat is' -> StaticShX sh -> Maybe (Permute is' (Permute is sh) :~: sh) provePermInverse perm perminv ssh = geqStaticShX (ssxPermute perminv (ssxPermute perm ssh)) ssh class KnownNatList l where makeNatList :: HList SNat l instance KnownNatList '[] where makeNatList = HNil instance (KnownNat n, KnownNatList l) => KnownNatList (n : l) where makeNatList = natSing `HCons` makeNatList -- | The list argument gives indices into the original dimension list. transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh) => StaticShX sh -> HList SNat is -> XArray sh a -> XArray (PermutePrefix is sh) a transpose ssh perm (XArray arr) | Dict <- lemKnownNatRankSSX ssh , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh) , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm , Refl <- lemRankDropLen ssh perm = let perm' = foldHList (\sn -> [fromSNat' sn]) perm :: [Int] in XArray (S.transpose perm' arr) -- | The list argument gives indices into the original dimension list. -- -- The permutation (the list) must have length <= @n@. If it is longer, this -- function throws. transposeUntyped :: forall n sh a. 
SNat n -> StaticShX sh -> [Int] -> XArray (Replicate n Nothing ++ sh) a -> XArray (Replicate n Nothing ++ sh) a transposeUntyped sn ssh perm (XArray arr) | length perm <= fromSNat' sn , Dict <- lemKnownNatRankSSX (ssxAppend (ssxReplicate sn) ssh) = XArray (S.transpose perm arr) | otherwise = error "Data.Array.Mixed.transposeUntyped: Permutation longer than length of unshaped prefix of shape type" transpose2 :: forall sh1 sh2 a. StaticShX sh1 -> StaticShX sh2 -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a transpose2 ssh1 ssh2 (XArray arr) | Refl <- lemRankApp ssh1 ssh2 , Refl <- lemRankApp ssh2 ssh1 , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2) , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1) , Refl <- lemRankAppComm ssh1 ssh2 , let n1 = ssxLength ssh1 = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) sumFull :: (Storable a, NumElt a) => XArray sh a -> a sumFull (XArray arr) = S.unScalar $ numEltSum1Inner (SNat @0) $ S.fromVector [product (S.shapeL arr)] $ S.toVector arr sumInner :: forall sh sh' a. (Storable a, NumElt a) => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a sumInner ssh ssh' arr | Refl <- lemAppNil @sh = let (_, sh') = shAppSplit (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) sh'F = flattenSh sh' :$% ZSX ssh'F = staticShapeFrom sh'F go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a go (XArray arr') | Refl <- lemRankApp ssh ssh'F , let sn = snatLengthListX (let StaticShX l = ssh in l) = XArray (numEltSum1Inner sn arr') in go $ transpose2 ssh'F ssh $ reshapePartial ssh' ssh sh'F $ transpose2 ssh ssh' $ arr sumOuter :: forall sh sh' a. 
(Storable a, NumElt a) => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a
-- Merge the outer @sh@ dimensions into a single dimension, move it to the
-- back with 'transpose2', and let 'sumInner' reduce it.
sumOuter ssh ssh' arr
  | Refl <- lemAppNil @sh
  = let (sh, _) = shAppSplit (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
        shF = flattenSh sh :$% ZSX
    in sumInner ssh' (staticShapeFrom shF) $
       transpose2 (staticShapeFrom shF) ssh' $
       reshapePartial ssh ssh' shF $
       arr

-- | Build an array from a list of sub-arrays. The sub-arrays become the
-- elements of the new outermost dimension.
--
-- If that outermost dimension is statically known ('SKnown'), the length of
-- the list must equal it; otherwise this calls 'error'.
fromListOuter :: forall n sh a. Storable a
              => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
fromListOuter ssh l
  | Dict <- lemKnownNatRankSSX ssh
  = let len = length l
    in case ssh of
         SKnown m :!% _
           | fromSNat' m /= len ->
               -- The trailing space after ")" separates the two sentence halves.
               error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show len ++ ") "
                       ++ "does not match the type (" ++ show (fromSNat' m) ++ ")"
         _ -> XArray (S.ravel (ORB.fromList [len] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l)))

-- | Split an array into the list of its sub-arrays along the outermost
-- dimension.
toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a]
toListOuter (XArray arr) =
  case S.shapeL arr of
    -- An empty outer dimension is answered directly with [] rather than
    -- going through 'S.unravel'.
    0 : _ -> []
    _ -> coerce (ORB.toList (S.unravel arr))

-- | Build a rank-1 array from a list of elements.
--
-- If the length is statically known ('SKnown'), the length of the list must
-- equal it; otherwise this calls 'error'.
fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a
fromList1 ssh l =
  let n = length l
  in case ssh of
       SKnown m :!% _
         | fromSNat' m /= n ->
             -- The trailing space after ")" separates the two sentence halves.
             error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ") "
                     ++ "does not match the type (" ++ show (fromSNat' m) ++ ")"
       _ -> XArray (S.fromVector [n] (VS.fromListN n l))

-- | The elements of a rank-1 array, in order.
toList1 :: Storable a => XArray '[n] a -> [a]
toList1 (XArray arr) = S.toList arr

-- | Throws if the given shape is not, in fact, empty.
empty :: forall sh a.
Storable a => IShX sh -> XArray sh a
empty sh
  | Dict <- lemKnownNatRank sh
  = let dims = shapeLshape sh
        -- Placeholder element. Presumably it is only forced when there is at
        -- least one element to fill, i.e. when the shape was not empty.
        poison = error "Data.Array.Mixed.empty: shape was not empty"
    in XArray (S.constant dims poison)

-- | Take @n@ consecutive elements, starting at index @i@, from the outermost
-- dimension.
slice :: SNat i -> SNat n -> XArray (Just (i + n + k) : sh) a -> XArray (Just n : sh) a
slice i n (XArray arr) = XArray (S.slice [window] arr)
  where window = (fromSNat' i, fromSNat' n)

-- | Like 'slice', but for an outermost dimension whose size is not known
-- statically. Takes @len@ elements starting at index @start@.
sliceU :: Int -> Int -> XArray (Nothing : sh) a -> XArray (Nothing : sh) a
sliceU start len (XArray arr) = XArray (S.slice [(start, len)] arr)

-- | Reverse the order of the outermost dimension.
rev1 :: XArray (n : sh) a -> XArray (n : sh) a
rev1 (XArray arr) = XArray (S.rev [outermost] arr)
  where outermost = 0

-- | Reinterpret the array's elements under the target shape @sh2@.
-- Throws if the array and the target shape do not have the same number of
-- elements.
reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a
reshape ssh1 sh2 (XArray arr) =
  case (lemKnownNatRankSSX ssh1, lemKnownNatRank sh2) of
    (Dict, Dict) -> XArray (S.reshape (shapeLshape sh2) arr)

-- | Like 'reshape', but only the leading @sh1@ dimensions are replaced by
-- @sh2@; the trailing @sh'@ dimensions are kept as they are. Throws if the
-- array and the resulting shape do not have the same number of elements.
reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
reshapePartial ssh1 ssh' sh2 (XArray arr) =
  case ( lemKnownNatRankSSX (ssxAppend ssh1 ssh')
       , lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh2) ssh') ) of
    (Dict, Dict) ->
      -- The runtime extents of the dimensions that are not reshaped.
      let trailing = drop (lengthStaticShX ssh1) (S.shapeL arr)
      in XArray (S.reshape (shapeLshape sh2 ++ trailing) arr)

-- Building the vector from an enumerated list was measured to be (slightly)
-- faster than S.iota, S.generate and S.fromVector (VS.enumFromTo).
-- | A rank-1 array of length @n@ holding the enumeration
-- @[toEnum 0 .. toEnum (n - 1)]@, truncated to @n@ elements.
iota :: (Enum a, Storable a) => SNat n -> XArray '[Just n] a
iota sn =
  let len = fromSNat' sn
  in XArray (S.fromVector [len] (VS.fromListN len [toEnum 0 .. toEnum (len - 1)]))