diff options
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r-- | src/Data/Array/Mixed.hs | 78 |
1 files changed, 37 insertions, 41 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index ce18431..69c44ab 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -127,17 +127,6 @@ deriving instance Show (StaticShX sh) infixr 3 :!$@ infixr 3 :!$? --- | Evidence for the static part of a shape. -type KnownShapeX :: [Maybe Nat] -> Constraint -class KnownShapeX sh where - knownShapeX :: StaticShX sh -instance KnownShapeX '[] where - knownShapeX = ZKSX -instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where - knownShapeX = natSing :!$@ knownShapeX -instance KnownShapeX sh => KnownShapeX (Nothing : sh) where - knownShapeX = () :!$? knownShapeX - type family Rank sh where Rank '[] = 0 Rank (_ : sh) = 1 + Rank sh @@ -162,6 +151,7 @@ completeShXzeros ZKSX = ZSX completeShXzeros (n :!$@ ssh) = n :$@ completeShXzeros ssh completeShXzeros (_ :!$? ssh) = 0 :$? completeShXzeros ssh +-- TODO: generalise all these things to arbitrary @i@ ixAppend :: IIxX sh -> IIxX sh' -> IIxX (sh ++ sh') ixAppend ZIX idx' = idx' ixAppend (i :.@ idx) idx' = i :.@ ixAppend idx idx' @@ -177,6 +167,15 @@ ixDrop sh ZIX = sh ixDrop (_ :.@ sh) (_ :.@ idx) = ixDrop sh idx ixDrop (_ :.? sh) (_ :.? idx) = ixDrop sh idx +shDropIx :: IShX (sh ++ sh') -> IIxX sh -> IShX sh' +shDropIx sh ZIX = sh +shDropIx (_ :$@ sh) (_ :.@ idx) = shDropIx sh idx +shDropIx (_ :$? sh) (_ :.? idx) = shDropIx sh idx + +shTail :: IShX (n : sh) -> IShX sh +shTail (_ :$@ sh) = sh +shTail (_ :$? sh) = sh + ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') ssxAppend ZKSX sh' = sh' ssxAppend (n :!$@ sh) sh' = n :!$@ ssxAppend sh sh' @@ -279,22 +278,22 @@ lemKnownNatRankSSX ZKSX = Dict lemKnownNatRankSSX (_ :!$@ ssh) | Dict <- lemKnownNatRankSSX ssh = Dict lemKnownNatRankSSX (_ :!$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict -lemKnownShapeX :: StaticShX sh -> Dict KnownShapeX sh -lemKnownShapeX ZKSX = Dict -lemKnownShapeX (GHC_SNat :!$@ ssh) | Dict <- lemKnownShapeX ssh = Dict -lemKnownShapeX (() :!$? ssh) | Dict <- lemKnownShapeX ssh = Dict - -lemAppKnownShapeX :: StaticShX sh1 -> StaticShX sh2 -> Dict KnownShapeX (sh1 ++ sh2) -lemAppKnownShapeX ZKSX ssh' = lemKnownShapeX ssh' -lemAppKnownShapeX (GHC_SNat :!$@ ssh) ssh' - | Dict <- lemAppKnownShapeX ssh ssh' - = Dict -lemAppKnownShapeX (() :!$? ssh) ssh' - | Dict <- lemAppKnownShapeX ssh ssh' - = Dict - -shape :: forall sh a. KnownShapeX sh => XArray sh a -> IShX sh -shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr) +-- lemKnownShapeX :: StaticShX sh -> Dict KnownShapeX sh +-- lemKnownShapeX ZKSX = Dict +-- lemKnownShapeX (GHC_SNat :!$@ ssh) | Dict <- lemKnownShapeX ssh = Dict +-- lemKnownShapeX (() :!$? ssh) | Dict <- lemKnownShapeX ssh = Dict + +-- lemAppKnownShapeX :: StaticShX sh1 -> StaticShX sh2 -> Dict KnownShapeX (sh1 ++ sh2) +-- lemAppKnownShapeX ZKSX ssh' = lemKnownShapeX ssh' +-- lemAppKnownShapeX (GHC_SNat :!$@ ssh) ssh' +-- | Dict <- lemAppKnownShapeX ssh ssh' +-- = Dict +-- lemAppKnownShapeX (() :!$? ssh) ssh' +-- | Dict <- lemAppKnownShapeX ssh ssh' +-- = Dict + +shape :: forall sh a. StaticShX sh -> XArray sh a -> IShX sh +shape = \ssh (XArray arr) -> go ssh (S.shapeL arr) where go :: StaticShX sh' -> [Int] -> IShX sh' go ZKSX [] = ZSX @@ -345,10 +344,10 @@ type family AddMaybe n m where AddMaybe (Just _) Nothing = Nothing AddMaybe (Just n) (Just m) = Just (n + m) -append :: forall n m sh a. (KnownShapeX sh, Storable a) - => XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a -append (XArray a) (XArray b) - | Dict <- lemKnownNatRankSSX (knownShapeX @sh) +append :: forall n m sh a. Storable a + => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a +append ssh (XArray a) (XArray b) + | Dict <- lemKnownNatRankSSX ssh = XArray (S.append a b) rerank :: forall sh sh1 sh2 a b. @@ -429,10 +428,6 @@ foldHList :: Monoid m => (forall a. f a -> m) -> HList f list -> m foldHList _ HNil = mempty foldHList f (x `HCons` l) = f x <> foldHList f l -class KnownNatList l where makeNatList :: HList SNat l -instance KnownNatList '[] where makeNatList = HNil -instance (KnownNat n, KnownNatList l) => KnownNatList (n : l) where makeNatList = natSing `HCons` makeNatList - type family TakeLen ref l where TakeLen '[] l = '[] TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs @@ -485,15 +480,16 @@ ssxIndex p pT (SS (i :: SNat i')) (() :!$? (sh :: StaticShX sh')) rest ssxIndex _ _ _ ZKSX _ = error "Index into empty shape" -- | The list argument gives indices into the original dimension list. -transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh, KnownShapeX sh) - => HList SNat is +transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh) + => StaticShX sh + -> HList SNat is -> XArray sh a -> XArray (PermutePrefix is sh) a -transpose perm (XArray arr) - | Dict <- lemKnownNatRankSSX (knownShapeX @sh) - , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm (knownShapeX @sh))) (ssxDropLen perm (knownShapeX @sh)) +transpose ssh perm (XArray arr) + | Dict <- lemKnownNatRankSSX ssh + , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh) , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm - , Refl <- lemRankDropLen (knownShapeX @sh) perm + , Refl <- lemRankDropLen ssh perm = let perm' = foldHList (\sn -> [fromSNat' sn]) perm :: [Int] in XArray (S.transpose perm' arr) |