aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-30 11:58:40 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-30 11:58:40 +0200
commita65306ba5d80891b20ac86fa3a3242f9497751e6 (patch)
tree834af370556a46bbeca807a92c31bef098b47a89 /src/Data/Array/Nested
parentd8e2fcf4ea979fe272db48fc2889f4c2636c50d7 (diff)
Refactor Mixed (modules, regular function names)
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r--src/Data/Array/Nested/Internal.hs326
-rw-r--r--src/Data/Array/Nested/Internal/Arith.hs435
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Foreign.hs55
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Lists.hs78
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Lists/TH.hs82
5 files changed, 165 insertions, 811 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 712c5f1..0870789 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -60,7 +60,11 @@ import Unsafe.Coerce
import Data.Array.Mixed
import qualified Data.Array.Mixed as X
-import Data.Array.Nested.Internal.Arith
+import Data.Array.Mixed.Lemmas
+import Data.Array.Mixed.Permutation
+import Data.Array.Mixed.Shape
+import Data.Array.Mixed.Internal.Arith
+import Data.Array.Mixed.Types
-- Invariant in the API
@@ -123,19 +127,19 @@ lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict
ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing)
ssxFromSNat SZ = ZKX
-ssxFromSNat (SS (n :: SNat nm1)) | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n
+ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n
lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing)
lemKnownReplicate sn = lemKnownShX (ssxFromSNat sn)
-lemRankReplicate :: SNat n -> X.Rank (Replicate n (Nothing @Nat)) :~: n
+lemRankReplicate :: SNat n -> Rank (Replicate n (Nothing @Nat)) :~: n
lemRankReplicate SZ = Refl
lemRankReplicate (SS (n :: SNat nm1))
- | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1
, Refl <- lemRankReplicate n
= Refl
-lemRankMapJust :: forall sh. ShS sh -> X.Rank (MapJust sh) :~: X.Rank sh
+lemRankMapJust :: forall sh. ShS sh -> Rank (MapJust sh) :~: Rank sh
lemRankMapJust ZSS = Refl
lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl
@@ -146,9 +150,9 @@ lemReplicatePlusApp sn _ _ = go sn
go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a
go SZ = Refl
go (SS (n :: SNat n'm1))
- | Refl <- X.lemReplicateSucc @a @n'm1
+ | Refl <- lemReplicateSucc @a @n'm1
, Refl <- go n
- = sym (X.lemReplicateSucc @a @(n'm1 + m))
+ = sym (lemReplicateSucc @a @(n'm1 + m))
lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True
lemLeqPlus _ _ _ = Refl
@@ -156,17 +160,17 @@ lemLeqPlus _ _ _ = Refl
lemLeqSuccSucc :: (k + 1 <= n) => Proxy k -> Proxy n -> (k <=? n - 1) :~: True
lemLeqSuccSucc _ _ = unsafeCoerce Refl
-lemDropLenApp :: X.Rank l1 <= X.Rank l2
+lemDropLenApp :: Rank l1 <= Rank l2
=> Proxy l1 -> Proxy l2 -> Proxy rest
- -> X.DropLen l1 l2 ++ rest :~: X.DropLen l1 (l2 ++ rest)
+ -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest)
lemDropLenApp _ _ _ = unsafeCoerce Refl
-lemTakeLenApp :: X.Rank l1 <= X.Rank l2
+lemTakeLenApp :: Rank l1 <= Rank l2
=> Proxy l1 -> Proxy l2 -> Proxy rest
- -> X.TakeLen l1 l2 :~: X.TakeLen l1 (l2 ++ rest)
+ -> TakeLen l1 l2 :~: TakeLen l1 (l2 ++ rest)
lemTakeLenApp _ _ _ = unsafeCoerce Refl
-srankSh :: ShX sh f -> SNat (X.Rank sh)
+srankSh :: ShX sh f -> SNat (Rank sh)
srankSh ZSX = SNat
srankSh (_ :$% sh) | SNat <- srankSh sh = SNat
@@ -585,11 +589,11 @@ class Elt a where
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
-> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a
- mcast :: forall sh1 sh2 sh'. X.Rank sh1 ~ X.Rank sh2
+ mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
=> StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a
- mtranspose :: forall is sh. (X.Permutation is, X.Rank is <= X.Rank sh)
- => HList SNat is -> Mixed sh a -> Mixed (X.PermutePrefix is sh) a
+ mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
+ => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a
-- ====== PRIVATE METHODS ====== --
@@ -635,20 +639,20 @@ class Elt a => KnownElt a where
instance Storable a => Elt (Primitive a) where
mshape (M_Primitive sh _) = sh
mindex (M_Primitive _ a) i = Primitive (X.index a i)
- mindexPartial (M_Primitive sh a) i = M_Primitive (X.shDropIx sh i) (X.indexPartial a i)
+ mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i)
mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)
mfromListOuter l@(arr1 :| _) =
let sh = SUnknown (length l) :$% mshape arr1
- in M_Primitive sh (X.fromListOuter (X.staticShapeFrom sh) (map (\(M_Primitive _ a) -> a) (toList l)))
- mtoListOuter (M_Primitive sh arr) = map (M_Primitive (X.shTail sh)) (X.toListOuter arr)
+ in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l)))
+ mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr)
mlift :: forall sh1 sh2.
StaticShX sh2
-> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)
-> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a)
mlift ssh2 f (M_Primitive _ a)
- | Refl <- X.lemAppNil @sh1
- , Refl <- X.lemAppNil @sh2
+ | Refl <- lemAppNil @sh1
+ , Refl <- lemAppNil @sh2
, let result = f ZKX a
= M_Primitive (X.shape ssh2 result) result
@@ -657,36 +661,36 @@ instance Storable a => Elt (Primitive a) where
-> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a)
-> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a)
mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b)
- | Refl <- X.lemAppNil @sh1
- , Refl <- X.lemAppNil @sh2
- , Refl <- X.lemAppNil @sh3
+ | Refl <- lemAppNil @sh1
+ , Refl <- lemAppNil @sh2
+ , Refl <- lemAppNil @sh3
, let result = f ZKX a b
= M_Primitive (X.shape ssh3 result) result
- mcast :: forall sh1 sh2 sh'. X.Rank sh1 ~ X.Rank sh2
+ mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
=> StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a)
mcast ssh1 sh2 _ (M_Primitive sh1' arr) =
- let (_, sh') = shAppSplit (Proxy @sh') ssh1 sh1'
- in M_Primitive (shAppend sh2 sh') (X.cast ssh1 sh2 (X.staticShapeFrom sh') arr)
+ let (_, sh') = shxSplitApp (Proxy @sh') ssh1 sh1'
+ in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr)
mtranspose perm (M_Primitive sh arr) =
- M_Primitive (X.shPermutePrefix perm sh)
- (X.transpose (X.staticShapeFrom sh) perm arr)
+ M_Primitive (shxPermutePrefix perm sh)
+ (X.transpose (ssxFromShape sh) perm arr)
mshapeTree _ = ()
mshapeTreeEq _ () () = True
mshapeTreeEmpty _ () = False
mshowShapeTree _ () = "()"
- mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (X.toLinearIdx sh i) x
+ mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x
-- TODO: this use of toVector is suboptimal
mvecsWritePartial
:: forall sh' sh s.
IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do
- let arrsh = X.shape (X.staticShapeFrom sh') arr
- offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIxX' arrsh))
- VS.copy (VSM.slice offset (X.shapeSize arrsh) v) (X.toVector arr)
+ let arrsh = X.shape (ssxFromShape sh') arr
+ offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh))
+ VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr)
mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v
@@ -701,7 +705,7 @@ deriving via Primitive () instance Elt ()
instance Storable a => KnownElt (Primitive a) where
memptyArray sh = M_Primitive sh (X.empty sh)
- mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh)
+ mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh)
mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0
-- [PRIMITIVE ELEMENT TYPES LIST]
@@ -755,7 +759,7 @@ instance Elt a => Elt (Mixed sh' a) where
-- moverlongShape method, a prefix of which is mshape.
mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh
mshape (M_Nest sh arr)
- = fst (shAppSplit (Proxy @sh') (X.staticShapeFrom sh) (mshape arr))
+ = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr))
mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a
mindex (M_Nest _ arr) i = mindexPartial arr i
@@ -763,8 +767,8 @@ instance Elt a => Elt (Mixed sh' a) where
mindexPartial :: forall sh1 sh2.
Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
mindexPartial (M_Nest sh arr) i
- | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = M_Nest (X.shDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
+ = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
mscalar = M_Nest ZSX
@@ -773,95 +777,95 @@ instance Elt a => Elt (Mixed sh' a) where
M_Nest (SUnknown (length l) :$% mshape arr)
(mfromListOuter ((\(M_Nest _ a) -> a) <$> l))
- mtoListOuter (M_Nest sh arr) = map (M_Nest (X.shTail sh)) (mtoListOuter arr)
+ mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr)
mlift :: forall sh1 sh2.
StaticShX sh2
-> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)
-> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
mlift ssh2 f (M_Nest sh1 arr) =
- let result = mlift (X.ssxAppend ssh2 ssh') f' arr
- (sh2, _) = shAppSplit (Proxy @sh') ssh2 (mshape result)
+ let result = mlift (ssxAppend ssh2 ssh') f' arr
+ (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result)
in M_Nest sh2 result
where
- ssh' = X.staticShapeFrom (snd (shAppSplit (Proxy @sh') (X.staticShapeFrom sh1) (mshape arr)))
+ ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr)))
f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b
f' sshT
- | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
- , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
- = f (X.ssxAppend ssh' sshT)
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
+ , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
+ = f (ssxAppend ssh' sshT)
mlift2 :: forall sh1 sh2 sh3.
StaticShX sh3
-> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b)
-> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a)
mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) =
- let result = mlift2 (X.ssxAppend ssh3 ssh') f' arr1 arr2
- (sh3, _) = shAppSplit (Proxy @sh') ssh3 (mshape result)
+ let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2
+ (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result)
in M_Nest sh3 result
where
- ssh' = X.staticShapeFrom (snd (shAppSplit (Proxy @sh') (X.staticShapeFrom sh1) (mshape arr1)))
+ ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1)))
f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b
f' sshT
- | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
- , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
- , Refl <- X.lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT)
- = f (X.ssxAppend ssh' sshT)
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
+ , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
+ , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT)
+ = f (ssxAppend ssh' sshT)
- mcast :: forall sh1 sh2 shT. X.Rank sh1 ~ X.Rank sh2
+ mcast :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2
=> StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a)
mcast ssh1 sh2 _ (M_Nest sh1T arr)
- | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh')
- , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh')
- = let (_, shT) = shAppSplit (Proxy @shT) ssh1 sh1T
- in M_Nest (shAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr)
-
- mtranspose :: forall is sh. (X.Permutation is, X.Rank is <= X.Rank sh)
- => HList SNat is -> Mixed sh (Mixed sh' a)
- -> Mixed (X.PermutePrefix is sh) (Mixed sh' a)
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh')
+ , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh')
+ = let (_, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T
+ in M_Nest (shxAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr)
+
+ mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
+ => Perm is -> Mixed sh (Mixed sh' a)
+ -> Mixed (PermutePrefix is sh) (Mixed sh' a)
mtranspose perm (M_Nest sh arr)
- | let sh' = X.shDropSh @sh @sh' (mshape arr) sh
- , Refl <- X.lemRankApp (X.staticShapeFrom sh) (X.staticShapeFrom sh')
- , Refl <- lemLeqPlus (Proxy @(X.Rank is)) (Proxy @(X.Rank sh)) (Proxy @(X.Rank sh'))
- , Refl <- X.lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh')
+ | let sh' = shxDropSh @sh @sh' (mshape arr) sh
+ , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh')
+ , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh'))
+ , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh')
, Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')
, Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')
- = M_Nest (X.shPermutePrefix perm sh)
+ = M_Nest (shxPermutePrefix perm sh)
(mtranspose perm arr)
mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a)
- mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (X.staticShapeFrom (mshape arr)))))
+ mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr)))))
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
- mshapeTreeEmpty _ (sh, t) = X.shapeSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+ mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
- mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.shAppend sh sh') idx val vecs
+ mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs
mvecsWritePartial :: forall sh1 sh2 s.
IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
-> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
-> ST s ()
mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs)
- | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = mvecsWritePartial (X.shAppend sh12 sh') idx arr vecs
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
+ = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs
- mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (X.shAppend sh sh') vecs
+ mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs
instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
- memptyArray sh = M_Nest sh (memptyArray (X.shAppend sh (X.completeShXzeros (knownShX @sh'))))
+ memptyArray sh = M_Nest sh (memptyArray (shxAppend sh (shxCompleteZeros (knownShX @sh'))))
mvecsUnsafeNew sh example
- | X.shapeSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a))
- | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.shAppend sh sh') (mindex example (X.zeroIxX (X.staticShapeFrom sh')))
+ | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a))
+ | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh')))
where
sh' = mshape example
- mvecsNewEmpty _ = MV_Nest (X.completeShXzeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a)
+ mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a)
-- | Create an array given a size and a function that computes the element at a
@@ -882,10 +886,10 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
-- array. The type of 'mgenerate' allows this requirement to be broken very
-- easily, hence the runtime check.
mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a
-mgenerate sh f = case X.enumShape sh of
+mgenerate sh f = case shxEnum sh of
[] -> memptyArray sh
firstidx : restidxs ->
- let firstelem = f (X.zeroIxX' sh)
+ let firstelem = f (ixxZero' sh)
shapetree = mshapeTree firstelem
in if mshapeTreeEmpty (Proxy @a) shapetree
then memptyArray sh
@@ -905,28 +909,28 @@ msumOuter1P :: forall sh n a. (Storable a, NumElt a)
=> Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
msumOuter1P (M_Primitive (n :$% sh) arr) =
let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX
- in M_Primitive sh (X.sumOuter nssh (X.staticShapeFrom sh) arr)
+ in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr)
msumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
=> Mixed (n : sh) a -> Mixed sh a
msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive
mappend :: forall n m sh a. Elt a
- => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a
+ => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a
mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
where
sn :$% sh = mshape arr1
sm :$% _ = mshape arr2
- ssh = X.staticShapeFrom sh
- snm :: SMayNat () SNat (X.AddMaybe n m)
+ ssh = ssxFromShape sh
+ snm :: SMayNat () SNat (AddMaybe n m)
snm = case (sn, sm) of
(SUnknown{}, _) -> SUnknown ()
(SKnown{}, SUnknown{}) -> SUnknown ()
- (SKnown n, SKnown m) -> SKnown (X.plusSNat n m)
+ (SKnown n, SKnown m) -> SKnown (snatPlus n m)
f :: forall sh' b. Storable b
- => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b
- f ssh' = X.append (X.ssxAppend ssh ssh')
+ => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b
+ f ssh' = X.append (ssxAppend ssh ssh')
mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a)
mfromVectorP sh v = M_Primitive sh (X.fromVector sh v)
@@ -971,9 +975,9 @@ mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
-> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b))
-> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b)
mrerankP ssh sh2 f (M_Primitive sh arr) =
- let sh1 = shDropSSX sh ssh
- in M_Primitive (X.shAppend (shTakeSSX (Proxy @sh1) sh ssh) sh2)
- (X.rerank ssh (X.staticShapeFrom sh1) (X.staticShapeFrom sh2)
+ let sh1 = shxDropSSX sh ssh
+ in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2)
+ (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2)
(\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)
arr)
@@ -988,10 +992,10 @@ mrerank ssh sh2 f (toPrimitive -> arr) =
mreplicate :: forall sh sh' a. Elt a
=> IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a
mreplicate sh arr =
- let ssh' = X.staticShapeFrom (mshape arr)
- in mlift (X.ssxAppend (X.staticShapeFrom sh) ssh')
+ let ssh' = ssxFromShape (mshape arr)
+ in mlift (ssxAppend (ssxFromShape sh) ssh')
(\(sshT :: StaticShX shT) ->
- case X.lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of
+ case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of
Refl -> X.replicate sh (ssxAppend ssh' sshT))
arr
@@ -1005,18 +1009,18 @@ mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x)
mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
mslice i n arr =
let _ :$% sh = mshape arr
- in mlift (SKnown n :!% X.staticShapeFrom sh) (\_ -> X.slice i n) arr
+ in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr
msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a
-msliceU i n arr = mlift (X.staticShapeFrom (mshape arr)) (\_ -> X.sliceU i n) arr
+msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr
mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a
-mrev1 arr = mlift (X.staticShapeFrom (mshape arr)) (\_ -> X.rev1) arr
+mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr
mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a
mreshape sh' arr =
- mlift (X.staticShapeFrom sh')
- (\sshIn -> X.reshapePartial (X.staticShapeFrom (mshape arr)) sshIn sh')
+ mlift (ssxFromShape sh')
+ (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh')
arr
miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a
@@ -1095,26 +1099,26 @@ instance (FloatElt a, NumElt a, PrimElt a) => Floating (Mixed sh a) where
log1pexp = mliftNumElt1 floatEltLog1pexp
log1mexp = mliftNumElt1 floatEltLog1mexp
-mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (X.Rank sh) a
+mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a
mtoRanked arr
- | Refl <- X.lemAppNil @sh
- , Refl <- X.lemAppNil @(Replicate (X.Rank sh) (Nothing @Nat))
+ | Refl <- lemAppNil @sh
+ , Refl <- lemAppNil @(Replicate (Rank sh) (Nothing @Nat))
, Refl <- lemRankReplicate (srankSh (mshape arr))
- = Ranked (mcast (X.staticShapeFrom (mshape arr)) (convSh (mshape arr)) (Proxy @'[]) arr)
+ = Ranked (mcast (ssxFromShape (mshape arr)) (convSh (mshape arr)) (Proxy @'[]) arr)
where
- convSh :: IShX sh' -> IShX (Replicate (X.Rank sh') Nothing)
+ convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing)
convSh ZSX = ZSX
convSh (smn :$% (sh :: IShX sh'T))
- | Refl <- X.lemReplicateSucc @(Nothing @Nat) @(X.Rank sh'T)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T)
= SUnknown (fromSMayNat' smn) :$% convSh sh
-mcastToShaped :: forall sh sh' a. (Elt a, X.Rank sh ~ X.Rank sh')
+mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
=> Mixed sh a -> ShS sh' -> Shaped sh' a
mcastToShaped arr targetsh
- | Refl <- X.lemAppNil @sh
- , Refl <- X.lemAppNil @(MapJust sh')
+ | Refl <- lemAppNil @sh
+ , Refl <- lemAppNil @(MapJust sh')
, Refl <- lemRankMapJust targetsh
- = Shaped (mcast (X.staticShapeFrom (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr)
+ = Shaped (mcast (ssxFromShape (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr)
-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
@@ -1418,7 +1422,7 @@ zeroIxR :: SNat n -> IIxR n
zeroIxR SZ = ZIR
zeroIxR (SS n) = 0 :.: zeroIxR n
-ixCvtXR :: IIxX sh -> IIxR (X.Rank sh)
+ixCvtXR :: IIxX sh -> IIxR (Rank sh)
ixCvtXR ZIX = ZIR
ixCvtXR (n :.% idx) = n :.: ixCvtXR idx
@@ -1429,7 +1433,7 @@ shCvtXR' ZSX =
shCvtXR' (n :$% (idx :: IShX sh))
| Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) =
castWith (subst2 (lem1 @sh Refl))
- (X.fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx))
+ (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx))
where
lem1 :: forall sh' n' k.
k : sh' :~: Replicate n' Nothing
@@ -1443,13 +1447,13 @@ shCvtXR' (n :$% (idx :: IShX sh))
ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
ixCvtRX ZIR = ZIX
ixCvtRX (n :.: (idx :: IxR m Int)) =
- castWith (subst2 @IxX @Int (X.lemReplicateSucc @(Nothing @Nat) @m))
+ castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m))
(n :.% ixCvtRX idx)
shCvtRX :: IShR n -> IShX (Replicate n Nothing)
shCvtRX ZSR = ZSX
shCvtRX (n :$: (idx :: ShR m Int)) =
- castWith (subst2 @ShX @Int (X.lemReplicateSucc @(Nothing @Nat) @m))
+ castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m))
(SUnknown n :$% shCvtRX idx)
shapeSizeR :: IShR n -> Int
@@ -1506,7 +1510,7 @@ rsumOuter1P :: forall n a.
(Storable a, NumElt a)
=> Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
rsumOuter1P (Ranked arr)
- | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
= Ranked (msumOuter1P arr)
rsumOuter1 :: forall n a. (NumElt a, PrimElt a)
@@ -1559,7 +1563,7 @@ rappend :: forall n a. Elt a
rappend arr1 arr2
| sn@SNat <- snatFromShR (rshape arr1)
, Dict <- lemKnownReplicate sn
- , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
+ , Refl <- lemReplicateSucc @(Nothing @Nat) @n
= coerce (mappend @Nothing @Nothing @(Replicate n Nothing))
arr1 arr2
@@ -1582,7 +1586,7 @@ rtoVector = coerce mtoVector
rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
rfromListOuter l
- | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
= Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a
@@ -1593,7 +1597,7 @@ rfromList1Prim l = Ranked (mfromList1Prim l)
rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
rtoListOuter (Ranked arr)
- | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
= coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr)
rtoList1 :: Elt a => Ranked 1 a -> [a]
@@ -1677,7 +1681,7 @@ rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x)
rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a
rslice i n arr
- | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
= rlift (snatFromShR (rshape arr))
(\_ -> X.sliceU i n)
arr
@@ -1686,7 +1690,7 @@ rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a
rrev1 arr =
rlift (snatFromShR (rshape arr))
(\(_ :: StaticShX sh') ->
- case X.lemReplicateSucc @(Nothing @Nat) @n of
+ case lemReplicateSucc @(Nothing @Nat) @n of
Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh'))
arr
@@ -1707,12 +1711,12 @@ rasXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing
rasXArrayPrim (Ranked arr) = first shCvtXR' (masXArrayPrim arr)
rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a)
-rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (X.staticShapeFrom (X.shape (ssxFromSNat sn) arr)) arr)
+rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr)
rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a
-rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (X.staticShapeFrom (X.shape (ssxFromSNat sn) arr)) arr)
+rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr)
-rcastToShaped :: Elt a => Ranked (X.Rank sh) a -> ShS sh -> Shaped sh a
+rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a
rcastToShaped (Ranked arr) targetsh
| Refl <- lemRankReplicate (srankSh (shCvtSX targetsh))
, Refl <- lemRankMapJust targetsh
@@ -1809,7 +1813,7 @@ shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh
shapeSizeS :: ShS sh -> Int
shapeSizeS ZSS = 1
-shapeSizeS (n :$$ sh) = X.fromSNat' n * shapeSizeS sh
+shapeSizeS (n :$$ sh) = fromSNat' n * shapeSizeS sh
sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
@@ -1838,14 +1842,14 @@ slift :: forall sh1 sh2 a. Elt a
=> ShS sh2
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
-> Shaped sh1 a -> Shaped sh2 a
-slift sh2 f (Shaped arr) = Shaped (mlift (X.staticShapeFrom (shCvtSX sh2)) f arr)
+slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShape (shCvtSX sh2)) f arr)
-- | See the documentation of 'mlift'.
slift2 :: forall sh1 sh2 sh3 a. Elt a
=> ShS sh3
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b)
-> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a
-slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (X.staticShapeFrom (shCvtSX sh3)) f arr1 arr2)
+slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShape (shCvtSX sh3)) f arr1 arr2)
ssumOuter1P :: forall sh n a. (Storable a, NumElt a)
=> Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
@@ -1855,28 +1859,28 @@ ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
=> Shaped (n : sh) a -> Shaped sh a
ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive
-lemCommMapJustTakeLen :: HList SNat is -> ShS sh -> X.TakeLen is (MapJust sh) :~: MapJust (X.TakeLen is sh)
-lemCommMapJustTakeLen HNil _ = Refl
-lemCommMapJustTakeLen (_ `HCons` is) (_ :$$ sh) | Refl <- lemCommMapJustTakeLen is sh = Refl
-lemCommMapJustTakeLen (_ `HCons` _) ZSS = error "TakeLen of empty"
+lemCommMapJustTakeLen :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)
+lemCommMapJustTakeLen PNil _ = Refl
+lemCommMapJustTakeLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemCommMapJustTakeLen is sh = Refl
+lemCommMapJustTakeLen (_ `PCons` _) ZSS = error "TakeLen of empty"
-lemCommMapJustDropLen :: HList SNat is -> ShS sh -> X.DropLen is (MapJust sh) :~: MapJust (X.DropLen is sh)
-lemCommMapJustDropLen HNil _ = Refl
-lemCommMapJustDropLen (_ `HCons` is) (_ :$$ sh) | Refl <- lemCommMapJustDropLen is sh = Refl
-lemCommMapJustDropLen (_ `HCons` _) ZSS = error "DropLen of empty"
+lemCommMapJustDropLen :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh)
+lemCommMapJustDropLen PNil _ = Refl
+lemCommMapJustDropLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemCommMapJustDropLen is sh = Refl
+lemCommMapJustDropLen (_ `PCons` _) ZSS = error "DropLen of empty"
-lemCommMapJustIndex :: SNat i -> ShS sh -> X.Index i (MapJust sh) :~: Just (X.Index i sh)
+lemCommMapJustIndex :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh)
lemCommMapJustIndex SZ (_ :$$ _) = Refl
lemCommMapJustIndex (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh'))
| Refl <- lemCommMapJustIndex i sh
- , Refl <- X.lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh'))
- , Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh'))
+ , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
= Refl
lemCommMapJustIndex _ ZSS = error "Index of empty"
-lemCommMapJustPermute :: HList SNat is -> ShS sh -> X.Permute is (MapJust sh) :~: MapJust (X.Permute is sh)
-lemCommMapJustPermute HNil _ = Refl
-lemCommMapJustPermute (i `HCons` is) sh
+lemCommMapJustPermute :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh)
+lemCommMapJustPermute PNil _ = Refl
+lemCommMapJustPermute (i `PCons` is) sh
| Refl <- lemCommMapJustPermute is sh
, Refl <- lemCommMapJustIndex i sh
= Refl
@@ -1885,53 +1889,53 @@ listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f
listsAppend ZS idx' = idx'
listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx'
-listsTakeLen :: forall f is sh. HList SNat is -> ListS sh f -> ListS (X.TakeLen is sh) f
-listsTakeLen HNil _ = ZS
-listsTakeLen (_ `HCons` is) (n ::$ sh) = n ::$ listsTakeLen is sh
-listsTakeLen (_ `HCons` _) ZS = error "Permutation longer than shape"
+listsTakeLen :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f
+listsTakeLen PNil _ = ZS
+listsTakeLen (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLen is sh
+listsTakeLen (_ `PCons` _) ZS = error "Permutation longer than shape"
-listsDropLen :: forall f is sh. HList SNat is -> ListS sh f -> ListS (DropLen is sh) f
-listsDropLen HNil sh = sh
-listsDropLen (_ `HCons` is) (_ ::$ sh) = listsDropLen is sh
-listsDropLen (_ `HCons` _) ZS = error "Permutation longer than shape"
+listsDropLen :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f
+listsDropLen PNil sh = sh
+listsDropLen (_ `PCons` is) (_ ::$ sh) = listsDropLen is sh
+listsDropLen (_ `PCons` _) ZS = error "Permutation longer than shape"
-listsPermute :: forall f is sh. HList SNat is -> ListS sh f -> ListS (X.Permute is sh) f
-listsPermute HNil _ = ZS
-listsPermute (i `HCons` (is :: HList SNat is')) (sh :: ListS sh f) = listsIndex (Proxy @is') (Proxy @sh) i sh (listsPermute is sh)
+listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f
+listsPermute PNil _ = ZS
+listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) = listsIndex (Proxy @is') (Proxy @sh) i sh (listsPermute is sh)
-listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> ListS (X.Permute is shT) f -> ListS (X.Index i sh : X.Permute is shT) f
+listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> ListS (Permute is shT) f -> ListS (Index i sh : Permute is shT) f
listsIndex _ _ SZ (n ::$ _) rest = n ::$ rest
listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) rest
- | Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
= listsIndex p pT i sh rest
listsIndex _ _ _ ZS _ = error "Index into empty shape"
-shsTakeLen :: HList SNat is -> ShS sh -> ShS (X.TakeLen is sh)
+shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh)
shsTakeLen = coerce (listsTakeLen @SNat)
-shsPermute :: HList SNat is -> ShS sh -> ShS (X.Permute is sh)
+shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh)
shsPermute = coerce (listsPermute @SNat)
-shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (X.Permute is shT) -> ShS (X.Index i sh : X.Permute is shT)
+shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (Permute is shT) -> ShS (Index i sh : Permute is shT)
shsIndex pis pshT = coerce (listsIndex @SNat pis pshT)
-applyPermS :: forall f is sh. HList SNat is -> ListS sh f -> ListS (PermutePrefix is sh) f
+applyPermS :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f
applyPermS perm sh = listsAppend (listsPermute perm (listsTakeLen perm sh)) (listsDropLen perm sh)
-applyPermIxS :: forall i is sh. HList SNat is -> IxS sh i -> IxS (PermutePrefix is sh) i
+applyPermIxS :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i
applyPermIxS = coerce (applyPermS @(Const i))
-applyPermShS :: forall is sh. HList SNat is -> ShS sh -> ShS (PermutePrefix is sh)
+applyPermShS :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)
applyPermShS = coerce (applyPermS @SNat)
-stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, Elt a)
- => HList SNat is -> Shaped sh a -> Shaped (X.PermutePrefix is sh) a
+stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a)
+ => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a
stranspose perm sarr@(Shaped arr)
| Refl <- lemRankMapJust (sshape sarr)
, Refl <- lemCommMapJustTakeLen perm (sshape sarr)
, Refl <- lemCommMapJustDropLen perm (sshape sarr)
, Refl <- lemCommMapJustPermute perm (shsTakeLen perm (sshape sarr))
- , Refl <- lemCommMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(X.DropLen is sh))
+ , Refl <- lemCommMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh))
= Shaped (mtranspose perm arr)
sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a
@@ -1969,7 +1973,7 @@ stoList1 = map sunScalar . stoListOuter
sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a
sfromListPrim sn l
- | Refl <- X.lemAppNil @'[Just n]
+ | Refl <- lemAppNil @'[Just n]
= let ssh = SUnknown () :!% ZKX
xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l)
in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr
@@ -1989,7 +1993,7 @@ srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
srerankP sh sh2 f sarr@(Shaped arr)
| Refl <- lemCommMapJustApp sh (Proxy @sh1)
, Refl <- lemCommMapJustApp sh (Proxy @sh2)
- = Shaped (mrerankP (X.staticShapeFrom (shTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (X.staticShapeFrom (shCvtSX sh))))
+ = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh))))
(shCvtSX sh2)
(\a -> let Shaped r = f (Shaped a) in r)
arr)
@@ -2033,12 +2037,12 @@ sasXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a)
sasXArrayPrim (Shaped arr) = first shCvtXS' (masXArrayPrim arr)
sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a)
-sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (X.staticShapeFrom (shCvtSX sh)) arr)
+sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShape (shCvtSX sh)) arr)
sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a
-sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (X.staticShapeFrom (shCvtSX sh)) arr)
+sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShape (shCvtSX sh)) arr)
-stoRanked :: Elt a => Shaped sh a -> Ranked (X.Rank sh) a
+stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a
stoRanked sarr@(Shaped arr)
| Refl <- lemRankMapJust (sshape sarr)
= mtoRanked arr
diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs
deleted file mode 100644
index 95fcfcf..0000000
--- a/src/Data/Array/Nested/Internal/Arith.hs
+++ /dev/null
@@ -1,435 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TemplateHaskell #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE ViewPatterns #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Nested.Internal.Arith where
-
-import Control.Monad (forM, guard)
-import qualified Data.Array.Internal as OI
-import qualified Data.Array.Internal.RankedG as RG
-import qualified Data.Array.Internal.RankedS as RS
-import Data.Bits
-import Data.Int
-import Data.List (sort)
-import qualified Data.Vector.Storable as VS
-import qualified Data.Vector.Storable.Mutable as VSM
-import Foreign.C.Types
-import Foreign.Ptr
-import Foreign.Storable (Storable)
-import GHC.TypeLits
-import Language.Haskell.TH
-import System.IO.Unsafe
-
-import Data.Array.Nested.Internal.Arith.Foreign
-import Data.Array.Nested.Internal.Arith.Lists
-
-
-liftVEltwise1 :: Storable a
- => SNat n
- -> (VS.Vector a -> VS.Vector a)
- -> RS.Array n a -> RS.Array n a
-liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
- | Just prefixSz <- stridesDense sh strides =
- let vec' = f (VS.slice offset prefixSz vec)
- in RS.A (RG.A sh (OI.T strides 0 vec'))
- | otherwise = RS.fromVector sh (f (RS.toVector arr))
-
-liftVEltwise2 :: Storable a
- => SNat n
- -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a)
- -> RS.Array n a -> RS.Array n a -> RS.Array n a
-liftVEltwise2 SNat f
- arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))
- arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))
- | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
- | product sh1 == 0 = arr1 -- if the arrays are empty, just return one of the empty inputs
- | otherwise = case (stridesDense sh1 strides1, stridesDense sh2 strides2) of
- (Just 1, Just 1) -> -- both are a (potentially replicated) scalar; just apply f to the scalars
- let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2))
- in RS.A (RG.A sh1 (OI.T strides1 0 vec'))
- (Just 1, Just n) -> -- scalar * dense
- RS.fromVector sh1 (f (Left (vec1 VS.! offset1)) (Right (VS.slice offset2 n vec2)))
- (Just n, Just 1) -> -- dense * scalar
- RS.fromVector sh1 (f (Right (VS.slice offset1 n vec1)) (Left (vec2 VS.! offset2)))
- (_, _) -> -- fallback case
- RS.fromVector sh1 (f (Right (RS.toVector arr1)) (Right (RS.toVector arr2)))
-
--- | Given the shape vector and the stride vector, return whether this vector
--- of strides uses a dense prefix of its backing array. If so, the number of
--- elements in this prefix is returned.
--- This excludes any offset.
-stridesDense :: [Int] -> [Int] -> Maybe Int
-stridesDense sh _ | any (<= 0) sh = Just 0
-stridesDense sh str =
- -- sort dimensions on their stride, ascending, dropping any zero strides
- case dropWhile ((== 0) . fst) (sort (zip str sh)) of
- [] -> Just 1
- (1, n) : (unzip -> (str', sh')) -> checkCover n sh' str'
- _ -> Nothing -- if the smallest stride is not 1, it will never be dense
- where
- -- Given size of currently densely covered region at beginning of the
- -- array, the remaining shape vector and the corresponding remaining stride
- -- vector, return whether this all together covers a dense prefix of the
- -- array. If it does, return the number of elements in this prefix.
- checkCover :: Int -> [Int] -> [Int] -> Maybe Int
- checkCover block [] [] = Just block
- checkCover block (n : sh') (s : str') = guard (s <= block) >> checkCover (max block (n * s)) sh' str'
- checkCover _ _ _ = error "Orthotope array's shape vector and stride vector have different lengths"
-
-{-# NOINLINE vectorOp1 #-}
-vectorOp1 :: forall a b. Storable a
- => (Ptr a -> Ptr b)
- -> (Int64 -> Ptr b -> Ptr b -> IO ())
- -> VS.Vector a -> VS.Vector a
-vectorOp1 ptrconv f v = unsafePerformIO $ do
- outv <- VSM.unsafeNew (VS.length v)
- VSM.unsafeWith outv $ \poutv ->
- VS.unsafeWith v $ \pv ->
- f (fromIntegral (VS.length v)) (ptrconv poutv) (ptrconv pv)
- VS.unsafeFreeze outv
-
--- | If two vectors are given, assumes that they have the same length.
-{-# NOINLINE vectorOp2 #-}
-vectorOp2 :: forall a b. Storable a
- => (a -> b)
- -> (Ptr a -> Ptr b)
- -> (a -> a -> a)
- -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- sv
- -> (Int64 -> Ptr b -> Ptr b -> b -> IO ()) -- vs
- -> (Int64 -> Ptr b -> Ptr b -> Ptr b -> IO ()) -- vv
- -> Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a
-vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases
- (Left x) (Left y) -> VS.singleton (fss x y)
-
- (Left x) (Right vy) ->
- unsafePerformIO $ do
- outv <- VSM.unsafeNew (VS.length vy)
- VSM.unsafeWith outv $ \poutv ->
- VS.unsafeWith vy $ \pvy ->
- fsv (fromIntegral (VS.length vy)) (ptrconv poutv) (valconv x) (ptrconv pvy)
- VS.unsafeFreeze outv
-
- (Right vx) (Left y) ->
- unsafePerformIO $ do
- outv <- VSM.unsafeNew (VS.length vx)
- VSM.unsafeWith outv $ \poutv ->
- VS.unsafeWith vx $ \pvx ->
- fvs (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (valconv y)
- VS.unsafeFreeze outv
-
- (Right vx) (Right vy)
- | VS.length vx == VS.length vy ->
- unsafePerformIO $ do
- outv <- VSM.unsafeNew (VS.length vx)
- VSM.unsafeWith outv $ \poutv ->
- VS.unsafeWith vx $ \pvx ->
- VS.unsafeWith vy $ \pvy ->
- fvv (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (ptrconv pvy)
- VS.unsafeFreeze outv
- | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy)
-
--- TODO: test all the weird cases of this function
--- | Reduce along the inner dimension
-{-# NOINLINE vectorRedInnerOp #-}
-vectorRedInnerOp :: forall a b n. (Num a, Storable a)
- => SNat n
- -> (a -> b)
- -> (Ptr a -> Ptr b)
- -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- ^ scale by constant
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel
- -> RS.Array (n + 1) a -> RS.Array n a
-vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec)))
- | null sh = error "unreachable"
- | last sh <= 0 = RS.stretch (init sh) (RS.fromList (map (const 1) (init sh)) [0])
- | any (<= 0) (init sh) = RS.A (RG.A (init sh) (OI.T (map (const 0) (init strides)) 0 VS.empty))
- -- now the input array is nonempty
- | last sh == 1 = RS.A (RG.A (init sh) (OI.T (init strides) offset vec))
- | last strides == 0 =
- liftVEltwise1 sn
- (vectorOp1 id (\n pout px -> fscale n (ptrconv pout) (valconv (fromIntegral (last sh))) (ptrconv px)))
- (RS.A (RG.A (init sh) (OI.T (init strides) offset vec)))
- -- now there is useful work along the inner dimension
- | otherwise =
- let -- filter out zero-stride dimensions; the reduction kernel need not concern itself with those
- (shF, stridesF) = unzip $ filter ((/= 0) . snd) (zip sh strides)
- ndimsF = length shF
- in unsafePerformIO $ do
- outv <- VSM.unsafeNew (product (init shF))
- VSM.unsafeWith outv $ \poutv ->
- VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF ->
- VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF ->
- VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec ->
- fred (fromIntegral ndimsF) pshF pstridesF (ptrconv poutv) (ptrconv pvec)
- RS.fromVector (init sh) <$> VS.unsafeFreeze outv
-
-flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())
- -> Int64 -> Ptr a -> Ptr a -> a -> IO ()
-flipOp f n out v s = f n out s v
-
-$(fmap concat . forM typesList $ \arithtype -> do
- let ttyp = conT (atType arithtype)
- fmap concat . forM [minBound..maxBound] $ \arithop -> do
- let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype))
- cnamebase = "c_binary_" ++ atCName arithtype
- c_ss = varE (aboNumOp arithop)
- c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
- c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
- c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
- sequence [SigD name <$>
- [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
- return $ FunD name [Clause [] (NormalB body) []]])
-
-$(fmap concat . forM floatTypesList $ \arithtype -> do
- let ttyp = conT (atType arithtype)
- fmap concat . forM [minBound..maxBound] $ \arithop -> do
- let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype))
- cnamebase = "c_fbinary_" ++ atCName arithtype
- c_ss = varE (afboNumOp arithop)
- c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
- c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
- c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
- sequence [SigD name <$>
- [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
- return $ FunD name [Clause [] (NormalB body) []]])
-
-$(fmap concat . forM typesList $ \arithtype -> do
- let ttyp = conT (atType arithtype)
- fmap concat . forM [minBound..maxBound] $ \arithop -> do
- let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype))
- c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop)))
- sequence [SigD name <$>
- [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]
- return $ FunD name [Clause [] (NormalB body) []]])
-
-$(fmap concat . forM floatTypesList $ \arithtype -> do
- let ttyp = conT (atType arithtype)
- fmap concat . forM [minBound..maxBound] $ \arithop -> do
- let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype))
- c_op = varE (mkName ("c_funary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (afuoEnum arithop)))
- sequence [SigD name <$>
- [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]
- return $ FunD name [Clause [] (NormalB body) []]])
-
-$(fmap concat . forM typesList $ \arithtype -> do
- let ttyp = conT (atType arithtype)
- fmap concat . forM [minBound..maxBound] $ \arithop -> do
- let name = mkName (aroName arithop ++ "Vector" ++ nameBase (atType arithtype))
- c_op = varE (mkName ("c_reduce_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop)))
- c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))
- sequence [SigD name <$>
- [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |]
- return $ FunD name [Clause [] (NormalB body) []]])
-
--- This branch is ostensibly a runtime branch, but will (hopefully) be
--- constant-folded away by GHC.
-intWidBranch1 :: forall i n. (FiniteBits i, Storable i)
- => (Int64 -> Ptr Int32 -> Ptr Int32 -> IO ())
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())
- -> (SNat n -> RS.Array n i -> RS.Array n i)
-intWidBranch1 f32 f64 sn
- | finiteBitSize (undefined :: i) == 32 = liftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32)
- | finiteBitSize (undefined :: i) == 64 = liftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64)
- | otherwise = error "Unsupported Int width"
-
-intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i)
- => (i -> i -> i) -- ss
- -- int32
- -> (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- sv
- -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Int32 -> IO ()) -- vs
- -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- vv
- -- int64
- -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- sv
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ()) -- vs
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- vv
- -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i)
-intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn
- | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32)
- | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64)
- | otherwise = error "Unsupported Int width"
-
-intWidBranchRed :: forall i n. (FiniteBits i, Storable i, Integral i)
- => -- int32
- (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- ^ scale by constant
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- ^ reduction kernel
- -- int64
- -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel
- -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i)
-intWidBranchRed fsc32 fred32 fsc64 fred64 sn
- | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32
- | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64
- | otherwise = error "Unsupported Int width"
-
-class NumElt a where
- numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a
- numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a
- numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a
- numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
- numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
-
-instance NumElt Int32 where
- numEltAdd = addVectorInt32
- numEltSub = subVectorInt32
- numEltMul = mulVectorInt32
- numEltNeg = negVectorInt32
- numEltAbs = absVectorInt32
- numEltSignum = signumVectorInt32
- numEltSum1Inner = sum1VectorInt32
- numEltProduct1Inner = product1VectorInt32
-
-instance NumElt Int64 where
- numEltAdd = addVectorInt64
- numEltSub = subVectorInt64
- numEltMul = mulVectorInt64
- numEltNeg = negVectorInt64
- numEltAbs = absVectorInt64
- numEltSignum = signumVectorInt64
- numEltSum1Inner = sum1VectorInt64
- numEltProduct1Inner = product1VectorInt64
-
-instance NumElt Float where
- numEltAdd = addVectorFloat
- numEltSub = subVectorFloat
- numEltMul = mulVectorFloat
- numEltNeg = negVectorFloat
- numEltAbs = absVectorFloat
- numEltSignum = signumVectorFloat
- numEltSum1Inner = sum1VectorFloat
- numEltProduct1Inner = product1VectorFloat
-
-instance NumElt Double where
- numEltAdd = addVectorDouble
- numEltSub = subVectorDouble
- numEltMul = mulVectorDouble
- numEltNeg = negVectorDouble
- numEltAbs = absVectorDouble
- numEltSignum = signumVectorDouble
- numEltSum1Inner = sum1VectorDouble
- numEltProduct1Inner = product1VectorDouble
-
-instance NumElt Int where
- numEltAdd = intWidBranch2 @Int (+)
- (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD))
- (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD))
- numEltSub = intWidBranch2 @Int (-)
- (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB))
- (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB))
- numEltMul = intWidBranch2 @Int (*)
- (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL))
- (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL))
- numEltNeg = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG))
- numEltAbs = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS))
- numEltSignum = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM))
- numEltSum1Inner = intWidBranchRed @Int
- (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1))
- (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1))
- numEltProduct1Inner = intWidBranchRed @Int
- (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))
- (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
-
-instance NumElt CInt where
- numEltAdd = intWidBranch2 @CInt (+)
- (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD))
- (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD))
- numEltSub = intWidBranch2 @CInt (-)
- (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB))
- (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB))
- numEltMul = intWidBranch2 @CInt (*)
- (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL))
- (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL))
- numEltNeg = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG))
- numEltAbs = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS))
- numEltSignum = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM))
- numEltSum1Inner = intWidBranchRed @CInt
- (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1))
- (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1))
- numEltProduct1Inner = intWidBranchRed @CInt
- (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))
- (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
-
-class FloatElt a where
- floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- floatEltPow :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- floatEltLogbase :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
- floatEltRecip :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltExp :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltLog :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltSqrt :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltSin :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltCos :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltTan :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltAsin :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltAcos :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltAtan :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltSinh :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltCosh :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltTanh :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltAsinh :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltAcosh :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltAtanh :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltLog1p :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltExpm1 :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltLog1pexp :: SNat n -> RS.Array n a -> RS.Array n a
- floatEltLog1mexp :: SNat n -> RS.Array n a -> RS.Array n a
-
-instance FloatElt Float where
- floatEltDiv = divVectorFloat
- floatEltPow = powVectorFloat
- floatEltLogbase = logbaseVectorFloat
- floatEltRecip = recipVectorFloat
- floatEltExp = expVectorFloat
- floatEltLog = logVectorFloat
- floatEltSqrt = sqrtVectorFloat
- floatEltSin = sinVectorFloat
- floatEltCos = cosVectorFloat
- floatEltTan = tanVectorFloat
- floatEltAsin = asinVectorFloat
- floatEltAcos = acosVectorFloat
- floatEltAtan = atanVectorFloat
- floatEltSinh = sinhVectorFloat
- floatEltCosh = coshVectorFloat
- floatEltTanh = tanhVectorFloat
- floatEltAsinh = asinhVectorFloat
- floatEltAcosh = acoshVectorFloat
- floatEltAtanh = atanhVectorFloat
- floatEltLog1p = log1pVectorFloat
- floatEltExpm1 = expm1VectorFloat
- floatEltLog1pexp = log1pexpVectorFloat
- floatEltLog1mexp = log1mexpVectorFloat
-
-instance FloatElt Double where
- floatEltDiv = divVectorDouble
- floatEltPow = powVectorDouble
- floatEltLogbase = logbaseVectorDouble
- floatEltRecip = recipVectorDouble
- floatEltExp = expVectorDouble
- floatEltLog = logVectorDouble
- floatEltSqrt = sqrtVectorDouble
- floatEltSin = sinVectorDouble
- floatEltCos = cosVectorDouble
- floatEltTan = tanVectorDouble
- floatEltAsin = asinVectorDouble
- floatEltAcos = acosVectorDouble
- floatEltAtan = atanVectorDouble
- floatEltSinh = sinhVectorDouble
- floatEltCosh = coshVectorDouble
- floatEltTanh = tanhVectorDouble
- floatEltAsinh = asinhVectorDouble
- floatEltAcosh = acoshVectorDouble
- floatEltAtanh = atanhVectorDouble
- floatEltLog1p = log1pVectorDouble
- floatEltExpm1 = expm1VectorDouble
- floatEltLog1pexp = log1pexpVectorDouble
- floatEltLog1mexp = log1mexpVectorDouble
diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Nested/Internal/Arith/Foreign.hs
deleted file mode 100644
index ac83188..0000000
--- a/src/Data/Array/Nested/Internal/Arith/Foreign.hs
+++ /dev/null
@@ -1,55 +0,0 @@
-{-# LANGUAGE ForeignFunctionInterface #-}
-{-# LANGUAGE TemplateHaskell #-}
-module Data.Array.Nested.Internal.Arith.Foreign where
-
-import Control.Monad
-import Data.Int
-import Data.Maybe
-import Foreign.C.Types
-import Foreign.Ptr
-import Language.Haskell.TH
-
-import Data.Array.Nested.Internal.Arith.Lists
-
-
-$(fmap concat . forM typesList $ \arithtype -> do
- let ttyp = conT (atType arithtype)
- let base = "binary_" ++ atCName arithtype
- sequence $ catMaybes
- [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$>
- [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
- ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$>
- [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
- ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$>
- [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
- ])
-
-$(fmap concat . forM floatTypesList $ \arithtype -> do
- let ttyp = conT (atType arithtype)
- let base = "fbinary_" ++ atCName arithtype
- sequence $ catMaybes
- [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$>
- [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
- ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$>
- [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
- ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$>
- [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
- ])
-
-$(fmap concat . forM typesList $ \arithtype -> do
- let ttyp = conT (atType arithtype)
- let base = "unary_" ++ atCName arithtype
- pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
- [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
-
-$(fmap concat . forM floatTypesList $ \arithtype -> do
- let ttyp = conT (atType arithtype)
- let base = "funary_" ++ atCName arithtype
- pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
- [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
-
-$(fmap concat . forM typesList $ \arithtype -> do
- let ttyp = conT (atType arithtype)
- let base = "reduce_" ++ atCName arithtype
- pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
- [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
diff --git a/src/Data/Array/Nested/Internal/Arith/Lists.hs b/src/Data/Array/Nested/Internal/Arith/Lists.hs
deleted file mode 100644
index ce2836d..0000000
--- a/src/Data/Array/Nested/Internal/Arith/Lists.hs
+++ /dev/null
@@ -1,78 +0,0 @@
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE TemplateHaskell #-}
-module Data.Array.Nested.Internal.Arith.Lists where
-
-import Data.Char
-import Data.Int
-import Language.Haskell.TH
-
-import Data.Array.Nested.Internal.Arith.Lists.TH
-
-
-data ArithType = ArithType
- { atType :: Name -- ''Int32
- , atCName :: String -- "i32"
- }
-
-floatTypesList :: [ArithType]
-floatTypesList =
- [ArithType ''Float "float"
- ,ArithType ''Double "double"
- ]
-
-typesList :: [ArithType]
-typesList =
- [ArithType ''Int32 "i32"
- ,ArithType ''Int64 "i64"
- ]
- ++ floatTypesList
-
--- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded)
-$(genArithDataType Binop "ArithBOp")
-
-$(genArithNameFun Binop ''ArithBOp "aboName" (map toLower . drop 3))
-$(genArithEnumFun Binop ''ArithBOp "aboEnum")
-
-$(do clauses <- readArithLists Binop
- (\name _num hsop -> return (Clause [ConP (mkName name) [] []]
- (NormalB (VarE 'mkName `AppE` LitE (StringL hsop)))
- []))
- return
- sequence [SigD (mkName "aboNumOp") <$> [t| ArithBOp -> Name |]
- ,return $ FunD (mkName "aboNumOp") clauses])
-
-
--- data ArithFBOp = FB_DIV deriving (Show, Enum, Bounded)
-$(genArithDataType FBinop "ArithFBOp")
-
-$(genArithNameFun FBinop ''ArithFBOp "afboName" (map toLower . drop 3))
-$(genArithEnumFun FBinop ''ArithFBOp "afboEnum")
-
-$(do clauses <- readArithLists FBinop
- (\name _num hsop -> return (Clause [ConP (mkName name) [] []]
- (NormalB (VarE 'mkName `AppE` LitE (StringL hsop)))
- []))
- return
- sequence [SigD (mkName "afboNumOp") <$> [t| ArithFBOp -> Name |]
- ,return $ FunD (mkName "afboNumOp") clauses])
-
-
--- data ArithUOp = UO_NEG | UO_ABS | UO_SIGNUM | ... deriving (Show, Enum, Bounded)
-$(genArithDataType Unop "ArithUOp")
-
-$(genArithNameFun Unop ''ArithUOp "auoName" (map toLower . drop 3))
-$(genArithEnumFun Unop ''ArithUOp "auoEnum")
-
-
--- data ArithFUOp = FU_RECIP | ... deriving (Show, Enum, Bounded)
-$(genArithDataType FUnop "ArithFUOp")
-
-$(genArithNameFun FUnop ''ArithFUOp "afuoName" (map toLower . drop 3))
-$(genArithEnumFun FUnop ''ArithFUOp "afuoEnum")
-
-
--- data ArithRedOp = RO_SUM1 | RO_PRODUCT1 deriving (Show, Enum, Bounded)
-$(genArithDataType Redop "ArithRedOp")
-
-$(genArithNameFun Redop ''ArithRedOp "aroName" (map toLower . drop 3))
-$(genArithEnumFun Redop ''ArithRedOp "aroEnum")
diff --git a/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs b/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs
deleted file mode 100644
index 7142dfa..0000000
--- a/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs
+++ /dev/null
@@ -1,82 +0,0 @@
-{-# LANGUAGE TemplateHaskellQuotes #-}
-module Data.Array.Nested.Internal.Arith.Lists.TH where
-
-import Control.Monad
-import Control.Monad.IO.Class
-import Data.Maybe
-import Foreign.C.Types
-import Language.Haskell.TH
-import Language.Haskell.TH.Syntax
-import Text.Read
-
-
-data OpKind = Binop | FBinop | Unop | FUnop | Redop
- deriving (Show, Eq)
-
-readArithLists :: OpKind
- -> (String -> Int -> String -> Q a)
- -> ([a] -> Q r)
- -> Q r
-readArithLists targetkind fop fcombine = do
- addDependentFile "cbits/arith_lists.h"
- lns <- liftIO $ lines <$> readFile "cbits/arith_lists.h"
-
- mvals <- forM lns $ \line -> do
- if null (dropWhile (== ' ') line)
- then return Nothing
- else do let (kind, name, num, aux) = parseLine line
- if kind == targetkind
- then Just <$> fop name num aux
- else return Nothing
-
- fcombine (catMaybes mvals)
- where
- parseLine s0
- | ("LIST_", s1) <- splitAt 5 s0
- , (kindstr, '(' : s2) <- break (== '(') s1
- , (f1, ',' : s3) <- parseField s2
- , (f2, ',' : s4) <- parseField s3
- , (f3, ')' : _) <- parseField s4
- , Just kind <- parseKind kindstr
- , let name = f1
- , Just num <- readMaybe f2
- , let aux = f3
- = (kind, name, num, aux)
- | otherwise
- = error $ "readArithLists: unrecognised line in cbits/arith_lists.h: " ++ show s0
-
- parseField s = break (`elem` ",)") (dropWhile (== ' ') s)
-
- parseKind "BINOP" = Just Binop
- parseKind "FBINOP" = Just FBinop
- parseKind "UNOP" = Just Unop
- parseKind "FUNOP" = Just FUnop
- parseKind "REDOP" = Just Redop
- parseKind _ = Nothing
-
-genArithDataType :: OpKind -> String -> Q [Dec]
-genArithDataType kind dtname = do
- cons <- readArithLists kind
- (\name _num _ -> return $ NormalC (mkName name) [])
- return
- return [DataD [] (mkName dtname) [] Nothing cons [DerivClause Nothing [ConT ''Show, ConT ''Enum, ConT ''Bounded]]]
-
-genArithNameFun :: OpKind -> Name -> String -> (String -> String) -> Q [Dec]
-genArithNameFun kind dtname funname nametrans = do
- clauses <- readArithLists kind
- (\name _num _ -> return (Clause [ConP (mkName name) [] []]
- (NormalB (LitE (StringL (nametrans name))))
- []))
- return
- return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''String)
- ,FunD (mkName funname) clauses]
-
-genArithEnumFun :: OpKind -> Name -> String -> Q [Dec]
-genArithEnumFun kind dtname funname = do
- clauses <- readArithLists kind
- (\name num _ -> return (Clause [ConP (mkName name) [] []]
- (NormalB (LitE (IntegerL (fromIntegral num))))
- []))
- return
- return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''CInt)
- ,FunD (mkName funname) clauses]