diff options
| -rw-r--r-- | chad-fast.cabal | 8 | ||||
| -rw-r--r-- | misc/SparseLattice.hs | 177 | ||||
| -rw-r--r-- | src/CHAD/AST/Pretty.hs | 2 |
3 files changed, 186 insertions, 1 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index 1eef3ed..e0464d8 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -112,6 +112,14 @@ test-suite test default-language: Haskell2010 ghc-options: -Wall -threaded -rtsopts +test-suite sparse-lattice + type: exitcode-stdio-1.0 + main-is: SparseLattice.hs + build-depends: base, chad-fast, some + hs-source-dirs: misc + default-language: Haskell2010 + ghc-options: -Wall -threaded + benchmark bench type: exitcode-stdio-1.0 main-is: Main.hs diff --git a/misc/SparseLattice.hs b/misc/SparseLattice.hs new file mode 100644 index 0000000..b7cebdc --- /dev/null +++ b/misc/SparseLattice.hs @@ -0,0 +1,177 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +module Main where + +import Control.Monad +import Data.Type.Equality +import Data.Some +import System.Exit + +import CHAD.AST.Types +import CHAD.AST.Pretty +import CHAD.AST.Sparse +import CHAD.Data + + +genTypes :: Int -> [Some SMTy] +genTypes dep | dep < 0 = [] +genTypes dep = + [Some SMTNil, Some (SMTScal STF64)] ++ + concat [[Some (SMTMaybe t), Some (SMTArr (SS SZ) t)] + | Some t <- genTypes (dep - 1)] ++ + concat [[Some (SMTPair a b), Some (SMTLEither a b)] + | Some a <- genTypes (dep - 1), Some b <- genTypes (dep - 1)] + +genSparse :: SMTy t -> [Some (Sparse t)] +genSparse SMTNil = + [Some SpAbsent] +genSparse (SMTPair a b) = + [Some SpAbsent] ++ + concat [[Some (SpPair s1 s2), Some (SpSparse (SpPair s1 s2))] + | Some s1 <- genSparse a, Some s2 <- genSparse b] +genSparse (SMTLEither a b) = + [Some SpAbsent] ++ [Some (SpLEither s1 s2) | Some s1 <- genSparse a, Some s2 <- genSparse b] +genSparse (SMTMaybe t) = + [Some SpAbsent] ++ [Some (SpMaybe s) | Some s <- genSparse t] +genSparse (SMTArr _ t) = + [Some SpAbsent] ++ concat [[Some (SpArr s), Some (SpSparse (SpArr s))] | Some s <- genSparse t] +genSparse (SMTScal _) = + [Some SpAbsent, Some SpScal] + +computeJoin :: SMTy t -> Sparse t t1 -> Sparse t t2 + -> (forall t3. Sparse t t3 -> Sparse t3 t1 -> Sparse t3 t2 -> r) -> r +computeJoin t SpAbsent s2 k = k s2 SpAbsent (spDense (applySparse s2 t)) +computeJoin t s1 SpAbsent k = k s1 (spDense (applySparse s1 t)) SpAbsent +computeJoin t (SpSparse s1) (SpSparse s2) k = + computeJoin t s1 s2 $ \s3 s13 s23 -> + k (SpSparse s3) (SpMaybe s13) (SpMaybe s23) +computeJoin t (SpSparse s1) s2 k = + computeJoin t s1 s2 $ \s3 s13 s23 -> + k s3 (SpSparse s13) s23 +computeJoin t s1 (SpSparse s2) k = + computeJoin t s1 s2 $ \s3 s13 s23 -> + k s3 s13 (SpSparse s23) +computeJoin (SMTPair a b) (SpPair sa1 sb1) (SpPair sa2 sb2) k = + computeJoin a sa1 sa2 $ \sa3 sa13 sa23 -> + computeJoin b sb1 sb2 $ \sb3 sb13 sb23 -> + k (SpPair sa3 sb3) (SpPair sa13 sb13) (SpPair sa23 sb23) +computeJoin (SMTLEither a b) (SpLEither sa1 sb1) (SpLEither sa2 sb2) k = + computeJoin a sa1 sa2 $ \sa3 sa13 sa23 -> + computeJoin b sb1 sb2 $ \sb3 sb13 sb23 -> + k (SpLEither sa3 sb3) (SpLEither sa13 sb13) (SpLEither sa23 sb23) +computeJoin (SMTMaybe t) (SpMaybe s1) (SpMaybe s2) k = + computeJoin t s1 s2 $ \s3 s13 s23 -> + k (SpMaybe s3) (SpMaybe s13) (SpMaybe s23) +computeJoin (SMTArr _ t) (SpArr s1) (SpArr s2) k = + computeJoin t s1 s2 $ \s3 s13 s23 -> + k (SpArr s3) (SpArr s13) (SpArr s23) +computeJoin (SMTScal _) SpScal SpScal k = k SpScal SpScal SpScal + +-- Checks that the sparsity structures are equal, not just the returned types +spEqual :: Sparse t t1 -> Sparse t t2 -> Maybe (t1 :~: t2) +spEqual (SpSparse s1) (SpSparse s2) | Just Refl <- spEqual s1 s2 = Just Refl +spEqual (SpSparse _) _ = Nothing +spEqual _ (SpSparse _) = Nothing +spEqual SpAbsent SpAbsent = Just Refl +spEqual SpAbsent _ = Nothing +spEqual _ SpAbsent = Nothing +spEqual (SpPair sa1 sb1) (SpPair sa2 sb2) + | Just Refl <- spEqual sa1 sa2, Just Refl <- spEqual sb1 sb2 = Just Refl + | otherwise = Nothing +spEqual (SpLEither sa1 sb1) (SpLEither sa2 sb2) + | Just Refl <- spEqual sa1 sa2, Just Refl <- spEqual sb1 sb2 = Just Refl + | otherwise = Nothing +spEqual (SpMaybe s1) (SpMaybe s2) + | Just Refl <- spEqual s1 s2 = Just Refl + | otherwise = Nothing +spEqual (SpArr s1) (SpArr s2) + | Just Refl <- spEqual s1 s2 = Just Refl + | otherwise = Nothing +spEqual SpScal SpScal = Just Refl + +stupid :: Sparse t t' -> Bool +stupid (SpMaybe (SpSparse _)) = True +stupid (SpMaybe SpAbsent) = True + +stupid (SpSparse s) = stupid s +stupid SpAbsent = False +stupid (SpPair s1 s2) = stupid s1 || stupid s2 +stupid (SpLEither s1 s2) = stupid s1 || stupid s2 +stupid (SpArr s) = stupid s +stupid (SpMaybe s) = stupid s +stupid SpScal = False + +-- derive big small +derive :: SMTy t1 -> SMTy t2 -> Maybe (Sparse t1 t2) +-- dense recursion +derive (SMTScal t) (SMTScal t') + | Just Refl <- testEquality t t' + = Just SpScal +derive (SMTPair a b) (SMTPair a' b') + | Just s1 <- derive a a' + , Just s2 <- derive b b' + = Just (SpPair s1 s2) +derive (SMTLEither a b) (SMTLEither a' b') + | Just s1 <- derive a a' + , Just s2 <- derive b b' + = Just (SpLEither s1 s2) +-- derive (SMTMaybe t) (SMTMaybe t') = _ -- no M in the paper +derive (SMTArr n t) (SMTArr n' t') + | Just Refl <- testEquality n n' + , Just s <- derive t t' + = Just (SpArr s) +-- sparsity +derive _ SMTNil = Just SpAbsent +derive t (SMTMaybe t') + | Just s <- derive t t' + = Just (SpSparse s) +-- remainder cannot work +derive SMTScal{} _ = Nothing +derive SMTNil _ = Nothing +derive SMTPair{} _ = Nothing +derive SMTLEither{} _ = Nothing +derive SMTMaybe{} _ = Nothing +derive SMTArr{} _ = Nothing + +typesEqual :: SMTy a -> SMTy b -> Bool +typesEqual a b = case testEquality a b of + Just Refl -> True + Nothing -> False + +main :: IO () +main = do + putStrLn "# Antisymmetry" + forM_ (genTypes 2) $ \(Some ty1) -> + forM_ (genTypes 2) $ \(Some ty2) -> + when (not (typesEqual ty1 ty2)) $ + case (derive ty1 ty2, derive ty2 ty1) of + (Just s1, Just s2) -> do + putStrLn $ ppSMTy 0 ty1 ++ " ≤ " ++ ppSMTy 0 ty2 ++ ": " ++ show s2 + putStrLn $ ppSMTy 0 ty2 ++ " ≤ " ++ ppSMTy 0 ty1 ++ ": " ++ show s1 + exitFailure @() + _ -> return () + + putStrLn "# Plus = join" + forM_ (genTypes 2) $ \(Some ty) -> + forM_ (genSparse ty) $ \(Some s1) -> + when (not (stupid s1)) $ + forM_ (genSparse ty) $ \(Some s2) -> + when (not (stupid s2)) $ + sparsePlusS SF SF ty s1 s2 $ \s3 _ _ _ -> + computeJoin ty s1 s2 $ \sJ _s1J _s2J -> + case spEqual s3 sJ of + Just Refl -> return () + Nothing -> do + putStrLn $ "type = " ++ ppSMTy 0 ty + putStrLn $ "s1 = " ++ show s1 + putStrLn $ "s2 = " ++ show s2 + putStrLn $ "plus -> " ++ show s3 + putStrLn $ "join -> " ++ show sJ + exitFailure @() + + -- print (length (genTypes 0)) + -- print (length (genTypes 1)) + -- print (length (genTypes 2)) + -- print (length (genTypes 3)) diff --git a/src/CHAD/AST/Pretty.hs b/src/CHAD/AST/Pretty.hs index 9ddcb35..aabe6d6 100644 --- a/src/CHAD/AST/Pretty.hs +++ b/src/CHAD/AST/Pretty.hs @@ -5,7 +5,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TupleSections #-} -module CHAD.AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where +module CHAD.AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, ppSparse, PrettyX(..)) where import Control.Monad (ap) import Data.List (intersperse, intercalate) |
