diff options
Diffstat (limited to 'src/Data/Array/Nested')
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 34 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape/Internal.hs | 50 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked.hs | 11 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 5 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped.hs | 5 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 5 |
6 files changed, 75 insertions, 35 deletions
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index 8aa5a77..5a45a09 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -16,6 +16,7 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE StrictData #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} @@ -44,6 +45,7 @@ import GHC.TypeLits import GHC.TypeLits.Orphans () #endif +import Data.Array.Nested.Mixed.Shape.Internal import Data.Array.Nested.Types @@ -276,33 +278,6 @@ ixxZipWith :: (i -> j -> k) -> IxX sh i -> IxX sh j -> IxX sh k ixxZipWith _ ZIX ZIX = ZIX ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js -{-# INLINEABLE ixxFromLinear #-} -ixxFromLinear :: Num i => IShX sh -> Int -> IxX sh i -ixxFromLinear = \sh -> -- give this function arity 1 so that suffixes is shared when it's called many times - let suffixes = drop 1 (scanr (*) 1 (shxToList sh)) - in \i -> - if i < 0 then outrange sh i else - case (sh, suffixes) of -- unfold first iteration of fromLin to do the range check - (ZSX, _) | i > 0 -> outrange sh i - | otherwise -> ZIX - (n :$% sh', suff : suffs) -> - let (q, r) = i `quotRem` suff - in if q >= fromSMayNat' n then outrange sh i else - fromIntegral q :.% fromLin sh' suffs r - _ -> error "impossible" - where - fromLin :: Num i => IShX sh -> [Int] -> Int -> IxX sh i - fromLin ZSX _ !_ = ZIX - fromLin (_ :$% sh') (suff : suffs) i = - let (q, r) = i `quotRem` suff -- suff == shrSize sh' - in fromIntegral q :.% fromLin sh' suffs r - fromLin _ _ _ = error "impossible" - - {-# NOINLINE outrange #-} - outrange :: IShX sh -> Int -> a - outrange sh i = error $ "ixxFromLinear: out of range (" ++ show i ++ - " in array of shape " ++ show sh ++ ")" - ixxToLinear :: IShX sh -> IIxX sh -> Int ixxToLinear = \sh i -> fst (go sh i) where @@ -684,3 +659,8 @@ instance KnownShX sh => IsList (ShX sh Int) where type Item (ShX sh Int) = Int fromList = shxFromList (knownShX @sh) toList = shxToList + +-- This needs to be at the bottom of the file to not split the file into +-- pieces; some of the shape/index stuff refers to StaticShX. +$(ixFromLinearStub "ixxFromLinear" [t| IShX |] [t| IxX |] [p| ZSX |] (\a b -> [p| (fromSMayNat' -> $a) :$% $b |]) [| ZIX |] [| (:.%) |] [| shxToList |]) +{-# INLINEABLE ixxFromLinear #-} diff --git a/src/Data/Array/Nested/Mixed/Shape/Internal.hs b/src/Data/Array/Nested/Mixed/Shape/Internal.hs new file mode 100644 index 0000000..cf44522 --- /dev/null +++ b/src/Data/Array/Nested/Mixed/Shape/Internal.hs @@ -0,0 +1,50 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +module Data.Array.Nested.Mixed.Shape.Internal where + +import Language.Haskell.TH + + +-- | A TH stub function to avoid having to write the same code three times for +-- the three kinds of shapes. +ixFromLinearStub :: String + -> TypeQ -> TypeQ + -> PatQ -> (PatQ -> PatQ -> PatQ) + -> ExpQ -> ExpQ + -> ExpQ + -> DecsQ +ixFromLinearStub fname' ishty ixty zshC consshC ixz ixcons shtolist = do + let fname = mkName fname' + typesig <- [t| forall i sh. Num i => $ishty sh -> Int -> $ixty sh i |] + + locals <- [d| + fromLin :: Num i => $ishty sh -> [Int] -> Int -> $ixty sh i + fromLin $zshC _ !_ = $ixz + fromLin ($(consshC wildP (varP (mkName "sh'")))) (suff : suffs) i = + let (q, r) = i `quotRem` suff -- suff == shrSize sh' + in $ixcons (fromIntegral q) (fromLin sh' suffs r) + fromLin _ _ _ = error "impossible" + + {-# NOINLINE outrange #-} + outrange :: $ishty sh -> Int -> a + outrange sh i = error $ fname' ++ ": out of range (" ++ show i ++ + " in array of shape " ++ show sh ++ ")" |] + + body <- [| + \sh -> -- give this function arity 1 so that 'suffixes' is shared when + -- it's called many times + let suffixes = drop 1 (scanr (*) 1 ($shtolist sh)) + in \i -> + if i < 0 then outrange sh i else + case (sh, suffixes) of -- unfold first iteration of fromLin to do the range check + ($zshC, _) | i > 0 -> outrange sh i + | otherwise -> $ixz + ($(consshC (varP (mkName "n")) (varP (mkName "sh'"))), suff : suffs) -> + let (q, r) = i `quotRem` suff + in if q >= n then outrange sh i else + $ixcons (fromIntegral q) (fromLin sh' suffs r) + _ -> error "impossible" |] + + return [SigD fname typesig + ,FunD fname [Clause [] (NormalB body) locals]] diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs index b77b529..d687983 100644 --- a/src/Data/Array/Nested/Ranked.hs +++ b/src/Data/Array/Nested/Ranked.hs @@ -70,16 +70,13 @@ rgenerate sh f , Refl <- lemRankReplicate sn = Ranked (mgenerate (shxFromShR sh) (f . ixrFromIxX)) --- TODO: this would be shorter and faster written with rfromVector, --- but unfortunately we don't have ixrFromLinear +-- | See 'mgeneratePrim'. {-# INLINE rgeneratePrim #-} rgeneratePrim :: forall n a i. (PrimElt a, Num i) => IShR n -> (IxR n i -> a) -> Ranked n a -rgeneratePrim sh f - | sn@SNat <- shrRank sh - , Dict <- lemKnownReplicate sn - , Refl <- lemRankReplicate sn - = Ranked (mgeneratePrim (shxFromShR sh) (f . ixrFromIxX)) +rgeneratePrim sh f = + let g i = f (ixrFromLinear sh i) + in rfromVector sh $ VS.generate (shrSize sh) g -- | See the documentation of 'mlift'. rlift :: forall n1 n2 a. Elt a diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index 9815c42..02d65b6 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -19,6 +19,7 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE StrictData #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} @@ -43,6 +44,7 @@ import GHC.TypeLits import GHC.TypeNats qualified as TN import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed.Shape.Internal import Data.Array.Nested.Types @@ -417,3 +419,6 @@ listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i listrCastWithName _ SZ ZR = ZR listrCastWithName name (SS n) (i ::: idx) = i ::: listrCastWithName name n idx listrCastWithName name _ _ = error $ name ++ ": ranks don't match" + +$(ixFromLinearStub "ixrFromLinear" [t| IShR |] [t| IxR |] [p| ZSR |] (\a b -> [p| $a :$: $b |]) [| ZIR |] [| (:.:) |] [| shrToList |]) +{-# INLINEABLE ixrFromLinear #-} diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index 075549d..99ad590 100644 --- a/src/Data/Array/Nested/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -72,10 +72,13 @@ sindexPartial sarr@(Shaped arr) idx = sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX sh)) +-- | See 'mgeneratePrim'. {-# INLINE sgeneratePrim #-} sgeneratePrim :: forall sh a i. (PrimElt a, Num i) => ShS sh -> (IxS sh i -> a) -> Shaped sh a -sgeneratePrim sh f = Shaped (mgeneratePrim (shxFromShS sh) (f . ixsFromIxX sh)) +sgeneratePrim sh f = + let g i = f (ixsFromLinear sh i) + in sfromVector sh $ VS.generate (shsSize sh) g -- | See the documentation of 'mlift'. slift :: forall sh1 sh2 a. Elt a diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index 0a4c1b9..a237b88 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -17,6 +17,7 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE StrictData #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} @@ -44,6 +45,7 @@ import GHC.IsList qualified as IsList import GHC.TypeLits import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Mixed.Shape.Internal import Data.Array.Nested.Permutation import Data.Array.Nested.Types @@ -465,3 +467,6 @@ instance KnownShS sh => IsList (ShS sh) where type Item (ShS sh) = Int fromList = shsFromList (knownShS @sh) toList = shsToList + +$(ixFromLinearStub "ixsFromLinear" [t| ShS |] [t| IxS |] [p| ZSS |] (\a b -> [p| (fromSNat' -> $a) :$$ $b |]) [| ZIS |] [| (:.$) |] [| shsToList |]) +{-# INLINEABLE ixsFromLinear #-} |
