aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r--src/Data/Array/Mixed.hs78
1 files changed, 37 insertions, 41 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index ce18431..69c44ab 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -127,17 +127,6 @@ deriving instance Show (StaticShX sh)
infixr 3 :!$@
infixr 3 :!$?
--- | Evidence for the static part of a shape.
-type KnownShapeX :: [Maybe Nat] -> Constraint
-class KnownShapeX sh where
- knownShapeX :: StaticShX sh
-instance KnownShapeX '[] where
- knownShapeX = ZKSX
-instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where
- knownShapeX = natSing :!$@ knownShapeX
-instance KnownShapeX sh => KnownShapeX (Nothing : sh) where
- knownShapeX = () :!$? knownShapeX
-
type family Rank sh where
Rank '[] = 0
Rank (_ : sh) = 1 + Rank sh
@@ -162,6 +151,7 @@ completeShXzeros ZKSX = ZSX
completeShXzeros (n :!$@ ssh) = n :$@ completeShXzeros ssh
completeShXzeros (_ :!$? ssh) = 0 :$? completeShXzeros ssh
+-- TODO: generalise all these things to arbitrary @i@
ixAppend :: IIxX sh -> IIxX sh' -> IIxX (sh ++ sh')
ixAppend ZIX idx' = idx'
ixAppend (i :.@ idx) idx' = i :.@ ixAppend idx idx'
@@ -177,6 +167,15 @@ ixDrop sh ZIX = sh
ixDrop (_ :.@ sh) (_ :.@ idx) = ixDrop sh idx
ixDrop (_ :.? sh) (_ :.? idx) = ixDrop sh idx
+shDropIx :: IShX (sh ++ sh') -> IIxX sh -> IShX sh'
+shDropIx sh ZIX = sh
+shDropIx (_ :$@ sh) (_ :.@ idx) = shDropIx sh idx
+shDropIx (_ :$? sh) (_ :.? idx) = shDropIx sh idx
+
+shTail :: IShX (n : sh) -> IShX sh
+shTail (_ :$@ sh) = sh
+shTail (_ :$? sh) = sh
+
ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
ssxAppend ZKSX sh' = sh'
ssxAppend (n :!$@ sh) sh' = n :!$@ ssxAppend sh sh'
@@ -279,22 +278,22 @@ lemKnownNatRankSSX ZKSX = Dict
lemKnownNatRankSSX (_ :!$@ ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
lemKnownNatRankSSX (_ :!$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
-lemKnownShapeX :: StaticShX sh -> Dict KnownShapeX sh
-lemKnownShapeX ZKSX = Dict
-lemKnownShapeX (GHC_SNat :!$@ ssh) | Dict <- lemKnownShapeX ssh = Dict
-lemKnownShapeX (() :!$? ssh) | Dict <- lemKnownShapeX ssh = Dict
-
-lemAppKnownShapeX :: StaticShX sh1 -> StaticShX sh2 -> Dict KnownShapeX (sh1 ++ sh2)
-lemAppKnownShapeX ZKSX ssh' = lemKnownShapeX ssh'
-lemAppKnownShapeX (GHC_SNat :!$@ ssh) ssh'
- | Dict <- lemAppKnownShapeX ssh ssh'
- = Dict
-lemAppKnownShapeX (() :!$? ssh) ssh'
- | Dict <- lemAppKnownShapeX ssh ssh'
- = Dict
-
-shape :: forall sh a. KnownShapeX sh => XArray sh a -> IShX sh
-shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr)
+-- lemKnownShapeX :: StaticShX sh -> Dict KnownShapeX sh
+-- lemKnownShapeX ZKSX = Dict
+-- lemKnownShapeX (GHC_SNat :!$@ ssh) | Dict <- lemKnownShapeX ssh = Dict
+-- lemKnownShapeX (() :!$? ssh) | Dict <- lemKnownShapeX ssh = Dict
+
+-- lemAppKnownShapeX :: StaticShX sh1 -> StaticShX sh2 -> Dict KnownShapeX (sh1 ++ sh2)
+-- lemAppKnownShapeX ZKSX ssh' = lemKnownShapeX ssh'
+-- lemAppKnownShapeX (GHC_SNat :!$@ ssh) ssh'
+-- | Dict <- lemAppKnownShapeX ssh ssh'
+-- = Dict
+-- lemAppKnownShapeX (() :!$? ssh) ssh'
+-- | Dict <- lemAppKnownShapeX ssh ssh'
+-- = Dict
+
+shape :: forall sh a. StaticShX sh -> XArray sh a -> IShX sh
+shape = \ssh (XArray arr) -> go ssh (S.shapeL arr)
where
go :: StaticShX sh' -> [Int] -> IShX sh'
go ZKSX [] = ZSX
@@ -345,10 +344,10 @@ type family AddMaybe n m where
AddMaybe (Just _) Nothing = Nothing
AddMaybe (Just n) (Just m) = Just (n + m)
-append :: forall n m sh a. (KnownShapeX sh, Storable a)
- => XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a
-append (XArray a) (XArray b)
- | Dict <- lemKnownNatRankSSX (knownShapeX @sh)
+append :: forall n m sh a. Storable a
+ => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a
+append ssh (XArray a) (XArray b)
+ | Dict <- lemKnownNatRankSSX ssh
= XArray (S.append a b)
rerank :: forall sh sh1 sh2 a b.
@@ -429,10 +428,6 @@ foldHList :: Monoid m => (forall a. f a -> m) -> HList f list -> m
foldHList _ HNil = mempty
foldHList f (x `HCons` l) = f x <> foldHList f l
-class KnownNatList l where makeNatList :: HList SNat l
-instance KnownNatList '[] where makeNatList = HNil
-instance (KnownNat n, KnownNatList l) => KnownNatList (n : l) where makeNatList = natSing `HCons` makeNatList
-
type family TakeLen ref l where
TakeLen '[] l = '[]
TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs
@@ -485,15 +480,16 @@ ssxIndex p pT (SS (i :: SNat i')) (() :!$? (sh :: StaticShX sh')) rest
ssxIndex _ _ _ ZKSX _ = error "Index into empty shape"
-- | The list argument gives indices into the original dimension list.
-transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh, KnownShapeX sh)
- => HList SNat is
+transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh)
+ => StaticShX sh
+ -> HList SNat is
-> XArray sh a
-> XArray (PermutePrefix is sh) a
-transpose perm (XArray arr)
- | Dict <- lemKnownNatRankSSX (knownShapeX @sh)
- , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm (knownShapeX @sh))) (ssxDropLen perm (knownShapeX @sh))
+transpose ssh perm (XArray arr)
+ | Dict <- lemKnownNatRankSSX ssh
+ , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh)
, Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm
- , Refl <- lemRankDropLen (knownShapeX @sh) perm
+ , Refl <- lemRankDropLen ssh perm
= let perm' = foldHList (\sn -> [fromSNat' sn]) perm :: [Int]
in XArray (S.transpose perm' arr)