{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-- | Hedgehog generators for ranks, shapes, storable vectors and permutations.
module Gen where

import Data.ByteString qualified as BS
import Data.Foldable (toList)
import Data.Type.Equality
import Data.Vector.Storable qualified as VS
import Foreign
import GHC.TypeLits
import GHC.TypeNats qualified as TN

import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
import Data.Array.Nested

import Hedgehog
import Hedgehog.Gen qualified as Gen
import Hedgehog.Range qualified as Range
import System.Random qualified as Random

import Util


-- | Draw a rank between 0 and 8 (inclusive, linear range) and hand it to the
-- continuation as a singleton 'SNat'.
genRank :: Monad m => (forall n. SNat n -> PropertyT m ()) -> PropertyT m ()
genRank k = do
  rank <- forAll $ Gen.int (Range.linear 0 8)
  TN.withSomeSNat (fromIntegral rank) k

-- | Generate a value in the interval @(lo, hi)@ that is biased towards @lo@:
-- a value @x@ is drawn from @[0, 1]@ and the result is
-- @lo + x^3 * (hi - lo)@.
genLowBiased :: RealFloat a => (a, a) -> Gen a
genLowBiased (lo, hi) = do
  x <- Gen.realFloat (Range.linearFrac 0 1)
  return (lo + x * x * x * (hi - lo))

-- | Produce a shape of the same rank whose dimensions are the dimensions of
-- the input shape in a randomly chosen order.
shuffleShR :: IShR n -> Gen (IShR n)
shuffleShR = \sh -> go (length (toList sh)) (toList sh) sh
  where
    -- Arguments: the number of dimensions left in the bag, the bag itself, and
    -- the remaining part of the original shape. The latter is traversed only
    -- for its spine (its elements are ignored); every output position is
    -- filled with a randomly picked element that is then removed from the bag.
    go :: Int -> [Int] -> IShR n -> Gen (IShR n)
    go _ _ ZSR = return ZSR
    go nbag bag (_ :$: sh) = do
      idx <- Gen.int (Range.linear 0 (nbag - 1))
      let (dim, bag') = case splitAt idx bag of
            (pre, n : post) -> (n, pre ++ post)
            _ -> error "unreachable"
      (dim :$:) <$> go (nbag - 1) bag' sh

-- | Generate a shape of the given rank.
--
-- A target size between 0 and 100000 is drawn first; dimensions are then
-- generated one at a time, dividing the remaining target by each generated
-- dimension. Afterwards every dimension is clamped to a common cap and the
-- dimensions are shuffled with 'shuffleShR'.
genShR :: SNat n -> Gen (IShR n)
genShR sn = do
  let n = fromSNat' sn
  targetSize <- Gen.int (Range.linear 0 100_000)
  let genDims :: SNat m -> Int -> Gen (IShR m)
      genDims SZ _ = return ZSR
      -- Remaining target is 0: draw the remaining dimensions from [0, 20].
      genDims (SS m) 0 = do
        dim <- Gen.int (Range.linear 0 20)
        dims <- genDims m 0
        return (dim :$: dims)
      genDims (SS m) tgt = do
        -- Mostly (weight 20 * rank) a low-biased value between 2 and
        -- sqrt(tgt); occasionally the whole remaining target, 1, or 0.
        dim <- Gen.frequency
          [ (20 * n, round <$> genLowBiased @Double (2.0, max 2.0 (sqrt (fromIntegral tgt))))
          , (2, return tgt)
          , (4, return 1)
          , (1, return 0)
          ]
        dims <- genDims m (if dim == 0 then 0 else tgt `div` dim)
        return (dim :$: dims)
  dims <- genDims sn targetSize
  let dimsL = toList dims
      -- 'maximum' is partial on the empty list, which occurs at rank 0. The
      -- placeholder is never used in that case because there are no
      -- dimensions to clamp, but being explicit avoids relying on laziness.
      maxdim = case dimsL of
        [] -> 0
        _ -> maximum dimsL
      -- NOTE(review): presumably searches between 1 and maxdim for the largest
      -- cap whose clamped product stays within targetSize -- confirm against
      -- Util.binarySearch.
      cap = binarySearch (`div` 2) 1 maxdim (\cap' -> product (min cap' <$> dimsL) <= targetSize)
  shuffleShR (min cap <$> dims)

-- | Generate a storable vector whose length is drawn from the given range.
--
-- A seed is generated with Hedgehog and used to initialise a 'Random.StdGen',
-- from which @8 * n@ bytes are produced. Element @i@ is obtained by applying
-- the given function to the 'Word64' assembled from bytes @8 * i@ up to
-- @8 * i + 7@.
genStorables :: forall a. Storable a => Range Int -> (Word64 -> a) -> GenT IO (VS.Vector a)
genStorables rng f = do
  n <- Gen.int rng
  -- The seed is generated at a fixed size of 99 so that its range does not
  -- depend on the ambient size parameter.
  seed <- Gen.resize 99 $ Gen.int Range.linearBounded
  let gen0 = Random.mkStdGen seed
      (bs, _) = Random.genByteString (8 * n) gen0
  -- Little-endian: byte j of the group is weighted by 256^j.
  let readW64 i = sum (zipWith (*) (iterate (*256) 1) [fromIntegral (bs `BS.index` (8 * i + j)) | j <- [0..7]])
  return $ VS.generate n (f . readW64)

-- | Generate a static mixed shape and pass it to the continuation. The rank
-- comes from 'genRank'; each dimension is independently either statically
-- known or unknown.
genStaticShX :: Monad m => (forall sh. StaticShX sh -> PropertyT m ()) -> PropertyT m ()
genStaticShX = \k -> genRank (\sn -> go sn k)
  where
    -- Build the shape one dimension at a time, in continuation-passing style.
    go :: Monad m => SNat n -> (forall sh. StaticShX sh -> PropertyT m ()) -> PropertyT m ()
    go SZ k = k ZKX
    go (SS n) k =
      genItem $ \item ->
      go n $ \ssh ->
        k (item :!% ssh)

    -- Flip a coin: either a known dimension (1 to 4 with weight 20, or 0 with
    -- weight 1) or an unknown dimension.
    genItem :: Monad m => (forall n. SMayNat () SNat n -> PropertyT m ()) -> PropertyT m ()
    genItem k = do
      b <- forAll Gen.bool
      if b
        then do
          n <- forAll $ Gen.frequency
            [ (20, Gen.int (Range.linear 1 4))
            , (1, return 0)
            ]
          TN.withSomeSNat (fromIntegral n) $ \sn -> k (SKnown sn)
        else k (SUnknown ())

-- | Generate a mixed shape matching the given static shape: known dimensions
-- are copied as-is, and each unknown dimension is given a size between 1 and 4.
genShX :: StaticShX sh -> Gen (IShX sh)
genShX ZKX = return ZSX
genShX (SKnown sn :!% ssh) = (SKnown sn :$%) <$> genShX ssh
genShX (SUnknown () :!% ssh) = do
  dim <- Gen.int (Range.linear 1 4)
  (SUnknown dim :$%) <$> genShX ssh

-- | Generate a random ordering of the indices @[0 .. n-1]@.
genPermR :: Int -> Gen PermR
genPermR n = Gen.shuffle [0 .. n-1]

-- | Generate a permutation of length @n@ and pass it to the continuation
-- together with the evidence that it is a permutation of the right rank.
--
-- Calls 'error' if the generated list fails 'permCheckPermutation' or if its
-- length does not match @n@.
genPerm :: Monad m => SNat n -> (forall p. (IsPermutation p, Rank p ~ n) => Perm p -> PropertyT m r) -> PropertyT m r
genPerm n@SNat k = do
  list <- forAll $ genPermR (fromSNat' n)
  permFromList list $ \perm -> do
    -- The outer Maybe comes from the permutation check, the inner one from the
    -- length comparison; both must succeed before the continuation can run.
    case permCheckPermutation perm $
           case sameNat' (permLengthSNat perm) n of
             Just Refl -> Just (k perm)
             Nothing -> Nothing
      of
        Just (Just act) -> act
        _ -> error "genPerm: generated list failed the permutation check or has the wrong length"