diff options
Diffstat (limited to 'src/Data')
-rw-r--r-- | src/Data/Array/Nested.hs | 3 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 57 |
2 files changed, 54 insertions, 6 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 148acf5..f383b99 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -1,8 +1,9 @@ {-# LANGUAGE ExplicitNamespaces #-} +{-# LANGUAGE PatternSynonyms #-} module Data.Array.Nested ( -- * Ranked arrays Ranked, - IxR(..), IIxR, + IxR, pattern (:.:), pattern ZIR, IIxR, rshape, rindex, rindexPartial, rgenerate, rsumOuter1, rtranspose, rappend, rscalar, rfromVector, runScalar, rconstant, rfromList, rfromList1, rtoList, rtoList1, diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 0582a14..e42de12 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -15,6 +15,7 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} {-| @@ -849,17 +850,63 @@ instance (KnownINat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where deriving via Ranked n (Primitive Int) instance KnownINat n => Num (Ranked n Int) deriving via Ranked n (Primitive Double) instance KnownINat n => Num (Ranked n Double) +type ListR :: Type -> INat -> Type +data ListR i n where + ZR :: ListR i Z + (:::) :: forall n i. i -> ListR i n -> ListR i (S n) +deriving instance Show i => Show (ListR i n) +deriving instance Eq i => Eq (ListR i n) +infixr 3 ::: + -- | An index into a rank-typed array. type IxR :: Type -> INat -> Type -data IxR i n where - ZIR :: IxR i Z - (:.:) :: forall n i. i -> IxR i n -> IxR i (S n) -deriving instance Show i => Show (IxR i n) -deriving instance Eq i => Eq (IxR i n) +newtype IxR i n = IxR (ListR i n) + deriving (Show, Eq) + +pattern ZIR :: forall n i. () => n ~ Z => IxR i n +pattern ZIR = IxR ZR + +pattern (:.:) + :: forall {n1} {i}. + forall n. ((S n) ~ n1) + => i -> IxR i n -> IxR i n1 +pattern i :.: sh <- (unconsIxR -> Just (UnconsIxRRes sh i)) + where i :.: (IxR sh) = IxR (i ::: sh) +{-# COMPLETE ZIR, (:.:) #-} infixr 3 :.: +data UnconsIxRRes i n1 = + forall n. ((S n) ~ n1) => UnconsIxRRes (IxR i n) i +unconsIxR :: IxR i n1 -> Maybe (UnconsIxRRes i n1) +unconsIxR (IxR sh) = case sh of + i ::: sh' -> Just (UnconsIxRRes (IxR sh') i) + ZR -> Nothing + type IIxR = IxR Int +type StaticShapeR :: Type -> INat -> Type +newtype StaticShapeR i n = StaticShapeR (ListR i n) + deriving (Show, Eq) + +pattern ZSR :: forall n i. () => n ~ Z => StaticShapeR i n +pattern ZSR = StaticShapeR ZR + +pattern (:$:) + :: forall {n1} {i}. + forall n. ((S n) ~ n1) + => i -> StaticShapeR i n -> StaticShapeR i n1 +pattern i :$: sh <- (unconsStaticShapeR -> Just (UnconsStaticShapeRRes sh i)) + where i :$: (StaticShapeR sh) = StaticShapeR (i ::: sh) +{-# COMPLETE ZSR, (:$:) #-} +infixr 3 :$: + +data UnconsStaticShapeRRes i n1 = + forall n. ((S n) ~ n1) => UnconsStaticShapeRRes (StaticShapeR i n) i +unconsStaticShapeR :: StaticShapeR i n1 -> Maybe (UnconsStaticShapeRRes i n1) +unconsStaticShapeR (StaticShapeR sh) = case sh of + i ::: sh' -> Just (UnconsStaticShapeRRes (StaticShapeR sh') i) + ZR -> Nothing + zeroIxR :: SINat n -> IIxR n zeroIxR SZ = ZIR zeroIxR (SS n) = 0 :.: zeroIxR n |