From a65306ba5d80891b20ac86fa3a3242f9497751e6 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 30 May 2024 11:58:40 +0200 Subject: Refactor Mixed (modules, regular function names) --- src/Data/Array/Mixed.hs | 757 ++--------------------- src/Data/Array/Mixed/Internal/Arith.hs | 435 +++++++++++++ src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 55 ++ src/Data/Array/Mixed/Internal/Arith/Lists.hs | 78 +++ src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs | 82 +++ src/Data/Array/Mixed/Lemmas.hs | 47 ++ src/Data/Array/Mixed/Permutation.hs | 252 ++++++++ src/Data/Array/Mixed/Shape.hs | 455 ++++++++++++++ src/Data/Array/Mixed/Types.hs | 110 ++++ src/Data/Array/Nested.hs | 10 +- src/Data/Array/Nested/Internal.hs | 326 +++++----- src/Data/Array/Nested/Internal/Arith.hs | 435 ------------- src/Data/Array/Nested/Internal/Arith/Foreign.hs | 55 -- src/Data/Array/Nested/Internal/Arith/Lists.hs | 78 --- src/Data/Array/Nested/Internal/Arith/Lists/TH.hs | 82 --- 15 files changed, 1729 insertions(+), 1528 deletions(-) create mode 100644 src/Data/Array/Mixed/Internal/Arith.hs create mode 100644 src/Data/Array/Mixed/Internal/Arith/Foreign.hs create mode 100644 src/Data/Array/Mixed/Internal/Arith/Lists.hs create mode 100644 src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs create mode 100644 src/Data/Array/Mixed/Lemmas.hs create mode 100644 src/Data/Array/Mixed/Permutation.hs create mode 100644 src/Data/Array/Mixed/Shape.hs create mode 100644 src/Data/Array/Mixed/Types.hs delete mode 100644 src/Data/Array/Nested/Internal/Arith.hs delete mode 100644 src/Data/Array/Nested/Internal/Arith/Foreign.hs delete mode 100644 src/Data/Array/Nested/Internal/Arith/Lists.hs delete mode 100644 src/Data/Array/Nested/Internal/Arith/Lists/TH.hs (limited to 'src/Data/Array') diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 4ae89a1..0100ec8 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -30,300 +30,23 @@ module Data.Array.Mixed where import Control.DeepSeq (NFData(..)) import qualified Data.Array.RankedS as S import qualified Data.Array.Ranked as ORB -import Data.Bifunctor (first) import Data.Coerce -import qualified Data.Foldable as Foldable -import Data.Functor.Const import Data.Kind -import Data.List (sort) -import Data.Monoid (Sum(..)) import Data.Proxy -import Data.Type.Bool import Data.Type.Equality import Data.Type.Ord import qualified Data.Vector.Storable as VS import Foreign.Storable (Storable) import GHC.Generics (Generic) -import GHC.IsList (IsList) -import qualified GHC.IsList as IsList -import GHC.TypeError import GHC.TypeLits -import qualified GHC.TypeNats as TypeNats -import Unsafe.Coerce (unsafeCoerce) -import Data.Array.Nested.Internal.Arith +import Data.Array.Mixed.Internal.Arith +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Permutation +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Types --- | Evidence for the constraint @c a@. -data Dict c a where - Dict :: c a => Dict c a - -fromSNat' :: SNat n -> Int -fromSNat' = fromIntegral . fromSNat - -pattern SZ :: () => (n ~ 0) => SNat n -pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl) - where SZ = SNat - -pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1 -pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl)) - where SS = snatSucc - -{-# COMPLETE SZ, SS #-} - -snatSucc :: SNat n -> SNat (n + 1) -snatSucc SNat = SNat - -data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1) -snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1) -snatPred snp1 = - withKnownNat snp1 $ - case cmpNat (Proxy @1) (Proxy @np1) of - LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) - EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) - GTI -> Nothing - - --- | Type-level list append. -type family l1 ++ l2 where - '[] ++ l2 = l2 - (x : xs) ++ l2 = x : xs ++ l2 - -lemAppNil :: l ++ '[] :~: l -lemAppNil = unsafeCoerce Refl - -lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) -lemAppAssoc _ _ _ = unsafeCoerce Refl - -type family Replicate n a where - Replicate 0 a = '[] - Replicate n a = a : Replicate (n - 1) a - - -type role ListX nominal representational -type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type -data ListX sh f where - ZX :: ListX '[] f - (::%) :: f n -> ListX sh f -> ListX (n : sh) f -deriving instance (forall n. Eq (f n)) => Eq (ListX sh f) -deriving instance (forall n. Ord (f n)) => Ord (ListX sh f) -infixr 3 ::% - -instance (forall n. Show (f n)) => Show (ListX sh f) where - showsPrec _ = showListX shows - -instance (forall n. NFData (f n)) => NFData (ListX sh f) where - rnf ZX = () - rnf (x ::% l) = rnf x `seq` rnf l - -data UnconsListXRes f sh1 = - forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n) -unconsListX :: ListX sh1 f -> Maybe (UnconsListXRes f sh1) -unconsListX (i ::% shl') = Just (UnconsListXRes shl' i) -unconsListX ZX = Nothing - -fmapListX :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g -fmapListX _ ZX = ZX -fmapListX f (x ::% xs) = f x ::% fmapListX f xs - -foldListX :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m -foldListX _ ZX = mempty -foldListX f (x ::% xs) = f x <> foldListX f xs - -lengthListX :: ListX sh f -> Int -lengthListX = getSum . foldListX (\_ -> Sum 1) - -snatLengthListX :: ListX sh f -> SNat (Rank sh) -snatLengthListX ZX = SNat -snatLengthListX (_ ::% l) | SNat <- snatLengthListX l = SNat - -showListX :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS -showListX f l = showString "[" . go "" l . showString "]" - where - go :: String -> ListX sh' f -> ShowS - go _ ZX = id - go prefix (x ::% xs) = showString prefix . f x . go "," xs - -listXToList :: ListX sh' (Const i) -> [i] -listXToList ZX = [] -listXToList (Const i ::% is) = i : listXToList is - - -type role IxX nominal representational -type IxX :: [Maybe Nat] -> Type -> Type -newtype IxX sh i = IxX (ListX sh (Const i)) - deriving (Eq, Ord, Generic) - -pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i -pattern ZIX = IxX ZX - -pattern (:.%) - :: forall {sh1} {i}. - forall n sh. (n : sh ~ sh1) - => i -> IxX sh i -> IxX sh1 i -pattern i :.% shl <- IxX (unconsListX -> Just (UnconsListXRes (IxX -> shl) (getConst -> i))) - where i :.% IxX shl = IxX (Const i ::% shl) -infixr 3 :.% - -{-# COMPLETE ZIX, (:.%) #-} - -type IIxX sh = IxX sh Int - -instance Show i => Show (IxX sh i) where - showsPrec _ (IxX l) = showListX (\(Const i) -> shows i) l - -instance Functor (IxX sh) where - fmap f (IxX l) = IxX (fmapListX (Const . f . getConst) l) - -instance Foldable (IxX sh) where - foldMap f (IxX l) = foldListX (f . getConst) l - -instance NFData i => NFData (IxX sh i) - - -data SMayNat i f n where - SUnknown :: i -> SMayNat i f Nothing - SKnown :: f n -> SMayNat i f (Just n) -deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n) -deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n) -deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n) - -instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where - rnf (SUnknown i) = rnf i - rnf (SKnown x) = rnf x - -fromSMayNat :: (n ~ Nothing => i -> r) -> (forall m. n ~ Just m => f m -> r) -> SMayNat i f n -> r -fromSMayNat f _ (SUnknown i) = f i -fromSMayNat _ g (SKnown s) = g s - -fromSMayNat' :: SMayNat Int SNat n -> Int -fromSMayNat' = fromSMayNat id fromSNat' - -type role ShX nominal representational -type ShX :: [Maybe Nat] -> Type -> Type -newtype ShX sh i = ShX (ListX sh (SMayNat i SNat)) - deriving (Eq, Ord, Generic) - -pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i -pattern ZSX = ShX ZX - -pattern (:$%) - :: forall {sh1} {i}. - forall n sh. (n : sh ~ sh1) - => SMayNat i SNat n -> ShX sh i -> ShX sh1 i -pattern i :$% shl <- ShX (unconsListX -> Just (UnconsListXRes (ShX -> shl) i)) - where i :$% ShX shl = ShX (i ::% shl) -infixr 3 :$% - -{-# COMPLETE ZSX, (:$%) #-} - -type IShX sh = ShX sh Int - -instance Show i => Show (ShX sh i) where - showsPrec _ (ShX l) = showListX (fromSMayNat shows (shows . fromSNat)) l - -instance Functor (ShX sh) where - fmap f (ShX l) = ShX (fmapListX (fromSMayNat (SUnknown . f) SKnown) l) - -instance NFData i => NFData (ShX sh i) where - rnf (ShX ZX) = () - rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l) - rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l) - -lengthShX :: ShX sh i -> Int -lengthShX (ShX l) = lengthListX l - -shXToList :: IShX sh -> [Int] -shXToList ZSX = [] -shXToList (smn :$% sh) = fromSMayNat' smn : shXToList sh - - --- | The part of a shape that is statically known. -type StaticShX :: [Maybe Nat] -> Type -newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat)) - deriving (Eq, Ord) - -pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh -pattern ZKX = StaticShX ZX - -pattern (:!%) - :: forall {sh1}. - forall n sh. (n : sh ~ sh1) - => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1 -pattern i :!% shl <- StaticShX (unconsListX -> Just (UnconsListXRes (StaticShX -> shl) i)) - where i :!% StaticShX shl = StaticShX (i ::% shl) -infixr 3 :!% - -{-# COMPLETE ZKX, (:!%) #-} - -instance Show (StaticShX sh) where - showsPrec _ (StaticShX l) = showListX (fromSMayNat shows (shows . fromSNat)) l - -lengthStaticShX :: StaticShX sh -> Int -lengthStaticShX (StaticShX l) = lengthListX l - -geqStaticShX :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh') -geqStaticShX ZKX ZKX = Just Refl -geqStaticShX (SKnown n@SNat :!% sh) (SKnown m@SNat :!% sh') - | Just Refl <- sameNat n m - , Just Refl <- geqStaticShX sh sh' - = Just Refl -geqStaticShX (SUnknown () :!% sh) (SUnknown () :!% sh') - | Just Refl <- geqStaticShX sh sh' - = Just Refl -geqStaticShX _ _ = Nothing - - --- | Evidence for the static part of a shape. This pops up only when you are --- polymorphic in the element type of an array. -type KnownShX :: [Maybe Nat] -> Constraint -class KnownShX sh where knownShX :: StaticShX sh -instance KnownShX '[] where knownShX = ZKX -instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX -instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX - - --- | Very untyped: only length is checked (at runtime). -instance KnownShX sh => IsList (ListX sh (Const i)) where - type Item (ListX sh (Const i)) = i - fromList topl = go (knownShX @sh) topl - where - go :: StaticShX sh' -> [i] -> ListX sh' (Const i) - go ZKX [] = ZX - go (_ :!% sh) (i : is) = Const i ::% go sh is - go _ _ = error $ "IsList(ListX): Mismatched list length (type says " - ++ show (lengthStaticShX (knownShX @sh)) ++ ", list has length " - ++ show (length topl) ++ ")" - toList = listXToList - --- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. -instance KnownShX sh => IsList (IxX sh i) where - type Item (IxX sh i) = i - fromList = IxX . IsList.fromList - toList = Foldable.toList - --- | Untyped: length and known dimensions are checked (at runtime). -instance KnownShX sh => IsList (ShX sh Int) where - type Item (ShX sh Int) = Int - fromList topl = ShX (go (knownShX @sh) topl) - where - go :: StaticShX sh' -> [Int] -> ListX sh' (SMayNat Int SNat) - go ZKX [] = ZX - go (SKnown sn :!% sh) (i : is) - | i == fromSNat' sn = SKnown sn ::% go sh is - | otherwise = error $ "IsList(ShX): Value does not match typing (type says " - ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" - go (SUnknown () :!% sh) (i : is) = SUnknown i ::% go sh is - go _ _ = error $ "IsList(ShX): Mismatched list length (type says " - ++ show (lengthStaticShX (knownShX @sh)) ++ ", list has length " - ++ show (length topl) ++ ")" - toList = shXToList - - -type family Rank sh where - Rank '[] = 0 - Rank (_ : sh) = Rank sh + 1 - type XArray :: [Maybe Nat] -> Type -> Type newtype XArray sh a = XArray (S.Array (Rank sh) a) deriving (Show, Eq, Generic) @@ -333,180 +56,6 @@ deriving instance (Ord a, Storable a) => Ord (XArray '[] a) instance NFData a => NFData (XArray sh a) -zeroIxX :: StaticShX sh -> IIxX sh -zeroIxX ZKX = ZIX -zeroIxX (_ :!% ssh) = 0 :.% zeroIxX ssh - -zeroIxX' :: IShX sh -> IIxX sh -zeroIxX' ZSX = ZIX -zeroIxX' (_ :$% sh) = 0 :.% zeroIxX' sh - --- This is a weird operation, so it has a long name -completeShXzeros :: StaticShX sh -> IShX sh -completeShXzeros ZKX = ZSX -completeShXzeros (SUnknown () :!% ssh) = SUnknown 0 :$% completeShXzeros ssh -completeShXzeros (SKnown n :!% ssh) = SKnown n :$% completeShXzeros ssh - -listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f -listxAppend ZX idx' = idx' -listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx' - -ixAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i -ixAppend = coerce (listxAppend @_ @(Const i)) - -shAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i -shAppend = coerce (listxAppend @_ @(SMayNat i SNat)) - -listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f -listxDrop long ZX = long -listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short - -ixDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i -ixDrop = coerce (listxDrop @(Const i) @(Const i)) - -shDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i -shDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat)) - -shDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i -shDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j)) - -shDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i -shDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat)) - -shTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i -shTakeSSX _ = flip go - where - go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i - go ZKX _ = ZSX - go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh - -ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh' -ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) - --- TODO: generalise all these things to arbitrary @i@ -shTail :: IShX (n : sh) -> IShX sh -shTail (_ :$% sh) = sh - -ssxTail :: StaticShX (n : sh) -> StaticShX sh -ssxTail (_ :!% ssh) = ssh - -shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh') -shAppSplit _ ZKX idx = (ZSX, idx) -shAppSplit p (_ :!% ssh) (i :$% idx) = first (i :$%) (shAppSplit p ssh idx) - -ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') -ssxAppend ZKX sh' = sh' -ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' - -shapeSize :: IShX sh -> Int -shapeSize ZSX = 1 -shapeSize (n :$% sh) = fromSMayNat' n * shapeSize sh - --- | This may fail if @sh@ has @Nothing@s in it. -ssxToShape' :: StaticShX sh -> Maybe (IShX sh) -ssxToShape' ZKX = Just ZSX -ssxToShape' (SKnown n :!% sh) = (SKnown n :$%) <$> ssxToShape' sh -ssxToShape' (SUnknown _ :!% _) = Nothing - -lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a -lemReplicateSucc = unsafeCoerce Refl - -ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) -ssxReplicate SZ = ZKX -ssxReplicate (SS (n :: SNat n')) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n' - = SUnknown () :!% ssxReplicate n - -fromLinearIdx :: IShX sh -> Int -> IIxX sh -fromLinearIdx = \sh i -> case go sh i of - (idx, 0) -> idx - _ -> error $ "fromLinearIdx: out of range (" ++ show i ++ - " in array of shape " ++ show sh ++ ")" - where - -- returns (index in subarray, remaining index in enclosing array) - go :: IShX sh -> Int -> (IIxX sh, Int) - go ZSX i = (ZIX, i) - go (n :$% sh) i = - let (idx, i') = go sh i - (upi, locali) = i' `quotRem` fromSMayNat' n - in (locali :.% idx, upi) - -toLinearIdx :: IShX sh -> IIxX sh -> Int -toLinearIdx = \sh i -> fst (go sh i) - where - -- returns (index in subarray, size of subarray) - go :: IShX sh -> IIxX sh -> (Int, Int) - go ZSX ZIX = (0, 1) - go (n :$% sh) (i :.% ix) = - let (lidx, sz) = go sh ix - in (sz * i + lidx, fromSMayNat' n * sz) - -enumShape :: IShX sh -> [IIxX sh] -enumShape = \sh -> go sh id [] - where - go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a] - go ZSX f = (f ZIX :) - go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]] - -shapeLshape :: IShX sh -> S.ShapeL -shapeLshape ZSX = [] -shapeLshape (n :$% sh) = fromSMayNat' n : shapeLshape sh - -ssxLength :: StaticShX sh -> Int -ssxLength ZKX = 0 -ssxLength (_ :!% ssh) = 1 + ssxLength ssh - -ssxIotaFrom :: Int -> StaticShX sh -> [Int] -ssxIotaFrom _ ZKX = [] -ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh - -type Flatten sh = Flatten' 1 sh - -type family Flatten' acc sh where - Flatten' acc '[] = Just acc - Flatten' acc (Nothing : sh) = Nothing - Flatten' acc (Just n : sh) = Flatten' (acc * n) sh - -flattenSSX :: StaticShX sh -> SMayNat () SNat (Flatten sh) -flattenSSX = go (SNat @1) - where - go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh) - go acc ZKX = SKnown acc - go _ (SUnknown () :!% _) = SUnknown () - go acc (SKnown sn :!% sh) = go (mulSNat acc sn) sh - -flattenSh :: IShX sh -> SMayNat Int SNat (Flatten sh) -flattenSh = go (SNat @1) - where - go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh) - go acc ZSX = SKnown acc - go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh) - go acc (SKnown sn :$% sh) = go (mulSNat acc sn) sh - - goUnknown :: Int -> IShX sh -> Int - goUnknown acc ZSX = acc - goUnknown acc (SUnknown n :$% sh) = goUnknown (acc * n) sh - goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh - -staticShapeFrom :: IShX sh -> StaticShX sh -staticShapeFrom ZSX = ZKX -staticShapeFrom (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% staticShapeFrom sh - -lemRankApp :: StaticShX sh1 -> StaticShX sh2 - -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2 -lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this - -lemRankAppComm :: StaticShX sh1 -> StaticShX sh2 - -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1) -lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this - -lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh) -lemKnownNatRank ZSX = Dict -lemKnownNatRank (_ :$% sh) | Dict <- lemKnownNatRank sh = Dict - -lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh) -lemKnownNatRankSSX ZKX = Dict -lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict shape :: forall sh a. StaticShX sh -> XArray sh a -> IShX sh shape = \ssh (XArray arr) -> go ssh (S.shapeL arr) @@ -519,7 +68,7 @@ shape = \ssh (XArray arr) -> go ssh (S.shapeL arr) fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a fromVector sh v | Dict <- lemKnownNatRank sh - = XArray (S.fromVector (shapeLshape sh) v) + = XArray (S.fromVector (shxToList sh) v) toVector :: Storable a => XArray sh a -> VS.Vector a toVector (XArray arr) = S.toVector arr @@ -527,23 +76,18 @@ toVector (XArray arr) = S.toVector arr scalar :: Storable a => a -> XArray '[] a scalar = XArray . S.scalar -eqShX :: IShX sh1 -> IShX sh2 -> Bool -eqShX ZSX ZSX = True -eqShX (n :$% sh1) (m :$% sh2) = fromSMayNat' n == fromSMayNat' m && eqShX sh1 sh2 -eqShX _ _ = False - -- | Will throw if the array does not have the casted-to shape. cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2 => StaticShX sh1 -> IShX sh2 -> StaticShX sh' -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a cast ssh1 sh2 ssh' (XArray arr) | Refl <- lemRankApp ssh1 ssh' - , Refl <- lemRankApp (staticShapeFrom sh2) ssh' + , Refl <- lemRankApp (ssxFromShape sh2) ssh' = let arrsh :: IShX sh1 - (arrsh, _) = shAppSplit (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr)) - in if eqShX arrsh sh2 - then XArray arr - else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")" + (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr)) + in case shxEqual arrsh sh2 of + Just _ -> XArray arr + Nothing -> error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")" unScalar :: Storable a => XArray '[] a -> a unScalar (XArray a) = S.unScalar a @@ -551,24 +95,24 @@ unScalar (XArray a) = S.unScalar a replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a replicate sh ssh' (XArray arr) | Dict <- lemKnownNatRankSSX ssh' - , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh) ssh') - , Refl <- lemRankApp (staticShapeFrom sh) ssh' - = XArray (S.stretch (shapeLshape sh ++ S.shapeL arr) $ - S.reshape (map (const 1) (shapeLshape sh) ++ S.shapeL arr) $ + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh) ssh') + , Refl <- lemRankApp (ssxFromShape sh) ssh' + = XArray (S.stretch (shxToList sh ++ S.shapeL arr) $ + S.reshape (map (const 1) (shxToList sh) ++ S.shapeL arr) $ arr) replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a replicateScal sh x | Dict <- lemKnownNatRank sh - = XArray (S.constant (shapeLshape sh) x) + = XArray (S.constant (shxToList sh) x) generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a -generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh) +generate sh f = fromVector sh $ VS.generate (shxSize sh) (f . ixxFromLinear sh) -- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a) -- generateM sh f | Dict <- lemKnownNatRank sh = --- XArray . S.fromVector (shapeLshape sh) --- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh) +-- XArray . S.fromVector (shxShapeL sh) +-- <$> VS.generateM (shxSize sh) (f . ixxFromLinear sh) indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a indexPartial (XArray arr) ZIX = XArray arr @@ -580,24 +124,6 @@ index xarr i = let XArray arr' = indexPartial xarr i :: XArray '[] a in S.unScalar arr' -type family AddMaybe n m where - AddMaybe Nothing _ = Nothing - AddMaybe (Just _) Nothing = Nothing - AddMaybe (Just n) (Just m) = Just (n + m) - --- This should be a function in base -plusSNat :: SNat n -> SNat m -> SNat (n + m) -plusSNat n m = TypeNats.withSomeSNat (TypeNats.fromSNat n + TypeNats.fromSNat m) unsafeCoerce - --- This should be a function in base -mulSNat :: SNat n -> SNat m -> SNat (n * m) -mulSNat n m = TypeNats.withSomeSNat (TypeNats.fromSNat n * TypeNats.fromSNat m) unsafeCoerce - -smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m) -smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) -smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) -smnAddMaybe (SKnown n) (SKnown m) = SKnown (plusSNat n m) - append :: forall n m sh a. Storable a => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a append ssh (XArray a) (XArray b) @@ -639,9 +165,9 @@ rerank :: forall sh sh1 sh2 a b. -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b rerank ssh ssh1 ssh2 f xarr@(XArray arr) | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) - = let (sh, _) = shAppSplit (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr) - in if any (== 0) (shapeLshape sh) - then XArray (S.fromList (shapeLshape (shAppend sh (completeShXzeros ssh2))) []) + = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr) + in if any (== 0) (shxToList sh) + then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) else case () of () | Dict <- lemKnownNatRankSSX ssh , Dict <- lemKnownNatRankSSX ssh2 @@ -666,9 +192,9 @@ rerank2 :: forall sh sh1 sh2 a b c. -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2) | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) - = let (sh, _) = shAppSplit (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1) - in if any (== 0) (shapeLshape sh) - then XArray (S.fromList (shapeLshape (shAppend sh (completeShXzeros ssh2))) []) + = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1) + in if any (== 0) (shxToList sh) + then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) else case () of () | Dict <- lemKnownNatRankSSX ssh , Dict <- lemKnownNatRankSSX ssh2 @@ -678,211 +204,14 @@ rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2) (\a b -> let XArray r = f (XArray a) (XArray b) in r) arr1 arr2) -type family Elem x l where - Elem x '[] = 'False - Elem x (x : _) = 'True - Elem x (_ : ys) = Elem x ys - -type family AllElem' as bs where - AllElem' '[] bs = 'True - AllElem' (a : as) bs = Elem a bs && AllElem' as bs - -type AllElem as bs = Assert (AllElem' as bs) - (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs)) - -type family Count i n where - Count n n = '[] - Count i n = i : Count (i + 1) n - -type Permutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as) - -type family Index i sh where - Index 0 (n : sh) = n - Index i (_ : sh) = Index (i - 1) sh - -type family Permute is sh where - Permute '[] sh = '[] - Permute (i : is) sh = Index i sh : Permute is sh - -type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh - -data HList f list where - HNil :: HList f '[] - HCons :: f a -> HList f l -> HList f (a : l) -infixr 5 `HCons` -deriving instance (forall a. Show (f a)) => Show (HList f list) -deriving instance (forall a. Eq (f a)) => Eq (HList f list) - -foldHList :: Monoid m => (forall a. f a -> m) -> HList f list -> m -foldHList _ HNil = mempty -foldHList f (x `HCons` l) = f x <> foldHList f l - -snatLengthHList :: HList f list -> SNat (Rank list) -snatLengthHList HNil = SNat -snatLengthHList (_ `HCons` l) | SNat <- snatLengthHList l = SNat - -permFromList :: [Int] -> (forall list. HList SNat list -> r) -> r -permFromList [] k = k HNil -permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case - Just sn -> permFromList xs $ \list -> k (sn `HCons` list) - Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x - -permToList :: HList SNat list -> [Int] -permToList = foldHList (pure . fromSNat') - -type family TakeLen ref l where - TakeLen '[] l = '[] - TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs - -type family DropLen ref l where - DropLen '[] l = l - DropLen (_ : ref) (_ : xs) = DropLen ref xs - -lemRankPermute :: Proxy sh -> HList SNat is -> Rank (Permute is sh) :~: Rank is -lemRankPermute _ HNil = Refl -lemRankPermute p (_ `HCons` is) | Refl <- lemRankPermute p is = Refl - -lemRankDropLen :: forall is sh. (Rank is <= Rank sh) - => StaticShX sh -> HList SNat is -> Rank (DropLen is sh) :~: Rank sh - Rank is -lemRankDropLen ZKX HNil = Refl -lemRankDropLen (_ :!% sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl -lemRankDropLen (_ :!% _) HNil = Refl -lemRankDropLen ZKX (_ `HCons` _) = error "1 <= 0" - -lemIndexSucc :: Proxy i -> Proxy a -> Proxy l -> Index (i + 1) (a : l) :~: Index i l -lemIndexSucc _ _ _ = unsafeCoerce Refl - -listxTakeLen :: forall f is sh. HList SNat is -> ListX sh f -> ListX (TakeLen is sh) f -listxTakeLen HNil _ = ZX -listxTakeLen (_ `HCons` is) (n ::% sh) = n ::% listxTakeLen is sh -listxTakeLen (_ `HCons` _) ZX = error "Permutation longer than shape" - -listxDropLen :: forall f is sh. HList SNat is -> ListX sh f -> ListX (DropLen is sh) f -listxDropLen HNil sh = sh -listxDropLen (_ `HCons` is) (_ ::% sh) = listxDropLen is sh -listxDropLen (_ `HCons` _) ZX = error "Permutation longer than shape" - -listxPermute :: forall f is sh. HList SNat is -> ListX sh f -> ListX (Permute is sh) f -listxPermute HNil _ = ZX -listxPermute (i `HCons` (is :: HList SNat is')) (sh :: ListX sh f) = listxIndex (Proxy @is') (Proxy @sh) i sh (listxPermute is sh) - -listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> ListX (Permute is shT) f -> ListX (Index i sh : Permute is shT) f -listxIndex _ _ SZ (n ::% _) rest = n ::% rest -listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) rest - | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = listxIndex p pT i sh rest -listxIndex _ _ _ ZX _ = error "Index into empty shape" - -listxPermutePrefix :: forall f is sh. HList SNat is -> ListX sh f -> ListX (PermutePrefix is sh) f -listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) - -ssxTakeLen :: forall is sh. HList SNat is -> StaticShX sh -> StaticShX (TakeLen is sh) -ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat)) - -ssxDropLen :: HList SNat is -> StaticShX sh -> StaticShX (DropLen is sh) -ssxDropLen = coerce (listxDropLen @(SMayNat () SNat)) - -ssxPermute :: HList SNat is -> StaticShX sh -> StaticShX (Permute is sh) -ssxPermute = coerce (listxPermute @(SMayNat () SNat)) - -ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT) -ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2) - -ssxPermutePrefix :: HList SNat is -> StaticShX sh -> StaticShX (PermutePrefix is sh) -ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat)) - -shPermutePrefix :: HList SNat is -> IShX sh -> IShX (PermutePrefix is sh) -shPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat)) - --- TODO: test this thing more properly -invertPermutation :: HList SNat is - -> (forall is'. - Permutation is' - => HList SNat is' - -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh) - -> r) - -> r -invertPermutation = \perm k -> - genPerm perm $ \(invperm :: HList SNat is') -> - let sn = snatLengthHList invperm - in case (provePerm1 (Proxy @is') sn invperm, provePerm2 (SNat @0) sn invperm) of - (Just Refl, Just Refl) -> - k invperm - (\ssh -> case provePermInverse perm invperm ssh of - Just eq -> eq - Nothing -> error $ "invertPermutation: did not generate inverse? perm = " ++ show perm - ++ " ; invperm = " ++ show invperm) - _ -> error $ "invertPermutation: did not generate permutation? perm = " ++ show perm - ++ " ; invperm = " ++ show invperm - where - genPerm :: HList SNat is -> (forall is'. HList SNat is' -> r) -> r - genPerm perm = - let permList = foldHList (pure . fromSNat) perm - in toHList $ map snd (sort (zip permList [0..])) - where - toHList :: [Natural] -> (forall is'. HList SNat is' -> r) -> r - toHList [] k = k HNil - toHList (n : ns) k = toHList ns $ \l -> TypeNats.withSomeSNat n $ \sn -> k (HCons sn l) - - lemElemCount :: (0 <= n, Compare n m ~ LT) => proxy n -> proxy m -> Elem n (Count 0 m) :~: True - lemElemCount _ _ = unsafeCoerce Refl - - lemCount :: (OrdCond (Compare i n) True False True ~ True) => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n - lemCount _ _ = unsafeCoerce Refl - - lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True - lemElem _ _ = unsafeCoerce Refl - - provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> HList SNat is' - -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True) - provePerm1 _ _ HNil = Just (Refl) - provePerm1 p rtop@SNat (HCons sn@SNat perm) - | Just Refl <- provePerm1 p rtop perm - = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of - (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl - (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl - _ -> Nothing - | otherwise - = Nothing - - provePerm2 :: SNat i -> SNat n -> HList SNat is' -> Maybe (AllElem' (Count i n) is' :~: True) - provePerm2 = \i@(SNat :: SNat i) n@SNat perm -> - case cmpNat i n of - EQI -> Just Refl - LTI | Refl <- lemCount i n - , Just Refl <- provePerm2 (SNat @(i + 1)) n perm - -> checkElem i perm - | otherwise -> Nothing - GTI -> error "unreachable" - where - checkElem :: SNat i -> HList SNat is' -> Maybe (Elem i is' :~: True) - checkElem _ HNil = Nothing - checkElem i@SNat (HCons k@SNat perm :: HList SNat is') = - case sameNat i k of - Just Refl -> Just Refl - Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl - | otherwise -> Nothing - - provePermInverse :: HList SNat is -> HList SNat is' -> StaticShX sh -> Maybe (Permute is' (Permute is sh) :~: sh) - provePermInverse perm perminv ssh = geqStaticShX (ssxPermute perminv (ssxPermute perm ssh)) ssh - -applyPermX :: forall f is sh. HList SNat is -> ListX sh f -> ListX (PermutePrefix is sh) f -applyPermX perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) - -applyPermIxX :: forall i is sh. HList SNat is -> IxX sh i -> IxX (PermutePrefix is sh) i -applyPermIxX = coerce (applyPermX @(Const i)) - -applyPermShX :: forall i is sh. HList SNat is -> ShX sh i -> ShX (PermutePrefix is sh) i -applyPermShX = coerce (applyPermX @(SMayNat i SNat)) - -class KnownNatList l where makeNatList :: HList SNat l -instance KnownNatList '[] where makeNatList = HNil -instance (KnownNat n, KnownNatList l) => KnownNatList (n : l) where makeNatList = natSing `HCons` makeNatList +class KnownNatList l where makeNatList :: Perm l +instance KnownNatList '[] where makeNatList = PNil +instance (KnownNat n, KnownNatList l) => KnownNatList (n : l) where makeNatList = natSing `PCons` makeNatList -- | The list argument gives indices into the original dimension list. -transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh) +transpose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh) => StaticShX sh - -> HList SNat is + -> Perm is -> XArray sh a -> XArray (PermutePrefix is sh) a transpose ssh perm (XArray arr) @@ -890,7 +219,7 @@ transpose ssh perm (XArray arr) , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh) , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm , Refl <- lemRankDropLen ssh perm - = XArray (S.transpose (permToList perm) arr) + = XArray (S.transpose (permToList' perm) arr) -- | The list argument gives indices into the original dimension list. -- @@ -929,14 +258,14 @@ sumInner :: forall sh sh' a. (Storable a, NumElt a) => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a sumInner ssh ssh' arr | Refl <- lemAppNil @sh - = let (_, sh') = shAppSplit (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) - sh'F = flattenSh sh' :$% ZSX - ssh'F = staticShapeFrom sh'F + = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + sh'F = shxFlatten sh' :$% ZSX + ssh'F = ssxFromShape sh'F go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a go (XArray arr') | Refl <- lemRankApp ssh ssh'F - , let sn = snatLengthListX (let StaticShX l = ssh in l) + , let sn = listxLengthSNat (let StaticShX l = ssh in l) = XArray (numEltSum1Inner sn arr') in go $ @@ -949,10 +278,10 @@ sumOuter :: forall sh sh' a. (Storable a, NumElt a) => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a sumOuter ssh ssh' arr | Refl <- lemAppNil @sh - = let (sh, _) = shAppSplit (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) - shF = flattenSh sh :$% ZSX - in sumInner ssh' (staticShapeFrom shF) $ - transpose2 (staticShapeFrom shF) ssh' $ + = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + shF = shxFlatten sh :$% ZSX + in sumInner ssh' (ssxFromShape shF) $ + transpose2 (ssxFromShape shF) ssh' $ reshapePartial ssh ssh' shF $ arr @@ -988,7 +317,7 @@ toList1 (XArray arr) = S.toList arr empty :: forall sh a. Storable a => IShX sh -> XArray sh a empty sh | Dict <- lemKnownNatRank sh - = XArray (S.constant (shapeLshape sh) + = XArray (S.constant (shxToList sh) (error "Data.Array.Mixed.empty: shape was not empty")) slice :: SNat i -> SNat n -> XArray (Just (i + n + k) : sh) a -> XArray (Just n : sh) a @@ -1005,14 +334,14 @@ reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray s reshape ssh1 sh2 (XArray arr) | Dict <- lemKnownNatRankSSX ssh1 , Dict <- lemKnownNatRank sh2 - = XArray (S.reshape (shapeLshape sh2) arr) + = XArray (S.reshape (shxToList sh2) arr) -- | Throws if the given array and the target shape do not have the same number of elements. reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a reshapePartial ssh1 ssh' sh2 (XArray arr) | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh') - , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh2) ssh') - = XArray (S.reshape (shapeLshape sh2 ++ drop (lengthStaticShX ssh1) (S.shapeL arr)) arr) + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh2) ssh') + = XArray (S.reshape (shxToList sh2 ++ drop (ssxLength ssh1) (S.shapeL arr)) arr) -- this was benchmarked to be (slightly) faster than S.iota, S.generate and S.fromVector(VS.enumFromTo). iota :: (Enum a, Storable a) => SNat n -> XArray '[Just n] a diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs new file mode 100644 index 0000000..cf6820b --- /dev/null +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -0,0 +1,435 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Mixed.Internal.Arith where + +import Control.Monad (forM, guard) +import qualified Data.Array.Internal as OI +import qualified Data.Array.Internal.RankedG as RG +import qualified Data.Array.Internal.RankedS as RS +import Data.Bits +import Data.Int +import Data.List (sort) +import qualified Data.Vector.Storable as VS +import qualified Data.Vector.Storable.Mutable as VSM +import Foreign.C.Types +import Foreign.Ptr +import Foreign.Storable (Storable) +import GHC.TypeLits +import Language.Haskell.TH +import System.IO.Unsafe + +import Data.Array.Mixed.Internal.Arith.Foreign +import Data.Array.Mixed.Internal.Arith.Lists + + +liftVEltwise1 :: Storable a + => SNat n + -> (VS.Vector a -> VS.Vector a) + -> RS.Array n a -> RS.Array n a +liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) + | Just prefixSz <- stridesDense sh strides = + let vec' = f (VS.slice offset prefixSz vec) + in RS.A (RG.A sh (OI.T strides 0 vec')) + | otherwise = RS.fromVector sh (f (RS.toVector arr)) + +liftVEltwise2 :: Storable a + => SNat n + -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a) + -> RS.Array n a -> RS.Array n a -> RS.Array n a +liftVEltwise2 SNat f + arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1))) + arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2))) + | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2 + | product sh1 == 0 = arr1 -- if the arrays are empty, just return one of the empty inputs + | otherwise = case (stridesDense sh1 strides1, stridesDense sh2 strides2) of + (Just 1, Just 1) -> -- both are a (potentially replicated) scalar; just apply f to the scalars + let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2)) + in RS.A (RG.A sh1 (OI.T strides1 0 vec')) + (Just 1, Just n) -> -- scalar * dense + RS.fromVector sh1 (f (Left (vec1 VS.! offset1)) (Right (VS.slice offset2 n vec2))) + (Just n, Just 1) -> -- dense * scalar + RS.fromVector sh1 (f (Right (VS.slice offset1 n vec1)) (Left (vec2 VS.! offset2))) + (_, _) -> -- fallback case + RS.fromVector sh1 (f (Right (RS.toVector arr1)) (Right (RS.toVector arr2))) + +-- | Given the shape vector and the stride vector, return whether this vector +-- of strides uses a dense prefix of its backing array. If so, the number of +-- elements in this prefix is returned. +-- This excludes any offset. +stridesDense :: [Int] -> [Int] -> Maybe Int +stridesDense sh _ | any (<= 0) sh = Just 0 +stridesDense sh str = + -- sort dimensions on their stride, ascending, dropping any zero strides + case dropWhile ((== 0) . fst) (sort (zip str sh)) of + [] -> Just 1 + (1, n) : (unzip -> (str', sh')) -> checkCover n sh' str' + _ -> Nothing -- if the smallest stride is not 1, it will never be dense + where + -- Given size of currently densely covered region at beginning of the + -- array, the remaining shape vector and the corresponding remaining stride + -- vector, return whether this all together covers a dense prefix of the + -- array. If it does, return the number of elements in this prefix. + checkCover :: Int -> [Int] -> [Int] -> Maybe Int + checkCover block [] [] = Just block + checkCover block (n : sh') (s : str') = guard (s <= block) >> checkCover (max block (n * s)) sh' str' + checkCover _ _ _ = error "Orthotope array's shape vector and stride vector have different lengths" + +{-# NOINLINE vectorOp1 #-} +vectorOp1 :: forall a b. Storable a + => (Ptr a -> Ptr b) + -> (Int64 -> Ptr b -> Ptr b -> IO ()) + -> VS.Vector a -> VS.Vector a +vectorOp1 ptrconv f v = unsafePerformIO $ do + outv <- VSM.unsafeNew (VS.length v) + VSM.unsafeWith outv $ \poutv -> + VS.unsafeWith v $ \pv -> + f (fromIntegral (VS.length v)) (ptrconv poutv) (ptrconv pv) + VS.unsafeFreeze outv + +-- | If two vectors are given, assumes that they have the same length. +{-# NOINLINE vectorOp2 #-} +vectorOp2 :: forall a b. Storable a + => (a -> b) + -> (Ptr a -> Ptr b) + -> (a -> a -> a) + -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- sv + -> (Int64 -> Ptr b -> Ptr b -> b -> IO ()) -- vs + -> (Int64 -> Ptr b -> Ptr b -> Ptr b -> IO ()) -- vv + -> Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a +vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases + (Left x) (Left y) -> VS.singleton (fss x y) + + (Left x) (Right vy) -> + unsafePerformIO $ do + outv <- VSM.unsafeNew (VS.length vy) + VSM.unsafeWith outv $ \poutv -> + VS.unsafeWith vy $ \pvy -> + fsv (fromIntegral (VS.length vy)) (ptrconv poutv) (valconv x) (ptrconv pvy) + VS.unsafeFreeze outv + + (Right vx) (Left y) -> + unsafePerformIO $ do + outv <- VSM.unsafeNew (VS.length vx) + VSM.unsafeWith outv $ \poutv -> + VS.unsafeWith vx $ \pvx -> + fvs (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (valconv y) + VS.unsafeFreeze outv + + (Right vx) (Right vy) + | VS.length vx == VS.length vy -> + unsafePerformIO $ do + outv <- VSM.unsafeNew (VS.length vx) + VSM.unsafeWith outv $ \poutv -> + VS.unsafeWith vx $ \pvx -> + VS.unsafeWith vy $ \pvy -> + fvv (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (ptrconv pvy) + VS.unsafeFreeze outv + | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy) + +-- TODO: test all the weird cases of this function +-- | Reduce along the inner dimension +{-# NOINLINE vectorRedInnerOp #-} +vectorRedInnerOp :: forall a b n. (Num a, Storable a) + => SNat n + -> (a -> b) + -> (Ptr a -> Ptr b) + -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- ^ scale by constant + -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel + -> RS.Array (n + 1) a -> RS.Array n a +vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec))) + | null sh = error "unreachable" + | last sh <= 0 = RS.stretch (init sh) (RS.fromList (map (const 1) (init sh)) [0]) + | any (<= 0) (init sh) = RS.A (RG.A (init sh) (OI.T (map (const 0) (init strides)) 0 VS.empty)) + -- now the input array is nonempty + | last sh == 1 = RS.A (RG.A (init sh) (OI.T (init strides) offset vec)) + | last strides == 0 = + liftVEltwise1 sn + (vectorOp1 id (\n pout px -> fscale n (ptrconv pout) (valconv (fromIntegral (last sh))) (ptrconv px))) + (RS.A (RG.A (init sh) (OI.T (init strides) offset vec))) + -- now there is useful work along the inner dimension + | otherwise = + let -- filter out zero-stride dimensions; the reduction kernel need not concern itself with those + (shF, stridesF) = unzip $ filter ((/= 0) . snd) (zip sh strides) + ndimsF = length shF + in unsafePerformIO $ do + outv <- VSM.unsafeNew (product (init shF)) + VSM.unsafeWith outv $ \poutv -> + VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF -> + VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF -> + VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec -> + fred (fromIntegral ndimsF) pshF pstridesF (ptrconv poutv) (ptrconv pvec) + RS.fromVector (init sh) <$> VS.unsafeFreeze outv + +flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ()) + -> Int64 -> Ptr a -> Ptr a -> a -> IO () +flipOp f n out v s = f n out s v + +$(fmap concat . forM typesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + fmap concat . forM [minBound..maxBound] $ \arithop -> do + let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype)) + cnamebase = "c_binary_" ++ atCName arithtype + c_ss = varE (aboNumOp arithop) + c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) + c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) + c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) + sequence [SigD name <$> + [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] + ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] + return $ FunD name [Clause [] (NormalB body) []]]) + +$(fmap concat . forM floatTypesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + fmap concat . forM [minBound..maxBound] $ \arithop -> do + let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype)) + cnamebase = "c_fbinary_" ++ atCName arithtype + c_ss = varE (afboNumOp arithop) + c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) + c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) + c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) + sequence [SigD name <$> + [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] + ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] + return $ FunD name [Clause [] (NormalB body) []]]) + +$(fmap concat . forM typesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + fmap concat . forM [minBound..maxBound] $ \arithop -> do + let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype)) + c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop))) + sequence [SigD name <$> + [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] + ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] + return $ FunD name [Clause [] (NormalB body) []]]) + +$(fmap concat . forM floatTypesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + fmap concat . forM [minBound..maxBound] $ \arithop -> do + let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype)) + c_op = varE (mkName ("c_funary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (afuoEnum arithop))) + sequence [SigD name <$> + [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] + ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] + return $ FunD name [Clause [] (NormalB body) []]]) + +$(fmap concat . forM typesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + fmap concat . forM [minBound..maxBound] $ \arithop -> do + let name = mkName (aroName arithop ++ "Vector" ++ nameBase (atType arithtype)) + c_op = varE (mkName ("c_reduce_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop))) + c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL))) + sequence [SigD name <$> + [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |] + ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |] + return $ FunD name [Clause [] (NormalB body) []]]) + +-- This branch is ostensibly a runtime branch, but will (hopefully) be +-- constant-folded away by GHC. +intWidBranch1 :: forall i n. (FiniteBits i, Storable i) + => (Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) + -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) + -> (SNat n -> RS.Array n i -> RS.Array n i) +intWidBranch1 f32 f64 sn + | finiteBitSize (undefined :: i) == 32 = liftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32) + | finiteBitSize (undefined :: i) == 64 = liftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64) + | otherwise = error "Unsupported Int width" + +intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i) + => (i -> i -> i) -- ss + -- int32 + -> (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- sv + -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Int32 -> IO ()) -- vs + -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- vv + -- int64 + -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- sv + -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ()) -- vs + -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- vv + -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i) +intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn + | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32) + | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64) + | otherwise = error "Unsupported Int width" + +intWidBranchRed :: forall i n. (FiniteBits i, Storable i, Integral i) + => -- int32 + (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- ^ scale by constant + -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- ^ reduction kernel + -- int64 + -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant + -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel + -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i) +intWidBranchRed fsc32 fred32 fsc64 fred64 sn + | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32 + | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64 + | otherwise = error "Unsupported Int width" + +class NumElt a where + numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a + numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a + numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a + numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a + numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a + +instance NumElt Int32 where + numEltAdd = addVectorInt32 + numEltSub = subVectorInt32 + numEltMul = mulVectorInt32 + numEltNeg = negVectorInt32 + numEltAbs = absVectorInt32 + numEltSignum = signumVectorInt32 + numEltSum1Inner = sum1VectorInt32 + numEltProduct1Inner = product1VectorInt32 + +instance NumElt Int64 where + numEltAdd = addVectorInt64 + numEltSub = subVectorInt64 + numEltMul = mulVectorInt64 + numEltNeg = negVectorInt64 + numEltAbs = absVectorInt64 + numEltSignum = signumVectorInt64 + numEltSum1Inner = sum1VectorInt64 + numEltProduct1Inner = product1VectorInt64 + +instance NumElt Float where + numEltAdd = addVectorFloat + numEltSub = subVectorFloat + numEltMul = mulVectorFloat + numEltNeg = negVectorFloat + numEltAbs = absVectorFloat + numEltSignum = signumVectorFloat + numEltSum1Inner = sum1VectorFloat + numEltProduct1Inner = product1VectorFloat + +instance NumElt Double where + numEltAdd = addVectorDouble + numEltSub = subVectorDouble + numEltMul = mulVectorDouble + numEltNeg = negVectorDouble + numEltAbs = absVectorDouble + numEltSignum = signumVectorDouble + numEltSum1Inner = sum1VectorDouble + numEltProduct1Inner = product1VectorDouble + +instance NumElt Int where + numEltAdd = intWidBranch2 @Int (+) + (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD)) + (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD)) + numEltSub = intWidBranch2 @Int (-) + (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB)) + (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB)) + numEltMul = intWidBranch2 @Int (*) + (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL)) + (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL)) + numEltNeg = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG)) + numEltAbs = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS)) + numEltSignum = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM)) + numEltSum1Inner = intWidBranchRed @Int + (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1)) + (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1)) + numEltProduct1Inner = intWidBranchRed @Int + (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) + (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) + +instance NumElt CInt where + numEltAdd = intWidBranch2 @CInt (+) + (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD)) + (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD)) + numEltSub = intWidBranch2 @CInt (-) + (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB)) + (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB)) + numEltMul = intWidBranch2 @CInt (*) + (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL)) + (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL)) + numEltNeg = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG)) + numEltAbs = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS)) + numEltSignum = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM)) + numEltSum1Inner = intWidBranchRed @CInt + (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1)) + (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1)) + numEltProduct1Inner = intWidBranchRed @CInt + (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) + (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) + +class FloatElt a where + floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + floatEltPow :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + floatEltLogbase :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a + floatEltRecip :: SNat n -> RS.Array n a -> RS.Array n a + floatEltExp :: SNat n -> RS.Array n a -> RS.Array n a + floatEltLog :: SNat n -> RS.Array n a -> RS.Array n a + floatEltSqrt :: SNat n -> RS.Array n a -> RS.Array n a + floatEltSin :: SNat n -> RS.Array n a -> RS.Array n a + floatEltCos :: SNat n -> RS.Array n a -> RS.Array n a + floatEltTan :: SNat n -> RS.Array n a -> RS.Array n a + floatEltAsin :: SNat n -> RS.Array n a -> RS.Array n a + floatEltAcos :: SNat n -> RS.Array n a -> RS.Array n a + floatEltAtan :: SNat n -> RS.Array n a -> RS.Array n a + floatEltSinh :: SNat n -> RS.Array n a -> RS.Array n a + floatEltCosh :: SNat n -> RS.Array n a -> RS.Array n a + floatEltTanh :: SNat n -> RS.Array n a -> RS.Array n a + floatEltAsinh :: SNat n -> RS.Array n a -> RS.Array n a + floatEltAcosh :: SNat n -> RS.Array n a -> RS.Array n a + floatEltAtanh :: SNat n -> RS.Array n a -> RS.Array n a + floatEltLog1p :: SNat n -> RS.Array n a -> RS.Array n a + floatEltExpm1 :: SNat n -> RS.Array n a -> RS.Array n a + floatEltLog1pexp :: SNat n -> RS.Array n a -> RS.Array n a + floatEltLog1mexp :: SNat n -> RS.Array n a -> RS.Array n a + +instance FloatElt Float where + floatEltDiv = divVectorFloat + floatEltPow = powVectorFloat + floatEltLogbase = logbaseVectorFloat + floatEltRecip = recipVectorFloat + floatEltExp = expVectorFloat + floatEltLog = logVectorFloat + floatEltSqrt = sqrtVectorFloat + floatEltSin = sinVectorFloat + floatEltCos = cosVectorFloat + floatEltTan = tanVectorFloat + floatEltAsin = asinVectorFloat + floatEltAcos = acosVectorFloat + floatEltAtan = atanVectorFloat + floatEltSinh = sinhVectorFloat + floatEltCosh = coshVectorFloat + floatEltTanh = tanhVectorFloat + floatEltAsinh = asinhVectorFloat + floatEltAcosh = acoshVectorFloat + floatEltAtanh = atanhVectorFloat + floatEltLog1p = log1pVectorFloat + floatEltExpm1 = expm1VectorFloat + floatEltLog1pexp = log1pexpVectorFloat + floatEltLog1mexp = log1mexpVectorFloat + +instance FloatElt Double where + floatEltDiv = divVectorDouble + floatEltPow = powVectorDouble + floatEltLogbase = logbaseVectorDouble + floatEltRecip = recipVectorDouble + floatEltExp = expVectorDouble + floatEltLog = logVectorDouble + floatEltSqrt = sqrtVectorDouble + floatEltSin = sinVectorDouble + floatEltCos = cosVectorDouble + floatEltTan = tanVectorDouble + floatEltAsin = asinVectorDouble + floatEltAcos = acosVectorDouble + floatEltAtan = atanVectorDouble + floatEltSinh = sinhVectorDouble + floatEltCosh = coshVectorDouble + floatEltTanh = tanhVectorDouble + floatEltAsinh = asinhVectorDouble + floatEltAcosh = acoshVectorDouble + floatEltAtanh = atanhVectorDouble + floatEltLog1p = log1pVectorDouble + floatEltExpm1 = expm1VectorDouble + floatEltLog1pexp = log1pexpVectorDouble + floatEltLog1mexp = log1mexpVectorDouble diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs new file mode 100644 index 0000000..6fc7229 --- /dev/null +++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs @@ -0,0 +1,55 @@ +{-# LANGUAGE ForeignFunctionInterface #-} +{-# LANGUAGE TemplateHaskell #-} +module Data.Array.Mixed.Internal.Arith.Foreign where + +import Control.Monad +import Data.Int +import Data.Maybe +import Foreign.C.Types +import Foreign.Ptr +import Language.Haskell.TH + +import Data.Array.Mixed.Internal.Arith.Lists + + +$(fmap concat . forM typesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + let base = "binary_" ++ atCName arithtype + sequence $ catMaybes + [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> + [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]) + ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> + [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> + [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) + ]) + +$(fmap concat . forM floatTypesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + let base = "fbinary_" ++ atCName arithtype + sequence $ catMaybes + [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> + [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]) + ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> + [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> + [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) + ]) + +$(fmap concat . forM typesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + let base = "unary_" ++ atCName arithtype + pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> + [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + +$(fmap concat . forM floatTypesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + let base = "funary_" ++ atCName arithtype + pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> + [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) + +$(fmap concat . forM typesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + let base = "reduce_" ++ atCName arithtype + pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> + [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) diff --git a/src/Data/Array/Mixed/Internal/Arith/Lists.hs b/src/Data/Array/Mixed/Internal/Arith/Lists.hs new file mode 100644 index 0000000..a284bc1 --- /dev/null +++ b/src/Data/Array/Mixed/Internal/Arith/Lists.hs @@ -0,0 +1,78 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TemplateHaskell #-} +module Data.Array.Mixed.Internal.Arith.Lists where + +import Data.Char +import Data.Int +import Language.Haskell.TH + +import Data.Array.Mixed.Internal.Arith.Lists.TH + + +data ArithType = ArithType + { atType :: Name -- ''Int32 + , atCName :: String -- "i32" + } + +floatTypesList :: [ArithType] +floatTypesList = + [ArithType ''Float "float" + ,ArithType ''Double "double" + ] + +typesList :: [ArithType] +typesList = + [ArithType ''Int32 "i32" + ,ArithType ''Int64 "i64" + ] + ++ floatTypesList + +-- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded) +$(genArithDataType Binop "ArithBOp") + +$(genArithNameFun Binop ''ArithBOp "aboName" (map toLower . drop 3)) +$(genArithEnumFun Binop ''ArithBOp "aboEnum") + +$(do clauses <- readArithLists Binop + (\name _num hsop -> return (Clause [ConP (mkName name) [] []] + (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) + [])) + return + sequence [SigD (mkName "aboNumOp") <$> [t| ArithBOp -> Name |] + ,return $ FunD (mkName "aboNumOp") clauses]) + + +-- data ArithFBOp = FB_DIV deriving (Show, Enum, Bounded) +$(genArithDataType FBinop "ArithFBOp") + +$(genArithNameFun FBinop ''ArithFBOp "afboName" (map toLower . drop 3)) +$(genArithEnumFun FBinop ''ArithFBOp "afboEnum") + +$(do clauses <- readArithLists FBinop + (\name _num hsop -> return (Clause [ConP (mkName name) [] []] + (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) + [])) + return + sequence [SigD (mkName "afboNumOp") <$> [t| ArithFBOp -> Name |] + ,return $ FunD (mkName "afboNumOp") clauses]) + + +-- data ArithUOp = UO_NEG | UO_ABS | UO_SIGNUM | ... deriving (Show, Enum, Bounded) +$(genArithDataType Unop "ArithUOp") + +$(genArithNameFun Unop ''ArithUOp "auoName" (map toLower . drop 3)) +$(genArithEnumFun Unop ''ArithUOp "auoEnum") + + +-- data ArithFUOp = FU_RECIP | ... deriving (Show, Enum, Bounded) +$(genArithDataType FUnop "ArithFUOp") + +$(genArithNameFun FUnop ''ArithFUOp "afuoName" (map toLower . drop 3)) +$(genArithEnumFun FUnop ''ArithFUOp "afuoEnum") + + +-- data ArithRedOp = RO_SUM1 | RO_PRODUCT1 deriving (Show, Enum, Bounded) +$(genArithDataType Redop "ArithRedOp") + +$(genArithNameFun Redop ''ArithRedOp "aroName" (map toLower . drop 3)) +$(genArithEnumFun Redop ''ArithRedOp "aroEnum") diff --git a/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs b/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs new file mode 100644 index 0000000..8b7d05f --- /dev/null +++ b/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs @@ -0,0 +1,82 @@ +{-# LANGUAGE TemplateHaskellQuotes #-} +module Data.Array.Mixed.Internal.Arith.Lists.TH where + +import Control.Monad +import Control.Monad.IO.Class +import Data.Maybe +import Foreign.C.Types +import Language.Haskell.TH +import Language.Haskell.TH.Syntax +import Text.Read + + +data OpKind = Binop | FBinop | Unop | FUnop | Redop + deriving (Show, Eq) + +readArithLists :: OpKind + -> (String -> Int -> String -> Q a) + -> ([a] -> Q r) + -> Q r +readArithLists targetkind fop fcombine = do + addDependentFile "cbits/arith_lists.h" + lns <- liftIO $ lines <$> readFile "cbits/arith_lists.h" + + mvals <- forM lns $ \line -> do + if null (dropWhile (== ' ') line) + then return Nothing + else do let (kind, name, num, aux) = parseLine line + if kind == targetkind + then Just <$> fop name num aux + else return Nothing + + fcombine (catMaybes mvals) + where + parseLine s0 + | ("LIST_", s1) <- splitAt 5 s0 + , (kindstr, '(' : s2) <- break (== '(') s1 + , (f1, ',' : s3) <- parseField s2 + , (f2, ',' : s4) <- parseField s3 + , (f3, ')' : _) <- parseField s4 + , Just kind <- parseKind kindstr + , let name = f1 + , Just num <- readMaybe f2 + , let aux = f3 + = (kind, name, num, aux) + | otherwise + = error $ "readArithLists: unrecognised line in cbits/arith_lists.h: " ++ show s0 + + parseField s = break (`elem` ",)") (dropWhile (== ' ') s) + + parseKind "BINOP" = Just Binop + parseKind "FBINOP" = Just FBinop + parseKind "UNOP" = Just Unop + parseKind "FUNOP" = Just FUnop + parseKind "REDOP" = Just Redop + parseKind _ = Nothing + +genArithDataType :: OpKind -> String -> Q [Dec] +genArithDataType kind dtname = do + cons <- readArithLists kind + (\name _num _ -> return $ NormalC (mkName name) []) + return + return [DataD [] (mkName dtname) [] Nothing cons [DerivClause Nothing [ConT ''Show, ConT ''Enum, ConT ''Bounded]]] + +genArithNameFun :: OpKind -> Name -> String -> (String -> String) -> Q [Dec] +genArithNameFun kind dtname funname nametrans = do + clauses <- readArithLists kind + (\name _num _ -> return (Clause [ConP (mkName name) [] []] + (NormalB (LitE (StringL (nametrans name)))) + [])) + return + return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''String) + ,FunD (mkName funname) clauses] + +genArithEnumFun :: OpKind -> Name -> String -> Q [Dec] +genArithEnumFun kind dtname funname = do + clauses <- readArithLists kind + (\name num _ -> return (Clause [ConP (mkName name) [] []] + (NormalB (LitE (IntegerL (fromIntegral num)))) + [])) + return + return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''CInt) + ,FunD (mkName funname) clauses] diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Mixed/Lemmas.hs new file mode 100644 index 0000000..30ec9c0 --- /dev/null +++ b/src/Data/Array/Mixed/Lemmas.hs @@ -0,0 +1,47 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE DataKinds #-} +module Data.Array.Mixed.Lemmas where + +import Data.Proxy +import Data.Type.Equality +import GHC.TypeLits + +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Types + + +lemRankApp :: forall sh1 sh2. + StaticShX sh1 -> StaticShX sh2 + -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2 +lemRankApp ZKX _ = Refl +lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2 + = lem2 (Proxy @(Rank sh1T)) Proxy Proxy $ + lem (Proxy @(Rank sh2)) (Proxy @(Rank sh1T)) (Proxy @(Rank (sh1T ++ sh2))) $ + lemRankApp ssh1 ssh2 + where + lem :: proxy a -> proxy b -> proxy c + -> c :~: b + a + -> b + a :~: c + lem _ _ _ Refl = Refl + + lem2 :: proxy a -> proxy b -> proxy c + -> (a + b :~: c) + -> c + 1 :~: (a + 1 + b) + lem2 _ _ _ Refl = Refl + +lemRankAppComm :: StaticShX sh1 -> StaticShX sh2 + -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1) +lemRankAppComm _ _ = unsafeCoerceRefl -- TODO improve this + +lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh) +lemKnownNatRank ZSX = Dict +lemKnownNatRank (_ :$% sh) | Dict <- lemKnownNatRank sh = Dict + +lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh) +lemKnownNatRankSSX ZKX = Dict +lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs new file mode 100644 index 0000000..2710018 --- /dev/null +++ b/src/Data/Array/Mixed/Permutation.hs @@ -0,0 +1,252 @@ +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Mixed.Permutation where + +import Data.Coerce (coerce) +import Data.Functor.Const +import Data.List (sort) +import Data.Proxy +import Data.Type.Bool +import Data.Type.Equality +import Data.Type.Ord +import GHC.TypeError +import GHC.TypeLits +import qualified GHC.TypeNats as TN + +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Types + + +-- * Permutations + +-- | A "backward" permutation of a dimension list. The operation on the +-- dimension list is most similar to 'Data.Vector.backpermute'; see 'Permute' +-- for code that implements this. +data Perm list where + PNil :: Perm '[] + PCons :: SNat a -> Perm l -> Perm (a : l) +infixr 5 `PCons` +deriving instance Show (Perm list) +deriving instance Eq (Perm list) + +permLengthSNat :: Perm list -> SNat (Rank list) +permLengthSNat PNil = SNat +permLengthSNat (_ `PCons` l) | SNat <- permLengthSNat l = SNat + +permFromList :: [Int] -> (forall list. Perm list -> r) -> r +permFromList [] k = k PNil +permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case + Just sn -> permFromList xs $ \list -> k (sn `PCons` list) + Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x + +permToList :: Perm list -> [Natural] +permToList PNil = mempty +permToList (x `PCons` l) = TN.fromSNat x : permToList l + +permToList' :: Perm list -> [Int] +permToList' = map fromIntegral . permToList + + +-- ** Applying permutations + +type family Elem x l where + Elem x '[] = 'False + Elem x (x : _) = 'True + Elem x (_ : ys) = Elem x ys + +type family AllElem' as bs where + AllElem' '[] bs = 'True + AllElem' (a : as) bs = Elem a bs && AllElem' as bs + +type AllElem as bs = Assert (AllElem' as bs) + (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs)) + +type family Count i n where + Count n n = '[] + Count i n = i : Count (i + 1) n + +type IsPermutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as) + +type family Index i sh where + Index 0 (n : sh) = n + Index i (_ : sh) = Index (i - 1) sh + +type family Permute is sh where + Permute '[] sh = '[] + Permute (i : is) sh = Index i sh : Permute is sh + +type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh + +type family TakeLen ref l where + TakeLen '[] l = '[] + TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs + +type family DropLen ref l where + DropLen '[] l = l + DropLen (_ : ref) (_ : xs) = DropLen ref xs + +listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f +listxTakeLen PNil _ = ZX +listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh +listxTakeLen (_ `PCons` _) ZX = error "IsPermutation longer than shape" + +listxDropLen :: forall f is sh. Perm is -> ListX sh f -> ListX (DropLen is sh) f +listxDropLen PNil sh = sh +listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh +listxDropLen (_ `PCons` _) ZX = error "IsPermutation longer than shape" + +listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f +listxPermute PNil _ = ZX +listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) = + listxIndex (Proxy @is') (Proxy @sh) i sh (listxPermute is sh) + +listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> ListX (Permute is shT) f -> ListX (Index i sh : Permute is shT) f +listxIndex _ _ SZ (n ::% _) rest = n ::% rest +listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) rest + | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = listxIndex p pT i sh rest +listxIndex _ _ _ ZX _ = error "Index into empty shape" + +listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f +listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) + +ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i +ixxPermutePrefix = coerce (listxPermutePrefix @(Const i)) + +ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh) +ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat)) + +ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh) +ssxDropLen = coerce (listxDropLen @(SMayNat () SNat)) + +ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh) +ssxPermute = coerce (listxPermute @(SMayNat () SNat)) + +ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT) +ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2) + +ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) +ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat)) + +shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh) +shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat)) + + +-- * Operations on permutations + +-- TODO: test this thing more properly +permInverse :: Perm is + -> (forall is'. + IsPermutation is' + => Perm is' + -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh) + -> r) + -> r +permInverse = \perm k -> + genPerm perm $ \(invperm :: Perm is') -> + let sn = permLengthSNat invperm + in case (provePerm1 (Proxy @is') sn invperm, provePerm2 (SNat @0) sn invperm) of + (Just Refl, Just Refl) -> + k invperm + (\ssh -> case provePermInverse perm invperm ssh of + Just eq -> eq + Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm + ++ " ; invperm = " ++ show invperm) + _ -> error $ "permInverse: did not generate permutation? perm = " ++ show perm + ++ " ; invperm = " ++ show invperm + where + genPerm :: Perm is -> (forall is'. Perm is' -> r) -> r + genPerm perm = + let permList = permToList' perm + in toHList $ map snd (sort (zip permList [0..])) + where + toHList :: [Natural] -> (forall is'. Perm is' -> r) -> r + toHList [] k = k PNil + toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l) + + lemElemCount :: (0 <= n, Compare n m ~ LT) => proxy n -> proxy m -> Elem n (Count 0 m) :~: True + lemElemCount _ _ = unsafeCoerceRefl + + lemCount :: (OrdCond (Compare i n) True False True ~ True) => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n + lemCount _ _ = unsafeCoerceRefl + + lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True + lemElem _ _ = unsafeCoerceRefl + + provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is' + -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True) + provePerm1 _ _ PNil = Just (Refl) + provePerm1 p rtop@SNat (PCons sn@SNat perm) + | Just Refl <- provePerm1 p rtop perm + = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of + (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl + (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl + _ -> Nothing + | otherwise + = Nothing + + provePerm2 :: SNat i -> SNat n -> Perm is' + -> Maybe (AllElem' (Count i n) is' :~: True) + provePerm2 = \i@(SNat :: SNat i) n@SNat perm -> + case cmpNat i n of + EQI -> Just Refl + LTI | Refl <- lemCount i n + , Just Refl <- provePerm2 (SNat @(i + 1)) n perm + -> checkElem i perm + | otherwise -> Nothing + GTI -> error "unreachable" + where + checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True) + checkElem _ PNil = Nothing + checkElem i@SNat (PCons k@SNat perm :: Perm is') = + case sameNat i k of + Just Refl -> Just Refl + Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl + | otherwise -> Nothing + + provePermInverse :: Perm is -> Perm is' -> StaticShX sh + -> Maybe (Permute is' (Permute is sh) :~: sh) + provePermInverse perm perminv ssh = + ssxGeq (ssxPermute perminv (ssxPermute perm ssh)) ssh + +type family MapSucc is where + MapSucc '[] = '[] + MapSucc (i : is) = i + 1 : MapSucc is + +permShift1 :: Perm l -> Perm (0 : MapSucc l) +permShift1 = (SNat @0 `PCons`) . permMapSucc + where + permMapSucc :: Perm l -> Perm (MapSucc l) + permMapSucc PNil = PNil + permMapSucc ((SNat :: SNat i) `PCons` ns) = SNat @(i + 1) `PCons` permMapSucc ns + + +-- * Lemmas + +lemRankPermute :: Proxy sh -> Perm is -> Rank (Permute is sh) :~: Rank is +lemRankPermute _ PNil = Refl +lemRankPermute p (_ `PCons` is) | Refl <- lemRankPermute p is = Refl + +lemRankDropLen :: forall is sh. (Rank is <= Rank sh) + => StaticShX sh -> Perm is -> Rank (DropLen is sh) :~: Rank sh - Rank is +lemRankDropLen ZKX PNil = Refl +lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl +lemRankDropLen (_ :!% _) PNil = Refl +lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0" + +lemIndexSucc :: Proxy i -> Proxy a -> Proxy l + -> Index (i + 1) (a : l) :~: Index i l +lemIndexSucc _ _ _ = unsafeCoerceRefl diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs new file mode 100644 index 0000000..a16da76 --- /dev/null +++ b/src/Data/Array/Mixed/Shape.hs @@ -0,0 +1,455 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Mixed.Shape where + +import Control.DeepSeq (NFData(..)) +import qualified Data.Foldable as Foldable +import Data.Functor.Const +import Data.Kind (Type, Constraint) +import Data.Monoid (Sum(..)) +import Data.Proxy +import Data.Type.Equality +import GHC.Generics (Generic) +import GHC.IsList (IsList) +import qualified GHC.IsList as IsList +import GHC.TypeLits + +import Data.Array.Mixed.Types +import Data.Coerce +import Data.Bifunctor (first) + + +-- | The length of a type-level list. If the argument is a shape, then the +-- result is the rank of that shape. +type family Rank sh where + Rank '[] = 0 + Rank (_ : sh) = Rank sh + 1 + + +-- * Mixed lists + +type role ListX nominal representational +type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type +data ListX sh f where + ZX :: ListX '[] f + (::%) :: f n -> ListX sh f -> ListX (n : sh) f +deriving instance (forall n. Eq (f n)) => Eq (ListX sh f) +deriving instance (forall n. Ord (f n)) => Ord (ListX sh f) +infixr 3 ::% + +instance (forall n. Show (f n)) => Show (ListX sh f) where + showsPrec _ = listxShow shows + +instance (forall n. NFData (f n)) => NFData (ListX sh f) where + rnf ZX = () + rnf (x ::% l) = rnf x `seq` rnf l + +data UnconsListXRes f sh1 = + forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n) +listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1) +listxUncons (i ::% shl') = Just (UnconsListXRes shl' i) +listxUncons ZX = Nothing + +listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g +listxFmap _ ZX = ZX +listxFmap f (x ::% xs) = f x ::% listxFmap f xs + +listxFold :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m +listxFold _ ZX = mempty +listxFold f (x ::% xs) = f x <> listxFold f xs + +listxLength :: ListX sh f -> Int +listxLength = getSum . listxFold (\_ -> Sum 1) + +listxLengthSNat :: ListX sh f -> SNat (Rank sh) +listxLengthSNat ZX = SNat +listxLengthSNat (_ ::% l) | SNat <- listxLengthSNat l = SNat + +listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS +listxShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> ListX sh' f -> ShowS + go _ ZX = id + go prefix (x ::% xs) = showString prefix . f x . go "," xs + +listxToList :: ListX sh' (Const i) -> [i] +listxToList ZX = [] +listxToList (Const i ::% is) = i : listxToList is + +listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f +listxAppend ZX idx' = idx' +listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx' + +listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f +listxDrop long ZX = long +listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short + + +-- * Mixed indices + +-- | This is a newtype over 'ListX'. +type role IxX nominal representational +type IxX :: [Maybe Nat] -> Type -> Type +newtype IxX sh i = IxX (ListX sh (Const i)) + deriving (Eq, Ord, Generic) + +pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i +pattern ZIX = IxX ZX + +pattern (:.%) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => i -> IxX sh i -> IxX sh1 i +pattern i :.% shl <- IxX (listxUncons -> Just (UnconsListXRes (IxX -> shl) (getConst -> i))) + where i :.% IxX shl = IxX (Const i ::% shl) +infixr 3 :.% + +{-# COMPLETE ZIX, (:.%) #-} + +type IIxX sh = IxX sh Int + +instance Show i => Show (IxX sh i) where + showsPrec _ (IxX l) = listxShow (\(Const i) -> shows i) l + +instance Functor (IxX sh) where + fmap f (IxX l) = IxX (listxFmap (Const . f . getConst) l) + +instance Foldable (IxX sh) where + foldMap f (IxX l) = listxFold (f . getConst) l + +instance NFData i => NFData (IxX sh i) + +ixxZero :: StaticShX sh -> IIxX sh +ixxZero ZKX = ZIX +ixxZero (_ :!% ssh) = 0 :.% ixxZero ssh + +ixxZero' :: IShX sh -> IIxX sh +ixxZero' ZSX = ZIX +ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh + +ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i +ixxAppend = coerce (listxAppend @_ @(Const i)) + +ixxDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i +ixxDrop = coerce (listxDrop @(Const i) @(Const i)) + +ixxFromLinear :: IShX sh -> Int -> IIxX sh +ixxFromLinear = \sh i -> case go sh i of + (idx, 0) -> idx + _ -> error $ "ixxFromLinear: out of range (" ++ show i ++ + " in array of shape " ++ show sh ++ ")" + where + -- returns (index in subarray, remaining index in enclosing array) + go :: IShX sh -> Int -> (IIxX sh, Int) + go ZSX i = (ZIX, i) + go (n :$% sh) i = + let (idx, i') = go sh i + (upi, locali) = i' `quotRem` fromSMayNat' n + in (locali :.% idx, upi) + +ixxToLinear :: IShX sh -> IIxX sh -> Int +ixxToLinear = \sh i -> fst (go sh i) + where + -- returns (index in subarray, size of subarray) + go :: IShX sh -> IIxX sh -> (Int, Int) + go ZSX ZIX = (0, 1) + go (n :$% sh) (i :.% ix) = + let (lidx, sz) = go sh ix + in (sz * i + lidx, fromSMayNat' n * sz) + + +-- * Mixed shapes + +data SMayNat i f n where + SUnknown :: i -> SMayNat i f Nothing + SKnown :: f n -> SMayNat i f (Just n) +deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n) +deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n) +deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n) + +instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where + rnf (SUnknown i) = rnf i + rnf (SKnown x) = rnf x + +fromSMayNat :: (n ~ Nothing => i -> r) + -> (forall m. n ~ Just m => f m -> r) + -> SMayNat i f n -> r +fromSMayNat f _ (SUnknown i) = f i +fromSMayNat _ g (SKnown s) = g s + +fromSMayNat' :: SMayNat Int SNat n -> Int +fromSMayNat' = fromSMayNat id fromSNat' + +type family AddMaybe n m where + AddMaybe Nothing _ = Nothing + AddMaybe (Just _) Nothing = Nothing + AddMaybe (Just n) (Just m) = Just (n + m) + +smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m) +smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) +smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) +smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) + + +-- | This is a newtype over 'ListX'. +type role ShX nominal representational +type ShX :: [Maybe Nat] -> Type -> Type +newtype ShX sh i = ShX (ListX sh (SMayNat i SNat)) + deriving (Eq, Ord, Generic) + +pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i +pattern ZSX = ShX ZX + +pattern (:$%) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => SMayNat i SNat n -> ShX sh i -> ShX sh1 i +pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i)) + where i :$% ShX shl = ShX (i ::% shl) +infixr 3 :$% + +{-# COMPLETE ZSX, (:$%) #-} + +type IShX sh = ShX sh Int + +instance Show i => Show (ShX sh i) where + showsPrec _ (ShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l + +instance Functor (ShX sh) where + fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l) + +instance NFData i => NFData (ShX sh i) where + rnf (ShX ZX) = () + rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l) + rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l) + +shxLength :: ShX sh i -> Int +shxLength (ShX l) = listxLength l + +-- | This is more than @geq@: it also checks that the integers (the unknown +-- dimensions) are the same. +shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh') +shxEqual ZSX ZSX = Just Refl +shxEqual (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh') + | Just Refl <- sameNat n m + , Just Refl <- shxEqual sh sh' + = Just Refl +shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh') + | i == j + , Just Refl <- shxEqual sh sh' + = Just Refl +shxEqual _ _ = Nothing + +-- | The number of elements in an array described by this shape. +shxSize :: IShX sh -> Int +shxSize ZSX = 1 +shxSize (n :$% sh) = fromSMayNat' n * shxSize sh + +shxToList :: IShX sh -> [Int] +shxToList ZSX = [] +shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh + +shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i +shxAppend = coerce (listxAppend @_ @(SMayNat i SNat)) + +shxTail :: ShX (n : sh) i -> ShX sh i +shxTail (_ :$% sh) = sh + +shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i +shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat)) + +shxDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i +shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j)) + +shxDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i +shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat)) + +shxTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i +shxTakeSSX _ = flip go + where + go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i + go ZKX _ = ZSX + go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh + +-- This is a weird operation, so it has a long name +shxCompleteZeros :: StaticShX sh -> IShX sh +shxCompleteZeros ZKX = ZSX +shxCompleteZeros (SUnknown () :!% ssh) = SUnknown 0 :$% shxCompleteZeros ssh +shxCompleteZeros (SKnown n :!% ssh) = SKnown n :$% shxCompleteZeros ssh + +shxSplitApp :: Proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i) +shxSplitApp _ ZKX idx = (ZSX, idx) +shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx) + +shxEnum :: IShX sh -> [IIxX sh] +shxEnum = \sh -> go sh id [] + where + go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a] + go ZSX f = (f ZIX :) + go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]] + + +-- * Static mixed shapes + +-- | The part of a shape that is statically known. (A newtype over 'ListX'.) +type StaticShX :: [Maybe Nat] -> Type +newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat)) + deriving (Eq, Ord) + +pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh +pattern ZKX = StaticShX ZX + +pattern (:!%) + :: forall {sh1}. + forall n sh. (n : sh ~ sh1) + => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1 +pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i)) + where i :!% StaticShX shl = StaticShX (i ::% shl) +infixr 3 :!% + +{-# COMPLETE ZKX, (:!%) #-} + +instance Show (StaticShX sh) where + showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l + +ssxLength :: StaticShX sh -> Int +ssxLength (StaticShX l) = listxLength l + +-- | This suffices as an implementation of @geq@ in the @Data.GADT.Compare@ +-- class of the @some@ package. +ssxGeq :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh') +ssxGeq ZKX ZKX = Just Refl +ssxGeq (SKnown n@SNat :!% sh) (SKnown m@SNat :!% sh') + | Just Refl <- sameNat n m + , Just Refl <- ssxGeq sh sh' + = Just Refl +ssxGeq (SUnknown () :!% sh) (SUnknown () :!% sh') + | Just Refl <- ssxGeq sh sh' + = Just Refl +ssxGeq _ _ = Nothing + +ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') +ssxAppend ZKX sh' = sh' +ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' + +ssxTail :: StaticShX (n : sh) -> StaticShX sh +ssxTail (_ :!% ssh) = ssh + +ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh' +ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) + +-- | This may fail if @sh@ has @Nothing@s in it. +ssxToShX' :: StaticShX sh -> Maybe (IShX sh) +ssxToShX' ZKX = Just ZSX +ssxToShX' (SKnown n :!% sh) = (SKnown n :$%) <$> ssxToShX' sh +ssxToShX' (SUnknown _ :!% _) = Nothing + +ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) +ssxReplicate SZ = ZKX +ssxReplicate (SS (n :: SNat n')) + | Refl <- lemReplicateSucc @(Nothing @Nat) @n' + = SUnknown () :!% ssxReplicate n + +ssxIotaFrom :: Int -> StaticShX sh -> [Int] +ssxIotaFrom _ ZKX = [] +ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh + +ssxFromShape :: IShX sh -> StaticShX sh +ssxFromShape ZSX = ZKX +ssxFromShape (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShape sh + + +-- | Evidence for the static part of a shape. This pops up only when you are +-- polymorphic in the element type of an array. +type KnownShX :: [Maybe Nat] -> Constraint +class KnownShX sh where knownShX :: StaticShX sh +instance KnownShX '[] where knownShX = ZKX +instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX +instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX + + +-- * Flattening + +type Flatten sh = Flatten' 1 sh + +type family Flatten' acc sh where + Flatten' acc '[] = Just acc + Flatten' acc (Nothing : sh) = Nothing + Flatten' acc (Just n : sh) = Flatten' (acc * n) sh + +-- This function is currently unused +ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh) +ssxFlatten = go (SNat @1) + where + go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh) + go acc ZKX = SKnown acc + go _ (SUnknown () :!% _) = SUnknown () + go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh + +shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh) +shxFlatten = go (SNat @1) + where + go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh) + go acc ZSX = SKnown acc + go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh) + go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh + + goUnknown :: Int -> IShX sh -> Int + goUnknown acc ZSX = acc + goUnknown acc (SUnknown n :$% sh) = goUnknown (acc * n) sh + goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh + + +-- | Very untyped: only length is checked (at runtime). +instance KnownShX sh => IsList (ListX sh (Const i)) where + type Item (ListX sh (Const i)) = i + fromList topl = go (knownShX @sh) topl + where + go :: StaticShX sh' -> [i] -> ListX sh' (Const i) + go ZKX [] = ZX + go (_ :!% sh) (i : is) = Const i ::% go sh is + go _ _ = error $ "IsList(ListX): Mismatched list length (type says " + ++ show (ssxLength (knownShX @sh)) ++ ", list has length " + ++ show (length topl) ++ ")" + toList = listxToList + +-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. +instance KnownShX sh => IsList (IxX sh i) where + type Item (IxX sh i) = i + fromList = IxX . IsList.fromList + toList = Foldable.toList + +-- | Untyped: length and known dimensions are checked (at runtime). +instance KnownShX sh => IsList (ShX sh Int) where + type Item (ShX sh Int) = Int + fromList topl = ShX (go (knownShX @sh) topl) + where + go :: StaticShX sh' -> [Int] -> ListX sh' (SMayNat Int SNat) + go ZKX [] = ZX + go (SKnown sn :!% sh) (i : is) + | i == fromSNat' sn = SKnown sn ::% go sh is + | otherwise = error $ "IsList(ShX): Value does not match typing (type says " + ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" + go (SUnknown () :!% sh) (i : is) = SUnknown i ::% go sh is + go _ _ = error $ "IsList(ShX): Mismatched list length (type says " + ++ show (ssxLength (knownShX @sh)) ++ ", list has length " + ++ show (length topl) ++ ")" + toList = shxToList diff --git a/src/Data/Array/Mixed/Types.hs b/src/Data/Array/Mixed/Types.hs new file mode 100644 index 0000000..d77513f --- /dev/null +++ b/src/Data/Array/Mixed/Types.hs @@ -0,0 +1,110 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Mixed.Types ( + -- * Reified evidence of a type class + Dict(..), + + -- * Type-level naturals + pattern SZ, pattern SS, + fromSNat', + snatPlus, snatMul, + + -- * Type-level lists + type (++), + lemAppNil, + lemAppAssoc, + Replicate, + lemReplicateSucc, + + -- * Unsafe + unsafeCoerceRefl, +) where + +import Data.Type.Equality +import Data.Proxy +import GHC.TypeLits +import qualified GHC.TypeNats as TN +import qualified Unsafe.Coerce + + +-- | Evidence for the constraint @c a@. +data Dict c a where + Dict :: c a => Dict c a + +fromSNat' :: SNat n -> Int +fromSNat' = fromIntegral . fromSNat + +pattern SZ :: () => (n ~ 0) => SNat n +pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl) + where SZ = SNat + +pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1 +pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl)) + where SS = snatSucc + +{-# COMPLETE SZ, SS #-} + +snatSucc :: SNat n -> SNat (n + 1) +snatSucc SNat = SNat + +data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1) +snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1) +snatPred snp1 = + withKnownNat snp1 $ + case cmpNat (Proxy @1) (Proxy @np1) of + LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) + EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) + GTI -> Nothing + +-- This should be a function in base +snatPlus :: SNat n -> SNat m -> SNat (n + m) +snatPlus n m = TN.withSomeSNat (TN.fromSNat n + TN.fromSNat m) Unsafe.Coerce.unsafeCoerce + +-- This should be a function in base +snatMul :: SNat n -> SNat m -> SNat (n * m) +snatMul n m = TN.withSomeSNat (TN.fromSNat n * TN.fromSNat m) Unsafe.Coerce.unsafeCoerce + + +-- | This is just @'Unsafe.Coerce.unsafeCoerce' 'Refl'@, but specialised to +-- only typecheck for actual type equalities. One cannot, e.g. accidentally +-- write this: +-- +-- @ +-- foo :: Proxy a -> Proxy b -> a :~: b +-- foo = unsafeCoerceRefl +-- @ +-- +-- which would have been permitted with normal 'Unsafe.Coerce.unsafeCoerce', +-- but would have resulted in interesting memory errors at runtime. +unsafeCoerceRefl :: a :~: b +unsafeCoerceRefl = Unsafe.Coerce.unsafeCoerce Refl + + +-- | Type-level list append. +type family l1 ++ l2 where + '[] ++ l2 = l2 + (x : xs) ++ l2 = x : xs ++ l2 + +lemAppNil :: l ++ '[] :~: l +lemAppNil = unsafeCoerceRefl + +lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) +lemAppAssoc _ _ _ = unsafeCoerceRefl + +type family Replicate n a where + Replicate 0 a = '[] + Replicate n a = a : Replicate (n - 1) a + +lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a +lemReplicateSucc = unsafeCoerceRefl diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index e3af0ee..1a4e094 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -58,8 +58,9 @@ module Data.Array.Nested ( type (++), Storable, SNat, pattern SNat, - HList, - Permutation, + pattern SZ, pattern SS, + Perm(..), + IsPermutation, KnownNatList(..), listSToList, shSToList, @@ -69,7 +70,10 @@ module Data.Array.Nested ( import Prelude hiding (mappend) import Data.Array.Mixed +import Data.Array.Mixed.Internal.Arith +import Data.Array.Mixed.Permutation +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Types import Data.Array.Nested.Internal -import Data.Array.Nested.Internal.Arith import Foreign.Storable import GHC.TypeLits diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 712c5f1..0870789 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -60,7 +60,11 @@ import Unsafe.Coerce import Data.Array.Mixed import qualified Data.Array.Mixed as X -import Data.Array.Nested.Internal.Arith +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Permutation +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Internal.Arith +import Data.Array.Mixed.Types -- Invariant in the API @@ -123,19 +127,19 @@ lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing) ssxFromSNat SZ = ZKX -ssxFromSNat (SS (n :: SNat nm1)) | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n +ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing) lemKnownReplicate sn = lemKnownShX (ssxFromSNat sn) -lemRankReplicate :: SNat n -> X.Rank (Replicate n (Nothing @Nat)) :~: n +lemRankReplicate :: SNat n -> Rank (Replicate n (Nothing @Nat)) :~: n lemRankReplicate SZ = Refl lemRankReplicate (SS (n :: SNat nm1)) - | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1 + | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 , Refl <- lemRankReplicate n = Refl -lemRankMapJust :: forall sh. ShS sh -> X.Rank (MapJust sh) :~: X.Rank sh +lemRankMapJust :: forall sh. ShS sh -> Rank (MapJust sh) :~: Rank sh lemRankMapJust ZSS = Refl lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl @@ -146,9 +150,9 @@ lemReplicatePlusApp sn _ _ = go sn go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a go SZ = Refl go (SS (n :: SNat n'm1)) - | Refl <- X.lemReplicateSucc @a @n'm1 + | Refl <- lemReplicateSucc @a @n'm1 , Refl <- go n - = sym (X.lemReplicateSucc @a @(n'm1 + m)) + = sym (lemReplicateSucc @a @(n'm1 + m)) lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True lemLeqPlus _ _ _ = Refl @@ -156,17 +160,17 @@ lemLeqPlus _ _ _ = Refl lemLeqSuccSucc :: (k + 1 <= n) => Proxy k -> Proxy n -> (k <=? n - 1) :~: True lemLeqSuccSucc _ _ = unsafeCoerce Refl -lemDropLenApp :: X.Rank l1 <= X.Rank l2 +lemDropLenApp :: Rank l1 <= Rank l2 => Proxy l1 -> Proxy l2 -> Proxy rest - -> X.DropLen l1 l2 ++ rest :~: X.DropLen l1 (l2 ++ rest) + -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest) lemDropLenApp _ _ _ = unsafeCoerce Refl -lemTakeLenApp :: X.Rank l1 <= X.Rank l2 +lemTakeLenApp :: Rank l1 <= Rank l2 => Proxy l1 -> Proxy l2 -> Proxy rest - -> X.TakeLen l1 l2 :~: X.TakeLen l1 (l2 ++ rest) + -> TakeLen l1 l2 :~: TakeLen l1 (l2 ++ rest) lemTakeLenApp _ _ _ = unsafeCoerce Refl -srankSh :: ShX sh f -> SNat (X.Rank sh) +srankSh :: ShX sh f -> SNat (Rank sh) srankSh ZSX = SNat srankSh (_ :$% sh) | SNat <- srankSh sh = SNat @@ -585,11 +589,11 @@ class Elt a where -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a - mcast :: forall sh1 sh2 sh'. X.Rank sh1 ~ X.Rank sh2 + mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a - mtranspose :: forall is sh. (X.Permutation is, X.Rank is <= X.Rank sh) - => HList SNat is -> Mixed sh a -> Mixed (X.PermutePrefix is sh) a + mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) + => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a -- ====== PRIVATE METHODS ====== -- @@ -635,20 +639,20 @@ class Elt a => KnownElt a where instance Storable a => Elt (Primitive a) where mshape (M_Primitive sh _) = sh mindex (M_Primitive _ a) i = Primitive (X.index a i) - mindexPartial (M_Primitive sh a) i = M_Primitive (X.shDropIx sh i) (X.indexPartial a i) + mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i) mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) mfromListOuter l@(arr1 :| _) = let sh = SUnknown (length l) :$% mshape arr1 - in M_Primitive sh (X.fromListOuter (X.staticShapeFrom sh) (map (\(M_Primitive _ a) -> a) (toList l))) - mtoListOuter (M_Primitive sh arr) = map (M_Primitive (X.shTail sh)) (X.toListOuter arr) + in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l))) + mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) mlift :: forall sh1 sh2. StaticShX sh2 -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) mlift ssh2 f (M_Primitive _ a) - | Refl <- X.lemAppNil @sh1 - , Refl <- X.lemAppNil @sh2 + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 , let result = f ZKX a = M_Primitive (X.shape ssh2 result) result @@ -657,36 +661,36 @@ instance Storable a => Elt (Primitive a) where -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a) mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b) - | Refl <- X.lemAppNil @sh1 - , Refl <- X.lemAppNil @sh2 - , Refl <- X.lemAppNil @sh3 + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 + , Refl <- lemAppNil @sh3 , let result = f ZKX a b = M_Primitive (X.shape ssh3 result) result - mcast :: forall sh1 sh2 sh'. X.Rank sh1 ~ X.Rank sh2 + mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) mcast ssh1 sh2 _ (M_Primitive sh1' arr) = - let (_, sh') = shAppSplit (Proxy @sh') ssh1 sh1' - in M_Primitive (shAppend sh2 sh') (X.cast ssh1 sh2 (X.staticShapeFrom sh') arr) + let (_, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' + in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr) mtranspose perm (M_Primitive sh arr) = - M_Primitive (X.shPermutePrefix perm sh) - (X.transpose (X.staticShapeFrom sh) perm arr) + M_Primitive (shxPermutePrefix perm sh) + (X.transpose (ssxFromShape sh) perm arr) mshapeTree _ = () mshapeTreeEq _ () () = True mshapeTreeEmpty _ () = False mshowShapeTree _ () = "()" - mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (X.toLinearIdx sh i) x + mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x -- TODO: this use of toVector is suboptimal mvecsWritePartial :: forall sh' sh s. IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do - let arrsh = X.shape (X.staticShapeFrom sh') arr - offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIxX' arrsh)) - VS.copy (VSM.slice offset (X.shapeSize arrsh) v) (X.toVector arr) + let arrsh = X.shape (ssxFromShape sh') arr + offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) + VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v @@ -701,7 +705,7 @@ deriving via Primitive () instance Elt () instance Storable a => KnownElt (Primitive a) where memptyArray sh = M_Primitive sh (X.empty sh) - mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh) + mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh) mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 -- [PRIMITIVE ELEMENT TYPES LIST] @@ -755,7 +759,7 @@ instance Elt a => Elt (Mixed sh' a) where -- moverlongShape method, a prefix of which is mshape. mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh mshape (M_Nest sh arr) - = fst (shAppSplit (Proxy @sh') (X.staticShapeFrom sh) (mshape arr)) + = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr)) mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a mindex (M_Nest _ arr) i = mindexPartial arr i @@ -763,8 +767,8 @@ instance Elt a => Elt (Mixed sh' a) where mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) mindexPartial (M_Nest sh arr) i - | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = M_Nest (X.shDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') + = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) mscalar = M_Nest ZSX @@ -773,95 +777,95 @@ instance Elt a => Elt (Mixed sh' a) where M_Nest (SUnknown (length l) :$% mshape arr) (mfromListOuter ((\(M_Nest _ a) -> a) <$> l)) - mtoListOuter (M_Nest sh arr) = map (M_Nest (X.shTail sh)) (mtoListOuter arr) + mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr) mlift :: forall sh1 sh2. StaticShX sh2 -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) mlift ssh2 f (M_Nest sh1 arr) = - let result = mlift (X.ssxAppend ssh2 ssh') f' arr - (sh2, _) = shAppSplit (Proxy @sh') ssh2 (mshape result) + let result = mlift (ssxAppend ssh2 ssh') f' arr + (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result) in M_Nest sh2 result where - ssh' = X.staticShapeFrom (snd (shAppSplit (Proxy @sh') (X.staticShapeFrom sh1) (mshape arr))) + ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr))) f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b f' sshT - | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) - , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) - = f (X.ssxAppend ssh' sshT) + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) + = f (ssxAppend ssh' sshT) mlift2 :: forall sh1 sh2 sh3. StaticShX sh3 -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a) mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) = - let result = mlift2 (X.ssxAppend ssh3 ssh') f' arr1 arr2 - (sh3, _) = shAppSplit (Proxy @sh') ssh3 (mshape result) + let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2 + (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result) in M_Nest sh3 result where - ssh' = X.staticShapeFrom (snd (shAppSplit (Proxy @sh') (X.staticShapeFrom sh1) (mshape arr1))) + ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b f' sshT - | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) - , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) - , Refl <- X.lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) - = f (X.ssxAppend ssh' sshT) + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) + = f (ssxAppend ssh' sshT) - mcast :: forall sh1 sh2 shT. X.Rank sh1 ~ X.Rank sh2 + mcast :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2 => StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a) mcast ssh1 sh2 _ (M_Nest sh1T arr) - | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') - , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') - = let (_, shT) = shAppSplit (Proxy @shT) ssh1 sh1T - in M_Nest (shAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr) - - mtranspose :: forall is sh. (X.Permutation is, X.Rank is <= X.Rank sh) - => HList SNat is -> Mixed sh (Mixed sh' a) - -> Mixed (X.PermutePrefix is sh) (Mixed sh' a) + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') + = let (_, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T + in M_Nest (shxAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr) + + mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) + => Perm is -> Mixed sh (Mixed sh' a) + -> Mixed (PermutePrefix is sh) (Mixed sh' a) mtranspose perm (M_Nest sh arr) - | let sh' = X.shDropSh @sh @sh' (mshape arr) sh - , Refl <- X.lemRankApp (X.staticShapeFrom sh) (X.staticShapeFrom sh') - , Refl <- lemLeqPlus (Proxy @(X.Rank is)) (Proxy @(X.Rank sh)) (Proxy @(X.Rank sh')) - , Refl <- X.lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') + | let sh' = shxDropSh @sh @sh' (mshape arr) sh + , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh') + , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh')) + , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') - = M_Nest (X.shPermutePrefix perm sh) + = M_Nest (shxPermutePrefix perm sh) (mtranspose perm arr) mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) - mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (X.staticShapeFrom (mshape arr))))) + mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr))))) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - mshapeTreeEmpty _ (sh, t) = X.shapeSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.shAppend sh sh') idx val vecs + mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs mvecsWritePartial :: forall sh1 sh2 s. IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) -> ST s () mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) - | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = mvecsWritePartial (X.shAppend sh12 sh') idx arr vecs + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') + = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs - mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (X.shAppend sh sh') vecs + mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where - memptyArray sh = M_Nest sh (memptyArray (X.shAppend sh (X.completeShXzeros (knownShX @sh')))) + memptyArray sh = M_Nest sh (memptyArray (shxAppend sh (shxCompleteZeros (knownShX @sh')))) mvecsUnsafeNew sh example - | X.shapeSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) - | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.shAppend sh sh') (mindex example (X.zeroIxX (X.staticShapeFrom sh'))) + | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) + | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh'))) where sh' = mshape example - mvecsNewEmpty _ = MV_Nest (X.completeShXzeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) + mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) -- | Create an array given a size and a function that computes the element at a @@ -882,10 +886,10 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where -- array. The type of 'mgenerate' allows this requirement to be broken very -- easily, hence the runtime check. mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a -mgenerate sh f = case X.enumShape sh of +mgenerate sh f = case shxEnum sh of [] -> memptyArray sh firstidx : restidxs -> - let firstelem = f (X.zeroIxX' sh) + let firstelem = f (ixxZero' sh) shapetree = mshapeTree firstelem in if mshapeTreeEmpty (Proxy @a) shapetree then memptyArray sh @@ -905,28 +909,28 @@ msumOuter1P :: forall sh n a. (Storable a, NumElt a) => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) msumOuter1P (M_Primitive (n :$% sh) arr) = let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX - in M_Primitive sh (X.sumOuter nssh (X.staticShapeFrom sh) arr) + in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr) msumOuter1 :: forall sh n a. (NumElt a, PrimElt a) => Mixed (n : sh) a -> Mixed sh a msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive mappend :: forall n m sh a. Elt a - => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a + => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 where sn :$% sh = mshape arr1 sm :$% _ = mshape arr2 - ssh = X.staticShapeFrom sh - snm :: SMayNat () SNat (X.AddMaybe n m) + ssh = ssxFromShape sh + snm :: SMayNat () SNat (AddMaybe n m) snm = case (sn, sm) of (SUnknown{}, _) -> SUnknown () (SKnown{}, SUnknown{}) -> SUnknown () - (SKnown n, SKnown m) -> SKnown (X.plusSNat n m) + (SKnown n, SKnown m) -> SKnown (snatPlus n m) f :: forall sh' b. Storable b - => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b - f ssh' = X.append (X.ssxAppend ssh ssh') + => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b + f ssh' = X.append (ssxAppend ssh ssh') mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) mfromVectorP sh v = M_Primitive sh (X.fromVector sh v) @@ -971,9 +975,9 @@ mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b) mrerankP ssh sh2 f (M_Primitive sh arr) = - let sh1 = shDropSSX sh ssh - in M_Primitive (X.shAppend (shTakeSSX (Proxy @sh1) sh ssh) sh2) - (X.rerank ssh (X.staticShapeFrom sh1) (X.staticShapeFrom sh2) + let sh1 = shxDropSSX sh ssh + in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2) + (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2) (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) arr) @@ -988,10 +992,10 @@ mrerank ssh sh2 f (toPrimitive -> arr) = mreplicate :: forall sh sh' a. Elt a => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a mreplicate sh arr = - let ssh' = X.staticShapeFrom (mshape arr) - in mlift (X.ssxAppend (X.staticShapeFrom sh) ssh') + let ssh' = ssxFromShape (mshape arr) + in mlift (ssxAppend (ssxFromShape sh) ssh') (\(sshT :: StaticShX shT) -> - case X.lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of + case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of Refl -> X.replicate sh (ssxAppend ssh' sshT)) arr @@ -1005,18 +1009,18 @@ mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x) mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a mslice i n arr = let _ :$% sh = mshape arr - in mlift (SKnown n :!% X.staticShapeFrom sh) (\_ -> X.slice i n) arr + in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a -msliceU i n arr = mlift (X.staticShapeFrom (mshape arr)) (\_ -> X.sliceU i n) arr +msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a -mrev1 arr = mlift (X.staticShapeFrom (mshape arr)) (\_ -> X.rev1) arr +mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a mreshape sh' arr = - mlift (X.staticShapeFrom sh') - (\sshIn -> X.reshapePartial (X.staticShapeFrom (mshape arr)) sshIn sh') + mlift (ssxFromShape sh') + (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh') arr miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a @@ -1095,26 +1099,26 @@ instance (FloatElt a, NumElt a, PrimElt a) => Floating (Mixed sh a) where log1pexp = mliftNumElt1 floatEltLog1pexp log1mexp = mliftNumElt1 floatEltLog1mexp -mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (X.Rank sh) a +mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a mtoRanked arr - | Refl <- X.lemAppNil @sh - , Refl <- X.lemAppNil @(Replicate (X.Rank sh) (Nothing @Nat)) + | Refl <- lemAppNil @sh + , Refl <- lemAppNil @(Replicate (Rank sh) (Nothing @Nat)) , Refl <- lemRankReplicate (srankSh (mshape arr)) - = Ranked (mcast (X.staticShapeFrom (mshape arr)) (convSh (mshape arr)) (Proxy @'[]) arr) + = Ranked (mcast (ssxFromShape (mshape arr)) (convSh (mshape arr)) (Proxy @'[]) arr) where - convSh :: IShX sh' -> IShX (Replicate (X.Rank sh') Nothing) + convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing) convSh ZSX = ZSX convSh (smn :$% (sh :: IShX sh'T)) - | Refl <- X.lemReplicateSucc @(Nothing @Nat) @(X.Rank sh'T) + | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T) = SUnknown (fromSMayNat' smn) :$% convSh sh -mcastToShaped :: forall sh sh' a. (Elt a, X.Rank sh ~ X.Rank sh') +mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') => Mixed sh a -> ShS sh' -> Shaped sh' a mcastToShaped arr targetsh - | Refl <- X.lemAppNil @sh - , Refl <- X.lemAppNil @(MapJust sh') + | Refl <- lemAppNil @sh + , Refl <- lemAppNil @(MapJust sh') , Refl <- lemRankMapJust targetsh - = Shaped (mcast (X.staticShapeFrom (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr) + = Shaped (mcast (ssxFromShape (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr) -- | A rank-typed array: the number of dimensions of the array (its /rank/) is @@ -1418,7 +1422,7 @@ zeroIxR :: SNat n -> IIxR n zeroIxR SZ = ZIR zeroIxR (SS n) = 0 :.: zeroIxR n -ixCvtXR :: IIxX sh -> IIxR (X.Rank sh) +ixCvtXR :: IIxX sh -> IIxR (Rank sh) ixCvtXR ZIX = ZIR ixCvtXR (n :.% idx) = n :.: ixCvtXR idx @@ -1429,7 +1433,7 @@ shCvtXR' ZSX = shCvtXR' (n :$% (idx :: IShX sh)) | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) = castWith (subst2 (lem1 @sh Refl)) - (X.fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx)) + (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx)) where lem1 :: forall sh' n' k. k : sh' :~: Replicate n' Nothing @@ -1443,13 +1447,13 @@ shCvtXR' (n :$% (idx :: IShX sh)) ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) ixCvtRX ZIR = ZIX ixCvtRX (n :.: (idx :: IxR m Int)) = - castWith (subst2 @IxX @Int (X.lemReplicateSucc @(Nothing @Nat) @m)) + castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m)) (n :.% ixCvtRX idx) shCvtRX :: IShR n -> IShX (Replicate n Nothing) shCvtRX ZSR = ZSX shCvtRX (n :$: (idx :: ShR m Int)) = - castWith (subst2 @ShX @Int (X.lemReplicateSucc @(Nothing @Nat) @m)) + castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) (SUnknown n :$% shCvtRX idx) shapeSizeR :: IShR n -> Int @@ -1506,7 +1510,7 @@ rsumOuter1P :: forall n a. (Storable a, NumElt a) => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) rsumOuter1P (Ranked arr) - | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) @n = Ranked (msumOuter1P arr) rsumOuter1 :: forall n a. (NumElt a, PrimElt a) @@ -1559,7 +1563,7 @@ rappend :: forall n a. Elt a rappend arr1 arr2 | sn@SNat <- snatFromShR (rshape arr1) , Dict <- lemKnownReplicate sn - , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n + , Refl <- lemReplicateSucc @(Nothing @Nat) @n = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) arr1 arr2 @@ -1582,7 +1586,7 @@ rtoVector = coerce mtoVector rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a rfromListOuter l - | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) @n = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a @@ -1593,7 +1597,7 @@ rfromList1Prim l = Ranked (mfromList1Prim l) rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] rtoListOuter (Ranked arr) - | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) @n = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr) rtoList1 :: Elt a => Ranked 1 a -> [a] @@ -1677,7 +1681,7 @@ rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x) rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a rslice i n arr - | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n + | Refl <- lemReplicateSucc @(Nothing @Nat) @n = rlift (snatFromShR (rshape arr)) (\_ -> X.sliceU i n) arr @@ -1686,7 +1690,7 @@ rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a rrev1 arr = rlift (snatFromShR (rshape arr)) (\(_ :: StaticShX sh') -> - case X.lemReplicateSucc @(Nothing @Nat) @n of + case lemReplicateSucc @(Nothing @Nat) @n of Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')) arr @@ -1707,12 +1711,12 @@ rasXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing rasXArrayPrim (Ranked arr) = first shCvtXR' (masXArrayPrim arr) rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a) -rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (X.staticShapeFrom (X.shape (ssxFromSNat sn) arr)) arr) +rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a -rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (X.staticShapeFrom (X.shape (ssxFromSNat sn) arr)) arr) +rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) -rcastToShaped :: Elt a => Ranked (X.Rank sh) a -> ShS sh -> Shaped sh a +rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a rcastToShaped (Ranked arr) targetsh | Refl <- lemRankReplicate (srankSh (shCvtSX targetsh)) , Refl <- lemRankMapJust targetsh @@ -1809,7 +1813,7 @@ shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh shapeSizeS :: ShS sh -> Int shapeSizeS ZSS = 1 -shapeSizeS (n :$$ sh) = X.fromSNat' n * shapeSizeS sh +shapeSizeS (n :$$ sh) = fromSNat' n * shapeSizeS sh sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh @@ -1838,14 +1842,14 @@ slift :: forall sh1 sh2 a. Elt a => ShS sh2 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) -> Shaped sh1 a -> Shaped sh2 a -slift sh2 f (Shaped arr) = Shaped (mlift (X.staticShapeFrom (shCvtSX sh2)) f arr) +slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShape (shCvtSX sh2)) f arr) -- | See the documentation of 'mlift'. slift2 :: forall sh1 sh2 sh3 a. Elt a => ShS sh3 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b) -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a -slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (X.staticShapeFrom (shCvtSX sh3)) f arr1 arr2) +slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShape (shCvtSX sh3)) f arr1 arr2) ssumOuter1P :: forall sh n a. (Storable a, NumElt a) => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) @@ -1855,28 +1859,28 @@ ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a) => Shaped (n : sh) a -> Shaped sh a ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive -lemCommMapJustTakeLen :: HList SNat is -> ShS sh -> X.TakeLen is (MapJust sh) :~: MapJust (X.TakeLen is sh) -lemCommMapJustTakeLen HNil _ = Refl -lemCommMapJustTakeLen (_ `HCons` is) (_ :$$ sh) | Refl <- lemCommMapJustTakeLen is sh = Refl -lemCommMapJustTakeLen (_ `HCons` _) ZSS = error "TakeLen of empty" +lemCommMapJustTakeLen :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh) +lemCommMapJustTakeLen PNil _ = Refl +lemCommMapJustTakeLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemCommMapJustTakeLen is sh = Refl +lemCommMapJustTakeLen (_ `PCons` _) ZSS = error "TakeLen of empty" -lemCommMapJustDropLen :: HList SNat is -> ShS sh -> X.DropLen is (MapJust sh) :~: MapJust (X.DropLen is sh) -lemCommMapJustDropLen HNil _ = Refl -lemCommMapJustDropLen (_ `HCons` is) (_ :$$ sh) | Refl <- lemCommMapJustDropLen is sh = Refl -lemCommMapJustDropLen (_ `HCons` _) ZSS = error "DropLen of empty" +lemCommMapJustDropLen :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh) +lemCommMapJustDropLen PNil _ = Refl +lemCommMapJustDropLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemCommMapJustDropLen is sh = Refl +lemCommMapJustDropLen (_ `PCons` _) ZSS = error "DropLen of empty" -lemCommMapJustIndex :: SNat i -> ShS sh -> X.Index i (MapJust sh) :~: Just (X.Index i sh) +lemCommMapJustIndex :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh) lemCommMapJustIndex SZ (_ :$$ _) = Refl lemCommMapJustIndex (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) | Refl <- lemCommMapJustIndex i sh - , Refl <- X.lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) - , Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) + , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') = Refl lemCommMapJustIndex _ ZSS = error "Index of empty" -lemCommMapJustPermute :: HList SNat is -> ShS sh -> X.Permute is (MapJust sh) :~: MapJust (X.Permute is sh) -lemCommMapJustPermute HNil _ = Refl -lemCommMapJustPermute (i `HCons` is) sh +lemCommMapJustPermute :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh) +lemCommMapJustPermute PNil _ = Refl +lemCommMapJustPermute (i `PCons` is) sh | Refl <- lemCommMapJustPermute is sh , Refl <- lemCommMapJustIndex i sh = Refl @@ -1885,53 +1889,53 @@ listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f listsAppend ZS idx' = idx' listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' -listsTakeLen :: forall f is sh. HList SNat is -> ListS sh f -> ListS (X.TakeLen is sh) f -listsTakeLen HNil _ = ZS -listsTakeLen (_ `HCons` is) (n ::$ sh) = n ::$ listsTakeLen is sh -listsTakeLen (_ `HCons` _) ZS = error "Permutation longer than shape" +listsTakeLen :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f +listsTakeLen PNil _ = ZS +listsTakeLen (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLen is sh +listsTakeLen (_ `PCons` _) ZS = error "Permutation longer than shape" -listsDropLen :: forall f is sh. HList SNat is -> ListS sh f -> ListS (DropLen is sh) f -listsDropLen HNil sh = sh -listsDropLen (_ `HCons` is) (_ ::$ sh) = listsDropLen is sh -listsDropLen (_ `HCons` _) ZS = error "Permutation longer than shape" +listsDropLen :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f +listsDropLen PNil sh = sh +listsDropLen (_ `PCons` is) (_ ::$ sh) = listsDropLen is sh +listsDropLen (_ `PCons` _) ZS = error "Permutation longer than shape" -listsPermute :: forall f is sh. HList SNat is -> ListS sh f -> ListS (X.Permute is sh) f -listsPermute HNil _ = ZS -listsPermute (i `HCons` (is :: HList SNat is')) (sh :: ListS sh f) = listsIndex (Proxy @is') (Proxy @sh) i sh (listsPermute is sh) +listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f +listsPermute PNil _ = ZS +listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) = listsIndex (Proxy @is') (Proxy @sh) i sh (listsPermute is sh) -listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> ListS (X.Permute is shT) f -> ListS (X.Index i sh : X.Permute is shT) f +listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> ListS (Permute is shT) f -> ListS (Index i sh : Permute is shT) f listsIndex _ _ SZ (n ::$ _) rest = n ::$ rest listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) rest - | Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') = listsIndex p pT i sh rest listsIndex _ _ _ ZS _ = error "Index into empty shape" -shsTakeLen :: HList SNat is -> ShS sh -> ShS (X.TakeLen is sh) +shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh) shsTakeLen = coerce (listsTakeLen @SNat) -shsPermute :: HList SNat is -> ShS sh -> ShS (X.Permute is sh) +shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh) shsPermute = coerce (listsPermute @SNat) -shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (X.Permute is shT) -> ShS (X.Index i sh : X.Permute is shT) +shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (Permute is shT) -> ShS (Index i sh : Permute is shT) shsIndex pis pshT = coerce (listsIndex @SNat pis pshT) -applyPermS :: forall f is sh. HList SNat is -> ListS sh f -> ListS (PermutePrefix is sh) f +applyPermS :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f applyPermS perm sh = listsAppend (listsPermute perm (listsTakeLen perm sh)) (listsDropLen perm sh) -applyPermIxS :: forall i is sh. HList SNat is -> IxS sh i -> IxS (PermutePrefix is sh) i +applyPermIxS :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i applyPermIxS = coerce (applyPermS @(Const i)) -applyPermShS :: forall is sh. HList SNat is -> ShS sh -> ShS (PermutePrefix is sh) +applyPermShS :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh) applyPermShS = coerce (applyPermS @SNat) -stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, Elt a) - => HList SNat is -> Shaped sh a -> Shaped (X.PermutePrefix is sh) a +stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a) + => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a stranspose perm sarr@(Shaped arr) | Refl <- lemRankMapJust (sshape sarr) , Refl <- lemCommMapJustTakeLen perm (sshape sarr) , Refl <- lemCommMapJustDropLen perm (sshape sarr) , Refl <- lemCommMapJustPermute perm (shsTakeLen perm (sshape sarr)) - , Refl <- lemCommMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(X.DropLen is sh)) + , Refl <- lemCommMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh)) = Shaped (mtranspose perm arr) sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a @@ -1969,7 +1973,7 @@ stoList1 = map sunScalar . stoListOuter sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a sfromListPrim sn l - | Refl <- X.lemAppNil @'[Just n] + | Refl <- lemAppNil @'[Just n] = let ssh = SUnknown () :!% ZKX xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l) in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr @@ -1989,7 +1993,7 @@ srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) srerankP sh sh2 f sarr@(Shaped arr) | Refl <- lemCommMapJustApp sh (Proxy @sh1) , Refl <- lemCommMapJustApp sh (Proxy @sh2) - = Shaped (mrerankP (X.staticShapeFrom (shTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (X.staticShapeFrom (shCvtSX sh)))) + = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh)))) (shCvtSX sh2) (\a -> let Shaped r = f (Shaped a) in r) arr) @@ -2033,12 +2037,12 @@ sasXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a) sasXArrayPrim (Shaped arr) = first shCvtXS' (masXArrayPrim arr) sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a) -sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (X.staticShapeFrom (shCvtSX sh)) arr) +sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShape (shCvtSX sh)) arr) sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a -sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (X.staticShapeFrom (shCvtSX sh)) arr) +sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShape (shCvtSX sh)) arr) -stoRanked :: Elt a => Shaped sh a -> Ranked (X.Rank sh) a +stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a stoRanked sarr@(Shaped arr) | Refl <- lemRankMapJust (sshape sarr) = mtoRanked arr diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs deleted file mode 100644 index 95fcfcf..0000000 --- a/src/Data/Array/Nested/Internal/Arith.hs +++ /dev/null @@ -1,435 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Nested.Internal.Arith where - -import Control.Monad (forM, guard) -import qualified Data.Array.Internal as OI -import qualified Data.Array.Internal.RankedG as RG -import qualified Data.Array.Internal.RankedS as RS -import Data.Bits -import Data.Int -import Data.List (sort) -import qualified Data.Vector.Storable as VS -import qualified Data.Vector.Storable.Mutable as VSM -import Foreign.C.Types -import Foreign.Ptr -import Foreign.Storable (Storable) -import GHC.TypeLits -import Language.Haskell.TH -import System.IO.Unsafe - -import Data.Array.Nested.Internal.Arith.Foreign -import Data.Array.Nested.Internal.Arith.Lists - - -liftVEltwise1 :: Storable a - => SNat n - -> (VS.Vector a -> VS.Vector a) - -> RS.Array n a -> RS.Array n a -liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) - | Just prefixSz <- stridesDense sh strides = - let vec' = f (VS.slice offset prefixSz vec) - in RS.A (RG.A sh (OI.T strides 0 vec')) - | otherwise = RS.fromVector sh (f (RS.toVector arr)) - -liftVEltwise2 :: Storable a - => SNat n - -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a) - -> RS.Array n a -> RS.Array n a -> RS.Array n a -liftVEltwise2 SNat f - arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1))) - arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2))) - | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2 - | product sh1 == 0 = arr1 -- if the arrays are empty, just return one of the empty inputs - | otherwise = case (stridesDense sh1 strides1, stridesDense sh2 strides2) of - (Just 1, Just 1) -> -- both are a (potentially replicated) scalar; just apply f to the scalars - let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2)) - in RS.A (RG.A sh1 (OI.T strides1 0 vec')) - (Just 1, Just n) -> -- scalar * dense - RS.fromVector sh1 (f (Left (vec1 VS.! offset1)) (Right (VS.slice offset2 n vec2))) - (Just n, Just 1) -> -- dense * scalar - RS.fromVector sh1 (f (Right (VS.slice offset1 n vec1)) (Left (vec2 VS.! offset2))) - (_, _) -> -- fallback case - RS.fromVector sh1 (f (Right (RS.toVector arr1)) (Right (RS.toVector arr2))) - --- | Given the shape vector and the stride vector, return whether this vector --- of strides uses a dense prefix of its backing array. If so, the number of --- elements in this prefix is returned. --- This excludes any offset. -stridesDense :: [Int] -> [Int] -> Maybe Int -stridesDense sh _ | any (<= 0) sh = Just 0 -stridesDense sh str = - -- sort dimensions on their stride, ascending, dropping any zero strides - case dropWhile ((== 0) . fst) (sort (zip str sh)) of - [] -> Just 1 - (1, n) : (unzip -> (str', sh')) -> checkCover n sh' str' - _ -> Nothing -- if the smallest stride is not 1, it will never be dense - where - -- Given size of currently densely covered region at beginning of the - -- array, the remaining shape vector and the corresponding remaining stride - -- vector, return whether this all together covers a dense prefix of the - -- array. If it does, return the number of elements in this prefix. - checkCover :: Int -> [Int] -> [Int] -> Maybe Int - checkCover block [] [] = Just block - checkCover block (n : sh') (s : str') = guard (s <= block) >> checkCover (max block (n * s)) sh' str' - checkCover _ _ _ = error "Orthotope array's shape vector and stride vector have different lengths" - -{-# NOINLINE vectorOp1 #-} -vectorOp1 :: forall a b. Storable a - => (Ptr a -> Ptr b) - -> (Int64 -> Ptr b -> Ptr b -> IO ()) - -> VS.Vector a -> VS.Vector a -vectorOp1 ptrconv f v = unsafePerformIO $ do - outv <- VSM.unsafeNew (VS.length v) - VSM.unsafeWith outv $ \poutv -> - VS.unsafeWith v $ \pv -> - f (fromIntegral (VS.length v)) (ptrconv poutv) (ptrconv pv) - VS.unsafeFreeze outv - --- | If two vectors are given, assumes that they have the same length. -{-# NOINLINE vectorOp2 #-} -vectorOp2 :: forall a b. Storable a - => (a -> b) - -> (Ptr a -> Ptr b) - -> (a -> a -> a) - -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- sv - -> (Int64 -> Ptr b -> Ptr b -> b -> IO ()) -- vs - -> (Int64 -> Ptr b -> Ptr b -> Ptr b -> IO ()) -- vv - -> Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a -vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases - (Left x) (Left y) -> VS.singleton (fss x y) - - (Left x) (Right vy) -> - unsafePerformIO $ do - outv <- VSM.unsafeNew (VS.length vy) - VSM.unsafeWith outv $ \poutv -> - VS.unsafeWith vy $ \pvy -> - fsv (fromIntegral (VS.length vy)) (ptrconv poutv) (valconv x) (ptrconv pvy) - VS.unsafeFreeze outv - - (Right vx) (Left y) -> - unsafePerformIO $ do - outv <- VSM.unsafeNew (VS.length vx) - VSM.unsafeWith outv $ \poutv -> - VS.unsafeWith vx $ \pvx -> - fvs (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (valconv y) - VS.unsafeFreeze outv - - (Right vx) (Right vy) - | VS.length vx == VS.length vy -> - unsafePerformIO $ do - outv <- VSM.unsafeNew (VS.length vx) - VSM.unsafeWith outv $ \poutv -> - VS.unsafeWith vx $ \pvx -> - VS.unsafeWith vy $ \pvy -> - fvv (fromIntegral (VS.length vx)) (ptrconv poutv) (ptrconv pvx) (ptrconv pvy) - VS.unsafeFreeze outv - | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy) - --- TODO: test all the weird cases of this function --- | Reduce along the inner dimension -{-# NOINLINE vectorRedInnerOp #-} -vectorRedInnerOp :: forall a b n. (Num a, Storable a) - => SNat n - -> (a -> b) - -> (Ptr a -> Ptr b) - -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- ^ scale by constant - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel - -> RS.Array (n + 1) a -> RS.Array n a -vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec))) - | null sh = error "unreachable" - | last sh <= 0 = RS.stretch (init sh) (RS.fromList (map (const 1) (init sh)) [0]) - | any (<= 0) (init sh) = RS.A (RG.A (init sh) (OI.T (map (const 0) (init strides)) 0 VS.empty)) - -- now the input array is nonempty - | last sh == 1 = RS.A (RG.A (init sh) (OI.T (init strides) offset vec)) - | last strides == 0 = - liftVEltwise1 sn - (vectorOp1 id (\n pout px -> fscale n (ptrconv pout) (valconv (fromIntegral (last sh))) (ptrconv px))) - (RS.A (RG.A (init sh) (OI.T (init strides) offset vec))) - -- now there is useful work along the inner dimension - | otherwise = - let -- filter out zero-stride dimensions; the reduction kernel need not concern itself with those - (shF, stridesF) = unzip $ filter ((/= 0) . snd) (zip sh strides) - ndimsF = length shF - in unsafePerformIO $ do - outv <- VSM.unsafeNew (product (init shF)) - VSM.unsafeWith outv $ \poutv -> - VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF -> - VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF -> - VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec -> - fred (fromIntegral ndimsF) pshF pstridesF (ptrconv poutv) (ptrconv pvec) - RS.fromVector (init sh) <$> VS.unsafeFreeze outv - -flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ()) - -> Int64 -> Ptr a -> Ptr a -> a -> IO () -flipOp f n out v s = f n out s v - -$(fmap concat . forM typesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - fmap concat . forM [minBound..maxBound] $ \arithop -> do - let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype)) - cnamebase = "c_binary_" ++ atCName arithtype - c_ss = varE (aboNumOp arithop) - c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) - c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) - c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (aboEnum arithop))) - sequence [SigD name <$> - [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] - ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] - return $ FunD name [Clause [] (NormalB body) []]]) - -$(fmap concat . forM floatTypesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - fmap concat . forM [minBound..maxBound] $ \arithop -> do - let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype)) - cnamebase = "c_fbinary_" ++ atCName arithtype - c_ss = varE (afboNumOp arithop) - c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) - c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) - c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (afboEnum arithop))) - sequence [SigD name <$> - [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |] - ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |] - return $ FunD name [Clause [] (NormalB body) []]]) - -$(fmap concat . forM typesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - fmap concat . forM [minBound..maxBound] $ \arithop -> do - let name = mkName (auoName arithop ++ "Vector" ++ nameBase (atType arithtype)) - c_op = varE (mkName ("c_unary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (auoEnum arithop))) - sequence [SigD name <$> - [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] - ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] - return $ FunD name [Clause [] (NormalB body) []]]) - -$(fmap concat . forM floatTypesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - fmap concat . forM [minBound..maxBound] $ \arithop -> do - let name = mkName (afuoName arithop ++ "Vector" ++ nameBase (atType arithtype)) - c_op = varE (mkName ("c_funary_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (afuoEnum arithop))) - sequence [SigD name <$> - [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |] - ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |] - return $ FunD name [Clause [] (NormalB body) []]]) - -$(fmap concat . forM typesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - fmap concat . forM [minBound..maxBound] $ \arithop -> do - let name = mkName (aroName arithop ++ "Vector" ++ nameBase (atType arithtype)) - c_op = varE (mkName ("c_reduce_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop))) - c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL))) - sequence [SigD name <$> - [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |] - ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |] - return $ FunD name [Clause [] (NormalB body) []]]) - --- This branch is ostensibly a runtime branch, but will (hopefully) be --- constant-folded away by GHC. -intWidBranch1 :: forall i n. (FiniteBits i, Storable i) - => (Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) - -> (SNat n -> RS.Array n i -> RS.Array n i) -intWidBranch1 f32 f64 sn - | finiteBitSize (undefined :: i) == 32 = liftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32) - | finiteBitSize (undefined :: i) == 64 = liftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64) - | otherwise = error "Unsupported Int width" - -intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i) - => (i -> i -> i) -- ss - -- int32 - -> (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- sv - -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Int32 -> IO ()) -- vs - -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- vv - -- int64 - -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- sv - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ()) -- vs - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- vv - -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i) -intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn - | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32) - | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64) - | otherwise = error "Unsupported Int width" - -intWidBranchRed :: forall i n. (FiniteBits i, Storable i, Integral i) - => -- int32 - (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- ^ scale by constant - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- ^ reduction kernel - -- int64 - -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant - -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel - -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i) -intWidBranchRed fsc32 fred32 fsc64 fred64 sn - | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32 - | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64 - | otherwise = error "Unsupported Int width" - -class NumElt a where - numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - numEltMul :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a - numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a - numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a - numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a - numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a - -instance NumElt Int32 where - numEltAdd = addVectorInt32 - numEltSub = subVectorInt32 - numEltMul = mulVectorInt32 - numEltNeg = negVectorInt32 - numEltAbs = absVectorInt32 - numEltSignum = signumVectorInt32 - numEltSum1Inner = sum1VectorInt32 - numEltProduct1Inner = product1VectorInt32 - -instance NumElt Int64 where - numEltAdd = addVectorInt64 - numEltSub = subVectorInt64 - numEltMul = mulVectorInt64 - numEltNeg = negVectorInt64 - numEltAbs = absVectorInt64 - numEltSignum = signumVectorInt64 - numEltSum1Inner = sum1VectorInt64 - numEltProduct1Inner = product1VectorInt64 - -instance NumElt Float where - numEltAdd = addVectorFloat - numEltSub = subVectorFloat - numEltMul = mulVectorFloat - numEltNeg = negVectorFloat - numEltAbs = absVectorFloat - numEltSignum = signumVectorFloat - numEltSum1Inner = sum1VectorFloat - numEltProduct1Inner = product1VectorFloat - -instance NumElt Double where - numEltAdd = addVectorDouble - numEltSub = subVectorDouble - numEltMul = mulVectorDouble - numEltNeg = negVectorDouble - numEltAbs = absVectorDouble - numEltSignum = signumVectorDouble - numEltSum1Inner = sum1VectorDouble - numEltProduct1Inner = product1VectorDouble - -instance NumElt Int where - numEltAdd = intWidBranch2 @Int (+) - (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD)) - (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD)) - numEltSub = intWidBranch2 @Int (-) - (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB)) - (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB)) - numEltMul = intWidBranch2 @Int (*) - (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL)) - (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL)) - numEltNeg = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG)) - numEltAbs = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS)) - numEltSignum = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM)) - numEltSum1Inner = intWidBranchRed @Int - (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1)) - (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1)) - numEltProduct1Inner = intWidBranchRed @Int - (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) - (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) - -instance NumElt CInt where - numEltAdd = intWidBranch2 @CInt (+) - (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD)) - (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD)) - numEltSub = intWidBranch2 @CInt (-) - (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB)) - (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB)) - numEltMul = intWidBranch2 @CInt (*) - (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL)) - (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL)) - numEltNeg = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG)) - numEltAbs = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS)) - numEltSignum = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM)) - numEltSum1Inner = intWidBranchRed @CInt - (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1)) - (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1)) - numEltProduct1Inner = intWidBranchRed @CInt - (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1)) - (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) - -class FloatElt a where - floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - floatEltPow :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - floatEltLogbase :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a - floatEltRecip :: SNat n -> RS.Array n a -> RS.Array n a - floatEltExp :: SNat n -> RS.Array n a -> RS.Array n a - floatEltLog :: SNat n -> RS.Array n a -> RS.Array n a - floatEltSqrt :: SNat n -> RS.Array n a -> RS.Array n a - floatEltSin :: SNat n -> RS.Array n a -> RS.Array n a - floatEltCos :: SNat n -> RS.Array n a -> RS.Array n a - floatEltTan :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAsin :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAcos :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAtan :: SNat n -> RS.Array n a -> RS.Array n a - floatEltSinh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltCosh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltTanh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAsinh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAcosh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltAtanh :: SNat n -> RS.Array n a -> RS.Array n a - floatEltLog1p :: SNat n -> RS.Array n a -> RS.Array n a - floatEltExpm1 :: SNat n -> RS.Array n a -> RS.Array n a - floatEltLog1pexp :: SNat n -> RS.Array n a -> RS.Array n a - floatEltLog1mexp :: SNat n -> RS.Array n a -> RS.Array n a - -instance FloatElt Float where - floatEltDiv = divVectorFloat - floatEltPow = powVectorFloat - floatEltLogbase = logbaseVectorFloat - floatEltRecip = recipVectorFloat - floatEltExp = expVectorFloat - floatEltLog = logVectorFloat - floatEltSqrt = sqrtVectorFloat - floatEltSin = sinVectorFloat - floatEltCos = cosVectorFloat - floatEltTan = tanVectorFloat - floatEltAsin = asinVectorFloat - floatEltAcos = acosVectorFloat - floatEltAtan = atanVectorFloat - floatEltSinh = sinhVectorFloat - floatEltCosh = coshVectorFloat - floatEltTanh = tanhVectorFloat - floatEltAsinh = asinhVectorFloat - floatEltAcosh = acoshVectorFloat - floatEltAtanh = atanhVectorFloat - floatEltLog1p = log1pVectorFloat - floatEltExpm1 = expm1VectorFloat - floatEltLog1pexp = log1pexpVectorFloat - floatEltLog1mexp = log1mexpVectorFloat - -instance FloatElt Double where - floatEltDiv = divVectorDouble - floatEltPow = powVectorDouble - floatEltLogbase = logbaseVectorDouble - floatEltRecip = recipVectorDouble - floatEltExp = expVectorDouble - floatEltLog = logVectorDouble - floatEltSqrt = sqrtVectorDouble - floatEltSin = sinVectorDouble - floatEltCos = cosVectorDouble - floatEltTan = tanVectorDouble - floatEltAsin = asinVectorDouble - floatEltAcos = acosVectorDouble - floatEltAtan = atanVectorDouble - floatEltSinh = sinhVectorDouble - floatEltCosh = coshVectorDouble - floatEltTanh = tanhVectorDouble - floatEltAsinh = asinhVectorDouble - floatEltAcosh = acoshVectorDouble - floatEltAtanh = atanhVectorDouble - floatEltLog1p = log1pVectorDouble - floatEltExpm1 = expm1VectorDouble - floatEltLog1pexp = log1pexpVectorDouble - floatEltLog1mexp = log1mexpVectorDouble diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Nested/Internal/Arith/Foreign.hs deleted file mode 100644 index ac83188..0000000 --- a/src/Data/Array/Nested/Internal/Arith/Foreign.hs +++ /dev/null @@ -1,55 +0,0 @@ -{-# LANGUAGE ForeignFunctionInterface #-} -{-# LANGUAGE TemplateHaskell #-} -module Data.Array.Nested.Internal.Arith.Foreign where - -import Control.Monad -import Data.Int -import Data.Maybe -import Foreign.C.Types -import Foreign.Ptr -import Language.Haskell.TH - -import Data.Array.Nested.Internal.Arith.Lists - - -$(fmap concat . forM typesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - let base = "binary_" ++ atCName arithtype - sequence $ catMaybes - [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> - [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]) - ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> - [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) - ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> - [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) - ]) - -$(fmap concat . forM floatTypesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - let base = "fbinary_" ++ atCName arithtype - sequence $ catMaybes - [Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_sv") (mkName ("c_" ++ base ++ "_sv")) <$> - [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |]) - ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vv") (mkName ("c_" ++ base ++ "_vv")) <$> - [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) - ,Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$> - [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |]) - ]) - -$(fmap concat . forM typesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - let base = "unary_" ++ atCName arithtype - pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> - [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) - -$(fmap concat . forM floatTypesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - let base = "funary_" ++ atCName arithtype - pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> - [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) - -$(fmap concat . forM typesList $ \arithtype -> do - let ttyp = conT (atType arithtype) - let base = "reduce_" ++ atCName arithtype - pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> - [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]) diff --git a/src/Data/Array/Nested/Internal/Arith/Lists.hs b/src/Data/Array/Nested/Internal/Arith/Lists.hs deleted file mode 100644 index ce2836d..0000000 --- a/src/Data/Array/Nested/Internal/Arith/Lists.hs +++ /dev/null @@ -1,78 +0,0 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE TemplateHaskell #-} -module Data.Array.Nested.Internal.Arith.Lists where - -import Data.Char -import Data.Int -import Language.Haskell.TH - -import Data.Array.Nested.Internal.Arith.Lists.TH - - -data ArithType = ArithType - { atType :: Name -- ''Int32 - , atCName :: String -- "i32" - } - -floatTypesList :: [ArithType] -floatTypesList = - [ArithType ''Float "float" - ,ArithType ''Double "double" - ] - -typesList :: [ArithType] -typesList = - [ArithType ''Int32 "i32" - ,ArithType ''Int64 "i64" - ] - ++ floatTypesList - --- data ArithBOp = BO_ADD | BO_SUB | BO_MUL deriving (Show, Enum, Bounded) -$(genArithDataType Binop "ArithBOp") - -$(genArithNameFun Binop ''ArithBOp "aboName" (map toLower . drop 3)) -$(genArithEnumFun Binop ''ArithBOp "aboEnum") - -$(do clauses <- readArithLists Binop - (\name _num hsop -> return (Clause [ConP (mkName name) [] []] - (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) - [])) - return - sequence [SigD (mkName "aboNumOp") <$> [t| ArithBOp -> Name |] - ,return $ FunD (mkName "aboNumOp") clauses]) - - --- data ArithFBOp = FB_DIV deriving (Show, Enum, Bounded) -$(genArithDataType FBinop "ArithFBOp") - -$(genArithNameFun FBinop ''ArithFBOp "afboName" (map toLower . drop 3)) -$(genArithEnumFun FBinop ''ArithFBOp "afboEnum") - -$(do clauses <- readArithLists FBinop - (\name _num hsop -> return (Clause [ConP (mkName name) [] []] - (NormalB (VarE 'mkName `AppE` LitE (StringL hsop))) - [])) - return - sequence [SigD (mkName "afboNumOp") <$> [t| ArithFBOp -> Name |] - ,return $ FunD (mkName "afboNumOp") clauses]) - - --- data ArithUOp = UO_NEG | UO_ABS | UO_SIGNUM | ... deriving (Show, Enum, Bounded) -$(genArithDataType Unop "ArithUOp") - -$(genArithNameFun Unop ''ArithUOp "auoName" (map toLower . drop 3)) -$(genArithEnumFun Unop ''ArithUOp "auoEnum") - - --- data ArithFUOp = FU_RECIP | ... deriving (Show, Enum, Bounded) -$(genArithDataType FUnop "ArithFUOp") - -$(genArithNameFun FUnop ''ArithFUOp "afuoName" (map toLower . drop 3)) -$(genArithEnumFun FUnop ''ArithFUOp "afuoEnum") - - --- data ArithRedOp = RO_SUM1 | RO_PRODUCT1 deriving (Show, Enum, Bounded) -$(genArithDataType Redop "ArithRedOp") - -$(genArithNameFun Redop ''ArithRedOp "aroName" (map toLower . drop 3)) -$(genArithEnumFun Redop ''ArithRedOp "aroEnum") diff --git a/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs b/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs deleted file mode 100644 index 7142dfa..0000000 --- a/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs +++ /dev/null @@ -1,82 +0,0 @@ -{-# LANGUAGE TemplateHaskellQuotes #-} -module Data.Array.Nested.Internal.Arith.Lists.TH where - -import Control.Monad -import Control.Monad.IO.Class -import Data.Maybe -import Foreign.C.Types -import Language.Haskell.TH -import Language.Haskell.TH.Syntax -import Text.Read - - -data OpKind = Binop | FBinop | Unop | FUnop | Redop - deriving (Show, Eq) - -readArithLists :: OpKind - -> (String -> Int -> String -> Q a) - -> ([a] -> Q r) - -> Q r -readArithLists targetkind fop fcombine = do - addDependentFile "cbits/arith_lists.h" - lns <- liftIO $ lines <$> readFile "cbits/arith_lists.h" - - mvals <- forM lns $ \line -> do - if null (dropWhile (== ' ') line) - then return Nothing - else do let (kind, name, num, aux) = parseLine line - if kind == targetkind - then Just <$> fop name num aux - else return Nothing - - fcombine (catMaybes mvals) - where - parseLine s0 - | ("LIST_", s1) <- splitAt 5 s0 - , (kindstr, '(' : s2) <- break (== '(') s1 - , (f1, ',' : s3) <- parseField s2 - , (f2, ',' : s4) <- parseField s3 - , (f3, ')' : _) <- parseField s4 - , Just kind <- parseKind kindstr - , let name = f1 - , Just num <- readMaybe f2 - , let aux = f3 - = (kind, name, num, aux) - | otherwise - = error $ "readArithLists: unrecognised line in cbits/arith_lists.h: " ++ show s0 - - parseField s = break (`elem` ",)") (dropWhile (== ' ') s) - - parseKind "BINOP" = Just Binop - parseKind "FBINOP" = Just FBinop - parseKind "UNOP" = Just Unop - parseKind "FUNOP" = Just FUnop - parseKind "REDOP" = Just Redop - parseKind _ = Nothing - -genArithDataType :: OpKind -> String -> Q [Dec] -genArithDataType kind dtname = do - cons <- readArithLists kind - (\name _num _ -> return $ NormalC (mkName name) []) - return - return [DataD [] (mkName dtname) [] Nothing cons [DerivClause Nothing [ConT ''Show, ConT ''Enum, ConT ''Bounded]]] - -genArithNameFun :: OpKind -> Name -> String -> (String -> String) -> Q [Dec] -genArithNameFun kind dtname funname nametrans = do - clauses <- readArithLists kind - (\name _num _ -> return (Clause [ConP (mkName name) [] []] - (NormalB (LitE (StringL (nametrans name)))) - [])) - return - return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''String) - ,FunD (mkName funname) clauses] - -genArithEnumFun :: OpKind -> Name -> String -> Q [Dec] -genArithEnumFun kind dtname funname = do - clauses <- readArithLists kind - (\name num _ -> return (Clause [ConP (mkName name) [] []] - (NormalB (LitE (IntegerL (fromIntegral num)))) - [])) - return - return [SigD (mkName funname) (ArrowT `AppT` ConT dtname `AppT` ConT ''CInt) - ,FunD (mkName funname) clauses] -- cgit v1.2.3-70-g09d2