aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-30 11:58:40 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-30 11:58:40 +0200
commita65306ba5d80891b20ac86fa3a3242f9497751e6 (patch)
tree834af370556a46bbeca807a92c31bef098b47a89 /src/Data/Array/Mixed
parentd8e2fcf4ea979fe272db48fc2889f4c2636c50d7 (diff)
Refactor Mixed (modules, regular function names)
Diffstat (limited to 'src/Data/Array/Mixed')
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs435
-rw-r--r--src/Data/Array/Mixed/Internal/Arith/Foreign.hs55
-rw-r--r--src/Data/Array/Mixed/Internal/Arith/Lists.hs78
-rw-r--r--src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs82
-rw-r--r--src/Data/Array/Mixed/Lemmas.hs47
-rw-r--r--src/Data/Array/Mixed/Permutation.hs252
-rw-r--r--src/Data/Array/Mixed/Shape.hs455
-rw-r--r--src/Data/Array/Mixed/Types.hs110
8 files changed, 1514 insertions, 0 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
new file mode 100644
index 0000000..cf6820b
--- /dev/null
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -0,0 +1,435 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Mixed.Internal.Arith where
+
+import Control.Monad (forM, guard)
+import qualified Data.Array.Internal as OI
+import qualified Data.Array.Internal.RankedG as RG
+import qualified Data.Array.Internal.RankedS as RS
+import Data.Bits
+import Data.Int
+import Data.List (sort)
+import qualified Data.Vector.Storable as VS
+import qualified Data.Vector.Storable.Mutable as VSM
+import Foreign.C.Types
+import Foreign.Ptr
+import Foreign.Storable (Storable)
+import GHC.TypeLits
+import Language.Haskell.TH
+import System.IO.Unsafe
+
+import Data.Array.Mixed.Internal.Arith.Foreign
+import Data.Array.Mixed.Internal.Arith.Lists
+
+
+liftVEltwise1 :: Storable a
+ => SNat n
+ -> (VS.Vector a -> VS.Vector a)
+ -> RS.Array n a -> RS.Array n a
+liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
+ | Just prefixSz <- stridesDense sh strides =
+ let vec' = f (VS.slice offset prefixSz vec)
+ in RS.A (RG.A sh (OI.T strides 0 vec'))
+ | otherwise = RS.fromVector sh (f (RS.toVector arr))
+
+liftVEltwise2 :: Storable a
+ => SNat n
+ -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a)
+ -> RS.Array n a -> RS.Array n a -> RS.Array n a
+liftVEltwise2 SNat f
+ arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))
+ arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))
+ | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
+ | product sh1 == 0 = arr1 -- if the arrays are empty, just return one of the empty inputs
+ | otherwise = case (stridesDense sh1 strides1, stridesDense sh2 strides2) of
+ (Just 1, Just 1) -> -- both are a (potentially replicated) scalar; just apply f to the scalars
+ let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2))
+ in RS.A (RG.A sh1 (OI.T strides1 0 vec'))
+ (Just 1, Just n) -> -- scalar * dense
+ RS.fromVector sh1 (f (Left (vec1 VS.! offset1)) (Right (VS.slice offset2 n vec2)))
+ (Just n, Just 1) -> -- dense * scalar
+ RS.fromVector sh1 (f (Right (VS.slice offset1 n vec1)) (Left (vec2 VS.! offset2)))
+ (_, _) -> -- fallback case
+ RS.fromVector sh1 (f (Right (RS.toVector arr1)) (Right (RS.toVector arr2)))
+
+-- | Given the shape vector and the stride vector, return whether this vector
+-- of strides uses a dense prefix of its backing array. If so, the number of
+-- elements in this prefix is returned.
+-- This excludes any offset.
+stridesDense :: [Int] -> [Int] -> Maybe Int
+stridesDense sh _ | any (<= 0) sh = Just 0
+stridesDense sh str =
+ -- sort dimensions on their stride, ascending, dropping any zero strides
+ case dropWhile ((== 0) . fst) (sort (zip str sh)) of
+ [] -> Just 1
+ (1, n) : (unzip -> (str', sh')) -> checkCover n sh' str'
+ _ -> Nothing -- if the smallest stride is not 1, it will never be dense
+ where
+ -- Given size of currently densely covered region at beginning of the
+ -- array, the remaining shape vector and the corresponding remaining stride
+ -- vector, return whether this all together covers a dense prefix of the
+ -- array. If it does, return the number of elements in this prefix.
+ checkCover :: Int -> [Int] -> [Int] -> Maybe Int
+ checkCover block [] [] = Just block
+ checkCover block (n : sh') (s : str') = guard (s <= block) >> checkCover (max block (n * s)) sh' str'
+ checkCover _ _ _ = error "Orthotope array's shape vector and stride vector have different lengths"
+
+{-# NOINLINE vectorOp1 #-}
+vectorOp1 :: forall a b. Storable a
+ => (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr b -> Ptr b -> IO ())
+ -> VS.Vector a -> VS.Vector a
+vectorOp1 ptrconv f v = unsafePerformIO $ do
+ outv <- VSM.unsafeNew (VS.length v)
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith v $ \pv ->
+ f (fromIntegral (VS.length v)) (ptrconv poutv) (ptrconv pv)
+ VS.unsafeFreeze outv
+
+-- | If two vectors are given, assumes that they have the same length.
+{-# NOINLINE vectorOp2 #-}
+vectorOp2 :: forall a b. Storable a
+ => (a -> b)
+ -> (Ptr a -> Ptr b)
+ -> (a -> a -> a)
+ -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- sv
+ -> (Int64 -> Ptr b -> Ptr b -> b -> IO ()) -- vs
+ -> (Int64 -> Ptr b -> Ptr b -> Ptr b -> IO ()) -- vv
+ -> Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a
+vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases
+ (Left x) (Left y) -> VS.singleton (fss x y)
+
+ (Left x) (Right vy) ->
+ unsafePerformIO $ do
+ outv <- VSM.unsafeNew (VS.length vy)
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith vy $ \pvy ->
+ fsv (fromIntegral (VS.length vy)) (ptrconv poutv) (valconv x) (ptrconv pvy)
+ VS.unsafeFreeze outv
+
+ (Right vx) (Left y) ->
+ unsafePerformIO $ do
+ outv <- VSM.unsafeNew (VS.length vx)
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith vx $ \pvx ->
+ fvs (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (valconv y)
+ VS.unsafeFreeze outv
+
+ (Right vx) (Right vy)
+ | VS.length vx == VS.length vy ->
+ unsafePerformIO $ do
+ outv <- VSM.unsafeNew (VS.length vx)
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith vx $ \pvx ->
+ VS.unsafeWith vy $ \pvy ->
+ fvv (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (ptrconv pvy)
+ VS.unsafeFreeze outv
+ | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy)
+
+-- TODO: test all the weird cases of this function
+-- | Reduce along the inner dimension
+{-# NOINLINE vectorRedInnerOp #-}
+vectorRedInnerOp :: forall a b n. (Num a, Storable a)
+ => SNat n
+ -> (a -> b)
+ -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- ^ scale by constant
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> RS.Array (n + 1) a -> RS.Array n a
+vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec)))
+ | null sh = error "unreachable"
+ | last sh <= 0 = RS.stretch (init sh) (RS.fromList (map (const 1) (init sh)) [0])
+ | any (<= 0) (init sh) = RS.A (RG.A (init sh) (OI.T (map (const 0) (init strides)) 0 VS.empty))
+ -- now the input array is nonempty
+ | last sh == 1 = RS.A (RG.A (init sh) (OI.T (init strides) offset vec))
+ | last strides == 0 =
+ liftVEltwise1 sn
+ (vectorOp1 id (\n pout px -> fscale n (ptrconv pout) (valconv (fromIntegral (last sh))) (ptrconv px)))
+ (RS.A (RG.A (init sh) (OI.T (init strides) offset vec)))
+ -- now there is useful work along the inner dimension
+ | otherwise =
+ let -- filter out zero-stride dimensions; the reduction kernel need not concern itself with those
+ (shF, stridesF) = unzip $ filter ((/= 0) . snd) (zip sh strides)
+ ndimsF = length shF
+ in unsafePerformIO $ do
+ outv <- VSM.unsafeNew (product (init shF))
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF ->
+ VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF ->
+ VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec ->
+ fred (fromIntegral ndimsF) pshF pstridesF (ptrconv poutv) (ptrconv pvec)
+ RS.fromVector (init sh) <$> VS.unsafeFreeze outv
+
+flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())
+ -> Int64 -> Ptr a -> Ptr a -> a -> IO ()
+flipOp f n out v s = f n out s v
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ cnamebase = "c_binary_" ++ atCName arithtype
+ c_ss = varE (aboNumOp arithop)
+ c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+ c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+ c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
+ ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ cnamebase = "c_fbinary_" ++ atCName arithtype
+ c_ss = varE (afboNumOp arithop)
+ c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+ c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+ c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
+ ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
+ ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ c_op = varE (mkName ("c_funary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (afuoEnum arithop)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
+ ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM [minBound..maxBound] $ \arithop -> do
+ let name = mkName (aroName arithop ++ "Vector" ++ nameBase (atType arithtype))
+ c_op = varE (mkName ("c_reduce_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop)))
+ c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |]
+ ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+-- This branch is ostensibly a runtime branch, but will (hopefully) be
+-- constant-folded away by GHC.
+intWidBranch1 :: forall i n. (FiniteBits i, Storable i)
+ => (Int64 -> Ptr Int32 -> Ptr Int32 -> IO ())
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())
+ -> (SNat n -> RS.Array n i -> RS.Array n i)
+intWidBranch1 f32 f64 sn
+ | finiteBitSize (undefined :: i) == 32 = liftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32)
+ | finiteBitSize (undefined :: i) == 64 = liftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64)
+ | otherwise = error "Unsupported Int width"
+
+intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i)
+ => (i -> i -> i) -- ss
+ -- int32
+ -> (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- sv
+ -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Int32 -> IO ()) -- vs
+ -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- vv
+ -- int64
+ -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- sv
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ()) -- vs
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- vv
+ -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i)
+intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn
+ | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32)
+ | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64)
+ | otherwise = error "Unsupported Int width"
+
+intWidBranchRed :: forall i n. (FiniteBits i, Storable i, Integral i)
+ => -- int32
+ (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- ^ scale by constant
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- ^ reduction kernel
+ -- int64
+ -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel
+ -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i)
+intWidBranchRed fsc32 fred32 fsc64 fred64 sn
+ | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32
+ | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64
+ | otherwise = error "Unsupported Int width"
+
+class NumElt a where
+ numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a
+ numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a
+ numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a
+ numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
+ numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
+
+instance NumElt Int32 where
+ numEltAdd = addVectorInt32
+ numEltSub = subVectorInt32
+ numEltMul = mulVectorInt32
+ numEltNeg = negVectorInt32
+ numEltAbs = absVectorInt32
+ numEltSignum = signumVectorInt32
+ numEltSum1Inner = sum1VectorInt32
+ numEltProduct1Inner = product1VectorInt32
+
+instance NumElt Int64 where
+ numEltAdd = addVectorInt64
+ numEltSub = subVectorInt64
+ numEltMul = mulVectorInt64
+ numEltNeg = negVectorInt64
+ numEltAbs = absVectorInt64
+ numEltSignum = signumVectorInt64
+ numEltSum1Inner = sum1VectorInt64
+ numEltProduct1Inner = product1VectorInt64
+
+instance NumElt Float where
+ numEltAdd = addVectorFloat
+ numEltSub = subVectorFloat
+ numEltMul = mulVectorFloat
+ numEltNeg = negVectorFloat
+ numEltAbs = absVectorFloat
+ numEltSignum = signumVectorFloat
+ numEltSum1Inner = sum1VectorFloat
+ numEltProduct1Inner = product1VectorFloat
+
+instance NumElt Double where
+ numEltAdd = addVectorDouble
+ numEltSub = subVectorDouble
+ numEltMul = mulVectorDouble
+ numEltNeg = negVectorDouble
+ numEltAbs = absVectorDouble
+ numEltSignum = signumVectorDouble
+ numEltSum1Inner = sum1VectorDouble
+ numEltProduct1Inner = product1VectorDouble
+
+instance NumElt Int where
+ numEltAdd = intWidBranch2 @Int (+)
+ (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD))
+ (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD))
+ numEltSub = intWidBranch2 @Int (-)
+ (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB))
+ (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB))
+ numEltMul = intWidBranch2 @Int (*)
+ (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL))
+ numEltNeg = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG))
+ numEltAbs = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS))
+ numEltSignum = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM))
+ numEltSum1Inner = intWidBranchRed @Int
+ (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1))
+ numEltProduct1Inner = intWidBranchRed @Int
+ (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
+
+instance NumElt CInt where
+ numEltAdd = intWidBranch2 @CInt (+)
+ (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD))
+ (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD))
+ numEltSub = intWidBranch2 @CInt (-)
+ (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB))
+ (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB))
+ numEltMul = intWidBranch2 @CInt (*)
+ (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL))
+ numEltNeg = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG))
+ numEltAbs = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS))
+ numEltSignum = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM))
+ numEltSum1Inner = intWidBranchRed @CInt
+ (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1))
+ numEltProduct1Inner = intWidBranchRed @CInt
+ (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
+
+class FloatElt a where
+ floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ floatEltPow :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ floatEltLogbase :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
+ floatEltRecip :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltExp :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltLog :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltSqrt :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltSin :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltCos :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltTan :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltAsin :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltAcos :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltAtan :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltSinh :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltCosh :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltTanh :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltAsinh :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltAcosh :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltAtanh :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltLog1p :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltExpm1 :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltLog1pexp :: SNat n -> RS.Array n a -> RS.Array n a
+ floatEltLog1mexp :: SNat n -> RS.Array n a -> RS.Array n a
+
+instance FloatElt Float where
+ floatEltDiv = divVectorFloat
+ floatEltPow = powVectorFloat
+ floatEltLogbase = logbaseVectorFloat
+ floatEltRecip = recipVectorFloat
+ floatEltExp = expVectorFloat
+ floatEltLog = logVectorFloat
+ floatEltSqrt = sqrtVectorFloat
+ floatEltSin = sinVectorFloat
+ floatEltCos = cosVectorFloat
+ floatEltTan = tanVectorFloat
+ floatEltAsin = asinVectorFloat
+ floatEltAcos = acosVectorFloat
+ floatEltAtan = atanVectorFloat
+ floatEltSinh = sinhVectorFloat
+ floatEltCosh = coshVectorFloat
+ floatEltTanh = tanhVectorFloat
+ floatEltAsinh = asinhVectorFloat
+ floatEltAcosh = acoshVectorFloat
+ floatEltAtanh = atanhVectorFloat
+ floatEltLog1p = log1pVectorFloat
+ floatEltExpm1 = expm1VectorFloat
+ floatEltLog1pexp = log1pexpVectorFloat
+ floatEltLog1mexp = log1mexpVectorFloat
+
+instance FloatElt Double where
+ floatEltDiv = divVectorDouble
+ floatEltPow = powVectorDouble
+ floatEltLogbase = logbaseVectorDouble
+ floatEltRecip = recipVectorDouble
+ floatEltExp = expVectorDouble
+ floatEltLog = logVectorDouble
+ floatEltSqrt = sqrtVectorDouble
+ floatEltSin = sinVectorDouble
+ floatEltCos = cosVectorDouble
+ floatEltTan = tanVectorDouble
+ floatEltAsin = asinVectorDouble
+ floatEltAcos = acosVectorDouble
+ floatEltAtan = atanVectorDouble
+ floatEltSinh = sinhVectorDouble
+ floatEltCosh = coshVectorDouble
+ floatEltTanh = tanhVectorDouble
+ floatEltAsinh = asinhVectorDouble
+ floatEltAcosh = acoshVectorDouble
+ floatEltAtanh = atanhVectorDouble
+ floatEltLog1p = log1pVectorDouble
+ floatEltExpm1 = expm1VectorDouble
+ floatEltLog1pexp = log1pexpVectorDouble
+ floatEltLog1mexp = log1mexpVectorDouble
diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
new file mode 100644
index 0000000..6fc7229
--- /dev/null
+++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
@@ -0,0 +1,55 @@
+{-# LANGUAGE ForeignFunctionInterface #-}
+{-# LANGUAGE TemplateHaskell #-}
+module Data.Array.Mixed.Internal.Arith.Foreign where
+
+import Control.Monad
+import Data.Int
+import Data.Maybe
+import Foreign.C.Types
+import Foreign.Ptr
+import Language.Haskell.TH
+
+import Data.Array.Mixed.Internal.Arith.Lists
+
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ let base = "binary_" ++ atCName arithtype
+ sequence $ catMaybes
+ [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
+ ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+ ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
+ ])
+
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ let base = "fbinary_" ++ atCName arithtype
+ sequence $ catMaybes
+ [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
+ ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+ ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
+ ])
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ let base = "unary_" ++ atCName arithtype
+ pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+
+$(fmap concat . forM floatTypesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ let base = "funary_" ++ atCName arithtype
+ pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
+ [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ let base = "reduce_" ++ atCName arithtype
+ pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
+ [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
diff --git a/src/Data/Array/Mixed/Internal/Arith/Lists.hs b/src/Data/Array/Mixed/Internal/Arith/Lists.hs
new file mode 100644
index 0000000..a284bc1
--- /dev/null
+++ b/src/Data/Array/Mixed/Internal/Arith/Lists.hs
@@ -0,0 +1,78 @@
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE TemplateHaskell #-}
+module Data.Array.Mixed.Internal.Arith.Lists where
+
+import Data.Char
+import Data.Int
+import Language.Haskell.TH
+
+import Data.Array.Mixed.Internal.Arith.Lists.TH
+
+
+data ArithType = ArithType
+ { atType :: Name -- ''Int32
+ , atCName :: String -- "i32"
+ }
+
+floatTypesList :: [ArithType]
+floatTypesList =
+ [ArithType ''Float "float"
+ ,ArithType ''Double "double"
+ ]
+
+typesList :: [ArithType]
+typesList =
+ [ArithType ''Int32 "i32"
+ ,ArithType ''Int64 "i64"
+ ]
+ ++ floatTypesList
+
+-- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded)
+$(genArithDataType Binop "ArithBOp")
+
+$(genArithNameFun Binop ''ArithBOp "aboName" (map toLower . drop 3))
+$(genArithEnumFun Binop ''ArithBOp "aboEnum")
+
+$(do clauses <- readArithLists Binop
+ (\name _num hsop -> return (Clause [ConP (mkName name) [] []]
+ (NormalB (VarE 'mkName `AppE` LitE (StringL hsop)))
+ []))
+ return
+ sequence [SigD (mkName "aboNumOp") <$> [t| ArithBOp -> Name |]
+ ,return $ FunD (mkName "aboNumOp") clauses])
+
+
+-- data ArithFBOp = FB_DIV deriving (Show, Enum, Bounded)
+$(genArithDataType FBinop "ArithFBOp")
+
+$(genArithNameFun FBinop ''ArithFBOp "afboName" (map toLower . drop 3))
+$(genArithEnumFun FBinop ''ArithFBOp "afboEnum")
+
+$(do clauses <- readArithLists FBinop
+ (\name _num hsop -> return (Clause [ConP (mkName name) [] []]
+ (NormalB (VarE 'mkName `AppE` LitE (StringL hsop)))
+ []))
+ return
+ sequence [SigD (mkName "afboNumOp") <$> [t| ArithFBOp -> Name |]
+ ,return $ FunD (mkName "afboNumOp") clauses])
+
+
+-- data ArithUOp = UO_NEG | UO_ABS | UO_SIGNUM | ... deriving (Show, Enum, Bounded)
+$(genArithDataType Unop "ArithUOp")
+
+$(genArithNameFun Unop ''ArithUOp "auoName" (map toLower . drop 3))
+$(genArithEnumFun Unop ''ArithUOp "auoEnum")
+
+
+-- data ArithFUOp = FU_RECIP | ... deriving (Show, Enum, Bounded)
+$(genArithDataType FUnop "ArithFUOp")
+
+$(genArithNameFun FUnop ''ArithFUOp "afuoName" (map toLower . drop 3))
+$(genArithEnumFun FUnop ''ArithFUOp "afuoEnum")
+
+
+-- data ArithRedOp = RO_SUM1 | RO_PRODUCT1 deriving (Show, Enum, Bounded)
+$(genArithDataType Redop "ArithRedOp")
+
+$(genArithNameFun Redop ''ArithRedOp "aroName" (map toLower . drop 3))
+$(genArithEnumFun Redop ''ArithRedOp "aroEnum")
diff --git a/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs b/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs
new file mode 100644
index 0000000..8b7d05f
--- /dev/null
+++ b/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs
@@ -0,0 +1,82 @@
+{-# LANGUAGE TemplateHaskellQuotes #-}
+module Data.Array.Mixed.Internal.Arith.Lists.TH where
+
+import Control.Monad
+import Control.Monad.IO.Class
+import Data.Maybe
+import Foreign.C.Types
+import Language.Haskell.TH
+import Language.Haskell.TH.Syntax
+import Text.Read
+
+
+data OpKind = Binop | FBinop | Unop | FUnop | Redop
+ deriving (Show, Eq)
+
+readArithLists :: OpKind
+ -> (String -> Int -> String -> Q a)
+ -> ([a] -> Q r)
+ -> Q r
+readArithLists targetkind fop fcombine = do
+ addDependentFile "cbits/arith_lists.h"
+ lns <- liftIO $ lines <$> readFile "cbits/arith_lists.h"
+
+ mvals <- forM lns $ \line -> do
+ if null (dropWhile (== ' ') line)
+ then return Nothing
+ else do let (kind, name, num, aux) = parseLine line
+ if kind == targetkind
+ then Just <$> fop name num aux
+ else return Nothing
+
+ fcombine (catMaybes mvals)
+ where
+ parseLine s0
+ | ("LIST_", s1) <- splitAt 5 s0
+ , (kindstr, '(' : s2) <- break (== '(') s1
+ , (f1, ',' : s3) <- parseField s2
+ , (f2, ',' : s4) <- parseField s3
+ , (f3, ')' : _) <- parseField s4
+ , Just kind <- parseKind kindstr
+ , let name = f1
+ , Just num <- readMaybe f2
+ , let aux = f3
+ = (kind, name, num, aux)
+ | otherwise
+ = error $ "readArithLists: unrecognised line in cbits/arith_lists.h: " ++ show s0
+
+ parseField s = break (`elem` ",)") (dropWhile (== ' ') s)
+
+ parseKind "BINOP" = Just Binop
+ parseKind "FBINOP" = Just FBinop
+ parseKind "UNOP" = Just Unop
+ parseKind "FUNOP" = Just FUnop
+ parseKind "REDOP" = Just Redop
+ parseKind _ = Nothing
+
+genArithDataType :: OpKind -> String -> Q [Dec]
+genArithDataType kind dtname = do
+ cons <- readArithLists kind
+ (\name _num _ -> return $ NormalC (mkName name) [])
+ return
+ return [DataD [] (mkName dtname) [] Nothing cons [DerivClause Nothing [ConT ''Show, ConT ''Enum, ConT ''Bounded]]]
+
+genArithNameFun :: OpKind -> Name -> String -> (String -> String) -> Q [Dec]
+genArithNameFun kind dtname funname nametrans = do
+ clauses <- readArithLists kind
+ (\name _num _ -> return (Clause [ConP (mkName name) [] []]
+ (NormalB (LitE (StringL (nametrans name))))
+ []))
+ return
+ return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''String)
+ ,FunD (mkName funname) clauses]
+
+genArithEnumFun :: OpKind -> Name -> String -> Q [Dec]
+genArithEnumFun kind dtname funname = do
+ clauses <- readArithLists kind
+ (\name num _ -> return (Clause [ConP (mkName name) [] []]
+ (NormalB (LitE (IntegerL (fromIntegral num))))
+ []))
+ return
+ return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''CInt)
+ ,FunD (mkName funname) clauses]
diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Mixed/Lemmas.hs
new file mode 100644
index 0000000..30ec9c0
--- /dev/null
+++ b/src/Data/Array/Mixed/Lemmas.hs
@@ -0,0 +1,47 @@
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeOperators #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE DataKinds #-}
+module Data.Array.Mixed.Lemmas where
+
+import Data.Proxy
+import Data.Type.Equality
+import GHC.TypeLits
+
+import Data.Array.Mixed.Shape
+import Data.Array.Mixed.Types
+
+
+lemRankApp :: forall sh1 sh2.
+ StaticShX sh1 -> StaticShX sh2
+ -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
+lemRankApp ZKX _ = Refl
+lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2
+ = lem2 (Proxy @(Rank sh1T)) Proxy Proxy $
+ lem (Proxy @(Rank sh2)) (Proxy @(Rank sh1T)) (Proxy @(Rank (sh1T ++ sh2))) $
+ lemRankApp ssh1 ssh2
+ where
+ lem :: proxy a -> proxy b -> proxy c
+ -> c :~: b + a
+ -> b + a :~: c
+ lem _ _ _ Refl = Refl
+
+ lem2 :: proxy a -> proxy b -> proxy c
+ -> (a + b :~: c)
+ -> c + 1 :~: (a + 1 + b)
+ lem2 _ _ _ Refl = Refl
+
+lemRankAppComm :: StaticShX sh1 -> StaticShX sh2
+ -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)
+lemRankAppComm _ _ = unsafeCoerceRefl -- TODO improve this
+
+lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh)
+lemKnownNatRank ZSX = Dict
+lemKnownNatRank (_ :$% sh) | Dict <- lemKnownNatRank sh = Dict
+
+lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh)
+lemKnownNatRankSSX ZKX = Dict
+lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs
new file mode 100644
index 0000000..2710018
--- /dev/null
+++ b/src/Data/Array/Mixed/Permutation.hs
@@ -0,0 +1,252 @@
+{-# LANGUAGE ConstraintKinds #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Mixed.Permutation where
+
+import Data.Coerce (coerce)
+import Data.Functor.Const
+import Data.List (sort)
+import Data.Proxy
+import Data.Type.Bool
+import Data.Type.Equality
+import Data.Type.Ord
+import GHC.TypeError
+import GHC.TypeLits
+import qualified GHC.TypeNats as TN
+
+import Data.Array.Mixed.Shape
+import Data.Array.Mixed.Types
+
+
+-- * Permutations
+
+-- | A "backward" permutation of a dimension list. The operation on the
+-- dimension list is most similar to 'Data.Vector.backpermute'; see 'Permute'
+-- for code that implements this.
+data Perm list where
+ PNil :: Perm '[]
+ PCons :: SNat a -> Perm l -> Perm (a : l)
+infixr 5 `PCons`
+deriving instance Show (Perm list)
+deriving instance Eq (Perm list)
+
+permLengthSNat :: Perm list -> SNat (Rank list)
+permLengthSNat PNil = SNat
+permLengthSNat (_ `PCons` l) | SNat <- permLengthSNat l = SNat
+
+permFromList :: [Int] -> (forall list. Perm list -> r) -> r
+permFromList [] k = k PNil
+permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case
+ Just sn -> permFromList xs $ \list -> k (sn `PCons` list)
+ Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x
+
+permToList :: Perm list -> [Natural]
+permToList PNil = mempty
+permToList (x `PCons` l) = TN.fromSNat x : permToList l
+
+permToList' :: Perm list -> [Int]
+permToList' = map fromIntegral . permToList
+
+
+-- ** Applying permutations
+
+type family Elem x l where
+ Elem x '[] = 'False
+ Elem x (x : _) = 'True
+ Elem x (_ : ys) = Elem x ys
+
+type family AllElem' as bs where
+ AllElem' '[] bs = 'True
+ AllElem' (a : as) bs = Elem a bs && AllElem' as bs
+
+type AllElem as bs = Assert (AllElem' as bs)
+ (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs))
+
+type family Count i n where
+ Count n n = '[]
+ Count i n = i : Count (i + 1) n
+
+type IsPermutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as)
+
+type family Index i sh where
+ Index 0 (n : sh) = n
+ Index i (_ : sh) = Index (i - 1) sh
+
+type family Permute is sh where
+ Permute '[] sh = '[]
+ Permute (i : is) sh = Index i sh : Permute is sh
+
+type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh
+
+type family TakeLen ref l where
+ TakeLen '[] l = '[]
+ TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs
+
+type family DropLen ref l where
+ DropLen '[] l = l
+ DropLen (_ : ref) (_ : xs) = DropLen ref xs
+
+listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f
+listxTakeLen PNil _ = ZX
+listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh
+listxTakeLen (_ `PCons` _) ZX = error "IsPermutation longer than shape"
+
+listxDropLen :: forall f is sh. Perm is -> ListX sh f -> ListX (DropLen is sh) f
+listxDropLen PNil sh = sh
+listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh
+listxDropLen (_ `PCons` _) ZX = error "IsPermutation longer than shape"
+
+listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f
+listxPermute PNil _ = ZX
+listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) =
+ listxIndex (Proxy @is') (Proxy @sh) i sh (listxPermute is sh)
+
+listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> ListX (Permute is shT) f -> ListX (Index i sh : Permute is shT) f
+listxIndex _ _ SZ (n ::% _) rest = n ::% rest
+listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) rest
+ | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ = listxIndex p pT i sh rest
+listxIndex _ _ _ ZX _ = error "Index into empty shape"
+
+listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f
+listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh)
+
+ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i
+ixxPermutePrefix = coerce (listxPermutePrefix @(Const i))
+
+ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh)
+ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat))
+
+ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh)
+ssxDropLen = coerce (listxDropLen @(SMayNat () SNat))
+
+ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh)
+ssxPermute = coerce (listxPermute @(SMayNat () SNat))
+
+ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT)
+ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2)
+
+ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
+ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat))
+
+shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh)
+shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat))
+
+
+-- * Operations on permutations
+
+-- TODO: test this thing more properly
+permInverse :: Perm is
+ -> (forall is'.
+ IsPermutation is'
+ => Perm is'
+ -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh)
+ -> r)
+ -> r
+permInverse = \perm k ->
+ genPerm perm $ \(invperm :: Perm is') ->
+ let sn = permLengthSNat invperm
+ in case (provePerm1 (Proxy @is') sn invperm, provePerm2 (SNat @0) sn invperm) of
+ (Just Refl, Just Refl) ->
+ k invperm
+ (\ssh -> case provePermInverse perm invperm ssh of
+ Just eq -> eq
+ Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm
+ ++ " ; invperm = " ++ show invperm)
+ _ -> error $ "permInverse: did not generate permutation? perm = " ++ show perm
+ ++ " ; invperm = " ++ show invperm
+ where
+ genPerm :: Perm is -> (forall is'. Perm is' -> r) -> r
+ genPerm perm =
+ let permList = permToList' perm
+ in toHList $ map snd (sort (zip permList [0..]))
+ where
+ toHList :: [Natural] -> (forall is'. Perm is' -> r) -> r
+ toHList [] k = k PNil
+ toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l)
+
+ lemElemCount :: (0 <= n, Compare n m ~ LT) => proxy n -> proxy m -> Elem n (Count 0 m) :~: True
+ lemElemCount _ _ = unsafeCoerceRefl
+
+ lemCount :: (OrdCond (Compare i n) True False True ~ True) => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n
+ lemCount _ _ = unsafeCoerceRefl
+
+ lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True
+ lemElem _ _ = unsafeCoerceRefl
+
+ provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is'
+ -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True)
+ provePerm1 _ _ PNil = Just (Refl)
+ provePerm1 p rtop@SNat (PCons sn@SNat perm)
+ | Just Refl <- provePerm1 p rtop perm
+ = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of
+ (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
+ (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
+ _ -> Nothing
+ | otherwise
+ = Nothing
+
+ provePerm2 :: SNat i -> SNat n -> Perm is'
+ -> Maybe (AllElem' (Count i n) is' :~: True)
+ provePerm2 = \i@(SNat :: SNat i) n@SNat perm ->
+ case cmpNat i n of
+ EQI -> Just Refl
+ LTI | Refl <- lemCount i n
+ , Just Refl <- provePerm2 (SNat @(i + 1)) n perm
+ -> checkElem i perm
+ | otherwise -> Nothing
+ GTI -> error "unreachable"
+ where
+ checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True)
+ checkElem _ PNil = Nothing
+ checkElem i@SNat (PCons k@SNat perm :: Perm is') =
+ case sameNat i k of
+ Just Refl -> Just Refl
+ Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl
+ | otherwise -> Nothing
+
+ provePermInverse :: Perm is -> Perm is' -> StaticShX sh
+ -> Maybe (Permute is' (Permute is sh) :~: sh)
+ provePermInverse perm perminv ssh =
+ ssxGeq (ssxPermute perminv (ssxPermute perm ssh)) ssh
+
+type family MapSucc is where
+ MapSucc '[] = '[]
+ MapSucc (i : is) = i + 1 : MapSucc is
+
+permShift1 :: Perm l -> Perm (0 : MapSucc l)
+permShift1 = (SNat @0 `PCons`) . permMapSucc
+ where
+ permMapSucc :: Perm l -> Perm (MapSucc l)
+ permMapSucc PNil = PNil
+ permMapSucc ((SNat :: SNat i) `PCons` ns) = SNat @(i + 1) `PCons` permMapSucc ns
+
+
+-- * Lemmas
+
+lemRankPermute :: Proxy sh -> Perm is -> Rank (Permute is sh) :~: Rank is
+lemRankPermute _ PNil = Refl
+lemRankPermute p (_ `PCons` is) | Refl <- lemRankPermute p is = Refl
+
+lemRankDropLen :: forall is sh. (Rank is <= Rank sh)
+ => StaticShX sh -> Perm is -> Rank (DropLen is sh) :~: Rank sh - Rank is
+lemRankDropLen ZKX PNil = Refl
+lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl
+lemRankDropLen (_ :!% _) PNil = Refl
+lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0"
+
+lemIndexSucc :: Proxy i -> Proxy a -> Proxy l
+ -> Index (i + 1) (a : l) :~: Index i l
+lemIndexSucc _ _ _ = unsafeCoerceRefl
diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs
new file mode 100644
index 0000000..a16da76
--- /dev/null
+++ b/src/Data/Array/Mixed/Shape.hs
@@ -0,0 +1,455 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE NoStarIsType #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE RoleAnnotations #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Mixed.Shape where
+
+import Control.DeepSeq (NFData(..))
+import qualified Data.Foldable as Foldable
+import Data.Functor.Const
+import Data.Kind (Type, Constraint)
+import Data.Monoid (Sum(..))
+import Data.Proxy
+import Data.Type.Equality
+import GHC.Generics (Generic)
+import GHC.IsList (IsList)
+import qualified GHC.IsList as IsList
+import GHC.TypeLits
+
+import Data.Array.Mixed.Types
+import Data.Coerce
+import Data.Bifunctor (first)
+
+
+-- | The length of a type-level list. If the argument is a shape, then the
+-- result is the rank of that shape.
+type family Rank sh where
+ Rank '[] = 0
+ Rank (_ : sh) = Rank sh + 1
+
+
+-- * Mixed lists
+
+type role ListX nominal representational
+type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type
+data ListX sh f where
+ ZX :: ListX '[] f
+ (::%) :: f n -> ListX sh f -> ListX (n : sh) f
+deriving instance (forall n. Eq (f n)) => Eq (ListX sh f)
+deriving instance (forall n. Ord (f n)) => Ord (ListX sh f)
+infixr 3 ::%
+
+instance (forall n. Show (f n)) => Show (ListX sh f) where
+ showsPrec _ = listxShow shows
+
+instance (forall n. NFData (f n)) => NFData (ListX sh f) where
+ rnf ZX = ()
+ rnf (x ::% l) = rnf x `seq` rnf l
+
+data UnconsListXRes f sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n)
+listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1)
+listxUncons (i ::% shl') = Just (UnconsListXRes shl' i)
+listxUncons ZX = Nothing
+
+listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g
+listxFmap _ ZX = ZX
+listxFmap f (x ::% xs) = f x ::% listxFmap f xs
+
+listxFold :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m
+listxFold _ ZX = mempty
+listxFold f (x ::% xs) = f x <> listxFold f xs
+
+listxLength :: ListX sh f -> Int
+listxLength = getSum . listxFold (\_ -> Sum 1)
+
+listxLengthSNat :: ListX sh f -> SNat (Rank sh)
+listxLengthSNat ZX = SNat
+listxLengthSNat (_ ::% l) | SNat <- listxLengthSNat l = SNat
+
+listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS
+listxShow f l = showString "[" . go "" l . showString "]"
+ where
+ go :: String -> ListX sh' f -> ShowS
+ go _ ZX = id
+ go prefix (x ::% xs) = showString prefix . f x . go "," xs
+
+listxToList :: ListX sh' (Const i) -> [i]
+listxToList ZX = []
+listxToList (Const i ::% is) = i : listxToList is
+
+listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f
+listxAppend ZX idx' = idx'
+listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx'
+
+listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f
+listxDrop long ZX = long
+listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short
+
+
+-- * Mixed indices
+
+-- | This is a newtype over 'ListX'.
+type role IxX nominal representational
+type IxX :: [Maybe Nat] -> Type -> Type
+newtype IxX sh i = IxX (ListX sh (Const i))
+ deriving (Eq, Ord, Generic)
+
+pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i
+pattern ZIX = IxX ZX
+
+pattern (:.%)
+ :: forall {sh1} {i}.
+ forall n sh. (n : sh ~ sh1)
+ => i -> IxX sh i -> IxX sh1 i
+pattern i :.% shl <- IxX (listxUncons -> Just (UnconsListXRes (IxX -> shl) (getConst -> i)))
+ where i :.% IxX shl = IxX (Const i ::% shl)
+infixr 3 :.%
+
+{-# COMPLETE ZIX, (:.%) #-}
+
+type IIxX sh = IxX sh Int
+
+instance Show i => Show (IxX sh i) where
+ showsPrec _ (IxX l) = listxShow (\(Const i) -> shows i) l
+
+instance Functor (IxX sh) where
+ fmap f (IxX l) = IxX (listxFmap (Const . f . getConst) l)
+
+instance Foldable (IxX sh) where
+ foldMap f (IxX l) = listxFold (f . getConst) l
+
+instance NFData i => NFData (IxX sh i)
+
+ixxZero :: StaticShX sh -> IIxX sh
+ixxZero ZKX = ZIX
+ixxZero (_ :!% ssh) = 0 :.% ixxZero ssh
+
+ixxZero' :: IShX sh -> IIxX sh
+ixxZero' ZSX = ZIX
+ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh
+
+ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i
+ixxAppend = coerce (listxAppend @_ @(Const i))
+
+ixxDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i
+ixxDrop = coerce (listxDrop @(Const i) @(Const i))
+
+ixxFromLinear :: IShX sh -> Int -> IIxX sh
+ixxFromLinear = \sh i -> case go sh i of
+ (idx, 0) -> idx
+ _ -> error $ "ixxFromLinear: out of range (" ++ show i ++
+ " in array of shape " ++ show sh ++ ")"
+ where
+ -- returns (index in subarray, remaining index in enclosing array)
+ go :: IShX sh -> Int -> (IIxX sh, Int)
+ go ZSX i = (ZIX, i)
+ go (n :$% sh) i =
+ let (idx, i') = go sh i
+ (upi, locali) = i' `quotRem` fromSMayNat' n
+ in (locali :.% idx, upi)
+
+ixxToLinear :: IShX sh -> IIxX sh -> Int
+ixxToLinear = \sh i -> fst (go sh i)
+ where
+ -- returns (index in subarray, size of subarray)
+ go :: IShX sh -> IIxX sh -> (Int, Int)
+ go ZSX ZIX = (0, 1)
+ go (n :$% sh) (i :.% ix) =
+ let (lidx, sz) = go sh ix
+ in (sz * i + lidx, fromSMayNat' n * sz)
+
+
+-- * Mixed shapes
+
+data SMayNat i f n where
+ SUnknown :: i -> SMayNat i f Nothing
+ SKnown :: f n -> SMayNat i f (Just n)
+deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n)
+deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n)
+deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n)
+
+instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where
+ rnf (SUnknown i) = rnf i
+ rnf (SKnown x) = rnf x
+
+fromSMayNat :: (n ~ Nothing => i -> r)
+ -> (forall m. n ~ Just m => f m -> r)
+ -> SMayNat i f n -> r
+fromSMayNat f _ (SUnknown i) = f i
+fromSMayNat _ g (SKnown s) = g s
+
+fromSMayNat' :: SMayNat Int SNat n -> Int
+fromSMayNat' = fromSMayNat id fromSNat'
+
+type family AddMaybe n m where
+ AddMaybe Nothing _ = Nothing
+ AddMaybe (Just _) Nothing = Nothing
+ AddMaybe (Just n) (Just m) = Just (n + m)
+
+smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m)
+smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m)
+smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m)
+smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m)
+
+
+-- | This is a newtype over 'ListX'.
+type role ShX nominal representational
+type ShX :: [Maybe Nat] -> Type -> Type
+newtype ShX sh i = ShX (ListX sh (SMayNat i SNat))
+ deriving (Eq, Ord, Generic)
+
+pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i
+pattern ZSX = ShX ZX
+
+pattern (:$%)
+ :: forall {sh1} {i}.
+ forall n sh. (n : sh ~ sh1)
+ => SMayNat i SNat n -> ShX sh i -> ShX sh1 i
+pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i))
+ where i :$% ShX shl = ShX (i ::% shl)
+infixr 3 :$%
+
+{-# COMPLETE ZSX, (:$%) #-}
+
+type IShX sh = ShX sh Int
+
+instance Show i => Show (ShX sh i) where
+ showsPrec _ (ShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l
+
+instance Functor (ShX sh) where
+ fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l)
+
+instance NFData i => NFData (ShX sh i) where
+ rnf (ShX ZX) = ()
+ rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l)
+ rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l)
+
+shxLength :: ShX sh i -> Int
+shxLength (ShX l) = listxLength l
+
+-- | This is more than @geq@: it also checks that the integers (the unknown
+-- dimensions) are the same.
+shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh')
+shxEqual ZSX ZSX = Just Refl
+shxEqual (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh')
+ | Just Refl <- sameNat n m
+ , Just Refl <- shxEqual sh sh'
+ = Just Refl
+shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh')
+ | i == j
+ , Just Refl <- shxEqual sh sh'
+ = Just Refl
+shxEqual _ _ = Nothing
+
+-- | The number of elements in an array described by this shape.
+shxSize :: IShX sh -> Int
+shxSize ZSX = 1
+shxSize (n :$% sh) = fromSMayNat' n * shxSize sh
+
+shxToList :: IShX sh -> [Int]
+shxToList ZSX = []
+shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh
+
+shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i
+shxAppend = coerce (listxAppend @_ @(SMayNat i SNat))
+
+shxTail :: ShX (n : sh) i -> ShX sh i
+shxTail (_ :$% sh) = sh
+
+shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i
+shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat))
+
+shxDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i
+shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j))
+
+shxDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i
+shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat))
+
+shxTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i
+shxTakeSSX _ = flip go
+ where
+ go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i
+ go ZKX _ = ZSX
+ go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh
+
+-- This is a weird operation, so it has a long name
+shxCompleteZeros :: StaticShX sh -> IShX sh
+shxCompleteZeros ZKX = ZSX
+shxCompleteZeros (SUnknown () :!% ssh) = SUnknown 0 :$% shxCompleteZeros ssh
+shxCompleteZeros (SKnown n :!% ssh) = SKnown n :$% shxCompleteZeros ssh
+
+shxSplitApp :: Proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i)
+shxSplitApp _ ZKX idx = (ZSX, idx)
+shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx)
+
+shxEnum :: IShX sh -> [IIxX sh]
+shxEnum = \sh -> go sh id []
+ where
+ go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a]
+ go ZSX f = (f ZIX :)
+ go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]]
+
+
+-- * Static mixed shapes
+
+-- | The part of a shape that is statically known. (A newtype over 'ListX'.)
+type StaticShX :: [Maybe Nat] -> Type
+newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat))
+ deriving (Eq, Ord)
+
+pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh
+pattern ZKX = StaticShX ZX
+
+pattern (:!%)
+ :: forall {sh1}.
+ forall n sh. (n : sh ~ sh1)
+ => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1
+pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i))
+ where i :!% StaticShX shl = StaticShX (i ::% shl)
+infixr 3 :!%
+
+{-# COMPLETE ZKX, (:!%) #-}
+
+instance Show (StaticShX sh) where
+ showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l
+
+ssxLength :: StaticShX sh -> Int
+ssxLength (StaticShX l) = listxLength l
+
+-- | This suffices as an implementation of @geq@ in the @Data.GADT.Compare@
+-- class of the @some@ package.
+ssxGeq :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh')
+ssxGeq ZKX ZKX = Just Refl
+ssxGeq (SKnown n@SNat :!% sh) (SKnown m@SNat :!% sh')
+ | Just Refl <- sameNat n m
+ , Just Refl <- ssxGeq sh sh'
+ = Just Refl
+ssxGeq (SUnknown () :!% sh) (SUnknown () :!% sh')
+ | Just Refl <- ssxGeq sh sh'
+ = Just Refl
+ssxGeq _ _ = Nothing
+
+ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
+ssxAppend ZKX sh' = sh'
+ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh'
+
+ssxTail :: StaticShX (n : sh) -> StaticShX sh
+ssxTail (_ :!% ssh) = ssh
+
+ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh'
+ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i))
+
+-- | This may fail if @sh@ has @Nothing@s in it.
+ssxToShX' :: StaticShX sh -> Maybe (IShX sh)
+ssxToShX' ZKX = Just ZSX
+ssxToShX' (SKnown n :!% sh) = (SKnown n :$%) <$> ssxToShX' sh
+ssxToShX' (SUnknown _ :!% _) = Nothing
+
+ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)
+ssxReplicate SZ = ZKX
+ssxReplicate (SS (n :: SNat n'))
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n'
+ = SUnknown () :!% ssxReplicate n
+
+ssxIotaFrom :: Int -> StaticShX sh -> [Int]
+ssxIotaFrom _ ZKX = []
+ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh
+
+ssxFromShape :: IShX sh -> StaticShX sh
+ssxFromShape ZSX = ZKX
+ssxFromShape (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShape sh
+
+
+-- | Evidence for the static part of a shape. This pops up only when you are
+-- polymorphic in the element type of an array.
+type KnownShX :: [Maybe Nat] -> Constraint
+class KnownShX sh where knownShX :: StaticShX sh
+instance KnownShX '[] where knownShX = ZKX
+instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX
+instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX
+
+
+-- * Flattening
+
+type Flatten sh = Flatten' 1 sh
+
+type family Flatten' acc sh where
+ Flatten' acc '[] = Just acc
+ Flatten' acc (Nothing : sh) = Nothing
+ Flatten' acc (Just n : sh) = Flatten' (acc * n) sh
+
+-- This function is currently unused
+ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh)
+ssxFlatten = go (SNat @1)
+ where
+ go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh)
+ go acc ZKX = SKnown acc
+ go _ (SUnknown () :!% _) = SUnknown ()
+ go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh
+
+shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh)
+shxFlatten = go (SNat @1)
+ where
+ go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh)
+ go acc ZSX = SKnown acc
+ go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh)
+ go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh
+
+ goUnknown :: Int -> IShX sh -> Int
+ goUnknown acc ZSX = acc
+ goUnknown acc (SUnknown n :$% sh) = goUnknown (acc * n) sh
+ goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh
+
+
+-- | Very untyped: only length is checked (at runtime).
+instance KnownShX sh => IsList (ListX sh (Const i)) where
+ type Item (ListX sh (Const i)) = i
+ fromList topl = go (knownShX @sh) topl
+ where
+ go :: StaticShX sh' -> [i] -> ListX sh' (Const i)
+ go ZKX [] = ZX
+ go (_ :!% sh) (i : is) = Const i ::% go sh is
+ go _ _ = error $ "IsList(ListX): Mismatched list length (type says "
+ ++ show (ssxLength (knownShX @sh)) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+ toList = listxToList
+
+-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
+instance KnownShX sh => IsList (IxX sh i) where
+ type Item (IxX sh i) = i
+ fromList = IxX . IsList.fromList
+ toList = Foldable.toList
+
+-- | Untyped: length and known dimensions are checked (at runtime).
+instance KnownShX sh => IsList (ShX sh Int) where
+ type Item (ShX sh Int) = Int
+ fromList topl = ShX (go (knownShX @sh) topl)
+ where
+ go :: StaticShX sh' -> [Int] -> ListX sh' (SMayNat Int SNat)
+ go ZKX [] = ZX
+ go (SKnown sn :!% sh) (i : is)
+ | i == fromSNat' sn = SKnown sn ::% go sh is
+ | otherwise = error $ "IsList(ShX): Value does not match typing (type says "
+ ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")"
+ go (SUnknown () :!% sh) (i : is) = SUnknown i ::% go sh is
+ go _ _ = error $ "IsList(ShX): Mismatched list length (type says "
+ ++ show (ssxLength (knownShX @sh)) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+ toList = shxToList
diff --git a/src/Data/Array/Mixed/Types.hs b/src/Data/Array/Mixed/Types.hs
new file mode 100644
index 0000000..d77513f
--- /dev/null
+++ b/src/Data/Array/Mixed/Types.hs
@@ -0,0 +1,110 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE NoStarIsType #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Mixed.Types (
+ -- * Reified evidence of a type class
+ Dict(..),
+
+ -- * Type-level naturals
+ pattern SZ, pattern SS,
+ fromSNat',
+ snatPlus, snatMul,
+
+ -- * Type-level lists
+ type (++),
+ lemAppNil,
+ lemAppAssoc,
+ Replicate,
+ lemReplicateSucc,
+
+ -- * Unsafe
+ unsafeCoerceRefl,
+) where
+
+import Data.Type.Equality
+import Data.Proxy
+import GHC.TypeLits
+import qualified GHC.TypeNats as TN
+import qualified Unsafe.Coerce
+
+
+-- | Evidence for the constraint @c a@.
+data Dict c a where
+ Dict :: c a => Dict c a
+
+fromSNat' :: SNat n -> Int
+fromSNat' = fromIntegral . fromSNat
+
+pattern SZ :: () => (n ~ 0) => SNat n
+pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl)
+ where SZ = SNat
+
+pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1
+pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl))
+ where SS = snatSucc
+
+{-# COMPLETE SZ, SS #-}
+
+snatSucc :: SNat n -> SNat (n + 1)
+snatSucc SNat = SNat
+
+data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1)
+snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1)
+snatPred snp1 =
+ withKnownNat snp1 $
+ case cmpNat (Proxy @1) (Proxy @np1) of
+ LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
+ EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
+ GTI -> Nothing
+
+-- This should be a function in base
+snatPlus :: SNat n -> SNat m -> SNat (n + m)
+snatPlus n m = TN.withSomeSNat (TN.fromSNat n + TN.fromSNat m) Unsafe.Coerce.unsafeCoerce
+
+-- This should be a function in base
+snatMul :: SNat n -> SNat m -> SNat (n * m)
+snatMul n m = TN.withSomeSNat (TN.fromSNat n * TN.fromSNat m) Unsafe.Coerce.unsafeCoerce
+
+
+-- | This is just @'Unsafe.Coerce.unsafeCoerce' 'Refl'@, but specialised to
+-- only typecheck for actual type equalities. One cannot, e.g. accidentally
+-- write this:
+--
+-- @
+-- foo :: Proxy a -> Proxy b -> a :~: b
+-- foo = unsafeCoerceRefl
+-- @
+--
+-- which would have been permitted with normal 'Unsafe.Coerce.unsafeCoerce',
+-- but would have resulted in interesting memory errors at runtime.
+unsafeCoerceRefl :: a :~: b
+unsafeCoerceRefl = Unsafe.Coerce.unsafeCoerce Refl
+
+
+-- | Type-level list append.
+type family l1 ++ l2 where
+ '[] ++ l2 = l2
+ (x : xs) ++ l2 = x : xs ++ l2
+
+lemAppNil :: l ++ '[] :~: l
+lemAppNil = unsafeCoerceRefl
+
+lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
+lemAppAssoc _ _ _ = unsafeCoerceRefl
+
+type family Replicate n a where
+ Replicate 0 a = '[]
+ Replicate n a = a : Replicate (n - 1) a
+
+lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
+lemReplicateSucc = unsafeCoerceRefl