1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
|
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
module Main where
import Control.Monad
import Data.Type.Equality
import Data.Some
import System.Exit
import CHAD.AST.Types
import CHAD.AST.Pretty
import CHAD.AST.Sparse
import CHAD.Data
-- | All types of constructor-nesting depth at most @dep@ (a negative depth
-- yields no types).
--
-- Depth 0 gives the leaves @Nil@ and @Scal F64@. Each further level adds the
-- previous level's types wrapped in @Maybe@ and in 1-dimensional arrays, and
-- every ordered pair of them combined with @Pair@ and @LEither@.
genTypes :: Int -> [Some SMTy]
genTypes dep
  | dep < 0 = []
  | otherwise =
      [Some SMTNil, Some (SMTScal STF64)]
        ++ concat [[Some (SMTMaybe t), Some (SMTArr (SS SZ) t)] | Some t <- sub]
        ++ concat
             [ [Some (SMTPair a b), Some (SMTLEither a b)]
             | Some a <- sub
             , Some b <- sub
             ]
  where
    -- Bound once and shared by all generators below. Written inline, the
    -- inner generator of the pair comprehension may be re-evaluated for
    -- every outer element (unless the optimiser floats it out), which
    -- compounds at every recursion level.
    sub = genTypes (dep - 1)
-- | Enumerate the sparsity structures over dense type @t@ that the tests
-- explore.
--
-- 'SpAbsent' is always listed first. Pairs and arrays are listed both as-is
-- and wrapped in 'SpSparse'. @Nil@, scalars, @Maybe@ and @LEither@ are never
-- wrapped in 'SpSparse'.
genSparse :: SMTy t -> [Some (Sparse t)]
genSparse ty = Some SpAbsent : present ty
  where
    -- Every structure other than 'SpAbsent', in enumeration order.
    present :: SMTy u -> [Some (Sparse u)]
    present SMTNil = []
    present (SMTScal _) = [Some SpScal]
    present (SMTMaybe inner) = [Some (SpMaybe sp) | Some sp <- genSparse inner]
    present (SMTLEither l r) =
      [Some (SpLEither sl sr) | Some sl <- genSparse l, Some sr <- genSparse r]
    present (SMTPair l r) =
      [ out
      | Some sl <- genSparse l
      , Some sr <- genSparse r
      , out <- [Some (SpPair sl sr), Some (SpSparse (SpPair sl sr))]
      ]
    present (SMTArr _ elt) =
      [ out
      | Some se <- genSparse elt
      , out <- [Some (SpArr se), Some (SpSparse (SpArr se))]
      ]
-- | Join two sparsity structures @t1@ and @t2@ over the same dense type @t@.
--
-- The continuation receives the joined structure (sparse type @t3@) together
-- with witnesses that @t1@ and @t2@ are in turn sparsifications of @t3@.
-- The "Plus = join" check in 'main' compares this @t3@ structure against the
-- one produced by 'sparsePlusS'.
--
-- Equation order matters: the 'SpAbsent' equations are tried before any
-- 'SpSparse' equation, and the one-sided 'SpSparse' equations rely on the
-- both-'SpSparse' equation having been tried first.
computeJoin :: SMTy t -> Sparse t t1 -> Sparse t t2
            -> (forall t3. Sparse t t3 -> Sparse t3 t1 -> Sparse t3 t2 -> r) -> r
-- One side absent: the join is the other side. The absent side embeds as
-- 'SpAbsent'; the present side embeds via 'spDense' of its own sparse type
-- (presumably the identity structure on it — confirm in CHAD.AST.Sparse).
computeJoin t SpAbsent s2 k = k s2 SpAbsent (spDense (applySparse s2 t))
computeJoin t s1 SpAbsent k = k s1 (spDense (applySparse s1 t)) SpAbsent
-- Both sides optional: the join stays optional ('SpSparse'), and each side
-- maps into it through 'SpMaybe'.
computeJoin t (SpSparse s1) (SpSparse s2) k =
  computeJoin t s1 s2 $ \s3 s13 s23 ->
    k (SpSparse s3) (SpMaybe s13) (SpMaybe s23)
-- Exactly one side optional: the join is the join of the payloads, and the
-- optional side becomes a 'SpSparse' of it.
computeJoin t (SpSparse s1) s2 k =
  computeJoin t s1 s2 $ \s3 s13 s23 ->
    k s3 (SpSparse s13) s23
computeJoin t s1 (SpSparse s2) k =
  computeJoin t s1 s2 $ \s3 s13 s23 ->
    k s3 s13 (SpSparse s23)
-- No 'SpAbsent'/'SpSparse' at the top of either side: both structures follow
-- the outer constructor of @t@, so join componentwise.
computeJoin (SMTPair a b) (SpPair sa1 sb1) (SpPair sa2 sb2) k =
  computeJoin a sa1 sa2 $ \sa3 sa13 sa23 ->
  computeJoin b sb1 sb2 $ \sb3 sb13 sb23 ->
    k (SpPair sa3 sb3) (SpPair sa13 sb13) (SpPair sa23 sb23)
computeJoin (SMTLEither a b) (SpLEither sa1 sb1) (SpLEither sa2 sb2) k =
  computeJoin a sa1 sa2 $ \sa3 sa13 sa23 ->
  computeJoin b sb1 sb2 $ \sb3 sb13 sb23 ->
    k (SpLEither sa3 sb3) (SpLEither sa13 sb13) (SpLEither sa23 sb23)
computeJoin (SMTMaybe t) (SpMaybe s1) (SpMaybe s2) k =
  computeJoin t s1 s2 $ \s3 s13 s23 ->
    k (SpMaybe s3) (SpMaybe s13) (SpMaybe s23)
computeJoin (SMTArr _ t) (SpArr s1) (SpArr s2) k =
  computeJoin t s1 s2 $ \s3 s13 s23 ->
    k (SpArr s3) (SpArr s13) (SpArr s23)
computeJoin (SMTScal _) SpScal SpScal k = k SpScal SpScal SpScal
-- | Structural equality of two sparsity structures over the same dense type.
--
-- Returns a proof that the two sparse types coincide only when both
-- structures are built from the same constructors in the same shape. It is
-- the sparsity structures themselves that are compared, not just the sparse
-- types they produce.
spEqual :: Sparse t t1 -> Sparse t t2 -> Maybe (t1 :~: t2)
spEqual (SpSparse x) (SpSparse y) =
  case spEqual x y of
    Just Refl -> Just Refl
    Nothing -> Nothing
spEqual (SpSparse _) _ = Nothing
spEqual _ (SpSparse _) = Nothing
spEqual SpAbsent SpAbsent = Just Refl
spEqual SpAbsent _ = Nothing
spEqual _ SpAbsent = Nothing
spEqual (SpPair l1 r1) (SpPair l2 r2) =
  case (spEqual l1 l2, spEqual r1 r2) of
    (Just Refl, Just Refl) -> Just Refl
    _ -> Nothing
spEqual (SpLEither l1 r1) (SpLEither l2 r2) =
  case (spEqual l1 l2, spEqual r1 r2) of
    (Just Refl, Just Refl) -> Just Refl
    _ -> Nothing
spEqual (SpMaybe x) (SpMaybe y) =
  case spEqual x y of
    Just Refl -> Just Refl
    Nothing -> Nothing
spEqual (SpArr x) (SpArr y) =
  case spEqual x y of
    Just Refl -> Just Refl
    Nothing -> Nothing
spEqual SpScal SpScal = Just Refl
-- | Whether a sparsity structure contains, anywhere inside it, a 'SpMaybe'
-- whose payload is itself 'SpSparse' or 'SpAbsent'. The "Plus = join" check
-- in 'main' skips such structures.
stupid :: Sparse t t' -> Bool
stupid sp =
  case sp of
    SpMaybe inner -> redundantPayload inner || stupid inner
    SpSparse inner -> stupid inner
    SpPair l r -> stupid l || stupid r
    SpLEither l r -> stupid l || stupid r
    SpArr inner -> stupid inner
    SpAbsent -> False
    SpScal -> False
  where
    -- Payloads that make an enclosing 'SpMaybe' count as redundant.
    redundantPayload :: Sparse u u' -> Bool
    redundantPayload (SpSparse _) = True
    redundantPayload SpAbsent = True
    redundantPayload _ = False
-- | @derive big small@ tries to build a @'Sparse' big small@ witness, i.e. to
-- express @small@ as a sparsification of @big@.
--
-- The outer constructor of @small@ selects the rule:
--
-- * @Nil@: everything is dropped ('SpAbsent').
-- * @Maybe s@: an optional version of @big@, provided @s@ derives from @big@
--   ('SpSparse').
-- * anything else: @big@ must have the same outer constructor, and the
--   components are derived pointwise. There is deliberately no
--   @Maybe@/@Maybe@ rule ("no M in the paper").
derive :: SMTy t1 -> SMTy t2 -> Maybe (Sparse t1 t2)
derive big small =
  case small of
    SMTNil -> Just SpAbsent
    SMTMaybe small' -> SpSparse <$> derive big small'
    _ -> matching big small
  where
    -- Structural ("dense") rules; any constructor mismatch yields Nothing.
    matching :: SMTy a -> SMTy b -> Maybe (Sparse a b)
    matching (SMTScal x) (SMTScal y) =
      case testEquality x y of
        Just Refl -> Just SpScal
        Nothing -> Nothing
    matching (SMTPair l r) (SMTPair l' r') =
      SpPair <$> derive l l' <*> derive r r'
    matching (SMTLEither l r) (SMTLEither l' r') =
      SpLEither <$> derive l l' <*> derive r r'
    matching (SMTArr n elt) (SMTArr n' elt') =
      case testEquality n n' of
        Just Refl -> SpArr <$> derive elt elt'
        Nothing -> Nothing
    matching _ _ = Nothing
-- | Whether two type singletons denote the same type. This is
-- 'testEquality' with the equality proof discarded.
typesEqual :: SMTy a -> SMTy b -> Bool
typesEqual a b = maybe False (const True) (testEquality a b)
-- | Exhaustive checks over every type of depth at most 2.
--
-- 1. Antisymmetry: for two distinct types, 'derive' must not succeed in both
--    directions.
-- 2. Plus = join: for every pair of non-'stupid' sparsity structures over the
--    same type, 'sparsePlusS' and 'computeJoin' must produce structurally
--    equal results (checked with 'spEqual').
--
-- Prints the first counterexample found and exits with failure.
main :: IO ()
main = do
  -- One shared list for both checks instead of regenerating it per loop.
  let types = genTypes 2

  putStrLn "# Antisymmetry"
  forM_ types $ \(Some ty1) ->
    forM_ types $ \(Some ty2) ->
      unless (typesEqual ty1 ty2) $
        case (derive ty1 ty2, derive ty2 ty1) of
          (Just s1, Just s2) -> do
            -- 'derive big small' yields a 'Sparse big small', so s2
            -- witnesses ty1 ≤ ty2 and s1 witnesses ty2 ≤ ty1.
            putStrLn $ ppSMTy 0 ty1 ++ " ≤ " ++ ppSMTy 0 ty2 ++ ": " ++ show s2
            putStrLn $ ppSMTy 0 ty2 ++ " ≤ " ++ ppSMTy 0 ty1 ++ ": " ++ show s1
            exitFailure @()
          _ -> return ()

  putStrLn "# Plus = join"
  forM_ types $ \(Some ty) ->
    forM_ (genSparse ty) $ \(Some s1) ->
      unless (stupid s1) $
        forM_ (genSparse ty) $ \(Some s2) ->
          unless (stupid s2) $
            sparsePlusS SF SF ty s1 s2 $ \s3 _ _ _ ->
              computeJoin ty s1 s2 $ \sJ _s1J _s2J ->
                case spEqual s3 sJ of
                  Just Refl -> return ()
                  Nothing -> do
                    putStrLn $ "type = " ++ ppSMTy 0 ty
                    putStrLn $ "s1 = " ++ show s1
                    putStrLn $ "s2 = " ++ show s2
                    putStrLn $ "plus -> " ++ show s3
                    putStrLn $ "join -> " ++ show sJ
                    exitFailure @()

-- Sizes of the enumerated type universes, for reference:
-- print (length (genTypes 0))
-- print (length (genTypes 1))
-- print (length (genTypes 2))
-- print (length (genTypes 3))
|