diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-05-30 11:58:40 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-30 11:58:40 +0200 | 
| commit | a65306ba5d80891b20ac86fa3a3242f9497751e6 (patch) | |
| tree | 834af370556a46bbeca807a92c31bef098b47a89 /src/Data/Array/Mixed | |
| parent | d8e2fcf4ea979fe272db48fc2889f4c2636c50d7 (diff) | |
Refactor Mixed (modules, regular function names)
Diffstat (limited to 'src/Data/Array/Mixed')
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 435 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 55 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Lists.hs | 78 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs | 82 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Lemmas.hs | 47 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Permutation.hs | 252 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Shape.hs | 455 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Types.hs | 110 | 
8 files changed, 1514 insertions, 0 deletions
| diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs new file mode 100644 index 0000000..cf6820b --- /dev/null +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -0,0 +1,435 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Mixed.Internal.Arith where + +import Control.Monad (forM, guard) +import qualified Data.Array.Internal as OI +import qualified Data.Array.Internal.RankedG as RG +import qualified Data.Array.Internal.RankedS as RS +import Data.Bits +import Data.Int +import Data.List (sort) +import qualified Data.Vector.Storable as VS +import qualified Data.Vector.Storable.Mutable as VSM +import Foreign.C.Types +import Foreign.Ptr +import Foreign.Storable (Storable) +import GHC.TypeLits +import Language.Haskell.TH +import System.IO.Unsafe + +import Data.Array.Mixed.Internal.Arith.Foreign +import Data.Array.Mixed.Internal.Arith.Lists + + +liftVEltwise1 :: Storable a +              => SNat n +              -> (VS.Vector a -> VS.Vector a) +              -> RS.Array n a -> RS.Array n a +liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) +  | Just prefixSz <- stridesDense sh strides = +      let vec' = f (VS.slice offset prefixSz vec) +      in RS.A (RG.A sh (OI.T strides 0 vec')) +  | otherwise = RS.fromVector sh (f (RS.toVector arr)) + +liftVEltwise2 :: Storable a +              => SNat n +              -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a) +              -> RS.Array n a -> RS.Array n a -> RS.Array n a +liftVEltwise2 SNat f +    arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1))) +    arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2))) +  | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2 +  | product sh1 == 0 = arr1  -- if the arrays are empty, just return one of the empty inputs +  | otherwise = case (stridesDense sh1 strides1, stridesDense sh2 strides2) of +      (Just 1, Just 1) ->  -- both are a (potentially replicated) scalar; just apply f to the scalars +        let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2)) +        in RS.A (RG.A sh1 (OI.T strides1 0 vec')) +      (Just 1, Just n) ->  -- scalar * dense +        RS.fromVector sh1 (f (Left (vec1 VS.! offset1)) (Right (VS.slice offset2 n vec2))) +      (Just n, Just 1) ->  -- dense * scalar +        RS.fromVector sh1 (f (Right (VS.slice offset1 n vec1)) (Left (vec2 VS.! offset2))) +      (_, _) ->  -- fallback case +        RS.fromVector sh1 (f (Right (RS.toVector arr1)) (Right (RS.toVector arr2))) + +-- | Given the shape vector and the stride vector, return whether this vector +-- of strides uses a dense prefix of its backing array. If so, the number of +-- elements in this prefix is returned. +-- This excludes any offset. +stridesDense :: [Int] -> [Int] -> Maybe Int +stridesDense sh _ | any (<= 0) sh = Just 0 +stridesDense sh str = +  -- sort dimensions on their stride, ascending, dropping any zero strides +  case dropWhile ((== 0) . fst) (sort (zip str sh)) of +    [] -> Just 1 +    (1, n) : (unzip -> (str', sh')) -> checkCover n sh' str' +    _ -> Nothing  -- if the smallest stride is not 1, it will never be dense +  where +    -- Given size of currently densely covered region at beginning of the +    -- array, the remaining shape vector and the corresponding remaining stride +    -- vector, return whether this all together covers a dense prefix of the +    -- array. If it does, return the number of elements in this prefix. +    checkCover :: Int -> [Int] -> [Int] -> Maybe Int +    checkCover block [] [] = Just block +    checkCover block (n : sh') (s : str') = guard (s <= block) >> checkCover (max block (n * s)) sh' str' +    checkCover _ _ _ = error "Orthotope array's shape vector and stride vector have different lengths" + +{-# NOINLINE vectorOp1 #-} +vectorOp1 :: forall a b. Storable a +          => (Ptr a -> Ptr b) +          -> (Int64 -> Ptr b -> Ptr b -> IO ()) +          -> VS.Vector a -> VS.Vector a +vectorOp1 ptrconv f v = unsafePerformIO $ do +  outv <- VSM.unsafeNew (VS.length v) +  VSM.unsafeWith outv $ \poutv -> +    VS.unsafeWith v $ \pv -> +      f (fromIntegral (VS.length v)) (ptrconv poutv) (ptrconv pv) +  VS.unsafeFreeze outv + +-- | If two vectors are given, assumes that they have the same length. +{-# NOINLINE vectorOp2 #-} +vectorOp2 :: forall a b. Storable a +          => (a -> b) +          -> (Ptr a -> Ptr b) +          -> (a -> a -> a) +          -> (Int64 -> Ptr b -> b -> Ptr b -> IO ())  -- sv +          -> (Int64 -> Ptr b -> Ptr b -> b -> IO ())  -- vs +          -> (Int64 -> Ptr b -> Ptr b -> Ptr b -> IO ())  -- vv +          -> Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a +vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases +  (Left x) (Left y) -> VS.singleton (fss x y) + +  (Left x) (Right vy) -> +    unsafePerformIO $ do +      outv <- VSM.unsafeNew (VS.length vy) +      VSM.unsafeWith outv $ \poutv -> +        VS.unsafeWith vy $ \pvy -> +          fsv (fromIntegral (VS.length vy)) (ptrconv poutv) (valconv x) (ptrconv pvy) +      VS.unsafeFreeze outv + +  (Right vx) (Left y) -> +    unsafePerformIO $ do +      outv <- VSM.unsafeNew (VS.length vx) +      VSM.unsafeWith outv $ \poutv -> +        VS.unsafeWith vx $ \pvx -> +          fvs (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (valconv y) +      VS.unsafeFreeze outv + +  (Right vx) (Right vy) +    | VS.length vx == VS.length vy -> +        unsafePerformIO $ do +          outv <- VSM.unsafeNew (VS.length vx) +          VSM.unsafeWith outv $ \poutv -> +            VS.unsafeWith vx $ \pvx -> +              VS.unsafeWith vy $ \pvy -> +                fvv (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (ptrconv pvy) +          VS.unsafeFreeze outv +    | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy) + +-- TODO: test all the weird cases of this function +-- | Reduce along the inner dimension +{-# NOINLINE vectorRedInnerOp #-} +vectorRedInnerOp :: forall a b n. (Num a, Storable a) +                 => SNat n +                 -> (a -> b) +                 -> (Ptr a -> Ptr b) +                 -> (Int64 -> Ptr b -> b -> Ptr b -> IO ())  -- ^ scale by constant +                 -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ())  -- ^ reduction kernel +                 -> RS.Array (n + 1) a -> RS.Array n a +vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec))) +  | null sh = error "unreachable" +  | last sh <= 0 = RS.stretch (init sh) (RS.fromList (map (const 1) (init sh)) [0]) +  | any (<= 0) (init sh) = RS.A (RG.A (init sh) (OI.T (map (const 0) (init strides)) 0 VS.empty)) +  -- now the input array is nonempty +  | last sh == 1 = RS.A (RG.A (init sh) (OI.T (init strides) offset vec)) +  | last strides == 0 = +      liftVEltwise1 sn +        (vectorOp1 id (\n pout px -> fscale n (ptrconv pout) (valconv (fromIntegral (last sh))) (ptrconv px))) +        (RS.A (RG.A (init sh) (OI.T (init strides) offset vec))) +  -- now there is useful work along the inner dimension +  | otherwise = +      let -- filter out zero-stride dimensions; the reduction kernel need not concern itself with those +          (shF, stridesF) = unzip $ filter ((/= 0) . snd) (zip sh strides) +          ndimsF = length shF +      in unsafePerformIO $ do +           outv <- VSM.unsafeNew (product (init shF)) +           VSM.unsafeWith outv $ \poutv -> +             VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF -> +               VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF -> +                 VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec -> +                   fred (fromIntegral ndimsF) pshF pstridesF (ptrconv poutv) (ptrconv pvec) +           RS.fromVector (init sh) <$> VS.unsafeFreeze outv + +flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ()) +       ->  Int64 -> Ptr a -> Ptr a -> a -> IO () +flipOp f n out v s = f n out s v + +$(fmap concat . forM typesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    fmap concat . forM [minBound..maxBound] $ \arithop -> do +      let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype)) +          cnamebase = "c_binary_" ++ atCName arithtype +          c_ss = varE (aboNumOp arithop) +          c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) +          c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) +          c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) +      sequence [SigD name <$> +                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] +               ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] +                   return $ FunD name [Clause [] (NormalB body) []]]) + +$(fmap concat . forM floatTypesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    fmap concat . forM [minBound..maxBound] $ \arithop -> do +      let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype)) +          cnamebase = "c_fbinary_" ++ atCName arithtype +          c_ss = varE (afboNumOp arithop) +          c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) +          c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) +          c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) +      sequence [SigD name <$> +                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] +               ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] +                   return $ FunD name [Clause [] (NormalB body) []]]) + +$(fmap concat . forM typesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    fmap concat . forM [minBound..maxBound] $ \arithop -> do +      let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype)) +          c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop))) +      sequence [SigD name <$> +                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] +               ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] +                   return $ FunD name [Clause [] (NormalB body) []]]) + +$(fmap concat . forM floatTypesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    fmap concat . forM [minBound..maxBound] $ \arithop -> do +      let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype)) +          c_op = varE (mkName ("c_funary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (afuoEnum arithop))) +      sequence [SigD name <$> +                     [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] +               ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] +                   return $ FunD name [Clause [] (NormalB body) []]]) + +$(fmap concat . forM typesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    fmap concat . forM [minBound..maxBound] $ \arithop -> do +      let name = mkName (aroName arithop ++ "Vector" ++ nameBase (atType arithtype)) +          c_op = varE (mkName ("c_reduce_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop))) +          c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL))) +      sequence [SigD name <$> +                     [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |] +               ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |] +                   return $ FunD name [Clause [] (NormalB body) []]]) + +-- This branch is ostensibly a runtime branch, but will (hopefully) be +-- constant-folded away by GHC. +intWidBranch1 :: forall i n. (FiniteBits i, Storable i) +              => (Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) +              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) +              -> (SNat n -> RS.Array n i -> RS.Array n i) +intWidBranch1 f32 f64 sn +  | finiteBitSize (undefined :: i) == 32 = liftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32) +  | finiteBitSize (undefined :: i) == 64 = liftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64) +  | otherwise = error "Unsupported Int width" + +intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i) +              => (i -> i -> i)  -- ss +                 -- int32 +              -> (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ())  -- sv +              -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Int32 -> IO ())  -- vs +              -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Ptr Int32 -> IO ())  -- vv +                 -- int64 +              -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ())  -- sv +              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ())  -- vs +              -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- vv +              -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i) +intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn +  | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32) +  | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64) +  | otherwise = error "Unsupported Int width" + +intWidBranchRed :: forall i n. (FiniteBits i, Storable i, Integral i) +                => -- int32 +                   (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ())  -- ^ scale by constant +                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ())  -- ^ reduction kernel +                   -- int64 +                -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ())  -- ^ scale by constant +                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- ^ reduction kernel +                -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i) +intWidBranchRed fsc32 fred32 fsc64 fred64 sn +  | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32 +  | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64 +  | otherwise = error "Unsupported Int width" + +class NumElt a where +  numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a +  numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a +  numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a +  numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a +  numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a +  numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a +  numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a +  numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a + +instance NumElt Int32 where +  numEltAdd = addVectorInt32 +  numEltSub = subVectorInt32 +  numEltMul = mulVectorInt32 +  numEltNeg = negVectorInt32 +  numEltAbs = absVectorInt32 +  numEltSignum = signumVectorInt32 +  numEltSum1Inner = sum1VectorInt32 +  numEltProduct1Inner = product1VectorInt32 + +instance NumElt Int64 where +  numEltAdd = addVectorInt64 +  numEltSub = subVectorInt64 +  numEltMul = mulVectorInt64 +  numEltNeg = negVectorInt64 +  numEltAbs = absVectorInt64 +  numEltSignum = signumVectorInt64 +  numEltSum1Inner = sum1VectorInt64 +  numEltProduct1Inner = product1VectorInt64 + +instance NumElt Float where +  numEltAdd = addVectorFloat +  numEltSub = subVectorFloat +  numEltMul = mulVectorFloat +  numEltNeg = negVectorFloat +  numEltAbs = absVectorFloat +  numEltSignum = signumVectorFloat +  numEltSum1Inner = sum1VectorFloat +  numEltProduct1Inner = product1VectorFloat + +instance NumElt Double where +  numEltAdd = addVectorDouble +  numEltSub = subVectorDouble +  numEltMul = mulVectorDouble +  numEltNeg = negVectorDouble +  numEltAbs = absVectorDouble +  numEltSignum = signumVectorDouble +  numEltSum1Inner = sum1VectorDouble +  numEltProduct1Inner = product1VectorDouble + +instance NumElt Int where +  numEltAdd = intWidBranch2 @Int (+) +                (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD)) +                (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD)) +  numEltSub = intWidBranch2 @Int (-) +                (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB)) +                (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB)) +  numEltMul = intWidBranch2 @Int (*) +                (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL)) +                (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL)) +  numEltNeg = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG)) +  numEltAbs = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS)) +  numEltSignum = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM)) +  numEltSum1Inner = intWidBranchRed @Int +                      (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1)) +                      (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1)) +  numEltProduct1Inner = intWidBranchRed @Int +                          (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) +                          (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) + +instance NumElt CInt where +  numEltAdd = intWidBranch2 @CInt (+) +                (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD)) +                (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD)) +  numEltSub = intWidBranch2 @CInt (-) +                (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB)) +                (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB)) +  numEltMul = intWidBranch2 @CInt (*) +                (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL)) +                (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL)) +  numEltNeg = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG)) +  numEltAbs = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS)) +  numEltSignum = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM)) +  numEltSum1Inner = intWidBranchRed @CInt +                      (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1)) +                      (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1)) +  numEltProduct1Inner = intWidBranchRed @CInt +                          (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) +                          (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) + +class FloatElt a where +  floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a +  floatEltPow :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a +  floatEltLogbase :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a +  floatEltRecip :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltExp :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltLog :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltSqrt :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltSin :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltCos :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltTan :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltAsin :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltAcos :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltAtan :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltSinh :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltCosh :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltTanh :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltAsinh :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltAcosh :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltAtanh :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltLog1p :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltExpm1 :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltLog1pexp :: SNat n -> RS.Array n a -> RS.Array n a +  floatEltLog1mexp :: SNat n -> RS.Array n a -> RS.Array n a + +instance FloatElt Float where +  floatEltDiv = divVectorFloat +  floatEltPow = powVectorFloat +  floatEltLogbase = logbaseVectorFloat +  floatEltRecip = recipVectorFloat +  floatEltExp = expVectorFloat +  floatEltLog = logVectorFloat +  floatEltSqrt = sqrtVectorFloat +  floatEltSin = sinVectorFloat +  floatEltCos = cosVectorFloat +  floatEltTan = tanVectorFloat +  floatEltAsin = asinVectorFloat +  floatEltAcos = acosVectorFloat +  floatEltAtan = atanVectorFloat +  floatEltSinh = sinhVectorFloat +  floatEltCosh = coshVectorFloat +  floatEltTanh = tanhVectorFloat +  floatEltAsinh = asinhVectorFloat +  floatEltAcosh = acoshVectorFloat +  floatEltAtanh = atanhVectorFloat +  floatEltLog1p = log1pVectorFloat +  floatEltExpm1 = expm1VectorFloat +  floatEltLog1pexp = log1pexpVectorFloat +  floatEltLog1mexp = log1mexpVectorFloat + +instance FloatElt Double where +  floatEltDiv = divVectorDouble +  floatEltPow = powVectorDouble +  floatEltLogbase = logbaseVectorDouble +  floatEltRecip = recipVectorDouble +  floatEltExp = expVectorDouble +  floatEltLog = logVectorDouble +  floatEltSqrt = sqrtVectorDouble +  floatEltSin = sinVectorDouble +  floatEltCos = cosVectorDouble +  floatEltTan = tanVectorDouble +  floatEltAsin = asinVectorDouble +  floatEltAcos = acosVectorDouble +  floatEltAtan = atanVectorDouble +  floatEltSinh = sinhVectorDouble +  floatEltCosh = coshVectorDouble +  floatEltTanh = tanhVectorDouble +  floatEltAsinh = asinhVectorDouble +  floatEltAcosh = acoshVectorDouble +  floatEltAtanh = atanhVectorDouble +  floatEltLog1p = log1pVectorDouble +  floatEltExpm1 = expm1VectorDouble +  floatEltLog1pexp = log1pexpVectorDouble +  floatEltLog1mexp = log1mexpVectorDouble diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs new file mode 100644 index 0000000..6fc7229 --- /dev/null +++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs @@ -0,0 +1,55 @@ +{-# LANGUAGE ForeignFunctionInterface #-} +{-# LANGUAGE TemplateHaskell #-} +module Data.Array.Mixed.Internal.Arith.Foreign where + +import Control.Monad +import Data.Int +import Data.Maybe +import Foreign.C.Types +import Foreign.Ptr +import Language.Haskell.TH + +import Data.Array.Mixed.Internal.Arith.Lists + + +$(fmap concat . forM typesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    let base = "binary_" ++ atCName arithtype +    sequence $ catMaybes +      [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> +               [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]) +      ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> +               [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) +      ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> +               [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) +      ]) + +$(fmap concat . forM floatTypesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    let base = "fbinary_" ++ atCName arithtype +    sequence $ catMaybes +      [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> +               [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]) +      ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> +               [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) +      ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> +               [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) +      ]) + +$(fmap concat . forM typesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    let base = "unary_" ++ atCName arithtype +    pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> +      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + +$(fmap concat . forM floatTypesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    let base = "funary_" ++ atCName arithtype +    pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> +      [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + +$(fmap concat . forM typesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    let base = "reduce_" ++ atCName arithtype +    pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> +      [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) diff --git a/src/Data/Array/Mixed/Internal/Arith/Lists.hs b/src/Data/Array/Mixed/Internal/Arith/Lists.hs new file mode 100644 index 0000000..a284bc1 --- /dev/null +++ b/src/Data/Array/Mixed/Internal/Arith/Lists.hs @@ -0,0 +1,78 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TemplateHaskell #-} +module Data.Array.Mixed.Internal.Arith.Lists where + +import Data.Char +import Data.Int +import Language.Haskell.TH + +import Data.Array.Mixed.Internal.Arith.Lists.TH + + +data ArithType = ArithType +  { atType :: Name  -- ''Int32 +  , atCName :: String  -- "i32" +  } + +floatTypesList :: [ArithType] +floatTypesList = +  [ArithType ''Float "float" +  ,ArithType ''Double "double" +  ] + +typesList :: [ArithType] +typesList = +  [ArithType ''Int32 "i32" +  ,ArithType ''Int64 "i64" +  ] +  ++ floatTypesList + +-- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded) +$(genArithDataType Binop "ArithBOp") + +$(genArithNameFun Binop ''ArithBOp "aboName" (map toLower . drop 3)) +$(genArithEnumFun Binop ''ArithBOp "aboEnum") + +$(do clauses <- readArithLists Binop +                  (\name _num hsop -> return (Clause [ConP (mkName name) [] []] +                                                     (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) +                                                     [])) +                  return +     sequence [SigD (mkName "aboNumOp") <$> [t| ArithBOp -> Name |] +              ,return $ FunD (mkName "aboNumOp") clauses]) + + +-- data ArithFBOp = FB_DIV deriving (Show, Enum, Bounded) +$(genArithDataType FBinop "ArithFBOp") + +$(genArithNameFun FBinop ''ArithFBOp "afboName" (map toLower . drop 3)) +$(genArithEnumFun FBinop ''ArithFBOp "afboEnum") + +$(do clauses <- readArithLists FBinop +                  (\name _num hsop -> return (Clause [ConP (mkName name) [] []] +                                                     (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) +                                                     [])) +                  return +     sequence [SigD (mkName "afboNumOp") <$> [t| ArithFBOp -> Name |] +              ,return $ FunD (mkName "afboNumOp") clauses]) + + +-- data ArithUOp = UO_NEG | UO_ABS | UO_SIGNUM | ... deriving (Show, Enum, Bounded) +$(genArithDataType Unop "ArithUOp") + +$(genArithNameFun Unop ''ArithUOp "auoName" (map toLower . drop 3)) +$(genArithEnumFun Unop ''ArithUOp "auoEnum") + + +-- data ArithFUOp = FU_RECIP | ... deriving (Show, Enum, Bounded) +$(genArithDataType FUnop "ArithFUOp") + +$(genArithNameFun FUnop ''ArithFUOp "afuoName" (map toLower . drop 3)) +$(genArithEnumFun FUnop ''ArithFUOp "afuoEnum") + + +-- data ArithRedOp = RO_SUM1 | RO_PRODUCT1 deriving (Show, Enum, Bounded) +$(genArithDataType Redop "ArithRedOp") + +$(genArithNameFun Redop ''ArithRedOp "aroName" (map toLower . drop 3)) +$(genArithEnumFun Redop ''ArithRedOp "aroEnum") diff --git a/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs b/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs new file mode 100644 index 0000000..8b7d05f --- /dev/null +++ b/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs @@ -0,0 +1,82 @@ +{-# LANGUAGE TemplateHaskellQuotes #-} +module Data.Array.Mixed.Internal.Arith.Lists.TH where + +import Control.Monad +import Control.Monad.IO.Class +import Data.Maybe +import Foreign.C.Types +import Language.Haskell.TH +import Language.Haskell.TH.Syntax +import Text.Read + + +data OpKind = Binop | FBinop | Unop | FUnop | Redop +  deriving (Show, Eq) + +readArithLists :: OpKind +               -> (String -> Int -> String -> Q a) +               -> ([a] -> Q r) +               -> Q r +readArithLists targetkind fop fcombine = do +  addDependentFile "cbits/arith_lists.h" +  lns <- liftIO $ lines <$> readFile "cbits/arith_lists.h" + +  mvals <- forM lns $ \line -> do +    if null (dropWhile (== ' ') line) +      then return Nothing +      else do let (kind, name, num, aux) = parseLine line +              if kind == targetkind +                then Just <$> fop name num aux +                else return Nothing + +  fcombine (catMaybes mvals) +  where +    parseLine s0 +      | ("LIST_", s1) <- splitAt 5 s0 +      , (kindstr, '(' : s2) <- break (== '(') s1 +      , (f1, ',' : s3) <- parseField s2 +      , (f2, ',' : s4) <- parseField s3 +      , (f3, ')' : _) <- parseField s4 +      , Just kind <- parseKind kindstr +      , let name = f1 +      , Just num <- readMaybe f2 +      , let aux = f3 +      = (kind, name, num, aux) +      | otherwise +      = error $ "readArithLists: unrecognised line in cbits/arith_lists.h: " ++ show s0 + +    parseField s = break (`elem` ",)") (dropWhile (== ' ') s) + +    parseKind "BINOP" = Just Binop +    parseKind "FBINOP" = Just FBinop +    parseKind "UNOP" = Just Unop +    parseKind "FUNOP" = Just FUnop +    parseKind "REDOP" = Just Redop +    parseKind _ = Nothing + +genArithDataType :: OpKind -> String -> Q [Dec] +genArithDataType kind dtname = do +  cons <- readArithLists kind +            (\name _num _ -> return $ NormalC (mkName name) []) +            return +  return [DataD [] (mkName dtname) [] Nothing cons [DerivClause Nothing [ConT ''Show, ConT ''Enum, ConT ''Bounded]]] + +genArithNameFun :: OpKind -> Name -> String -> (String -> String) -> Q [Dec] +genArithNameFun kind dtname funname nametrans = do +  clauses <- readArithLists kind +               (\name _num _ -> return (Clause [ConP (mkName name) [] []] +                                               (NormalB (LitE (StringL (nametrans name)))) +                                               [])) +               return +  return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''String) +         ,FunD (mkName funname) clauses] + +genArithEnumFun :: OpKind -> Name -> String -> Q [Dec] +genArithEnumFun kind dtname funname = do +  clauses <- readArithLists kind +               (\name num _ -> return (Clause [ConP (mkName name) [] []] +                                              (NormalB (LitE (IntegerL (fromIntegral num)))) +                                              [])) +               return +  return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''CInt) +         ,FunD (mkName funname) clauses] diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Mixed/Lemmas.hs new file mode 100644 index 0000000..30ec9c0 --- /dev/null +++ b/src/Data/Array/Mixed/Lemmas.hs @@ -0,0 +1,47 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE DataKinds #-} +module Data.Array.Mixed.Lemmas where + +import Data.Proxy +import Data.Type.Equality +import GHC.TypeLits + +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Types + + +lemRankApp :: forall sh1 sh2. +              StaticShX sh1 -> StaticShX sh2 +           -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2 +lemRankApp ZKX _ = Refl +lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2 +  = lem2 (Proxy @(Rank sh1T)) Proxy Proxy $ +      lem (Proxy @(Rank sh2)) (Proxy @(Rank sh1T)) (Proxy @(Rank (sh1T ++ sh2))) $ +        lemRankApp ssh1 ssh2 +  where +    lem :: proxy a -> proxy b -> proxy c +        -> c :~: b + a +        -> b + a :~: c +    lem _ _ _ Refl = Refl + +    lem2 :: proxy a -> proxy b -> proxy c +         -> (a + b :~: c) +         -> c + 1 :~: (a + 1 + b) +    lem2 _ _ _ Refl = Refl + +lemRankAppComm :: StaticShX sh1 -> StaticShX sh2 +               -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1) +lemRankAppComm _ _ = unsafeCoerceRefl  -- TODO improve this + +lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh) +lemKnownNatRank ZSX = Dict +lemKnownNatRank (_ :$% sh) | Dict <- lemKnownNatRank sh = Dict + +lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh) +lemKnownNatRankSSX ZKX = Dict +lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs new file mode 100644 index 0000000..2710018 --- /dev/null +++ b/src/Data/Array/Mixed/Permutation.hs @@ -0,0 +1,252 @@ +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Mixed.Permutation where + +import Data.Coerce (coerce) +import Data.Functor.Const +import Data.List (sort) +import Data.Proxy +import Data.Type.Bool +import Data.Type.Equality +import Data.Type.Ord +import GHC.TypeError +import GHC.TypeLits +import qualified GHC.TypeNats as TN + +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Types + + +-- * Permutations + +-- | A "backward" permutation of a dimension list. The operation on the +-- dimension list is most similar to 'Data.Vector.backpermute'; see 'Permute' +-- for code that implements this. +data Perm list where +  PNil :: Perm '[] +  PCons :: SNat a -> Perm l -> Perm (a : l) +infixr 5 `PCons` +deriving instance Show (Perm list) +deriving instance Eq (Perm list) + +permLengthSNat :: Perm list -> SNat (Rank list) +permLengthSNat PNil = SNat +permLengthSNat (_ `PCons` l) | SNat <- permLengthSNat l = SNat + +permFromList :: [Int] -> (forall list. Perm list -> r) -> r +permFromList [] k = k PNil +permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case +  Just sn -> permFromList xs $ \list -> k (sn `PCons` list) +  Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x + +permToList :: Perm list -> [Natural] +permToList PNil = mempty +permToList (x `PCons` l) = TN.fromSNat x : permToList l + +permToList' :: Perm list -> [Int] +permToList' = map fromIntegral . permToList + + +-- ** Applying permutations + +type family Elem x l where +  Elem x '[] = 'False +  Elem x (x : _) = 'True +  Elem x (_ : ys) = Elem x ys + +type family AllElem' as bs where +  AllElem' '[] bs = 'True +  AllElem' (a : as) bs = Elem a bs && AllElem' as bs + +type AllElem as bs = Assert (AllElem' as bs) +  (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs)) + +type family Count i n where +  Count n n = '[] +  Count i n = i : Count (i + 1) n + +type IsPermutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as) + +type family Index i sh where +  Index 0 (n : sh) = n +  Index i (_ : sh) = Index (i - 1) sh + +type family Permute is sh where +  Permute '[] sh = '[] +  Permute (i : is) sh = Index i sh : Permute is sh + +type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh + +type family TakeLen ref l where +  TakeLen '[] l = '[] +  TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs + +type family DropLen ref l where +  DropLen '[] l = l +  DropLen (_ : ref) (_ : xs) = DropLen ref xs + +listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f +listxTakeLen PNil _ = ZX +listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh +listxTakeLen (_ `PCons` _) ZX = error "IsPermutation longer than shape" + +listxDropLen :: forall f is sh. Perm is -> ListX sh f -> ListX (DropLen is sh) f +listxDropLen PNil sh = sh +listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh +listxDropLen (_ `PCons` _) ZX = error "IsPermutation longer than shape" + +listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f +listxPermute PNil _ = ZX +listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) = +  listxIndex (Proxy @is') (Proxy @sh) i sh (listxPermute is sh) + +listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> ListX (Permute is shT) f -> ListX (Index i sh : Permute is shT) f +listxIndex _ _ SZ (n ::% _) rest = n ::% rest +listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) rest +  | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') +  = listxIndex p pT i sh rest +listxIndex _ _ _ ZX _ = error "Index into empty shape" + +listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f +listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) + +ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i +ixxPermutePrefix = coerce (listxPermutePrefix @(Const i)) + +ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh) +ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat)) + +ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh) +ssxDropLen = coerce (listxDropLen @(SMayNat () SNat)) + +ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh) +ssxPermute = coerce (listxPermute @(SMayNat () SNat)) + +ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT) +ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2) + +ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) +ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat)) + +shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh) +shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat)) + + +-- * Operations on permutations + +-- TODO: test this thing more properly +permInverse :: Perm is +            -> (forall is'. +                     IsPermutation is' +                  => Perm is' +                  -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh) +                  -> r) +            -> r +permInverse = \perm k -> +  genPerm perm $ \(invperm :: Perm is') -> +    let sn = permLengthSNat invperm +    in case (provePerm1 (Proxy @is') sn invperm, provePerm2 (SNat @0) sn invperm) of +         (Just Refl, Just Refl) -> +           k invperm +             (\ssh -> case provePermInverse perm invperm ssh of +                        Just eq -> eq +                        Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm +                                           ++ " ; invperm = " ++ show invperm) +         _ -> error $ "permInverse: did not generate permutation? perm = " ++ show perm +                      ++ " ; invperm = " ++ show invperm +  where +    genPerm :: Perm is -> (forall is'. Perm is' -> r) -> r +    genPerm perm = +      let permList = permToList' perm +      in toHList $ map snd (sort (zip permList [0..])) +      where +        toHList :: [Natural] -> (forall is'. Perm is' -> r) -> r +        toHList [] k = k PNil +        toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l) + +    lemElemCount :: (0 <= n, Compare n m ~ LT) => proxy n -> proxy m -> Elem n (Count 0 m) :~: True +    lemElemCount _ _ = unsafeCoerceRefl + +    lemCount :: (OrdCond (Compare i n) True False True ~ True) => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n +    lemCount _ _ = unsafeCoerceRefl + +    lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True +    lemElem _ _ = unsafeCoerceRefl + +    provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is' +               -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True) +    provePerm1 _ _ PNil = Just (Refl) +    provePerm1 p rtop@SNat (PCons sn@SNat perm) +      | Just Refl <- provePerm1 p rtop perm +      = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of +          (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl +          (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl +          _ -> Nothing +      | otherwise +      = Nothing + +    provePerm2 :: SNat i -> SNat n -> Perm is' +               -> Maybe (AllElem' (Count i n) is' :~: True) +    provePerm2 = \i@(SNat :: SNat i) n@SNat perm -> +      case cmpNat i n of +        EQI -> Just Refl +        LTI | Refl <- lemCount i n +            , Just Refl <- provePerm2 (SNat @(i + 1)) n perm +            -> checkElem i perm +            | otherwise -> Nothing +        GTI -> error "unreachable" +      where +        checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True) +        checkElem _ PNil = Nothing +        checkElem i@SNat (PCons k@SNat perm :: Perm is') = +          case sameNat i k of +            Just Refl -> Just Refl +            Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl +                    | otherwise -> Nothing + +    provePermInverse :: Perm is -> Perm is' -> StaticShX sh +                     -> Maybe (Permute is' (Permute is sh) :~: sh) +    provePermInverse perm perminv ssh = +      ssxGeq (ssxPermute perminv (ssxPermute perm ssh)) ssh + +type family MapSucc is where +  MapSucc '[] = '[] +  MapSucc (i : is) = i + 1 : MapSucc is + +permShift1 :: Perm l -> Perm (0 : MapSucc l) +permShift1 = (SNat @0 `PCons`) . permMapSucc +  where +    permMapSucc :: Perm l -> Perm (MapSucc l) +    permMapSucc PNil = PNil +    permMapSucc ((SNat :: SNat i) `PCons` ns) = SNat @(i + 1) `PCons` permMapSucc ns + + +-- * Lemmas + +lemRankPermute :: Proxy sh -> Perm is -> Rank (Permute is sh) :~: Rank is +lemRankPermute _ PNil = Refl +lemRankPermute p (_ `PCons` is) | Refl <- lemRankPermute p is = Refl + +lemRankDropLen :: forall is sh. (Rank is <= Rank sh) +               => StaticShX sh -> Perm is -> Rank (DropLen is sh) :~: Rank sh - Rank is +lemRankDropLen ZKX PNil = Refl +lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl +lemRankDropLen (_ :!% _) PNil = Refl +lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0" + +lemIndexSucc :: Proxy i -> Proxy a -> Proxy l +             -> Index (i + 1) (a : l) :~: Index i l +lemIndexSucc _ _ _ = unsafeCoerceRefl diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs new file mode 100644 index 0000000..a16da76 --- /dev/null +++ b/src/Data/Array/Mixed/Shape.hs @@ -0,0 +1,455 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Mixed.Shape where + +import Control.DeepSeq (NFData(..)) +import qualified Data.Foldable as Foldable +import Data.Functor.Const +import Data.Kind (Type, Constraint) +import Data.Monoid (Sum(..)) +import Data.Proxy +import Data.Type.Equality +import GHC.Generics (Generic) +import GHC.IsList (IsList) +import qualified GHC.IsList as IsList +import GHC.TypeLits + +import Data.Array.Mixed.Types +import Data.Coerce +import Data.Bifunctor (first) + + +-- | The length of a type-level list. If the argument is a shape, then the +-- result is the rank of that shape. +type family Rank sh where +  Rank '[] = 0 +  Rank (_ : sh) = Rank sh + 1 + + +-- * Mixed lists + +type role ListX nominal representational +type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type +data ListX sh f where +  ZX :: ListX '[] f +  (::%) :: f n -> ListX sh f -> ListX (n : sh) f +deriving instance (forall n. Eq (f n)) => Eq (ListX sh f) +deriving instance (forall n. Ord (f n)) => Ord (ListX sh f) +infixr 3 ::% + +instance (forall n. Show (f n)) => Show (ListX sh f) where +  showsPrec _ = listxShow shows + +instance (forall n. NFData (f n)) => NFData (ListX sh f) where +  rnf ZX = () +  rnf (x ::% l) = rnf x `seq` rnf l + +data UnconsListXRes f sh1 = +  forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n) +listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1) +listxUncons (i ::% shl') = Just (UnconsListXRes shl' i) +listxUncons ZX = Nothing + +listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g +listxFmap _ ZX = ZX +listxFmap f (x ::% xs) = f x ::% listxFmap f xs + +listxFold :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m +listxFold _ ZX = mempty +listxFold f (x ::% xs) = f x <> listxFold f xs + +listxLength :: ListX sh f -> Int +listxLength = getSum . listxFold (\_ -> Sum 1) + +listxLengthSNat :: ListX sh f -> SNat (Rank sh) +listxLengthSNat ZX = SNat +listxLengthSNat (_ ::% l) | SNat <- listxLengthSNat l = SNat + +listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS +listxShow f l = showString "[" . go "" l . showString "]" +  where +    go :: String -> ListX sh' f -> ShowS +    go _ ZX = id +    go prefix (x ::% xs) = showString prefix . f x . go "," xs + +listxToList :: ListX sh' (Const i) -> [i] +listxToList ZX = [] +listxToList (Const i ::% is) = i : listxToList is + +listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f +listxAppend ZX idx' = idx' +listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx' + +listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f +listxDrop long ZX = long +listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short + + +-- * Mixed indices + +-- | This is a newtype over 'ListX'. +type role IxX nominal representational +type IxX :: [Maybe Nat] -> Type -> Type +newtype IxX sh i = IxX (ListX sh (Const i)) +  deriving (Eq, Ord, Generic) + +pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i +pattern ZIX = IxX ZX + +pattern (:.%) +  :: forall {sh1} {i}. +     forall n sh. (n : sh ~ sh1) +  => i -> IxX sh i -> IxX sh1 i +pattern i :.% shl <- IxX (listxUncons -> Just (UnconsListXRes (IxX -> shl) (getConst -> i))) +  where i :.% IxX shl = IxX (Const i ::% shl) +infixr 3 :.% + +{-# COMPLETE ZIX, (:.%) #-} + +type IIxX sh = IxX sh Int + +instance Show i => Show (IxX sh i) where +  showsPrec _ (IxX l) = listxShow (\(Const i) -> shows i) l + +instance Functor (IxX sh) where +  fmap f (IxX l) = IxX (listxFmap (Const . f . getConst) l) + +instance Foldable (IxX sh) where +  foldMap f (IxX l) = listxFold (f . getConst) l + +instance NFData i => NFData (IxX sh i) + +ixxZero :: StaticShX sh -> IIxX sh +ixxZero ZKX = ZIX +ixxZero (_ :!% ssh) = 0 :.% ixxZero ssh + +ixxZero' :: IShX sh -> IIxX sh +ixxZero' ZSX = ZIX +ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh + +ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i +ixxAppend = coerce (listxAppend @_ @(Const i)) + +ixxDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i +ixxDrop = coerce (listxDrop @(Const i) @(Const i)) + +ixxFromLinear :: IShX sh -> Int -> IIxX sh +ixxFromLinear = \sh i -> case go sh i of +  (idx, 0) -> idx +  _ -> error $ "ixxFromLinear: out of range (" ++ show i ++ +               " in array of shape " ++ show sh ++ ")" +  where +    -- returns (index in subarray, remaining index in enclosing array) +    go :: IShX sh -> Int -> (IIxX sh, Int) +    go ZSX i = (ZIX, i) +    go (n :$% sh) i = +      let (idx, i') = go sh i +          (upi, locali) = i' `quotRem` fromSMayNat' n +      in (locali :.% idx, upi) + +ixxToLinear :: IShX sh -> IIxX sh -> Int +ixxToLinear = \sh i -> fst (go sh i) +  where +    -- returns (index in subarray, size of subarray) +    go :: IShX sh -> IIxX sh -> (Int, Int) +    go ZSX ZIX = (0, 1) +    go (n :$% sh) (i :.% ix) = +      let (lidx, sz) = go sh ix +      in (sz * i + lidx, fromSMayNat' n * sz) + + +-- * Mixed shapes + +data SMayNat i f n where +  SUnknown :: i -> SMayNat i f Nothing +  SKnown :: f n -> SMayNat i f (Just n) +deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n) +deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n) +deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n) + +instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where +  rnf (SUnknown i) = rnf i +  rnf (SKnown x) = rnf x + +fromSMayNat :: (n ~ Nothing => i -> r) +            -> (forall m. n ~ Just m => f m -> r) +            -> SMayNat i f n -> r +fromSMayNat f _ (SUnknown i) = f i +fromSMayNat _ g (SKnown s) = g s + +fromSMayNat' :: SMayNat Int SNat n -> Int +fromSMayNat' = fromSMayNat id fromSNat' + +type family AddMaybe n m where +  AddMaybe Nothing _ = Nothing +  AddMaybe (Just _) Nothing = Nothing +  AddMaybe (Just n) (Just m) = Just (n + m) + +smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m) +smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) +smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) +smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) + + +-- | This is a newtype over 'ListX'. +type role ShX nominal representational +type ShX :: [Maybe Nat] -> Type -> Type +newtype ShX sh i = ShX (ListX sh (SMayNat i SNat)) +  deriving (Eq, Ord, Generic) + +pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i +pattern ZSX = ShX ZX + +pattern (:$%) +  :: forall {sh1} {i}. +     forall n sh. (n : sh ~ sh1) +  => SMayNat i SNat n -> ShX sh i -> ShX sh1 i +pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i)) +  where i :$% ShX shl = ShX (i ::% shl) +infixr 3 :$% + +{-# COMPLETE ZSX, (:$%) #-} + +type IShX sh = ShX sh Int + +instance Show i => Show (ShX sh i) where +  showsPrec _ (ShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l + +instance Functor (ShX sh) where +  fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l) + +instance NFData i => NFData (ShX sh i) where +  rnf (ShX ZX) = () +  rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l) +  rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l) + +shxLength :: ShX sh i -> Int +shxLength (ShX l) = listxLength l + +-- | This is more than @geq@: it also checks that the integers (the unknown +-- dimensions) are the same. +shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh') +shxEqual ZSX ZSX = Just Refl +shxEqual (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh') +  | Just Refl <- sameNat n m +  , Just Refl <- shxEqual sh sh' +  = Just Refl +shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh') +  | i == j +  , Just Refl <- shxEqual sh sh' +  = Just Refl +shxEqual _ _ = Nothing + +-- | The number of elements in an array described by this shape. +shxSize :: IShX sh -> Int +shxSize ZSX = 1 +shxSize (n :$% sh) = fromSMayNat' n * shxSize sh + +shxToList :: IShX sh -> [Int] +shxToList ZSX = [] +shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh + +shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i +shxAppend = coerce (listxAppend @_ @(SMayNat i SNat)) + +shxTail :: ShX (n : sh) i -> ShX sh i +shxTail (_ :$% sh) = sh + +shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i +shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat)) + +shxDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i +shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j)) + +shxDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i +shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat)) + +shxTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i +shxTakeSSX _ = flip go +  where +    go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i +    go ZKX _ = ZSX +    go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh + +-- This is a weird operation, so it has a long name +shxCompleteZeros :: StaticShX sh -> IShX sh +shxCompleteZeros ZKX = ZSX +shxCompleteZeros (SUnknown () :!% ssh) = SUnknown 0 :$% shxCompleteZeros ssh +shxCompleteZeros (SKnown n :!% ssh) = SKnown n :$% shxCompleteZeros ssh + +shxSplitApp :: Proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i) +shxSplitApp _ ZKX idx = (ZSX, idx) +shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx) + +shxEnum :: IShX sh -> [IIxX sh] +shxEnum = \sh -> go sh id [] +  where +    go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a] +    go ZSX f = (f ZIX :) +    go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]] + + +-- * Static mixed shapes + +-- | The part of a shape that is statically known. (A newtype over 'ListX'.) +type StaticShX :: [Maybe Nat] -> Type +newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat)) +  deriving (Eq, Ord) + +pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh +pattern ZKX = StaticShX ZX + +pattern (:!%) +  :: forall {sh1}. +     forall n sh. (n : sh ~ sh1) +  => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1 +pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i)) +  where i :!% StaticShX shl = StaticShX (i ::% shl) +infixr 3 :!% + +{-# COMPLETE ZKX, (:!%) #-} + +instance Show (StaticShX sh) where +  showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l + +ssxLength :: StaticShX sh -> Int +ssxLength (StaticShX l) = listxLength l + +-- | This suffices as an implementation of @geq@ in the @Data.GADT.Compare@ +-- class of the @some@ package. +ssxGeq :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh') +ssxGeq ZKX ZKX = Just Refl +ssxGeq (SKnown n@SNat :!% sh) (SKnown m@SNat :!% sh') +  | Just Refl <- sameNat n m +  , Just Refl <- ssxGeq sh sh' +  = Just Refl +ssxGeq (SUnknown () :!% sh) (SUnknown () :!% sh') +  | Just Refl <- ssxGeq sh sh' +  = Just Refl +ssxGeq _ _ = Nothing + +ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') +ssxAppend ZKX sh' = sh' +ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' + +ssxTail :: StaticShX (n : sh) -> StaticShX sh +ssxTail (_ :!% ssh) = ssh + +ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh' +ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) + +-- | This may fail if @sh@ has @Nothing@s in it. +ssxToShX' :: StaticShX sh -> Maybe (IShX sh) +ssxToShX' ZKX = Just ZSX +ssxToShX' (SKnown n :!% sh) = (SKnown n :$%) <$> ssxToShX' sh +ssxToShX' (SUnknown _ :!% _) = Nothing + +ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) +ssxReplicate SZ = ZKX +ssxReplicate (SS (n :: SNat n')) +  | Refl <- lemReplicateSucc @(Nothing @Nat) @n' +  = SUnknown () :!% ssxReplicate n + +ssxIotaFrom :: Int -> StaticShX sh -> [Int] +ssxIotaFrom _ ZKX = [] +ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh + +ssxFromShape :: IShX sh -> StaticShX sh +ssxFromShape ZSX = ZKX +ssxFromShape (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShape sh + + +-- | Evidence for the static part of a shape. This pops up only when you are +-- polymorphic in the element type of an array. +type KnownShX :: [Maybe Nat] -> Constraint +class KnownShX sh where knownShX :: StaticShX sh +instance KnownShX '[] where knownShX = ZKX +instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX +instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX + + +-- * Flattening + +type Flatten sh = Flatten' 1 sh + +type family Flatten' acc sh where +  Flatten' acc '[] = Just acc +  Flatten' acc (Nothing : sh) = Nothing +  Flatten' acc (Just n : sh) = Flatten' (acc * n) sh + +-- This function is currently unused +ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh) +ssxFlatten = go (SNat @1) +  where +    go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh) +    go acc ZKX = SKnown acc +    go _ (SUnknown () :!% _) = SUnknown () +    go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh + +shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh) +shxFlatten = go (SNat @1) +  where +    go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh) +    go acc ZSX = SKnown acc +    go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh) +    go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh + +    goUnknown :: Int -> IShX sh -> Int +    goUnknown acc ZSX = acc +    goUnknown acc (SUnknown n :$% sh) = goUnknown (acc * n) sh +    goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh + + +-- | Very untyped: only length is checked (at runtime). +instance KnownShX sh => IsList (ListX sh (Const i)) where +  type Item (ListX sh (Const i)) = i +  fromList topl = go (knownShX @sh) topl +    where +      go :: StaticShX sh' -> [i] -> ListX sh' (Const i) +      go ZKX [] = ZX +      go (_ :!% sh) (i : is) = Const i ::% go sh is +      go _ _ = error $ "IsList(ListX): Mismatched list length (type says " +                         ++ show (ssxLength (knownShX @sh)) ++ ", list has length " +                         ++ show (length topl) ++ ")" +  toList = listxToList + +-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. +instance KnownShX sh => IsList (IxX sh i) where +  type Item (IxX sh i) = i +  fromList = IxX . IsList.fromList +  toList = Foldable.toList + +-- | Untyped: length and known dimensions are checked (at runtime). +instance KnownShX sh => IsList (ShX sh Int) where +  type Item (ShX sh Int) = Int +  fromList topl = ShX (go (knownShX @sh) topl) +    where +      go :: StaticShX sh' -> [Int] -> ListX sh' (SMayNat Int SNat) +      go ZKX [] = ZX +      go (SKnown sn :!% sh) (i : is) +        | i == fromSNat' sn = SKnown sn ::% go sh is +        | otherwise = error $ "IsList(ShX): Value does not match typing (type says " +                                ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" +      go (SUnknown () :!% sh) (i : is) = SUnknown i ::% go sh is +      go _ _ = error $ "IsList(ShX): Mismatched list length (type says " +                         ++ show (ssxLength (knownShX @sh)) ++ ", list has length " +                         ++ show (length topl) ++ ")" +  toList = shxToList diff --git a/src/Data/Array/Mixed/Types.hs b/src/Data/Array/Mixed/Types.hs new file mode 100644 index 0000000..d77513f --- /dev/null +++ b/src/Data/Array/Mixed/Types.hs @@ -0,0 +1,110 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Mixed.Types ( +  -- * Reified evidence of a type class +  Dict(..), + +  -- * Type-level naturals +  pattern SZ, pattern SS, +  fromSNat', +  snatPlus, snatMul, + +  -- * Type-level lists +  type (++), +  lemAppNil, +  lemAppAssoc, +  Replicate, +  lemReplicateSucc, + +  -- * Unsafe +  unsafeCoerceRefl, +) where + +import Data.Type.Equality +import Data.Proxy +import GHC.TypeLits +import qualified GHC.TypeNats as TN +import qualified Unsafe.Coerce + + +-- | Evidence for the constraint @c a@. +data Dict c a where +  Dict :: c a => Dict c a + +fromSNat' :: SNat n -> Int +fromSNat' = fromIntegral . fromSNat + +pattern SZ :: () => (n ~ 0) => SNat n +pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl) +  where SZ = SNat + +pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1 +pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl)) +  where SS = snatSucc + +{-# COMPLETE SZ, SS #-} + +snatSucc :: SNat n -> SNat (n + 1) +snatSucc SNat = SNat + +data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1) +snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1) +snatPred snp1 = +  withKnownNat snp1 $ +    case cmpNat (Proxy @1) (Proxy @np1) of +      LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) +      EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) +      GTI -> Nothing + +-- This should be a function in base +snatPlus :: SNat n -> SNat m -> SNat (n + m) +snatPlus n m = TN.withSomeSNat (TN.fromSNat n + TN.fromSNat m) Unsafe.Coerce.unsafeCoerce + +-- This should be a function in base +snatMul :: SNat n -> SNat m -> SNat (n * m) +snatMul n m = TN.withSomeSNat (TN.fromSNat n * TN.fromSNat m) Unsafe.Coerce.unsafeCoerce + + +-- | This is just @'Unsafe.Coerce.unsafeCoerce' 'Refl'@, but specialised to +-- only typecheck for actual type equalities. One cannot, e.g. accidentally +-- write this: +-- +-- @ +-- foo :: Proxy a -> Proxy b -> a :~: b +-- foo = unsafeCoerceRefl +-- @ +-- +-- which would have been permitted with normal 'Unsafe.Coerce.unsafeCoerce', +-- but would have resulted in interesting memory errors at runtime. +unsafeCoerceRefl :: a :~: b +unsafeCoerceRefl = Unsafe.Coerce.unsafeCoerce Refl + + +-- | Type-level list append. +type family l1 ++ l2 where +  '[] ++ l2 = l2 +  (x : xs) ++ l2 = x : xs ++ l2 + +lemAppNil :: l ++ '[] :~: l +lemAppNil = unsafeCoerceRefl + +lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) +lemAppAssoc _ _ _ = unsafeCoerceRefl + +type family Replicate n a where +  Replicate 0 a = '[] +  Replicate n a = a : Replicate (n - 1) a + +lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a +lemReplicateSucc = unsafeCoerceRefl | 
