diff options
Diffstat (limited to 'src/Data/Array/Nested')
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 324 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Arith.hs | 435 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Arith/Foreign.hs | 55 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Arith/Lists.hs | 78 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Arith/Lists/TH.hs | 82 | 
5 files changed, 164 insertions, 810 deletions
| diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 712c5f1..0870789 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -60,7 +60,11 @@ import Unsafe.Coerce  import Data.Array.Mixed  import qualified Data.Array.Mixed as X -import Data.Array.Nested.Internal.Arith +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Permutation +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Internal.Arith +import Data.Array.Mixed.Types  -- Invariant in the API @@ -123,19 +127,19 @@ lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict  ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing)  ssxFromSNat SZ = ZKX -ssxFromSNat (SS (n :: SNat nm1)) | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n +ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n  lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing)  lemKnownReplicate sn = lemKnownShX (ssxFromSNat sn) -lemRankReplicate :: SNat n -> X.Rank (Replicate n (Nothing @Nat)) :~: n +lemRankReplicate :: SNat n -> Rank (Replicate n (Nothing @Nat)) :~: n  lemRankReplicate SZ = Refl  lemRankReplicate (SS (n :: SNat nm1)) -  | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1 +  | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1    , Refl <- lemRankReplicate n    = Refl -lemRankMapJust :: forall sh. ShS sh -> X.Rank (MapJust sh) :~: X.Rank sh +lemRankMapJust :: forall sh. ShS sh -> Rank (MapJust sh) :~: Rank sh  lemRankMapJust ZSS = Refl  lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl @@ -146,9 +150,9 @@ lemReplicatePlusApp sn _ _ = go sn      go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a      go SZ = Refl      go (SS (n :: SNat n'm1)) -      | Refl <- X.lemReplicateSucc @a @n'm1 +      | Refl <- lemReplicateSucc @a @n'm1        , Refl <- go n -      = sym (X.lemReplicateSucc @a @(n'm1 + m)) +      = sym (lemReplicateSucc @a @(n'm1 + m))  lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True  lemLeqPlus _ _ _ = Refl @@ -156,17 +160,17 @@ lemLeqPlus _ _ _ = Refl  lemLeqSuccSucc :: (k + 1 <= n) => Proxy k -> Proxy n -> (k <=? n - 1) :~: True  lemLeqSuccSucc _ _ = unsafeCoerce Refl -lemDropLenApp :: X.Rank l1 <= X.Rank l2 +lemDropLenApp :: Rank l1 <= Rank l2                => Proxy l1 -> Proxy l2 -> Proxy rest -              -> X.DropLen l1 l2 ++ rest :~: X.DropLen l1 (l2 ++ rest) +              -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest)  lemDropLenApp _ _ _ = unsafeCoerce Refl -lemTakeLenApp :: X.Rank l1 <= X.Rank l2 +lemTakeLenApp :: Rank l1 <= Rank l2                => Proxy l1 -> Proxy l2 -> Proxy rest -              -> X.TakeLen l1 l2 :~: X.TakeLen l1 (l2 ++ rest) +              -> TakeLen l1 l2 :~: TakeLen l1 (l2 ++ rest)  lemTakeLenApp _ _ _ = unsafeCoerce Refl -srankSh :: ShX sh f -> SNat (X.Rank sh) +srankSh :: ShX sh f -> SNat (Rank sh)  srankSh ZSX = SNat  srankSh (_ :$% sh) | SNat <- srankSh sh = SNat @@ -585,11 +589,11 @@ class Elt a where           -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)           -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a -  mcast :: forall sh1 sh2 sh'. X.Rank sh1 ~ X.Rank sh2 +  mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2          => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a -  mtranspose :: forall is sh. (X.Permutation is, X.Rank is <= X.Rank sh) -             => HList SNat is -> Mixed sh a -> Mixed (X.PermutePrefix is sh) a +  mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) +             => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a    -- ====== PRIVATE METHODS ====== -- @@ -635,20 +639,20 @@ class Elt a => KnownElt a where  instance Storable a => Elt (Primitive a) where    mshape (M_Primitive sh _) = sh    mindex (M_Primitive _ a) i = Primitive (X.index a i) -  mindexPartial (M_Primitive sh a) i = M_Primitive (X.shDropIx sh i) (X.indexPartial a i) +  mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i)    mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)    mfromListOuter l@(arr1 :| _) =      let sh = SUnknown (length l) :$% mshape arr1 -    in M_Primitive sh (X.fromListOuter (X.staticShapeFrom sh) (map (\(M_Primitive _ a) -> a) (toList l))) -  mtoListOuter (M_Primitive sh arr) = map (M_Primitive (X.shTail sh)) (X.toListOuter arr) +    in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l))) +  mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr)    mlift :: forall sh1 sh2.             StaticShX sh2          -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)          -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a)    mlift ssh2 f (M_Primitive _ a) -    | Refl <- X.lemAppNil @sh1 -    , Refl <- X.lemAppNil @sh2 +    | Refl <- lemAppNil @sh1 +    , Refl <- lemAppNil @sh2      , let result = f ZKX a      = M_Primitive (X.shape ssh2 result) result @@ -657,36 +661,36 @@ instance Storable a => Elt (Primitive a) where           -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a)           -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a)    mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b) -    | Refl <- X.lemAppNil @sh1 -    , Refl <- X.lemAppNil @sh2 -    , Refl <- X.lemAppNil @sh3 +    | Refl <- lemAppNil @sh1 +    , Refl <- lemAppNil @sh2 +    , Refl <- lemAppNil @sh3      , let result = f ZKX a b      = M_Primitive (X.shape ssh3 result) result -  mcast :: forall sh1 sh2 sh'. X.Rank sh1 ~ X.Rank sh2 +  mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2          => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a)    mcast ssh1 sh2 _ (M_Primitive sh1' arr) = -    let (_, sh') = shAppSplit (Proxy @sh') ssh1 sh1' -    in M_Primitive (shAppend sh2 sh') (X.cast ssh1 sh2 (X.staticShapeFrom sh') arr) +    let (_, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' +    in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr)    mtranspose perm (M_Primitive sh arr) = -    M_Primitive (X.shPermutePrefix perm sh) -                (X.transpose (X.staticShapeFrom sh) perm arr) +    M_Primitive (shxPermutePrefix perm sh) +                (X.transpose (ssxFromShape sh) perm arr)    mshapeTree _ = ()    mshapeTreeEq _ () () = True    mshapeTreeEmpty _ () = False    mshowShapeTree _ () = "()" -  mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (X.toLinearIdx sh i) x +  mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x    -- TODO: this use of toVector is suboptimal    mvecsWritePartial      :: forall sh' sh s.         IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()    mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do -    let arrsh = X.shape (X.staticShapeFrom sh') arr -        offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIxX' arrsh)) -    VS.copy (VSM.slice offset (X.shapeSize arrsh) v) (X.toVector arr) +    let arrsh = X.shape (ssxFromShape sh') arr +        offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) +    VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr)    mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v @@ -701,7 +705,7 @@ deriving via Primitive () instance Elt ()  instance Storable a => KnownElt (Primitive a) where    memptyArray sh = M_Primitive sh (X.empty sh) -  mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh) +  mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh)    mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0  -- [PRIMITIVE ELEMENT TYPES LIST] @@ -755,7 +759,7 @@ instance Elt a => Elt (Mixed sh' a) where    -- moverlongShape method, a prefix of which is mshape.    mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh    mshape (M_Nest sh arr) -    = fst (shAppSplit (Proxy @sh') (X.staticShapeFrom sh) (mshape arr)) +    = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr))    mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a    mindex (M_Nest _ arr) i = mindexPartial arr i @@ -763,8 +767,8 @@ instance Elt a => Elt (Mixed sh' a) where    mindexPartial :: forall sh1 sh2.                     Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)    mindexPartial (M_Nest sh arr) i -    | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') -    = M_Nest (X.shDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) +    | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') +    = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)    mscalar = M_Nest ZSX @@ -773,95 +777,95 @@ instance Elt a => Elt (Mixed sh' a) where      M_Nest (SUnknown (length l) :$% mshape arr)             (mfromListOuter ((\(M_Nest _ a) -> a) <$> l)) -  mtoListOuter (M_Nest sh arr) = map (M_Nest (X.shTail sh)) (mtoListOuter arr) +  mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr)    mlift :: forall sh1 sh2.             StaticShX sh2          -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)          -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)    mlift ssh2 f (M_Nest sh1 arr) = -    let result = mlift (X.ssxAppend ssh2 ssh') f' arr -        (sh2, _) = shAppSplit (Proxy @sh') ssh2 (mshape result) +    let result = mlift (ssxAppend ssh2 ssh') f' arr +        (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result)      in M_Nest sh2 result      where -      ssh' = X.staticShapeFrom (snd (shAppSplit (Proxy @sh') (X.staticShapeFrom sh1) (mshape arr))) +      ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr)))        f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b        f' sshT -        | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) -        , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) -        = f (X.ssxAppend ssh' sshT) +        | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) +        , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) +        = f (ssxAppend ssh' sshT)    mlift2 :: forall sh1 sh2 sh3.              StaticShX sh3           -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b)           -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a)    mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) = -    let result = mlift2 (X.ssxAppend ssh3 ssh') f' arr1 arr2 -        (sh3, _) = shAppSplit (Proxy @sh') ssh3 (mshape result) +    let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2 +        (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result)      in M_Nest sh3 result      where -      ssh' = X.staticShapeFrom (snd (shAppSplit (Proxy @sh') (X.staticShapeFrom sh1) (mshape arr1))) +      ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1)))        f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b        f' sshT -        | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) -        , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) -        , Refl <- X.lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) -        = f (X.ssxAppend ssh' sshT) +        | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) +        , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) +        , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) +        = f (ssxAppend ssh' sshT) -  mcast :: forall sh1 sh2 shT. X.Rank sh1 ~ X.Rank sh2 +  mcast :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2          => StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a)    mcast ssh1 sh2 _ (M_Nest sh1T arr) -    | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') -    , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') -    = let (_, shT) = shAppSplit (Proxy @shT) ssh1 sh1T -      in M_Nest (shAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr) +    | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') +    , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') +    = let (_, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T +      in M_Nest (shxAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr) -  mtranspose :: forall is sh. (X.Permutation is, X.Rank is <= X.Rank sh) -             => HList SNat is -> Mixed sh (Mixed sh' a) -             -> Mixed (X.PermutePrefix is sh) (Mixed sh' a) +  mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) +             => Perm is -> Mixed sh (Mixed sh' a) +             -> Mixed (PermutePrefix is sh) (Mixed sh' a)    mtranspose perm (M_Nest sh arr) -    | let sh' = X.shDropSh @sh @sh' (mshape arr) sh -    , Refl <- X.lemRankApp (X.staticShapeFrom sh) (X.staticShapeFrom sh') -    , Refl <- lemLeqPlus (Proxy @(X.Rank is)) (Proxy @(X.Rank sh)) (Proxy @(X.Rank sh')) -    , Refl <- X.lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') +    | let sh' = shxDropSh @sh @sh' (mshape arr) sh +    , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh') +    , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh')) +    , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh')      , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')      , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') -    = M_Nest (X.shPermutePrefix perm sh) +    = M_Nest (shxPermutePrefix perm sh)               (mtranspose perm arr)    mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) -  mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (X.staticShapeFrom (mshape arr))))) +  mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr)))))    mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 -  mshapeTreeEmpty _ (sh, t) = X.shapeSize sh == 0 && mshapeTreeEmpty (Proxy @a) t +  mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t    mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" -  mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.shAppend sh sh') idx val vecs +  mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs    mvecsWritePartial :: forall sh1 sh2 s.                         IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)                      -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)                      -> ST s ()    mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) -    | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') -    = mvecsWritePartial (X.shAppend sh12 sh') idx arr vecs +    | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') +    = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs -  mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (X.shAppend sh sh') vecs +  mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs  instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where -  memptyArray sh = M_Nest sh (memptyArray (X.shAppend sh (X.completeShXzeros (knownShX @sh')))) +  memptyArray sh = M_Nest sh (memptyArray (shxAppend sh (shxCompleteZeros (knownShX @sh'))))    mvecsUnsafeNew sh example -    | X.shapeSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) -    | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.shAppend sh sh') (mindex example (X.zeroIxX (X.staticShapeFrom sh'))) +    | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) +    | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh')))      where        sh' = mshape example -  mvecsNewEmpty _ = MV_Nest (X.completeShXzeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) +  mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a)  -- | Create an array given a size and a function that computes the element at a @@ -882,10 +886,10 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where  -- array. The type of 'mgenerate' allows this requirement to be broken very  -- easily, hence the runtime check.  mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a -mgenerate sh f = case X.enumShape sh of +mgenerate sh f = case shxEnum sh of    [] -> memptyArray sh    firstidx : restidxs -> -    let firstelem = f (X.zeroIxX' sh) +    let firstelem = f (ixxZero' sh)          shapetree = mshapeTree firstelem      in if mshapeTreeEmpty (Proxy @a) shapetree           then memptyArray sh @@ -905,28 +909,28 @@ msumOuter1P :: forall sh n a. (Storable a, NumElt a)              => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)  msumOuter1P (M_Primitive (n :$% sh) arr) =    let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX -  in M_Primitive sh (X.sumOuter nssh (X.staticShapeFrom sh) arr) +  in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr)  msumOuter1 :: forall sh n a. (NumElt a, PrimElt a)             => Mixed (n : sh) a -> Mixed sh a  msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive  mappend :: forall n m sh a. Elt a -        => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a +        => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a  mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2    where      sn :$% sh = mshape arr1      sm :$% _ = mshape arr2 -    ssh = X.staticShapeFrom sh -    snm :: SMayNat () SNat (X.AddMaybe n m) +    ssh = ssxFromShape sh +    snm :: SMayNat () SNat (AddMaybe n m)      snm = case (sn, sm) of              (SUnknown{}, _) -> SUnknown ()              (SKnown{}, SUnknown{}) -> SUnknown () -            (SKnown n, SKnown m) -> SKnown (X.plusSNat n m) +            (SKnown n, SKnown m) -> SKnown (snatPlus n m)      f :: forall sh' b. Storable b -      => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b -    f ssh' = X.append (X.ssxAppend ssh ssh') +      => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b +    f ssh' = X.append (ssxAppend ssh ssh')  mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a)  mfromVectorP sh v = M_Primitive sh (X.fromVector sh v) @@ -971,9 +975,9 @@ mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)           -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b))           -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b)  mrerankP ssh sh2 f (M_Primitive sh arr) = -  let sh1 = shDropSSX sh ssh -  in M_Primitive (X.shAppend (shTakeSSX (Proxy @sh1) sh ssh) sh2) -                 (X.rerank ssh (X.staticShapeFrom sh1) (X.staticShapeFrom sh2) +  let sh1 = shxDropSSX sh ssh +  in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2) +                 (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2)                             (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)                             arr) @@ -988,10 +992,10 @@ mrerank ssh sh2 f (toPrimitive -> arr) =  mreplicate :: forall sh sh' a. Elt a             => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a  mreplicate sh arr = -  let ssh' = X.staticShapeFrom (mshape arr) -  in mlift (X.ssxAppend (X.staticShapeFrom sh) ssh') +  let ssh' = ssxFromShape (mshape arr) +  in mlift (ssxAppend (ssxFromShape sh) ssh')             (\(sshT :: StaticShX shT) -> -              case X.lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of +              case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of                  Refl -> X.replicate sh (ssxAppend ssh' sshT))             arr @@ -1005,18 +1009,18 @@ mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x)  mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a  mslice i n arr =    let _ :$% sh = mshape arr -  in mlift (SKnown n :!% X.staticShapeFrom sh) (\_ -> X.slice i n) arr +  in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr  msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a -msliceU i n arr = mlift (X.staticShapeFrom (mshape arr)) (\_ -> X.sliceU i n) arr +msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr  mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a -mrev1 arr = mlift (X.staticShapeFrom (mshape arr)) (\_ -> X.rev1) arr +mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr  mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a  mreshape sh' arr = -  mlift (X.staticShapeFrom sh') -        (\sshIn -> X.reshapePartial (X.staticShapeFrom (mshape arr)) sshIn sh') +  mlift (ssxFromShape sh') +        (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh')          arr  miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a @@ -1095,26 +1099,26 @@ instance (FloatElt a, NumElt a, PrimElt a) => Floating (Mixed sh a) where    log1pexp = mliftNumElt1 floatEltLog1pexp    log1mexp = mliftNumElt1 floatEltLog1mexp -mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (X.Rank sh) a +mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a  mtoRanked arr -  | Refl <- X.lemAppNil @sh -  , Refl <- X.lemAppNil @(Replicate (X.Rank sh) (Nothing @Nat)) +  | Refl <- lemAppNil @sh +  , Refl <- lemAppNil @(Replicate (Rank sh) (Nothing @Nat))    , Refl <- lemRankReplicate (srankSh (mshape arr)) -  = Ranked (mcast (X.staticShapeFrom (mshape arr)) (convSh (mshape arr)) (Proxy @'[]) arr) +  = Ranked (mcast (ssxFromShape (mshape arr)) (convSh (mshape arr)) (Proxy @'[]) arr)    where -    convSh :: IShX sh' -> IShX (Replicate (X.Rank sh') Nothing) +    convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing)      convSh ZSX = ZSX      convSh (smn :$% (sh :: IShX sh'T)) -      | Refl <- X.lemReplicateSucc @(Nothing @Nat) @(X.Rank sh'T) +      | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T)        = SUnknown (fromSMayNat' smn) :$% convSh sh -mcastToShaped :: forall sh sh' a. (Elt a, X.Rank sh ~ X.Rank sh') +mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')                => Mixed sh a -> ShS sh' -> Shaped sh' a  mcastToShaped arr targetsh -  | Refl <- X.lemAppNil @sh -  , Refl <- X.lemAppNil @(MapJust sh') +  | Refl <- lemAppNil @sh +  , Refl <- lemAppNil @(MapJust sh')    , Refl <- lemRankMapJust targetsh -  = Shaped (mcast (X.staticShapeFrom (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr) +  = Shaped (mcast (ssxFromShape (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr)  -- | A rank-typed array: the number of dimensions of the array (its /rank/) is @@ -1418,7 +1422,7 @@ zeroIxR :: SNat n -> IIxR n  zeroIxR SZ = ZIR  zeroIxR (SS n) = 0 :.: zeroIxR n -ixCvtXR :: IIxX sh -> IIxR (X.Rank sh) +ixCvtXR :: IIxX sh -> IIxR (Rank sh)  ixCvtXR ZIX = ZIR  ixCvtXR (n :.% idx) = n :.: ixCvtXR idx @@ -1429,7 +1433,7 @@ shCvtXR' ZSX =  shCvtXR' (n :$% (idx :: IShX sh))    | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) =    castWith (subst2 (lem1 @sh Refl)) -    (X.fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx)) +    (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx))    where      lem1 :: forall sh' n' k.              k : sh' :~: Replicate n' Nothing @@ -1443,13 +1447,13 @@ shCvtXR' (n :$% (idx :: IShX sh))  ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)  ixCvtRX ZIR = ZIX  ixCvtRX (n :.: (idx :: IxR m Int)) = -  castWith (subst2 @IxX @Int (X.lemReplicateSucc @(Nothing @Nat) @m)) +  castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m))      (n :.% ixCvtRX idx)  shCvtRX :: IShR n -> IShX (Replicate n Nothing)  shCvtRX ZSR = ZSX  shCvtRX (n :$: (idx :: ShR m Int)) = -  castWith (subst2 @ShX @Int (X.lemReplicateSucc @(Nothing @Nat) @m)) +  castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m))      (SUnknown n :$% shCvtRX idx)  shapeSizeR :: IShR n -> Int @@ -1506,7 +1510,7 @@ rsumOuter1P :: forall n a.                 (Storable a, NumElt a)              => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)  rsumOuter1P (Ranked arr) -  | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n +  | Refl <- lemReplicateSucc @(Nothing @Nat) @n    = Ranked (msumOuter1P arr)  rsumOuter1 :: forall n a. (NumElt a, PrimElt a) @@ -1559,7 +1563,7 @@ rappend :: forall n a. Elt a  rappend arr1 arr2    | sn@SNat <- snatFromShR (rshape arr1)    , Dict <- lemKnownReplicate sn -  , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n +  , Refl <- lemReplicateSucc @(Nothing @Nat) @n    = coerce (mappend @Nothing @Nothing @(Replicate n Nothing))        arr1 arr2 @@ -1582,7 +1586,7 @@ rtoVector = coerce mtoVector  rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a  rfromListOuter l -  | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n +  | Refl <- lemReplicateSucc @(Nothing @Nat) @n    = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))  rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a @@ -1593,7 +1597,7 @@ rfromList1Prim l = Ranked (mfromList1Prim l)  rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]  rtoListOuter (Ranked arr) -  | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n +  | Refl <- lemReplicateSucc @(Nothing @Nat) @n    = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr)  rtoList1 :: Elt a => Ranked 1 a -> [a] @@ -1677,7 +1681,7 @@ rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x)  rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a  rslice i n arr -  | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n +  | Refl <- lemReplicateSucc @(Nothing @Nat) @n    = rlift (snatFromShR (rshape arr))            (\_ -> X.sliceU i n)            arr @@ -1686,7 +1690,7 @@ rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a  rrev1 arr =    rlift (snatFromShR (rshape arr))          (\(_ :: StaticShX sh') -> -          case X.lemReplicateSucc @(Nothing @Nat) @n of +          case lemReplicateSucc @(Nothing @Nat) @n of              Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh'))          arr @@ -1707,12 +1711,12 @@ rasXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing  rasXArrayPrim (Ranked arr) = first shCvtXR' (masXArrayPrim arr)  rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a) -rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (X.staticShapeFrom (X.shape (ssxFromSNat sn) arr)) arr) +rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr)  rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a -rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (X.staticShapeFrom (X.shape (ssxFromSNat sn) arr)) arr) +rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) -rcastToShaped :: Elt a => Ranked (X.Rank sh) a -> ShS sh -> Shaped sh a +rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a  rcastToShaped (Ranked arr) targetsh    | Refl <- lemRankReplicate (srankSh (shCvtSX targetsh))    , Refl <- lemRankMapJust targetsh @@ -1809,7 +1813,7 @@ shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh  shapeSizeS :: ShS sh -> Int  shapeSizeS ZSS = 1 -shapeSizeS (n :$$ sh) = X.fromSNat' n * shapeSizeS sh +shapeSizeS (n :$$ sh) = fromSNat' n * shapeSizeS sh  sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh @@ -1838,14 +1842,14 @@ slift :: forall sh1 sh2 a. Elt a        => ShS sh2        -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)        -> Shaped sh1 a -> Shaped sh2 a -slift sh2 f (Shaped arr) = Shaped (mlift (X.staticShapeFrom (shCvtSX sh2)) f arr) +slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShape (shCvtSX sh2)) f arr)  -- | See the documentation of 'mlift'.  slift2 :: forall sh1 sh2 sh3 a. Elt a         => ShS sh3         -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b)         -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a -slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (X.staticShapeFrom (shCvtSX sh3)) f arr1 arr2) +slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShape (shCvtSX sh3)) f arr1 arr2)  ssumOuter1P :: forall sh n a. (Storable a, NumElt a)              => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) @@ -1855,28 +1859,28 @@ ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a)             => Shaped (n : sh) a -> Shaped sh a  ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive -lemCommMapJustTakeLen :: HList SNat is -> ShS sh -> X.TakeLen is (MapJust sh) :~: MapJust (X.TakeLen is sh) -lemCommMapJustTakeLen HNil _ = Refl -lemCommMapJustTakeLen (_ `HCons` is) (_ :$$ sh) | Refl <- lemCommMapJustTakeLen is sh = Refl -lemCommMapJustTakeLen (_ `HCons` _) ZSS = error "TakeLen of empty" +lemCommMapJustTakeLen :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh) +lemCommMapJustTakeLen PNil _ = Refl +lemCommMapJustTakeLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemCommMapJustTakeLen is sh = Refl +lemCommMapJustTakeLen (_ `PCons` _) ZSS = error "TakeLen of empty" -lemCommMapJustDropLen :: HList SNat is -> ShS sh -> X.DropLen is (MapJust sh) :~: MapJust (X.DropLen is sh) -lemCommMapJustDropLen HNil _ = Refl -lemCommMapJustDropLen (_ `HCons` is) (_ :$$ sh) | Refl <- lemCommMapJustDropLen is sh = Refl -lemCommMapJustDropLen (_ `HCons` _) ZSS = error "DropLen of empty" +lemCommMapJustDropLen :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh) +lemCommMapJustDropLen PNil _ = Refl +lemCommMapJustDropLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemCommMapJustDropLen is sh = Refl +lemCommMapJustDropLen (_ `PCons` _) ZSS = error "DropLen of empty" -lemCommMapJustIndex :: SNat i -> ShS sh -> X.Index i (MapJust sh) :~: Just (X.Index i sh) +lemCommMapJustIndex :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh)  lemCommMapJustIndex SZ (_ :$$ _) = Refl  lemCommMapJustIndex (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh'))    | Refl <- lemCommMapJustIndex i sh -  , Refl <- X.lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) -  , Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') +  , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) +  , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')    = Refl  lemCommMapJustIndex _ ZSS = error "Index of empty" -lemCommMapJustPermute :: HList SNat is -> ShS sh -> X.Permute is (MapJust sh) :~: MapJust (X.Permute is sh) -lemCommMapJustPermute HNil _ = Refl -lemCommMapJustPermute (i `HCons` is) sh +lemCommMapJustPermute :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh) +lemCommMapJustPermute PNil _ = Refl +lemCommMapJustPermute (i `PCons` is) sh    | Refl <- lemCommMapJustPermute is sh    , Refl <- lemCommMapJustIndex i sh    = Refl @@ -1885,53 +1889,53 @@ listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f  listsAppend ZS idx' = idx'  listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' -listsTakeLen :: forall f is sh. HList SNat is -> ListS sh f -> ListS (X.TakeLen is sh) f -listsTakeLen HNil _ = ZS -listsTakeLen (_ `HCons` is) (n ::$ sh) = n ::$ listsTakeLen is sh -listsTakeLen (_ `HCons` _) ZS = error "Permutation longer than shape" +listsTakeLen :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f +listsTakeLen PNil _ = ZS +listsTakeLen (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLen is sh +listsTakeLen (_ `PCons` _) ZS = error "Permutation longer than shape" -listsDropLen :: forall f is sh. HList SNat is -> ListS sh f -> ListS (DropLen is sh) f -listsDropLen HNil sh = sh -listsDropLen (_ `HCons` is) (_ ::$ sh) = listsDropLen is sh -listsDropLen (_ `HCons` _) ZS = error "Permutation longer than shape" +listsDropLen :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f +listsDropLen PNil sh = sh +listsDropLen (_ `PCons` is) (_ ::$ sh) = listsDropLen is sh +listsDropLen (_ `PCons` _) ZS = error "Permutation longer than shape" -listsPermute :: forall f is sh. HList SNat is -> ListS sh f -> ListS (X.Permute is sh) f -listsPermute HNil _ = ZS -listsPermute (i `HCons` (is :: HList SNat is')) (sh :: ListS sh f) = listsIndex (Proxy @is') (Proxy @sh) i sh (listsPermute is sh) +listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f +listsPermute PNil _ = ZS +listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) = listsIndex (Proxy @is') (Proxy @sh) i sh (listsPermute is sh) -listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> ListS (X.Permute is shT) f -> ListS (X.Index i sh : X.Permute is shT) f +listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> ListS (Permute is shT) f -> ListS (Index i sh : Permute is shT) f  listsIndex _ _ SZ (n ::$ _) rest = n ::$ rest  listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) rest -  | Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') +  | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')    = listsIndex p pT i sh rest  listsIndex _ _ _ ZS _ = error "Index into empty shape" -shsTakeLen :: HList SNat is -> ShS sh -> ShS (X.TakeLen is sh) +shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh)  shsTakeLen = coerce (listsTakeLen @SNat) -shsPermute :: HList SNat is -> ShS sh -> ShS (X.Permute is sh) +shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh)  shsPermute = coerce (listsPermute @SNat) -shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (X.Permute is shT) -> ShS (X.Index i sh : X.Permute is shT) +shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (Permute is shT) -> ShS (Index i sh : Permute is shT)  shsIndex pis pshT = coerce (listsIndex @SNat pis pshT) -applyPermS :: forall f is sh. HList SNat is -> ListS sh f -> ListS (PermutePrefix is sh) f +applyPermS :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f  applyPermS perm sh = listsAppend (listsPermute perm (listsTakeLen perm sh)) (listsDropLen perm sh) -applyPermIxS :: forall i is sh. HList SNat is -> IxS sh i -> IxS (PermutePrefix is sh) i +applyPermIxS :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i  applyPermIxS = coerce (applyPermS @(Const i)) -applyPermShS :: forall is sh. HList SNat is -> ShS sh -> ShS (PermutePrefix is sh) +applyPermShS :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)  applyPermShS = coerce (applyPermS @SNat) -stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, Elt a) -           => HList SNat is -> Shaped sh a -> Shaped (X.PermutePrefix is sh) a +stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a) +           => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a  stranspose perm sarr@(Shaped arr)    | Refl <- lemRankMapJust (sshape sarr)    , Refl <- lemCommMapJustTakeLen perm (sshape sarr)    , Refl <- lemCommMapJustDropLen perm (sshape sarr)    , Refl <- lemCommMapJustPermute perm (shsTakeLen perm (sshape sarr)) -  , Refl <- lemCommMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(X.DropLen is sh)) +  , Refl <- lemCommMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh))    = Shaped (mtranspose perm arr)  sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a @@ -1969,7 +1973,7 @@ stoList1 = map sunScalar . stoListOuter  sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a  sfromListPrim sn l -  | Refl <- X.lemAppNil @'[Just n] +  | Refl <- lemAppNil @'[Just n]    = let ssh = SUnknown () :!% ZKX          xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l)      in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr @@ -1989,7 +1993,7 @@ srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)  srerankP sh sh2 f sarr@(Shaped arr)    | Refl <- lemCommMapJustApp sh (Proxy @sh1)    , Refl <- lemCommMapJustApp sh (Proxy @sh2) -  = Shaped (mrerankP (X.staticShapeFrom (shTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (X.staticShapeFrom (shCvtSX sh)))) +  = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh))))                       (shCvtSX sh2)                       (\a -> let Shaped r = f (Shaped a) in r)                       arr) @@ -2033,12 +2037,12 @@ sasXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a)  sasXArrayPrim (Shaped arr) = first shCvtXS' (masXArrayPrim arr)  sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a) -sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (X.staticShapeFrom (shCvtSX sh)) arr) +sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShape (shCvtSX sh)) arr)  sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a -sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (X.staticShapeFrom (shCvtSX sh)) arr) +sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShape (shCvtSX sh)) arr) -stoRanked :: Elt a => Shaped sh a -> Ranked (X.Rank sh) a +stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a  stoRanked sarr@(Shaped arr)    | Refl <- lemRankMapJust (sshape sarr)    = mtoRanked arr diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs deleted file mode 100644 index 95fcfcf..0000000 --- a/src/Data/Array/Nested/Internal/Arith.hs +++ /dev/null @@ -1,435 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Nested.Internal.Arith where - -import Control.Monad (forM, guard) -import qualified Data.Array.Internal as OI -import qualified Data.Array.Internal.RankedG as RG -import qualified Data.Array.Internal.RankedS as RS -import Data.Bits -import Data.Int -import Data.List (sort) -import qualified Data.Vector.Storable as VS -import qualified Data.Vector.Storable.Mutable as VSM -import Foreign.C.Types -import Foreign.Ptr -import Foreign.Storable (Storable) -import GHC.TypeLits -import Language.Haskell.TH -import System.IO.Unsafe - -import Data.Array.Nested.Internal.Arith.Foreign -import Data.Array.Nested.Internal.Arith.Lists - - -liftVEltwise1 :: Storable a -              => SNat n -              -> (VS.Vector a -> VS.Vector a) -              -> RS.Array n a -> RS.Array n a -liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) -  | Just prefixSz <- stridesDense sh strides = -      let vec' = f (VS.slice offset prefixSz vec) -      in RS.A (RG.A sh (OI.T strides 0 vec')) -  | otherwise = RS.fromVector sh (f (RS.toVector arr)) - -liftVEltwise2 :: Storable a -              => SNat n -              -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a) -              -> RS.Array n a -> RS.Array n a -> RS.Array n a -liftVEltwise2 SNat f -    arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1))) -    arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2))) -  | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2 -  | product sh1 == 0 = arr1  -- if the arrays are empty, just return one of the empty inputs -  | otherwise = case (stridesDense sh1 strides1, stridesDense sh2 strides2) of -      (Just 1, Just 1) ->  -- both are a (potentially replicated) scalar; just apply f to the scalars -        let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2)) -        in RS.A (RG.A sh1 (OI.T strides1 0 vec')) -      (Just 1, Just n) ->  -- scalar * dense -        RS.fromVector sh1 (f (Left (vec1 VS.! offset1)) (Right (VS.slice offset2 n vec2))) -      (Just n, Just 1) ->  -- dense * scalar -        RS.fromVector sh1 (f (Right (VS.slice offset1 n vec1)) (Left (vec2 VS.! offset2))) -      (_, _) ->  -- fallback case -        RS.fromVector sh1 (f (Right (RS.toVector arr1)) (Right (RS.toVector arr2))) - --- | Given the shape vector and the stride vector, return whether this vector --- of strides uses a dense prefix of its backing array. If so, the number of --- elements in this prefix is returned. --- This excludes any offset. -stridesDense :: [Int] -> [Int] -> Maybe Int -stridesDense sh _ | any (<= 0) sh = Just 0 -stridesDense sh str = -  -- sort dimensions on their stride, ascending, dropping any zero strides -  case dropWhile ((== 0) . fst) (sort (zip str sh)) of -    [] -> Just 1 -    (1, n) : (unzip -> (str', sh')) -> checkCover n sh' str' -    _ -> Nothing  -- if the smallest stride is not 1, it will never be dense -  where -    -- Given size of currently densely covered region at beginning of the -    -- array, the remaining shape vector and the corresponding remaining stride -    -- vector, return whether this all together covers a dense prefix of the -    -- array. If it does, return the number of elements in this prefix. -    checkCover :: Int -> [Int] -> [Int] -> Maybe Int -    checkCover block [] [] = Just block -    checkCover block (n : sh') (s : str') = guard (s <= block) >> checkCover (max block (n * s)) sh' str' -    checkCover _ _ _ = error "Orthotope array's shape vector and stride vector have different lengths" - -{-# NOINLINE vectorOp1 #-} -vectorOp1 :: forall a b. Storable a -          => (Ptr a -> Ptr b) -          -> (Int64 -> Ptr b -> Ptr b -> IO ()) -          -> VS.Vector a -> VS.Vector a -vectorOp1 ptrconv f v = unsafePerformIO $ do -  outv <- VSM.unsafeNew (VS.length v) -  VSM.unsafeWith outv $ \poutv -> -    VS.unsafeWith v $ \pv -> -      f (fromIntegral (VS.length v)) (ptrconv poutv) (ptrconv pv) -  VS.unsafeFreeze outv - --- | If two vectors are given, assumes that they have the same length. -{-# NOINLINE vectorOp2 #-} -vectorOp2 :: forall a b. Storable a -          => (a -> b) -          -> (Ptr a -> Ptr b) -          -> (a -> a -> a) -          -> (Int64 -> Ptr b -> b -> Ptr b -> IO ())  -- sv -          -> (Int64 -> Ptr b -> Ptr b -> b -> IO ())  -- vs -          -> (Int64 -> Ptr b -> Ptr b -> Ptr b -> IO ())  -- vv -          -> Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a -vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases -  (Left x) (Left y) -> VS.singleton (fss x y) - -  (Left x) (Right vy) -> -    unsafePerformIO $ do -      outv <- VSM.unsafeNew (VS.length vy) -      VSM.unsafeWith outv $ \poutv -> -        VS.unsafeWith vy $ \pvy -> -          fsv (fromIntegral (VS.length vy)) (ptrconv poutv) (valconv x) (ptrconv pvy) -      VS.unsafeFreeze outv - -  (Right vx) (Left y) -> -    unsafePerformIO $ do -      outv <- VSM.unsafeNew (VS.length vx) -      VSM.unsafeWith outv $ \poutv -> -        VS.unsafeWith vx $ \pvx -> -          fvs (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (valconv y) -      VS.unsafeFreeze outv - -  (Right vx) (Right vy) -    | VS.length vx == VS.length vy -> -        unsafePerformIO $ do -          outv <- VSM.unsafeNew (VS.length vx) -          VSM.unsafeWith outv $ \poutv -> -            VS.unsafeWith vx $ \pvx -> -              VS.unsafeWith vy $ \pvy -> -                fvv (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (ptrconv pvy) -          VS.unsafeFreeze outv -    | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy) - --- TODO: test all the weird cases of this function --- | Reduce along the inner dimension -{-# NOINLINE vectorRedInnerOp #-} -vectorRedInnerOp :: forall a b n. (Num a, Storable a) -                 => SNat n -                 -> (a -> b) -                 -> (Ptr a -> Ptr b) -                 -> (Int64 -> Ptr b -> b -> Ptr b -> IO ())  -- ^ scale by constant -                 -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ())  -- ^ reduction kernel -                 -> RS.Array (n + 1) a -> RS.Array n a -vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec))) -  | null sh = error "unreachable" -  | last sh <= 0 = RS.stretch (init sh) (RS.fromList (map (const 1) (init sh)) [0]) -  | any (<= 0) (init sh) = RS.A (RG.A (init sh) (OI.T (map (const 0) (init strides)) 0 VS.empty)) -  -- now the input array is nonempty -  | last sh == 1 = RS.A (RG.A (init sh) (OI.T (init strides) offset vec)) -  | last strides == 0 = -      liftVEltwise1 sn -        (vectorOp1 id (\n pout px -> fscale n (ptrconv pout) (valconv (fromIntegral (last sh))) (ptrconv px))) -        (RS.A (RG.A (init sh) (OI.T (init strides) offset vec))) -  -- now there is useful work along the inner dimension -  | otherwise = -      let -- filter out zero-stride dimensions; the reduction kernel need not concern itself with those -          (shF, stridesF) = unzip $ filter ((/= 0) . snd) (zip sh strides) -          ndimsF = length shF -      in unsafePerformIO $ do -           outv <- VSM.unsafeNew (product (init shF)) -           VSM.unsafeWith outv $ \poutv -> -             VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF -> -               VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF -> -                 VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec -> -                   fred (fromIntegral ndimsF) pshF pstridesF (ptrconv poutv) (ptrconv pvec) -           RS.fromVector (init sh) <$> VS.unsafeFreeze outv - -flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ()) -       ->  Int64 -> Ptr a -> Ptr a -> a -> IO () -flipOp f n out v s = f n out s v - -$(fmap concat . forM typesList $ \arithtype -> do -    let ttyp = conT (atType arithtype) -    fmap concat . forM [minBound..maxBound] $ \arithop -> do -      let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype)) -          cnamebase = "c_binary_" ++ atCName arithtype -          c_ss = varE (aboNumOp arithop) -          c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) -          c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) -          c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) -      sequence [SigD name <$> -                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] -               ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] -                   return $ FunD name [Clause [] (NormalB body) []]]) - -$(fmap concat . forM floatTypesList $ \arithtype -> do -    let ttyp = conT (atType arithtype) -    fmap concat . forM [minBound..maxBound] $ \arithop -> do -      let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype)) -          cnamebase = "c_fbinary_" ++ atCName arithtype -          c_ss = varE (afboNumOp arithop) -          c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) -          c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) -          c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) -      sequence [SigD name <$> -                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] -               ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] -                   return $ FunD name [Clause [] (NormalB body) []]]) - -$(fmap concat . forM typesList $ \arithtype -> do -    let ttyp = conT (atType arithtype) -    fmap concat . forM [minBound..maxBound] $ \arithop -> do -      let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype)) -          c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop))) -      sequence [SigD name <$> -                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] -               ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] -                   return $ FunD name [Clause [] (NormalB body) []]]) - -$(fmap concat . forM floatTypesList $ \arithtype -> do -    let ttyp = conT (atType arithtype) -    fmap concat . forM [minBound..maxBound] $ \arithop -> do -      let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype)) -          c_op = varE (mkName ("c_funary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (afuoEnum arithop))) -      sequence [SigD name <$> -                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] -               ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] -                   return $ FunD name [Clause [] (NormalB body) []]]) - -$(fmap concat . forM typesList $ \arithtype -> do -    let ttyp = conT (atType arithtype) -    fmap concat . forM [minBound..maxBound] $ \arithop -> do -      let name = mkName (aroName arithop ++ "Vector" ++ nameBase (atType arithtype)) -          c_op = varE (mkName ("c_reduce_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop))) -          c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL))) -      sequence [SigD name <$> -                     [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |] -               ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |] -                   return $ FunD name [Clause [] (NormalB body) []]]) - --- This branch is ostensibly a runtime branch, but will (hopefully) be --- constant-folded away by GHC. -intWidBranch1 :: forall i n. (FiniteBits i, Storable i) -              => (Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -              -> (SNat n -> RS.Array n i -> RS.Array n i) -intWidBranch1 f32 f64 sn -  | finiteBitSize (undefined :: i) == 32 = liftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32) -  | finiteBitSize (undefined :: i) == 64 = liftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64) -  | otherwise = error "Unsupported Int width" - -intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i) -              => (i -> i -> i)  -- ss -                 -- int32 -              -> (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ())  -- sv -              -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Int32 -> IO ())  -- vs -              -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Ptr Int32 -> IO ())  -- vv -                 -- int64 -              -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ())  -- sv -              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ())  -- vs -              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- vv -              -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i) -intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn -  | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32) -  | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64) -  | otherwise = error "Unsupported Int width" - -intWidBranchRed :: forall i n. (FiniteBits i, Storable i, Integral i) -                => -- int32 -                   (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ())  -- ^ scale by constant -                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ())  -- ^ reduction kernel -                   -- int64 -                -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ())  -- ^ scale by constant -                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- ^ reduction kernel -                -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i) -intWidBranchRed fsc32 fred32 fsc64 fred64 sn -  | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32 -  | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64 -  | otherwise = error "Unsupported Int width" - -class NumElt a where -  numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a -  numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a -  numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a -  numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a -  numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a -  numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a -  numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a -  numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a - -instance NumElt Int32 where -  numEltAdd = addVectorInt32 -  numEltSub = subVectorInt32 -  numEltMul = mulVectorInt32 -  numEltNeg = negVectorInt32 -  numEltAbs = absVectorInt32 -  numEltSignum = signumVectorInt32 -  numEltSum1Inner = sum1VectorInt32 -  numEltProduct1Inner = product1VectorInt32 - -instance NumElt Int64 where -  numEltAdd = addVectorInt64 -  numEltSub = subVectorInt64 -  numEltMul = mulVectorInt64 -  numEltNeg = negVectorInt64 -  numEltAbs = absVectorInt64 -  numEltSignum = signumVectorInt64 -  numEltSum1Inner = sum1VectorInt64 -  numEltProduct1Inner = product1VectorInt64 - -instance NumElt Float where -  numEltAdd = addVectorFloat -  numEltSub = subVectorFloat -  numEltMul = mulVectorFloat -  numEltNeg = negVectorFloat -  numEltAbs = absVectorFloat -  numEltSignum = signumVectorFloat -  numEltSum1Inner = sum1VectorFloat -  numEltProduct1Inner = product1VectorFloat - -instance NumElt Double where -  numEltAdd = addVectorDouble -  numEltSub = subVectorDouble -  numEltMul = mulVectorDouble -  numEltNeg = negVectorDouble -  numEltAbs = absVectorDouble -  numEltSignum = signumVectorDouble -  numEltSum1Inner = sum1VectorDouble -  numEltProduct1Inner = product1VectorDouble - -instance NumElt Int where -  numEltAdd = intWidBranch2 @Int (+) -                (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD)) -                (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD)) -  numEltSub = intWidBranch2 @Int (-) -                (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB)) -                (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB)) -  numEltMul = intWidBranch2 @Int (*) -                (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL)) -                (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL)) -  numEltNeg = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG)) -  numEltAbs = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS)) -  numEltSignum = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM)) -  numEltSum1Inner = intWidBranchRed @Int -                      (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1)) -                      (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1)) -  numEltProduct1Inner = intWidBranchRed @Int -                          (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) -                          (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) - -instance NumElt CInt where -  numEltAdd = intWidBranch2 @CInt (+) -                (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD)) -                (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD)) -  numEltSub = intWidBranch2 @CInt (-) -                (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB)) -                (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB)) -  numEltMul = intWidBranch2 @CInt (*) -                (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL)) -                (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL)) -  numEltNeg = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG)) -  numEltAbs = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS)) -  numEltSignum = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM)) -  numEltSum1Inner = intWidBranchRed @CInt -                      (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1)) -                      (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1)) -  numEltProduct1Inner = intWidBranchRed @CInt -                          (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) -                          (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) - -class FloatElt a where -  floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a -  floatEltPow :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a -  floatEltLogbase :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a -  floatEltRecip :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltExp :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltLog :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltSqrt :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltSin :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltCos :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltTan :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltAsin :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltAcos :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltAtan :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltSinh :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltCosh :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltTanh :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltAsinh :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltAcosh :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltAtanh :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltLog1p :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltExpm1 :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltLog1pexp :: SNat n -> RS.Array n a -> RS.Array n a -  floatEltLog1mexp :: SNat n -> RS.Array n a -> RS.Array n a - -instance FloatElt Float where -  floatEltDiv = divVectorFloat -  floatEltPow = powVectorFloat -  floatEltLogbase = logbaseVectorFloat -  floatEltRecip = recipVectorFloat -  floatEltExp = expVectorFloat -  floatEltLog = logVectorFloat -  floatEltSqrt = sqrtVectorFloat -  floatEltSin = sinVectorFloat -  floatEltCos = cosVectorFloat -  floatEltTan = tanVectorFloat -  floatEltAsin = asinVectorFloat -  floatEltAcos = acosVectorFloat -  floatEltAtan = atanVectorFloat -  floatEltSinh = sinhVectorFloat -  floatEltCosh = coshVectorFloat -  floatEltTanh = tanhVectorFloat -  floatEltAsinh = asinhVectorFloat -  floatEltAcosh = acoshVectorFloat -  floatEltAtanh = atanhVectorFloat -  floatEltLog1p = log1pVectorFloat -  floatEltExpm1 = expm1VectorFloat -  floatEltLog1pexp = log1pexpVectorFloat -  floatEltLog1mexp = log1mexpVectorFloat - -instance FloatElt Double where -  floatEltDiv = divVectorDouble -  floatEltPow = powVectorDouble -  floatEltLogbase = logbaseVectorDouble -  floatEltRecip = recipVectorDouble -  floatEltExp = expVectorDouble -  floatEltLog = logVectorDouble -  floatEltSqrt = sqrtVectorDouble -  floatEltSin = sinVectorDouble -  floatEltCos = cosVectorDouble -  floatEltTan = tanVectorDouble -  floatEltAsin = asinVectorDouble -  floatEltAcos = acosVectorDouble -  floatEltAtan = atanVectorDouble -  floatEltSinh = sinhVectorDouble -  floatEltCosh = coshVectorDouble -  floatEltTanh = tanhVectorDouble -  floatEltAsinh = asinhVectorDouble -  floatEltAcosh = acoshVectorDouble -  floatEltAtanh = atanhVectorDouble -  floatEltLog1p = log1pVectorDouble -  floatEltExpm1 = expm1VectorDouble -  floatEltLog1pexp = log1pexpVectorDouble -  floatEltLog1mexp = log1mexpVectorDouble diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Nested/Internal/Arith/Foreign.hs deleted file mode 100644 index ac83188..0000000 --- a/src/Data/Array/Nested/Internal/Arith/Foreign.hs +++ /dev/null @@ -1,55 +0,0 @@ -{-# LANGUAGE ForeignFunctionInterface #-} -{-# LANGUAGE TemplateHaskell #-} -module Data.Array.Nested.Internal.Arith.Foreign where - -import Control.Monad -import Data.Int -import Data.Maybe -import Foreign.C.Types -import Foreign.Ptr -import Language.Haskell.TH - -import Data.Array.Nested.Internal.Arith.Lists - - -$(fmap concat . forM typesList $ \arithtype -> do -    let ttyp = conT (atType arithtype) -    let base = "binary_" ++ atCName arithtype -    sequence $ catMaybes -      [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> -               [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]) -      ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> -               [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) -      ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> -               [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) -      ]) - -$(fmap concat . forM floatTypesList $ \arithtype -> do -    let ttyp = conT (atType arithtype) -    let base = "fbinary_" ++ atCName arithtype -    sequence $ catMaybes -      [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> -               [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]) -      ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> -               [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) -      ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> -               [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) -      ]) - -$(fmap concat . forM typesList $ \arithtype -> do -    let ttyp = conT (atType arithtype) -    let base = "unary_" ++ atCName arithtype -    pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> -      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) - -$(fmap concat . forM floatTypesList $ \arithtype -> do -    let ttyp = conT (atType arithtype) -    let base = "funary_" ++ atCName arithtype -    pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> -      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) - -$(fmap concat . forM typesList $ \arithtype -> do -    let ttyp = conT (atType arithtype) -    let base = "reduce_" ++ atCName arithtype -    pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> -      [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) diff --git a/src/Data/Array/Nested/Internal/Arith/Lists.hs b/src/Data/Array/Nested/Internal/Arith/Lists.hs deleted file mode 100644 index ce2836d..0000000 --- a/src/Data/Array/Nested/Internal/Arith/Lists.hs +++ /dev/null @@ -1,78 +0,0 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE TemplateHaskell #-} -module Data.Array.Nested.Internal.Arith.Lists where - -import Data.Char -import Data.Int -import Language.Haskell.TH - -import Data.Array.Nested.Internal.Arith.Lists.TH - - -data ArithType = ArithType -  { atType :: Name  -- ''Int32 -  , atCName :: String  -- "i32" -  } - -floatTypesList :: [ArithType] -floatTypesList = -  [ArithType ''Float "float" -  ,ArithType ''Double "double" -  ] - -typesList :: [ArithType] -typesList = -  [ArithType ''Int32 "i32" -  ,ArithType ''Int64 "i64" -  ] -  ++ floatTypesList - --- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded) -$(genArithDataType Binop "ArithBOp") - -$(genArithNameFun Binop ''ArithBOp "aboName" (map toLower . drop 3)) -$(genArithEnumFun Binop ''ArithBOp "aboEnum") - -$(do clauses <- readArithLists Binop -                  (\name _num hsop -> return (Clause [ConP (mkName name) [] []] -                                                     (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) -                                                     [])) -                  return -     sequence [SigD (mkName "aboNumOp") <$> [t| ArithBOp -> Name |] -              ,return $ FunD (mkName "aboNumOp") clauses]) - - --- data ArithFBOp = FB_DIV deriving (Show, Enum, Bounded) -$(genArithDataType FBinop "ArithFBOp") - -$(genArithNameFun FBinop ''ArithFBOp "afboName" (map toLower . drop 3)) -$(genArithEnumFun FBinop ''ArithFBOp "afboEnum") - -$(do clauses <- readArithLists FBinop -                  (\name _num hsop -> return (Clause [ConP (mkName name) [] []] -                                                     (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) -                                                     [])) -                  return -     sequence [SigD (mkName "afboNumOp") <$> [t| ArithFBOp -> Name |] -              ,return $ FunD (mkName "afboNumOp") clauses]) - - --- data ArithUOp = UO_NEG | UO_ABS | UO_SIGNUM | ... deriving (Show, Enum, Bounded) -$(genArithDataType Unop "ArithUOp") - -$(genArithNameFun Unop ''ArithUOp "auoName" (map toLower . drop 3)) -$(genArithEnumFun Unop ''ArithUOp "auoEnum") - - --- data ArithFUOp = FU_RECIP | ... deriving (Show, Enum, Bounded) -$(genArithDataType FUnop "ArithFUOp") - -$(genArithNameFun FUnop ''ArithFUOp "afuoName" (map toLower . drop 3)) -$(genArithEnumFun FUnop ''ArithFUOp "afuoEnum") - - --- data ArithRedOp = RO_SUM1 | RO_PRODUCT1 deriving (Show, Enum, Bounded) -$(genArithDataType Redop "ArithRedOp") - -$(genArithNameFun Redop ''ArithRedOp "aroName" (map toLower . drop 3)) -$(genArithEnumFun Redop ''ArithRedOp "aroEnum") diff --git a/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs b/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs deleted file mode 100644 index 7142dfa..0000000 --- a/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs +++ /dev/null @@ -1,82 +0,0 @@ -{-# LANGUAGE TemplateHaskellQuotes #-} -module Data.Array.Nested.Internal.Arith.Lists.TH where - -import Control.Monad -import Control.Monad.IO.Class -import Data.Maybe -import Foreign.C.Types -import Language.Haskell.TH -import Language.Haskell.TH.Syntax -import Text.Read - - -data OpKind = Binop | FBinop | Unop | FUnop | Redop -  deriving (Show, Eq) - -readArithLists :: OpKind -               -> (String -> Int -> String -> Q a) -               -> ([a] -> Q r) -               -> Q r -readArithLists targetkind fop fcombine = do -  addDependentFile "cbits/arith_lists.h" -  lns <- liftIO $ lines <$> readFile "cbits/arith_lists.h" - -  mvals <- forM lns $ \line -> do -    if null (dropWhile (== ' ') line) -      then return Nothing -      else do let (kind, name, num, aux) = parseLine line -              if kind == targetkind -                then Just <$> fop name num aux -                else return Nothing - -  fcombine (catMaybes mvals) -  where -    parseLine s0 -      | ("LIST_", s1) <- splitAt 5 s0 -      , (kindstr, '(' : s2) <- break (== '(') s1 -      , (f1, ',' : s3) <- parseField s2 -      , (f2, ',' : s4) <- parseField s3 -      , (f3, ')' : _) <- parseField s4 -      , Just kind <- parseKind kindstr -      , let name = f1 -      , Just num <- readMaybe f2 -      , let aux = f3 -      = (kind, name, num, aux) -      | otherwise -      = error $ "readArithLists: unrecognised line in cbits/arith_lists.h: " ++ show s0 - -    parseField s = break (`elem` ",)") (dropWhile (== ' ') s) - -    parseKind "BINOP" = Just Binop -    parseKind "FBINOP" = Just FBinop -    parseKind "UNOP" = Just Unop -    parseKind "FUNOP" = Just FUnop -    parseKind "REDOP" = Just Redop -    parseKind _ = Nothing - -genArithDataType :: OpKind -> String -> Q [Dec] -genArithDataType kind dtname = do -  cons <- readArithLists kind -            (\name _num _ -> return $ NormalC (mkName name) []) -            return -  return [DataD [] (mkName dtname) [] Nothing cons [DerivClause Nothing [ConT ''Show, ConT ''Enum, ConT ''Bounded]]] - -genArithNameFun :: OpKind -> Name -> String -> (String -> String) -> Q [Dec] -genArithNameFun kind dtname funname nametrans = do -  clauses <- readArithLists kind -               (\name _num _ -> return (Clause [ConP (mkName name) [] []] -                                               (NormalB (LitE (StringL (nametrans name)))) -                                               [])) -               return -  return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''String) -         ,FunD (mkName funname) clauses] - -genArithEnumFun :: OpKind -> Name -> String -> Q [Dec] -genArithEnumFun kind dtname funname = do -  clauses <- readArithLists kind -               (\name num _ -> return (Clause [ConP (mkName name) [] []] -                                              (NormalB (LitE (IntegerL (fromIntegral num)))) -                                              [])) -               return -  return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''CInt) -         ,FunD (mkName funname) clauses] | 
