aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--chad-fast.cabal8
-rw-r--r--misc/SparseLattice.hs177
-rw-r--r--src/CHAD/AST/Pretty.hs2
3 files changed, 186 insertions, 1 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index 1eef3ed..e0464d8 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -112,6 +112,14 @@ test-suite test
default-language: Haskell2010
ghc-options: -Wall -threaded -rtsopts
+test-suite sparse-lattice
+ type: exitcode-stdio-1.0
+ main-is: SparseLattice.hs
+ build-depends: base, chad-fast, some
+ hs-source-dirs: misc
+ default-language: Haskell2010
+ ghc-options: -Wall -threaded
+
benchmark bench
type: exitcode-stdio-1.0
main-is: Main.hs
diff --git a/misc/SparseLattice.hs b/misc/SparseLattice.hs
new file mode 100644
index 0000000..b7cebdc
--- /dev/null
+++ b/misc/SparseLattice.hs
@@ -0,0 +1,177 @@
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+module Main where
+
+import Control.Monad
+import Data.Type.Equality
+import Data.Some
+import System.Exit
+
+import CHAD.AST.Types
+import CHAD.AST.Pretty
+import CHAD.AST.Sparse
+import CHAD.Data
+
+
+genTypes :: Int -> [Some SMTy]
+genTypes dep | dep < 0 = []
+genTypes dep =
+ [Some SMTNil, Some (SMTScal STF64)] ++
+ concat [[Some (SMTMaybe t), Some (SMTArr (SS SZ) t)]
+ | Some t <- genTypes (dep - 1)] ++
+ concat [[Some (SMTPair a b), Some (SMTLEither a b)]
+ | Some a <- genTypes (dep - 1), Some b <- genTypes (dep - 1)]
+
+genSparse :: SMTy t -> [Some (Sparse t)]
+genSparse SMTNil =
+ [Some SpAbsent]
+genSparse (SMTPair a b) =
+ [Some SpAbsent] ++
+ concat [[Some (SpPair s1 s2), Some (SpSparse (SpPair s1 s2))]
+ | Some s1 <- genSparse a, Some s2 <- genSparse b]
+genSparse (SMTLEither a b) =
+ [Some SpAbsent] ++ [Some (SpLEither s1 s2) | Some s1 <- genSparse a, Some s2 <- genSparse b]
+genSparse (SMTMaybe t) =
+ [Some SpAbsent] ++ [Some (SpMaybe s) | Some s <- genSparse t]
+genSparse (SMTArr _ t) =
+ [Some SpAbsent] ++ concat [[Some (SpArr s), Some (SpSparse (SpArr s))] | Some s <- genSparse t]
+genSparse (SMTScal _) =
+ [Some SpAbsent, Some SpScal]
+
+computeJoin :: SMTy t -> Sparse t t1 -> Sparse t t2
+ -> (forall t3. Sparse t t3 -> Sparse t3 t1 -> Sparse t3 t2 -> r) -> r
+computeJoin t SpAbsent s2 k = k s2 SpAbsent (spDense (applySparse s2 t))
+computeJoin t s1 SpAbsent k = k s1 (spDense (applySparse s1 t)) SpAbsent
+computeJoin t (SpSparse s1) (SpSparse s2) k =
+ computeJoin t s1 s2 $ \s3 s13 s23 ->
+ k (SpSparse s3) (SpMaybe s13) (SpMaybe s23)
+computeJoin t (SpSparse s1) s2 k =
+ computeJoin t s1 s2 $ \s3 s13 s23 ->
+ k s3 (SpSparse s13) s23
+computeJoin t s1 (SpSparse s2) k =
+ computeJoin t s1 s2 $ \s3 s13 s23 ->
+ k s3 s13 (SpSparse s23)
+computeJoin (SMTPair a b) (SpPair sa1 sb1) (SpPair sa2 sb2) k =
+ computeJoin a sa1 sa2 $ \sa3 sa13 sa23 ->
+ computeJoin b sb1 sb2 $ \sb3 sb13 sb23 ->
+ k (SpPair sa3 sb3) (SpPair sa13 sb13) (SpPair sa23 sb23)
+computeJoin (SMTLEither a b) (SpLEither sa1 sb1) (SpLEither sa2 sb2) k =
+ computeJoin a sa1 sa2 $ \sa3 sa13 sa23 ->
+ computeJoin b sb1 sb2 $ \sb3 sb13 sb23 ->
+ k (SpLEither sa3 sb3) (SpLEither sa13 sb13) (SpLEither sa23 sb23)
+computeJoin (SMTMaybe t) (SpMaybe s1) (SpMaybe s2) k =
+ computeJoin t s1 s2 $ \s3 s13 s23 ->
+ k (SpMaybe s3) (SpMaybe s13) (SpMaybe s23)
+computeJoin (SMTArr _ t) (SpArr s1) (SpArr s2) k =
+ computeJoin t s1 s2 $ \s3 s13 s23 ->
+ k (SpArr s3) (SpArr s13) (SpArr s23)
+computeJoin (SMTScal _) SpScal SpScal k = k SpScal SpScal SpScal
+
+-- Checks that the sparsity structures are equal, not just the returned types
+spEqual :: Sparse t t1 -> Sparse t t2 -> Maybe (t1 :~: t2)
+spEqual (SpSparse s1) (SpSparse s2) | Just Refl <- spEqual s1 s2 = Just Refl
+spEqual (SpSparse _) _ = Nothing
+spEqual _ (SpSparse _) = Nothing
+spEqual SpAbsent SpAbsent = Just Refl
+spEqual SpAbsent _ = Nothing
+spEqual _ SpAbsent = Nothing
+spEqual (SpPair sa1 sb1) (SpPair sa2 sb2)
+ | Just Refl <- spEqual sa1 sa2, Just Refl <- spEqual sb1 sb2 = Just Refl
+ | otherwise = Nothing
+spEqual (SpLEither sa1 sb1) (SpLEither sa2 sb2)
+ | Just Refl <- spEqual sa1 sa2, Just Refl <- spEqual sb1 sb2 = Just Refl
+ | otherwise = Nothing
+spEqual (SpMaybe s1) (SpMaybe s2)
+ | Just Refl <- spEqual s1 s2 = Just Refl
+ | otherwise = Nothing
+spEqual (SpArr s1) (SpArr s2)
+ | Just Refl <- spEqual s1 s2 = Just Refl
+ | otherwise = Nothing
+spEqual SpScal SpScal = Just Refl
+
+stupid :: Sparse t t' -> Bool
+stupid (SpMaybe (SpSparse _)) = True
+stupid (SpMaybe SpAbsent) = True
+
+stupid (SpSparse s) = stupid s
+stupid SpAbsent = False
+stupid (SpPair s1 s2) = stupid s1 || stupid s2
+stupid (SpLEither s1 s2) = stupid s1 || stupid s2
+stupid (SpArr s) = stupid s
+stupid (SpMaybe s) = stupid s
+stupid SpScal = False
+
+-- derive big small
+derive :: SMTy t1 -> SMTy t2 -> Maybe (Sparse t1 t2)
+-- dense recursion
+derive (SMTScal t) (SMTScal t')
+ | Just Refl <- testEquality t t'
+ = Just SpScal
+derive (SMTPair a b) (SMTPair a' b')
+ | Just s1 <- derive a a'
+ , Just s2 <- derive b b'
+ = Just (SpPair s1 s2)
+derive (SMTLEither a b) (SMTLEither a' b')
+ | Just s1 <- derive a a'
+ , Just s2 <- derive b b'
+ = Just (SpLEither s1 s2)
+-- derive (SMTMaybe t) (SMTMaybe t') = _ -- no M in the paper
+derive (SMTArr n t) (SMTArr n' t')
+ | Just Refl <- testEquality n n'
+ , Just s <- derive t t'
+ = Just (SpArr s)
+-- sparsity
+derive _ SMTNil = Just SpAbsent
+derive t (SMTMaybe t')
+ | Just s <- derive t t'
+ = Just (SpSparse s)
+-- remainder cannot work
+derive SMTScal{} _ = Nothing
+derive SMTNil _ = Nothing
+derive SMTPair{} _ = Nothing
+derive SMTLEither{} _ = Nothing
+derive SMTMaybe{} _ = Nothing
+derive SMTArr{} _ = Nothing
+
+typesEqual :: SMTy a -> SMTy b -> Bool
+typesEqual a b = case testEquality a b of
+ Just Refl -> True
+ Nothing -> False
+
+main :: IO ()
+main = do
+ putStrLn "# Antisymmetry"
+ forM_ (genTypes 2) $ \(Some ty1) ->
+ forM_ (genTypes 2) $ \(Some ty2) ->
+ when (not (typesEqual ty1 ty2)) $
+ case (derive ty1 ty2, derive ty2 ty1) of
+ (Just s1, Just s2) -> do
+ putStrLn $ ppSMTy 0 ty1 ++ " ≤ " ++ ppSMTy 0 ty2 ++ ": " ++ show s2
+ putStrLn $ ppSMTy 0 ty2 ++ " ≤ " ++ ppSMTy 0 ty1 ++ ": " ++ show s1
+ exitFailure @()
+ _ -> return ()
+
+ putStrLn "# Plus = join"
+ forM_ (genTypes 2) $ \(Some ty) ->
+ forM_ (genSparse ty) $ \(Some s1) ->
+ when (not (stupid s1)) $
+ forM_ (genSparse ty) $ \(Some s2) ->
+ when (not (stupid s2)) $
+ sparsePlusS SF SF ty s1 s2 $ \s3 _ _ _ ->
+ computeJoin ty s1 s2 $ \sJ _s1J _s2J ->
+ case spEqual s3 sJ of
+ Just Refl -> return ()
+ Nothing -> do
+ putStrLn $ "type = " ++ ppSMTy 0 ty
+ putStrLn $ "s1 = " ++ show s1
+ putStrLn $ "s2 = " ++ show s2
+ putStrLn $ "plus -> " ++ show s3
+ putStrLn $ "join -> " ++ show sJ
+ exitFailure @()
+
+ -- print (length (genTypes 0))
+ -- print (length (genTypes 1))
+ -- print (length (genTypes 2))
+ -- print (length (genTypes 3))
diff --git a/src/CHAD/AST/Pretty.hs b/src/CHAD/AST/Pretty.hs
index 9ddcb35..aabe6d6 100644
--- a/src/CHAD/AST/Pretty.hs
+++ b/src/CHAD/AST/Pretty.hs
@@ -5,7 +5,7 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TupleSections #-}
-module CHAD.AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where
+module CHAD.AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, ppSparse, PrettyX(..)) where
import Control.Monad (ap)
import Data.List (intersperse, intercalate)