{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
-- | Exhaustive sanity checks for the sparsity machinery in "CHAD.AST.Sparse".
--
-- For every 'SMTy' produced by @'genTypes' 2@ this program checks:
--
-- * __Antisymmetry__: for two distinct types, 'derive' must not succeed in
--   both directions.
-- * __Plus = join__: for every pair of non-'stupid' sparsity structures on a
--   type, the result sparsity of 'sparsePlusS' must be structurally equal
--   (per 'spEqual') to the one computed by the reference 'computeJoin'.
--
-- On the first counterexample the offending inputs are printed and the
-- program exits with a failure status.
module Main where

import Control.Monad
import Data.Type.Equality
import System.Exit

import Data.Some

import CHAD.AST.Pretty
import CHAD.AST.Sparse
import CHAD.AST.Types
import CHAD.Data


-- | All types built from 'SMTNil' and @'SMTScal' 'STF64'@ using at most @dep@
-- nested applications of 'SMTMaybe', rank-1 'SMTArr', 'SMTPair' and
-- 'SMTLEither'. A negative depth yields no types.
genTypes :: Int -> [Some SMTy]
genTypes dep
  | dep < 0 = []
  | otherwise =
      [Some SMTNil, Some (SMTScal STF64)]
        ++ concat
             [ [Some (SMTMaybe t), Some (SMTArr (SS SZ) t)]
             | Some t <- smaller ]
        ++ concat
             [ [Some (SMTPair a b), Some (SMTLEither a b)]
             | Some a <- smaller, Some b <- smaller ]
  where
    -- Bound once so the recursive enumeration is shared by both
    -- comprehensions (and by every iteration of the pair generator) instead
    -- of relying on the optimiser to float it out.
    smaller = genTypes (dep - 1)

-- | Candidate sparsity structures on a type. Every type admits 'SpAbsent';
-- pairs and arrays additionally get an optional ('SpSparse') variant of each
-- of their structural sparsities.
genSparse :: SMTy t -> [Some (Sparse t)]
genSparse SMTNil = [Some SpAbsent]
genSparse (SMTPair a b) =
  Some SpAbsent
    : concat [ [Some (SpPair s1 s2), Some (SpSparse (SpPair s1 s2))]
             | Some s1 <- genSparse a, Some s2 <- genSparse b ]
genSparse (SMTLEither a b) =
  Some SpAbsent
    : [ Some (SpLEither s1 s2)
      | Some s1 <- genSparse a, Some s2 <- genSparse b ]
genSparse (SMTMaybe t) =
  Some SpAbsent : [Some (SpMaybe s) | Some s <- genSparse t]
genSparse (SMTArr _ t) =
  Some SpAbsent
    : concat [ [Some (SpArr s), Some (SpSparse (SpArr s))]
             | Some s <- genSparse t ]
genSparse (SMTScal _) = [Some SpAbsent, Some SpScal]

-- | Reference computation of the join of two sparsity structures on the same
-- type @t@, used to cross-check 'sparsePlusS'.
--
-- The continuation receives the joined structure (from @t@ to @t3@) together
-- with sparsities from @t3@ down to each input's result type @t1@ and @t2@.
computeJoin
  :: SMTy t -> Sparse t t1 -> Sparse t t2
  -> (forall t3. Sparse t t3 -> Sparse t3 t1 -> Sparse t3 t2 -> r)
  -> r
-- An absent side contributes nothing: the join is the other side, the absent
-- side is reached via 'SpAbsent', and the present side maps onto itself
-- (presumably 'spDense' yields the dense/identity sparsity on that type).
computeJoin t SpAbsent s2 k = k s2 SpAbsent (spDense (applySparse s2 t))
computeJoin t s1 SpAbsent k = k s1 (spDense (applySparse s1 t)) SpAbsent
-- Both sides optional: the join stays optional.
computeJoin t (SpSparse s1) (SpSparse s2) k =
  computeJoin t s1 s2 $ \s3 s13 s23 ->
    k (SpSparse s3) (SpMaybe s13) (SpMaybe s23)
-- Only one side optional: the join is not optional, and the optional side is
-- recovered from it with 'SpSparse'.
computeJoin t (SpSparse s1) s2 k =
  computeJoin t s1 s2 $ \s3 s13 s23 -> k s3 (SpSparse s13) s23
computeJoin t s1 (SpSparse s2) k =
  computeJoin t s1 s2 $ \s3 s13 s23 -> k s3 s13 (SpSparse s23)
-- Structural cases: join componentwise.
computeJoin (SMTPair a b) (SpPair sa1 sb1) (SpPair sa2 sb2) k =
  computeJoin a sa1 sa2 $ \sa3 sa13 sa23 ->
  computeJoin b sb1 sb2 $ \sb3 sb13 sb23 ->
    k (SpPair sa3 sb3) (SpPair sa13 sb13) (SpPair sa23 sb23)
computeJoin (SMTLEither a b) (SpLEither sa1 sb1) (SpLEither sa2 sb2) k =
  computeJoin a sa1 sa2 $ \sa3 sa13 sa23 ->
  computeJoin b sb1 sb2 $ \sb3 sb13 sb23 ->
    k (SpLEither sa3 sb3) (SpLEither sa13 sb13) (SpLEither sa23 sb23)
computeJoin (SMTMaybe t) (SpMaybe s1) (SpMaybe s2) k =
  computeJoin t s1 s2 $ \s3 s13 s23 ->
    k (SpMaybe s3) (SpMaybe s13) (SpMaybe s23)
computeJoin (SMTArr _ t) (SpArr s1) (SpArr s2) k =
  computeJoin t s1 s2 $ \s3 s13 s23 ->
    k (SpArr s3) (SpArr s13) (SpArr s23)
computeJoin (SMTScal _) SpScal SpScal k = k SpScal SpScal SpScal

-- | Checks that the sparsity structures are equal, not just the returned
-- types. On success, returns evidence that the result types coincide.
spEqual :: Sparse t t1 -> Sparse t t2 -> Maybe (t1 :~: t2)
spEqual (SpSparse s1) (SpSparse s2)
  | Just Refl <- spEqual s1 s2 = Just Refl
spEqual (SpSparse _) _ = Nothing
spEqual _ (SpSparse _) = Nothing
spEqual SpAbsent SpAbsent = Just Refl
spEqual SpAbsent _ = Nothing
spEqual _ SpAbsent = Nothing
spEqual (SpPair sa1 sb1) (SpPair sa2 sb2)
  | Just Refl <- spEqual sa1 sa2
  , Just Refl <- spEqual sb1 sb2
  = Just Refl
  | otherwise = Nothing
spEqual (SpLEither sa1 sb1) (SpLEither sa2 sb2)
  | Just Refl <- spEqual sa1 sa2
  , Just Refl <- spEqual sb1 sb2
  = Just Refl
  | otherwise = Nothing
spEqual (SpMaybe s1) (SpMaybe s2)
  | Just Refl <- spEqual s1 s2 = Just Refl
  | otherwise = Nothing
spEqual (SpArr s1) (SpArr s2)
  | Just Refl <- spEqual s1 s2 = Just Refl
  | otherwise = Nothing
spEqual SpScal SpScal = Just Refl

-- | Flags redundant sparsity structures that the checks skip: anywhere in the
-- structure, a 'SpMaybe' whose contents are themselves made optional
-- ('SpSparse') or dropped ('SpAbsent').
stupid :: Sparse t t' -> Bool
stupid (SpMaybe (SpSparse _)) = True
stupid (SpMaybe SpAbsent) = True
stupid (SpSparse s) = stupid s
stupid SpAbsent = False
stupid (SpPair s1 s2) = stupid s1 || stupid s2
stupid (SpLEither s1 s2) = stupid s1 || stupid s2
stupid (SpArr s) = stupid s
stupid (SpMaybe s) = stupid s
stupid SpScal = False

-- | @derive big small@: a sparsity structure turning type @big@ into type
-- @small@, if one exists in the restricted form recognised here. Returns
-- 'Nothing' otherwise.
derive :: SMTy t1 -> SMTy t2 -> Maybe (Sparse t1 t2)
-- Dense recursion: same head constructor on both sides.
derive (SMTScal t) (SMTScal t')
  | Just Refl <- testEquality t t' = Just SpScal
derive (SMTPair a b) (SMTPair a' b')
  | Just s1 <- derive a a'
  , Just s2 <- derive b b'
  = Just (SpPair s1 s2)
derive (SMTLEither a b) (SMTLEither a' b')
  | Just s1 <- derive a a'
  , Just s2 <- derive b b'
  = Just (SpLEither s1 s2)
-- derive (SMTMaybe t) (SMTMaybe t') = _  -- no M in the paper
derive (SMTArr n t) (SMTArr n' t')
  | Just Refl <- testEquality n n'
  , Just s <- derive t t'
  = Just (SpArr s)
-- Sparsity: any type can be dropped entirely, or made optional.
derive _ SMTNil = Just SpAbsent
derive t (SMTMaybe t')
  | Just s <- derive t t' = Just (SpSparse s)
-- Remainder cannot work.
derive SMTScal{} _ = Nothing
derive SMTNil _ = Nothing
derive SMTPair{} _ = Nothing
derive SMTLEither{} _ = Nothing
derive SMTMaybe{} _ = Nothing
derive SMTArr{} _ = Nothing

-- | Whether two type representations denote the same type.
typesEqual :: SMTy a -> SMTy b -> Bool
typesEqual a b = case testEquality a b of
  Just Refl -> True
  Nothing -> False

-- | Runs both checks, printing a section header for each; exits with failure
-- after printing the first counterexample found.
main :: IO ()
main = do
  -- Enumerated once and shared by both checks.
  let types = genTypes 2

  putStrLn "# Antisymmetry"
  forM_ types $ \(Some ty1) ->
    forM_ types $ \(Some ty2) ->
      unless (typesEqual ty1 ty2) $
        case (derive ty1 ty2, derive ty2 ty1) of
          (Just s1, Just s2) -> do
            putStrLn $ ppSMTy 0 ty1 ++ " ≤ " ++ ppSMTy 0 ty2 ++ ": " ++ show s2
            putStrLn $ ppSMTy 0 ty2 ++ " ≤ " ++ ppSMTy 0 ty1 ++ ": " ++ show s1
            exitFailure @()
          _ -> return ()

  putStrLn "# Plus = join"
  forM_ types $ \(Some ty) -> do
    -- Non-redundant sparsities on this type, computed once and reused for
    -- both operands (same order as filtering inside each loop).
    let sparsities = filter (\(Some s) -> not (stupid s)) (genSparse ty)
    forM_ sparsities $ \(Some s1) ->
      forM_ sparsities $ \(Some s2) ->
        sparsePlusS SF SF ty s1 s2 $ \s3 _ _ _ ->
          computeJoin ty s1 s2 $ \sJ _s1J _s2J ->
            case spEqual s3 sJ of
              Just Refl -> return ()
              Nothing -> do
                putStrLn $ "type = " ++ ppSMTy 0 ty
                putStrLn $ "s1 = " ++ show s1
                putStrLn $ "s2 = " ++ show s2
                putStrLn $ "plus -> " ++ show s3
                putStrLn $ "join -> " ++ show sJ
                exitFailure @()

  -- print (length (genTypes 0))
  -- print (length (genTypes 1))
  -- print (length (genTypes 2))
  -- print (length (genTypes 3))