{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Main where

import qualified Data.Array.RankedS as OR
import Data.Foldable (toList)
import Data.Type.Equality
import GHC.TypeLits
import qualified GHC.TypeNats as TN

import qualified Data.Array.Mixed as X
import Data.Array.Mixed (fromSNat', pattern SZ, pattern SS)
import Data.Array.Nested
import qualified Data.Array.Nested.Internal as I

-- test framework stuff
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import Test.Tasty
import Test.Tasty.Hedgehog

import Debug.Trace


-- | Run a property continuation at a randomly chosen rank drawn from
-- @Range.linear 0 8@. The chosen rank is passed to the continuation as a
-- singleton 'SNat'.
genRank :: (forall n. SNat n -> PropertyT IO ()) -> PropertyT IO ()
genRank k = do
  rank <- forAll $ Gen.int (Range.linear 0 8)
  TN.withSomeSNat (fromIntegral rank) k

-- | Generate a value in @[lo, hi]@ biased towards @lo@.
--
-- A fraction @x@ drawn from @[0, 1]@ is cubed before scaling, so values near
-- @lo@ are much more likely than values near @hi@.
genLowBiased :: RealFloat a => (a, a) -> Gen a
genLowBiased (lo, hi) = do
  x <- Gen.realFloat (Range.linearFrac 0 1)
  return (lo + x * x * x * (hi - lo))

-- | Randomly permute the dimensions of a shape. The rank and the multiset of
-- dimensions are unchanged; only their order is shuffled.
shuffleShR :: IShR n -> Gen (IShR n)
shuffleShR = \sh -> go (length (toList sh)) (toList sh) sh
  where
    -- @go nbag bag sh@ rebuilds a shape with the same rank as @sh@.
    -- Each position takes a dimension removed from @bag@ (holding @nbag@
    -- elements) at a randomly generated index. The dimensions of @sh@
    -- itself are ignored; it only drives the type-level recursion.
    go :: Int -> [Int] -> IShR n -> Gen (IShR n)
    go _ _ ZSR = return ZSR
    go nbag bag (_ :$: sh) = do
      idx <- Gen.int (Range.linear 0 (nbag - 1))
      let (dim, bag') = case splitAt idx bag of
            (pre, n : post) -> (n, pre ++ post)
            _ -> error "unreachable"
      (dim :$:) <$> go (nbag - 1) bag' sh

-- | Generate a shape of rank @n@.
--
-- How the shape is built:
--
-- * A target element count is drawn from @[0, 1000 * 3 ^ n]@.
-- * Dimensions are chosen one by one so that their product roughly matches
--   the target; the remaining budget is divided by each chosen dimension.
-- * The result is clamped with 'capShR' so its element count never exceeds
--   @1000 * 3 ^ n@.
-- * Finally the dimensions are shuffled with 'shuffleShR'.
genShR :: SNat n -> Gen (IShR n)
genShR sn = do
  let n = fromSNat' sn
      -- Largest target size for this rank; also the hard element-count cap.
      maxSize = 1000 * 3 ^ n
  targetSize <- Gen.int (Range.linear 0 maxSize)
  let genDims :: SNat m -> Int -> Gen (IShR m)
      genDims SZ _ = return ZSR
      -- Budget 0 is reached in three ways:
      --   * a previous dimension was 0, so the shape is already empty;
      --   * `tgt `div` dim` truncated to 0 because dim > tgt;
      --   * the initial target was 0.
      -- In every case the remaining dimensions are drawn freely from [0, 20].
      -- In the last two cases the shape may be non-empty and far larger
      -- than intended, which is why the result is capped below.
      genDims (SS m) 0 = do
        dim <- Gen.int (Range.linear 0 20)
        dims <- genDims m 0
        return (dim :$: dims)
      genDims (SS m) tgt = do
        dim <- Gen.frequency [(20 * n, round <$> genLowBiased @Double (2.0, max 2.0 (sqrt (fromIntegral tgt))))
                             ,(2     , return tgt)
                             ,(4     , return 1)
                             ,(1     , return 0)]
        dims <- genDims m (if dim == 0 then 0 else tgt `div` dim)
        return (dim :$: dims)
  dims <- genDims sn targetSize
  shuffleShR (capShR maxSize dims)

-- | Clamp every dimension of a shape to a common per-dimension cap.
--
-- The cap is the largest value for which the total element count is at most
-- @limit@ (assumed to be at least 1).
--
-- * Shapes already within the limit are returned unchanged.
-- * Zero dimensions stay zero, so empty shapes remain empty.
-- * Non-zero dimensions are never clamped below 1.
capShR :: Int -> IShR n -> IShR n
capShR limit sh
  | withinLimit maxDim = sh
  | otherwise = clampTo (search 1 maxDim) sh
  where
    dims = toList sh
    -- The leading 1 keeps 'maximum' total for rank-0 shapes.
    maxDim = maximum (1 : dims)
    -- Computed in 'Integer' so that huge unclamped shapes cannot overflow 'Int'.
    withinLimit c = product (map (toInteger . min c) dims) <= toInteger limit
    -- Largest c in [lo, hi] satisfying 'withinLimit', given that it holds at
    -- lo. 'withinLimit' is monotone in c, so binary search applies.
    search lo hi
      | lo >= hi = lo
      | withinLimit mid = search mid hi
      | otherwise = search lo (mid - 1)
      where
        mid = (lo + hi + 1) `div` 2
    clampTo :: Int -> IShR m -> IShR m
    clampTo _ ZSR = ZSR
    clampTo c (d :$: ds) = min c d :$: clampTo c ds

-- | Reference implementation of 'rsumOuter1' on orthotope arrays: sum out the
-- outermost dimension of a rank-@(n + 1)@ array.
--
-- Axis 0 is transposed to the innermost position. The array is then reranked
-- so that each innermost rank-1 cell is summed into a scalar.
orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n a
orSumOuter1 (sn@SNat :: SNat n) =
  let n = fromSNat' sn
  in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0])

-- | Property tests comparing ox-arrays operations against orthotope.
main :: IO ()
main = defaultMain $
  testGroup "Tests"
    [testGroup "C"
      [testGroup "sum"
        [testProperty "random" $ property $ genRank $ \outrank@(SNat @n) -> do
          -- The input has one more dimension than the summed result.
          let inrank = SNat @(n + 1)
          sh <- forAll $ genShR inrank
          arr <- forAll $
            OR.fromList @_ @(n + 1) (toList sh)
              <$> Gen.list (Range.singleton (product sh))
                           (Gen.realFloat (Range.linearFrac @Double 0 1))
          let rarr = rfromOrthotope inrank arr
          annotateShow rarr

          -- NOTE(review): presumably brings a rank equality into scope that
          -- the pattern on the result below needs -- confirm against
          -- I.lemRankReplicate.
          Refl <- return $ I.lemRankReplicate outrank
          let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr
          let rhs = orSumOuter1 outrank arr
          annotateShow lhs
          annotateShow rhs

          lhs === rhs
        ]
      ]
    ]