From 92902c4f66db111b439f3b7eba9de50ad7c73f7b Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 3 Apr 2024 12:37:35 +0200 Subject: Reorganise, documentation --- src/Array.hs | 312 ----------------------------------------------------------- 1 file changed, 312 deletions(-) delete mode 100644 src/Array.hs (limited to 'src/Array.hs') diff --git a/src/Array.hs b/src/Array.hs deleted file mode 100644 index cbf04fc..0000000 --- a/src/Array.hs +++ /dev/null @@ -1,312 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -module Array where - -import qualified Data.Array.RankedU as U -import Data.Kind -import Data.Proxy -import Data.Type.Equality -import qualified Data.Vector.Unboxed as VU -import qualified GHC.TypeLits as GHC -import Unsafe.Coerce (unsafeCoerce) - -import Nats - - -type family l1 ++ l2 where - '[] ++ l2 = l2 - (x : xs) ++ l2 = x : xs ++ l2 - -lemAppNil :: l ++ '[] :~: l -lemAppNil = unsafeCoerce Refl - -lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) -lemAppAssoc _ _ _ = unsafeCoerce Refl - - -type IxX :: [Maybe Nat] -> Type -data IxX sh where - IZX :: IxX '[] - (::@) :: Int -> IxX sh -> IxX (Just n : sh) - (::?) :: Int -> IxX sh -> IxX (Nothing : sh) -deriving instance Show (IxX sh) - -type StaticShapeX :: [Maybe Nat] -> Type -data StaticShapeX sh where - SZX :: StaticShapeX '[] - (:$@) :: SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh) - (:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh) -deriving instance Show (StaticShapeX sh) - -type KnownShapeX :: [Maybe Nat] -> Constraint -class KnownShapeX sh where - knownShapeX :: StaticShapeX sh -instance KnownShapeX '[] where - knownShapeX = SZX -instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where - knownShapeX = knownNat :$@ knownShapeX -instance KnownShapeX sh => KnownShapeX (Nothing : sh) where - knownShapeX = () :$? knownShapeX - -type family Rank sh where - Rank '[] = Z - Rank (_ : sh) = S (Rank sh) - -type XArray :: [Maybe Nat] -> Type -> Type -data XArray sh a = XArray (U.Array (GNat (Rank sh)) a) - -zeroIdx :: StaticShapeX sh -> IxX sh -zeroIdx SZX = IZX -zeroIdx (_ :$@ ssh) = 0 ::@ zeroIdx ssh -zeroIdx (_ :$? ssh) = 0 ::? zeroIdx ssh - -zeroIdx' :: IxX sh -> IxX sh -zeroIdx' IZX = IZX -zeroIdx' (_ ::@ sh) = 0 ::@ zeroIdx' sh -zeroIdx' (_ ::? sh) = 0 ::? zeroIdx' sh - -ixAppend :: IxX sh -> IxX sh' -> IxX (sh ++ sh') -ixAppend IZX idx' = idx' -ixAppend (i ::@ idx) idx' = i ::@ ixAppend idx idx' -ixAppend (i ::? idx) idx' = i ::? ixAppend idx idx' - -ixDrop :: IxX (sh ++ sh') -> IxX sh -> IxX sh' -ixDrop sh IZX = sh -ixDrop (_ ::@ sh) (_ ::@ idx) = ixDrop sh idx -ixDrop (_ ::? sh) (_ ::? idx) = ixDrop sh idx - -ssxAppend :: StaticShapeX sh -> StaticShapeX sh' -> StaticShapeX (sh ++ sh') -ssxAppend SZX idx' = idx' -ssxAppend (n :$@ idx) idx' = n :$@ ssxAppend idx idx' -ssxAppend (() :$? idx) idx' = () :$? ssxAppend idx idx' - -shapeSize :: IxX sh -> Int -shapeSize IZX = 1 -shapeSize (n ::@ sh) = n * shapeSize sh -shapeSize (n ::? sh) = n * shapeSize sh - -fromLinearIdx :: IxX sh -> Int -> IxX sh -fromLinearIdx = \sh i -> case go sh i of - (idx, 0) -> idx - _ -> error $ "fromLinearIdx: out of range (" ++ show i ++ - " in array of shape " ++ show sh ++ ")" - where - -- returns (index in subarray, remaining index in enclosing array) - go :: IxX sh -> Int -> (IxX sh, Int) - go IZX i = (IZX, i) - go (n ::@ sh) i = - let (idx, i') = go sh i - (upi, locali) = i' `quotRem` n - in (locali ::@ idx, upi) - go (n ::? sh) i = - let (idx, i') = go sh i - (upi, locali) = i' `quotRem` n - in (locali ::? idx, upi) - -toLinearIdx :: IxX sh -> IxX sh -> Int -toLinearIdx = \sh i -> fst (go sh i) - where - -- returns (index in subarray, size of subarray) - go :: IxX sh -> IxX sh -> (Int, Int) - go IZX IZX = (0, 1) - go (n ::@ sh) (i ::@ ix) = - let (lidx, sz) = go sh ix - in (sz * i + lidx, n * sz) - go (n ::? sh) (i ::? ix) = - let (lidx, sz) = go sh ix - in (sz * i + lidx, n * sz) - -enumShape :: IxX sh -> [IxX sh] -enumShape = \sh -> go 0 sh id [] - where - go :: Int -> IxX sh -> (IxX sh -> a) -> [a] -> [a] - go _ IZX _ = id - go i (n ::@ sh) f - | i < n = go (i + 1) (n ::@ sh) f . go 0 sh (f . (i ::@)) - | otherwise = id - go i (n ::? sh) f - | i < n = go (i + 1) (n ::? sh) f . go 0 sh (f . (i ::?)) - | otherwise = id - -shapeLshape :: IxX sh -> U.ShapeL -shapeLshape IZX = [] -shapeLshape (n ::@ sh) = n : shapeLshape sh -shapeLshape (n ::? sh) = n : shapeLshape sh - -ssxLength :: StaticShapeX sh -> Int -ssxLength SZX = 0 -ssxLength (_ :$@ ssh) = 1 + ssxLength ssh -ssxLength (_ :$? ssh) = 1 + ssxLength ssh - -ssxIotaFrom :: Int -> StaticShapeX sh -> [Int] -ssxIotaFrom _ SZX = [] -ssxIotaFrom i (_ :$@ ssh) = i : ssxIotaFrom (i+1) ssh -ssxIotaFrom i (_ :$? ssh) = i : ssxIotaFrom (i+1) ssh - -lemRankApp :: StaticShapeX sh1 -> StaticShapeX sh2 - -> GNat (Rank (sh1 ++ sh2)) :~: GNat (Rank sh1) GHC.+ GNat (Rank sh2) -lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this - -lemRankAppComm :: StaticShapeX sh1 -> StaticShapeX sh2 - -> GNat (Rank (sh1 ++ sh2)) :~: GNat (Rank (sh2 ++ sh1)) -lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this - -lemKnownNatRank :: IxX sh -> Dict KnownNat (Rank sh) -lemKnownNatRank IZX = Dict -lemKnownNatRank (_ ::@ sh) | Dict <- lemKnownNatRank sh = Dict -lemKnownNatRank (_ ::? sh) | Dict <- lemKnownNatRank sh = Dict - -lemKnownNatRankSSX :: StaticShapeX sh -> Dict KnownNat (Rank sh) -lemKnownNatRankSSX SZX = Dict -lemKnownNatRankSSX (_ :$@ ssh) | Dict <- lemKnownNatRankSSX ssh = Dict -lemKnownNatRankSSX (_ :$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict - -lemKnownShapeX :: StaticShapeX sh -> Dict KnownShapeX sh -lemKnownShapeX SZX = Dict -lemKnownShapeX (n :$@ ssh) | Dict <- lemKnownShapeX ssh, Dict <- snatKnown n = Dict -lemKnownShapeX (() :$? ssh) | Dict <- lemKnownShapeX ssh = Dict - -lemAppKnownShapeX :: StaticShapeX sh1 -> StaticShapeX sh2 -> Dict KnownShapeX (sh1 ++ sh2) -lemAppKnownShapeX SZX ssh' = lemKnownShapeX ssh' -lemAppKnownShapeX (n :$@ ssh) ssh' - | Dict <- lemAppKnownShapeX ssh ssh' - , Dict <- snatKnown n - = Dict -lemAppKnownShapeX (() :$? ssh) ssh' - | Dict <- lemAppKnownShapeX ssh ssh' - = Dict - -shape :: forall sh a. KnownShapeX sh => XArray sh a -> IxX sh -shape (XArray arr) = go (knownShapeX @sh) (U.shapeL arr) - where - go :: StaticShapeX sh' -> [Int] -> IxX sh' - go SZX [] = IZX - go (n :$@ ssh) (_ : l) = fromIntegral (unSNat n) ::@ go ssh l - go (() :$? ssh) (n : l) = n ::? go ssh l - go _ _ = error "Invalid shapeL" - -fromVector :: forall sh a. U.Unbox a => IxX sh -> VU.Vector a -> XArray sh a -fromVector sh v - | Dict <- lemKnownNatRank sh - , Dict <- gknownNat (Proxy @(Rank sh)) - = XArray (U.fromVector (shapeLshape sh) v) - -toVector :: U.Unbox a => XArray sh a -> VU.Vector a -toVector (XArray arr) = U.toVector arr - -scalar :: U.Unbox a => a -> XArray '[] a -scalar = XArray . U.scalar - -unScalar :: U.Unbox a => XArray '[] a -> a -unScalar (XArray a) = U.unScalar a - -generate :: U.Unbox a => IxX sh -> (IxX sh -> a) -> XArray sh a -generate sh f = fromVector sh $ VU.generate (shapeSize sh) (f . fromLinearIdx sh) - --- generateM :: (Monad m, U.Unbox a) => IxX sh -> (IxX sh -> m a) -> m (XArray sh a) --- generateM sh f | Dict <- lemKnownNatRank sh = --- XArray . U.fromVector (shapeLshape sh) --- <$> VU.generateM (shapeSize sh) (f . fromLinearIdx sh) - -indexPartial :: U.Unbox a => XArray (sh ++ sh') a -> IxX sh -> XArray sh' a -indexPartial (XArray arr) IZX = XArray arr -indexPartial (XArray arr) (i ::@ idx) = indexPartial (XArray (U.index arr i)) idx -indexPartial (XArray arr) (i ::? idx) = indexPartial (XArray (U.index arr i)) idx - -index :: forall sh a. U.Unbox a => XArray sh a -> IxX sh -> a -index xarr i - | Refl <- lemAppNil @sh - = let XArray arr' = indexPartial xarr i :: XArray '[] a - in U.unScalar arr' - -append :: forall sh a. (KnownShapeX sh, U.Unbox a) => XArray sh a -> XArray sh a -> XArray sh a -append (XArray a) (XArray b) - | Dict <- lemKnownNatRankSSX (knownShapeX @sh) - , Dict <- gknownNat (Proxy @(Rank sh)) - = XArray (U.append a b) - -rerank :: forall sh sh1 sh2 a b. - (U.Unbox a, U.Unbox b) - => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2 - -> (XArray sh1 a -> XArray sh2 b) - -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b -rerank ssh ssh1 ssh2 f (XArray arr) - | Dict <- lemKnownNatRankSSX ssh - , Dict <- gknownNat (Proxy @(Rank sh)) - , Dict <- lemKnownNatRankSSX ssh2 - , Dict <- gknownNat (Proxy @(Rank sh2)) - , Refl <- lemRankApp ssh ssh1 - , Refl <- lemRankApp ssh ssh2 - , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the - , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough - = XArray (U.rerank @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2)) - (\a -> unXArray (f (XArray a))) - arr) - where - unXArray (XArray a) = a - -rerank2 :: forall sh sh1 sh2 a b c. - (U.Unbox a, U.Unbox b, U.Unbox c) - => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2 - -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) - -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c -rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2) - | Dict <- lemKnownNatRankSSX ssh - , Dict <- gknownNat (Proxy @(Rank sh)) - , Dict <- lemKnownNatRankSSX ssh2 - , Dict <- gknownNat (Proxy @(Rank sh2)) - , Refl <- lemRankApp ssh ssh1 - , Refl <- lemRankApp ssh ssh2 - , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the - , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough - = XArray (U.rerank2 @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2)) - (\a b -> unXArray (f (XArray a) (XArray b))) - arr1 arr2) - where - unXArray (XArray a) = a - --- | The list argument gives indices into the original dimension list. -transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a -transpose perm (XArray arr) - | Dict <- lemKnownNatRankSSX (knownShapeX @sh) - , Dict <- gknownNat (Proxy @(Rank sh)) - = XArray (U.transpose perm arr) - -transpose2 :: forall sh1 sh2 a. - StaticShapeX sh1 -> StaticShapeX sh2 - -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a -transpose2 ssh1 ssh2 (XArray arr) - | Refl <- lemRankApp ssh1 ssh2 - , Refl <- lemRankApp ssh2 ssh1 - , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2) - , Dict <- gknownNat (Proxy @(Rank (sh1 ++ sh2))) - , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1) - , Dict <- gknownNat (Proxy @(Rank (sh2 ++ sh1))) - , Refl <- lemRankAppComm ssh1 ssh2 - , let n1 = ssxLength ssh1 - = XArray (U.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) - -sumFull :: (U.Unbox a, Num a) => XArray sh a -> a -sumFull (XArray arr) = U.sumA arr - -sumInner :: forall sh sh' a. (U.Unbox a, Num a) - => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh a -sumInner ssh ssh' - | Refl <- lemAppNil @sh - = rerank ssh ssh' SZX (scalar . sumFull) - -sumOuter :: forall sh sh' a. (U.Unbox a, Num a) - => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh' a -sumOuter ssh ssh' - | Refl <- lemAppNil @sh - = sumInner ssh' ssh . transpose2 ssh ssh' -- cgit v1.2.3-70-g09d2