aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ops/Data/Array/Strided/Arith/Internal.hs6
-rw-r--r--ops/Data/Array/Strided/Array.hs2
-rw-r--r--ox-arrays.cabal1
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs2
-rw-r--r--src/Data/Array/Nested/Internal/Shape.hs4
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs2
-rw-r--r--test/Gen.hs4
-rw-r--r--test/Util.hs2
8 files changed, 9 insertions, 14 deletions
diff --git a/ops/Data/Array/Strided/Arith/Internal.hs b/ops/Data/Array/Strided/Arith/Internal.hs
index 95e5af2..5802573 100644
--- a/ops/Data/Array/Strided/Arith/Internal.hs
+++ b/ops/Data/Array/Strided/Arith/Internal.hs
@@ -53,7 +53,7 @@ debugShow (Array sh strides offset vec) =
-- TODO: test all the cases of this thing with various input strides
-liftOpEltwise1 :: (Storable a, Storable b)
+liftOpEltwise1 :: Storable a
=> SNat n
-> (Ptr a -> Ptr b)
-> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())
@@ -62,7 +62,7 @@ liftOpEltwise1 sn@SNat ptrconv cf_strided arr@(Array sh strides offset vec)
| Just (blockOff, blockSz) <- stridesDense sh offset strides =
if blockSz == 0
then Array sh (map (const 0) strides) 0 VS.empty
- else let resvec = arrValues $ wrapUnary sn ptrconv cf_strided (Array [fromIntegral blockSz] [1] blockOff vec)
+ else let resvec = arrValues $ wrapUnary sn ptrconv cf_strided (Array [blockSz] [1] blockOff vec)
in Array sh strides (offset - blockOff) resvec
| otherwise = wrapUnary sn ptrconv cf_strided arr
@@ -673,7 +673,7 @@ intWidBranchRedFull fsc fred32 fred64 sn
| finiteBitSize (undefined :: i) == 64 = vectorRedFullOp @i @Int64 sn fsc fromIntegral castPtr fred64
| otherwise = error "Unsupported Int width"
-intWidBranchExtr :: forall i n. (FiniteBits i, Storable i, Integral i)
+intWidBranchExtr :: forall i n. (FiniteBits i, Storable i)
=> -- int32
(forall b. b ~ Int32 => Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ extremum kernel
-- int64
diff --git a/ops/Data/Array/Strided/Array.hs b/ops/Data/Array/Strided/Array.hs
index f757cd5..9280fe0 100644
--- a/ops/Data/Array/Strided/Array.hs
+++ b/ops/Data/Array/Strided/Array.hs
@@ -31,7 +31,7 @@ arrayFromVector sh vec
shsize = product sh
strides = NE.tail (NE.scanr (*) 1 sh)
-arrayFromConstant :: (Storable a, KnownNat n) => [Int] -> a -> Array n a
+arrayFromConstant :: Storable a => [Int] -> a -> Array n a
arrayFromConstant sh x = Array sh (0 <$ sh) 0 (VS.singleton x)
arrayRevDims :: [Bool] -> Array n a -> Array n a
diff --git a/ox-arrays.cabal b/ox-arrays.cabal
index 58aea1b..2d45b08 100644
--- a/ox-arrays.cabal
+++ b/ox-arrays.cabal
@@ -70,7 +70,6 @@ library
ghc-typelits-knownnat,
ghc-typelits-natnormalise,
orthotope < 0.2,
- template-haskell,
vector
hs-source-dirs: src
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index 1d076e8..3bd4581 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -493,7 +493,7 @@ rreshape sh' rarr@(Ranked arr)
rflatten :: Elt a => Ranked n a -> Ranked 1 a
rflatten (Ranked arr) = mtoRanked (mflatten arr)
-riota :: (Enum a, PrimElt a, Elt a) => Int -> Ranked 1 a
+riota :: (Enum a, PrimElt a) => Int -> Ranked 1 a
riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota
-- | Throws if the array is empty.
diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs
index e9e0103..c877412 100644
--- a/src/Data/Array/Nested/Internal/Shape.hs
+++ b/src/Data/Array/Nested/Internal/Shape.hs
@@ -115,21 +115,17 @@ listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l)
listrHead :: ListR (n + 1) i -> i
listrHead (i ::: _) = i
-listrHead ZR = error "unreachable"
listrTail :: ListR (n + 1) i -> ListR n i
listrTail (_ ::: sh) = sh
-listrTail ZR = error "unreachable"
listrInit :: ListR (n + 1) i -> ListR n i
listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh
listrInit (_ ::: ZR) = ZR
-listrInit ZR = error "unreachable"
listrLast :: ListR (n + 1) i -> i
listrLast (_ ::: sh@(_ ::: _)) = listrLast sh
listrLast (n ::: ZR) = n
-listrLast ZR = error "unreachable"
listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
listrIndex SZ (x ::: _) = x
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index 109fb70..3bdbac2 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -333,7 +333,7 @@ sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!%
sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1
-sfromList1Prim :: (PrimElt a, Elt a) => SNat n -> [a] -> Shaped '[n] a
+sfromList1Prim :: PrimElt a => SNat n -> [a] -> Shaped '[n] a
sfromList1Prim sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1Prim
stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
diff --git a/test/Gen.hs b/test/Gen.hs
index 8099f0d..ae1d1f0 100644
--- a/test/Gen.hs
+++ b/test/Gen.hs
@@ -98,7 +98,7 @@ genReplicatedShR = \m n -> do
(sh2, sh3) <- injectOnes n sh1 sh1
return (sh1, sh2, sh3)
where
- injectOnes :: m <= n => SNat n -> IShR m -> IShR m -> Gen (IShR n, IShR n)
+ injectOnes :: SNat n -> IShR m -> IShR m -> Gen (IShR n, IShR n)
injectOnes n@SNat shOnes sh
| m@SNat <- shrRank sh
= case cmpNat n m of
@@ -110,7 +110,7 @@ genReplicatedShR = \m n -> do
Refl <- return (lem n m)
injectOnes n (inject index 1 shOnes) (inject index value sh)
- lem :: forall n m proxy. Compare n m ~ GT => proxy n -> proxy m -> (m + 1 <=? n) :~: True
+ lem :: forall n m proxy. proxy n -> proxy m -> (m + 1 <=? n) :~: True
lem _ _ = unsafeCoerceRefl
inject :: Int -> Int -> IShR m -> IShR (m + 1)
diff --git a/test/Util.hs b/test/Util.hs
index 7c06b2f..34cf8ab 100644
--- a/test/Util.hs
+++ b/test/Util.hs
@@ -42,7 +42,7 @@ class AlmostEq f where
almostEq :: (AlmostEqConstr f a, Ord a, Show a, Fractional a, MonadTest m)
=> a -> f a -> f a -> m ()
-instance KnownNat n => AlmostEq (OR.Array n) where
+instance AlmostEq (OR.Array n) where
type AlmostEqConstr (OR.Array n) = OR.Unbox
almostEq atol lhs rhs
| OR.allA (< atol) (OR.zipWithA (\a b -> abs (a - b)) rhs lhs) =