aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Shaped/Base.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Shaped/Base.hs')
-rw-r--r--src/Data/Array/Nested/Shaped/Base.hs27
1 files changed, 11 insertions, 16 deletions
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs
index ddd44bf..98f1241 100644
--- a/src/Data/Array/Nested/Shaped/Base.hs
+++ b/src/Data/Array/Nested/Shaped/Base.hs
@@ -90,8 +90,8 @@ instance Elt a => Elt (Shaped sh a) where
mscalar (Shaped x) = M_Shaped (M_Nest ZSX x)
- mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a)
- mfromListOuter l = M_Shaped (mfromListOuter (coerce l))
+ mfromListOuterSN :: SNat n -> NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Just n : sh') (Shaped sh a)
+ mfromListOuterSN sn l = M_Shaped (mfromListOuterSN sn (coerce l))
mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)]
mtoListOuter (M_Shaped arr)
@@ -136,7 +136,7 @@ instance Elt a => Elt (Shaped sh a) where
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
- mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+ mshapeTreeIsEmpty _ (sh, t) = shsSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t
mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
@@ -172,10 +172,10 @@ instance Elt a => Elt (Shaped sh a) where
instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
- memptyArrayUnsafe i
+ memptyArrayUnsafe sh
| Dict <- lemKnownMapJust (Proxy @sh)
= coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
- memptyArrayUnsafe i
+ memptyArrayUnsafe sh
mvecsUnsafeNew idx (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh)
@@ -203,15 +203,15 @@ instance (NumElt a, PrimElt a) => Num (Shaped sh a) where
negate = liftShaped1 negate
abs = liftShaped1 abs
signum = liftShaped1 signum
- fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicateScal"
+ fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicatePrim"
instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where
- fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal"
+ fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicatePrim"
recip = liftShaped1 recip
(/) = liftShaped2 (/)
instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where
- pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal"
+ pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicatePrim"
exp = liftShaped1 exp
log = liftShaped1 log
sqrt = liftShaped1 sqrt
@@ -246,15 +246,10 @@ sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
sshape (Shaped arr) = shsFromShX (mshape arr)
-- Needed already here, but re-exported in Data.Array.Nested.Convert.
-shsFromShX :: forall sh. IShX (MapJust sh) -> ShS sh
+shsFromShX :: forall sh i. ShX (MapJust sh) i -> ShS sh
shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
-shsFromShX (SKnown n@SNat :$% (idx :: IShX mjshT)) =
- castWith (subst1 (lem Refl)) $
+shsFromShX (SKnown n@SNat :$% (idx :: ShX mjshT i)) =
+ castWith (subst1 (sym (lemMapJustCons Refl))) $
n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
idx)
- where
- lem :: forall sh1 sh' n.
- Just n : sh1 :~: MapJust sh'
- -> n : Tail sh' :~: sh'
- lem Refl = unsafeCoerceRefl
shsFromShX (SUnknown _ :$% _) = error "impossible"