aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
authorMikolaj Konarski <mikolaj.konarski@gmail.com>2024-04-21 17:50:30 +0200
committerMikolaj Konarski <mikolaj.konarski@gmail.com>2024-04-21 18:12:26 +0200
commitb3c92786635568e652b98095c3d0db5b4ec312b2 (patch)
treee5c1ad29e5e87157e813d50eea61b959ca522e54 /src/Data/Array
parentd4397160c5c5476dc4d93a169b06f6a03f1dab02 (diff)
Flesh out ranked sized lists
Diffstat (limited to 'src/Data/Array')
-rw-r--r--src/Data/Array/Nested.hs3
-rw-r--r--src/Data/Array/Nested/Internal.hs57
2 files changed, 54 insertions, 6 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 148acf5..f383b99 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -1,8 +1,9 @@
{-# LANGUAGE ExplicitNamespaces #-}
+{-# LANGUAGE PatternSynonyms #-}
module Data.Array.Nested (
-- * Ranked arrays
Ranked,
- IxR(..), IIxR,
+ IxR, pattern (:.:), pattern ZIR, IIxR,
rshape, rindex, rindexPartial, rgenerate, rsumOuter1,
rtranspose, rappend, rscalar, rfromVector, runScalar,
rconstant, rfromList, rfromList1, rtoList, rtoList1,
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 0582a14..e42de12 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -15,6 +15,7 @@
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-|
@@ -849,17 +850,63 @@ instance (KnownINat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where
deriving via Ranked n (Primitive Int) instance KnownINat n => Num (Ranked n Int)
deriving via Ranked n (Primitive Double) instance KnownINat n => Num (Ranked n Double)
+type ListR :: Type -> INat -> Type
+data ListR i n where
+ ZR :: ListR i Z
+ (:::) :: forall n i. i -> ListR i n -> ListR i (S n)
+deriving instance Show i => Show (ListR i n)
+deriving instance Eq i => Eq (ListR i n)
+infixr 3 :::
+
-- | An index into a rank-typed array.
type IxR :: Type -> INat -> Type
-data IxR i n where
- ZIR :: IxR i Z
- (:.:) :: forall n i. i -> IxR i n -> IxR i (S n)
-deriving instance Show i => Show (IxR i n)
-deriving instance Eq i => Eq (IxR i n)
+newtype IxR i n = IxR (ListR i n)
+ deriving (Show, Eq)
+
+pattern ZIR :: forall n i. () => n ~ Z => IxR i n
+pattern ZIR = IxR ZR
+
+pattern (:.:)
+ :: forall {n1} {i}.
+ forall n. ((S n) ~ n1)
+ => i -> IxR i n -> IxR i n1
+pattern i :.: sh <- (unconsIxR -> Just (UnconsIxRRes sh i))
+ where i :.: (IxR sh) = IxR (i ::: sh)
+{-# COMPLETE ZIR, (:.:) #-}
infixr 3 :.:
+data UnconsIxRRes i n1 =
+ forall n. ((S n) ~ n1) => UnconsIxRRes (IxR i n) i
+unconsIxR :: IxR i n1 -> Maybe (UnconsIxRRes i n1)
+unconsIxR (IxR sh) = case sh of
+ i ::: sh' -> Just (UnconsIxRRes (IxR sh') i)
+ ZR -> Nothing
+
type IIxR = IxR Int
+type StaticShapeR :: Type -> INat -> Type
+newtype StaticShapeR i n = StaticShapeR (ListR i n)
+ deriving (Show, Eq)
+
+pattern ZSR :: forall n i. () => n ~ Z => StaticShapeR i n
+pattern ZSR = StaticShapeR ZR
+
+pattern (:$:)
+ :: forall {n1} {i}.
+ forall n. ((S n) ~ n1)
+ => i -> StaticShapeR i n -> StaticShapeR i n1
+pattern i :$: sh <- (unconsStaticShapeR -> Just (UnconsStaticShapeRRes sh i))
+ where i :$: (StaticShapeR sh) = StaticShapeR (i ::: sh)
+{-# COMPLETE ZSR, (:$:) #-}
+infixr 3 :$:
+
+data UnconsStaticShapeRRes i n1 =
+ forall n. ((S n) ~ n1) => UnconsStaticShapeRRes (StaticShapeR i n) i
+unconsStaticShapeR :: StaticShapeR i n1 -> Maybe (UnconsStaticShapeRRes i n1)
+unconsStaticShapeR (StaticShapeR sh) = case sh of
+ i ::: sh' -> Just (UnconsStaticShapeRRes (StaticShapeR sh') i)
+ ZR -> Nothing
+
zeroIxR :: SINat n -> IIxR n
zeroIxR SZ = ZIR
zeroIxR (SS n) = 0 :.: zeroIxR n