blob: ff2e45ca5f26b45e134b316631a9d22ffc26709f (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed.Lemmas where
import Data.Proxy
import Data.Type.Equality
import GHC.TypeLits
import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
-- * Reasoning helpers
-- | Congruence of type equality under a one-argument type constructor:
-- a proof that @a@ equals @b@ is lifted to a proof that @f a@ equals @f b@.
subst1 :: forall f a b. a :~: b -> f a :~: f b
subst1 proof = case proof of
  -- Matching on 'Refl' teaches the type checker that @a ~ b@, after which
  -- the goal is trivially reflexive.
  Refl -> Refl
-- | Like 'subst1', but for the /first/ argument of a two-argument type
-- constructor: a proof of @a ~ b@ becomes a proof of @f a c ~ f b c@, with
-- the second argument @c@ left unchanged.
subst2 :: forall f c a b. a :~: b -> f a c :~: f b c
subst2 proof = case proof of
  -- Matching on 'Refl' brings @a ~ b@ into scope, making the goal reflexive.
  Refl -> Refl
-- * Lemmas
-- ** Nat
-- | Shift an inequality down by one: from @k + 1 <= n@ conclude
-- @k <= n - 1@.
--
-- Both proxy arguments exist only to fix the type variables; they are never
-- inspected.
--
-- NOTE(review): the equality is not derived here. It is produced by
-- 'unsafeCoerceRefl', which (judging by its name) asserts the equality
-- without the type checker verifying it, so the soundness of this lemma
-- rests on the arithmetic argument above, not on the compiler.
lemLeqSuccSucc :: (k + 1 <= n) => Proxy k -> Proxy n -> (k <=? n - 1) :~: 'True
lemLeqSuccSucc _k _n = unsafeCoerceRefl
-- | Weakening of an upper bound: if @n <= m@ then @n <= m + k@ for any @k@.
--
-- The three proxies only pin down @n@, @m@ and @k@; none of them is
-- inspected. No explicit proof term is given: a bare 'Refl' is accepted,
-- presumably because the "GHC.TypeLits.Normalise" plugin enabled for this
-- module discharges the inequality (TODO confirm).
lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True
lemLeqPlus _n _m _k = Refl
-- ** Append
-- | Right identity of type-level append: appending the empty list @'[]@ to
-- @l@ gives back @l@.
--
-- NOTE(review): this is not derived by induction on @l@ (no value describing
-- @l@ is available here). The equality is produced by 'unsafeCoerceRefl',
-- which, judging by its name, asserts it without the type checker verifying
-- it.
lemAppNil :: l ++ '[] :~: l
lemAppNil = unsafeCoerceRefl
-- | Associativity of type-level append: @(a ++ b) ++ c@ equals
-- @a ++ (b ++ c)@.
--
-- The proxies merely select the three lists and are never inspected.
--
-- NOTE(review): the equality comes from 'unsafeCoerceRefl', which, judging
-- by its name, asserts it without the type checker verifying it.
lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
lemAppAssoc _a _b _c = unsafeCoerceRefl
-- | Congruence of append in its left operand: if @a@ equals @b@ then
-- @a ++ l@ equals @b ++ l@.
--
-- The proxy only fixes the common suffix @l@ and is never inspected.
lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l
lemAppLeft _suffix proof = case proof of
  -- With @a ~ b@ in scope both sides are the same type.
  Refl -> Refl
-- ** Rank
-- | The rank of an appended shape is the sum of the ranks of its parts.
--
-- Proved by recursion on the first shape descriptor:
--
-- * for the empty descriptor 'ZKX' the goal is accepted as 'Refl' directly;
--
-- * for a non-empty descriptor the lemma is applied to the tail, the
--   resulting equation is flipped with 'sym', and the local helper
--   @addHead@ adds one to both sides to account for the head element.
--
-- The second descriptor is only passed along to the recursive call.
lemRankApp :: forall sh1 sh2.
              StaticShX sh1 -> StaticShX sh2
           -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
lemRankApp ZKX _ = Refl
lemRankApp (_ :!% (restShape :: StaticShX restT)) otherShape
  = addHead (Proxy @(Rank restT)) Proxy Proxy
      -- Induction hypothesis, turned around so that the sum is on the left:
      --   Rank restT + Rank sh2 :~: Rank (restT ++ sh2)
      (sym (lemRankApp restShape otherShape))
  where
    -- From @a + b ~ c@ derive @c + 1 ~ a + 1 + b@. The proxies only help
    -- fix the type variables; the proof itself is 'Refl' once the
    -- hypothesis has been matched.
    addHead :: proxy a -> proxy b -> proxy c
            -> (a + b :~: c)
            -> c + 1 :~: (a + 1 + b)
    addHead _ _ _ Refl = Refl
-- | Appending two shapes in either order gives the same rank:
-- @Rank (sh1 ++ sh2)@ equals @Rank (sh2 ++ sh1)@.
--
-- Neither shape descriptor is inspected.
--
-- NOTE(review): the equality comes from 'unsafeCoerceRefl', which, judging
-- by its name, asserts it without the type checker verifying it.
lemRankAppComm :: StaticShX sh1 -> StaticShX sh2
               -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)
lemRankAppComm _ssh1 _ssh2 = unsafeCoerceRefl  -- TODO improve this
-- | A shape made of @n@ copies of @Nothing@ has rank @n@.
--
-- Proved by recursion on the singleton @n@: the zero case is accepted as
-- 'Refl' directly; the successor case first uses 'lemReplicateSucc' to
-- rewrite @Replicate@ at the successor, then applies the lemma to the
-- predecessor.
lemRankReplicate :: SNat n -> Rank (Replicate n (Nothing @Nat)) :~: n
lemRankReplicate SZ = Refl
lemRankReplicate (SS (predN :: SNat p)) =
  case lemReplicateSucc @(Nothing @Nat) @p of
    Refl -> case lemRankReplicate predN of
      -- Both equalities are now in scope, which is enough for the goal.
      Refl -> Refl
-- ** Various type families
-- | Replicating @n + m@ times is the same as appending @n@ copies and @m@
-- copies: @Replicate (n + m) a@ equals @Replicate n a ++ Replicate m a@.
--
-- Only the singleton for @n@ is inspected; the two proxies just fix @m@ and
-- the replicated element @a@.
lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a
                    -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
lemReplicatePlusApp count _m _a = induct count
  where
    -- The same statement, generalised over the count so that it can be
    -- proved by recursion on the singleton; @m@ and @a@ are the ones bound
    -- by the outer signature.
    induct :: SNat k -> Replicate (k + m) a :~: Replicate k a ++ Replicate m a
    induct SZ = Refl
    induct (SS (prev :: SNat j)) =
      -- Rewrite @Replicate@ at the successor of @j@ ...
      case lemReplicateSucc @a @j of
        -- ... bring the induction hypothesis for @j@ into scope ...
        Refl -> case induct prev of
          -- ... and finish with the successor rewrite at @j + m@, flipped.
          Refl -> sym (lemReplicateSucc @a @(j + m))
-- | Provided @Rank l1 <= Rank l2@, appending @rest@ commutes with
-- @DropLen l1@: @DropLen l1 l2 ++ rest@ equals @DropLen l1 (l2 ++ rest)@.
--
-- The proxies only select @l1@, @l2@ and @rest@ and are never inspected.
--
-- NOTE(review): the equality comes from 'unsafeCoerceRefl', which, judging
-- by its name, asserts it without the type checker verifying it.
lemDropLenApp :: Rank l1 <= Rank l2
              => Proxy l1 -> Proxy l2 -> Proxy rest
              -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest)
lemDropLenApp _l1 _l2 _rest = unsafeCoerceRefl
-- | Provided @Rank l1 <= Rank l2@, appending @rest@ to @l2@ does not change
-- the result of @TakeLen l1@: @TakeLen l1 l2@ equals
-- @TakeLen l1 (l2 ++ rest)@.
--
-- The proxies only select @l1@, @l2@ and @rest@ and are never inspected.
--
-- NOTE(review): the equality comes from 'unsafeCoerceRefl', which, judging
-- by its name, asserts it without the type checker verifying it.
lemTakeLenApp :: Rank l1 <= Rank l2
              => Proxy l1 -> Proxy l2 -> Proxy rest
              -> TakeLen l1 l2 :~: TakeLen l1 (l2 ++ rest)
lemTakeLenApp _l1 _l2 _rest = unsafeCoerceRefl
-- ** KnownNat
-- | From @KnownNat n@ obtain a dictionary witnessing @KnownNat (n + 1)@.
--
-- No evidence is constructed by hand: the bare 'Dict' requires the
-- @KnownNat (n + 1)@ constraint to be solved from the given @KnownNat n@,
-- presumably by the "GHC.TypeLits.KnownNat.Solver" plugin enabled for this
-- module (TODO confirm).
lemKnownNatSucc :: KnownNat n => Dict KnownNat (n + 1)
lemKnownNatSucc = Dict
-- | Produce a @KnownNat@ dictionary for the rank of a shape by walking the
-- shape value.
--
-- The empty shape 'ZSX' yields the dictionary directly. For a non-empty
-- shape the head element is ignored, the dictionary for the tail is brought
-- into scope, and the dictionary for the whole shape is solved from it
-- (presumably with help from the KnownNat solver plugin enabled for this
-- module).
lemKnownNatRank :: ShX sh i -> Dict KnownNat (Rank sh)
lemKnownNatRank ZSX = Dict
lemKnownNatRank (_ :$% rest) =
  case lemKnownNatRank rest of
    Dict -> Dict
-- | Counterpart of 'lemKnownNatRank' for static shape descriptors: produce
-- a @KnownNat@ dictionary for the rank of @sh@ by walking a 'StaticShX'.
--
-- The empty descriptor 'ZKX' yields the dictionary directly. For a
-- non-empty descriptor the head is ignored and the dictionary for the tail
-- is brought into scope before the final 'Dict' is built.
lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh)
lemKnownNatRankSSX ZKX = Dict
lemKnownNatRankSSX (_ :!% rest) =
  case lemKnownNatRankSSX rest of
    Dict -> Dict
-- ** Known shapes
-- | A shape consisting of @n@ copies of @Nothing@ is a known shape.
--
-- Implemented by turning the singleton into a static shape descriptor with
-- 'ssxFromSNat' and handing that to 'lemKnownShX'.
lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing)
lemKnownReplicate = lemKnownShX . ssxFromSNat
-- | Recover a @KnownShX@ dictionary from an explicit static shape
-- descriptor, by recursion on the descriptor.
--
-- The empty descriptor 'ZKX' yields the dictionary directly. In both
-- non-empty cases the dictionary for the tail is brought into scope first:
--
-- * @SKnown SNat@ head: matching on @SNat@ presumably also brings the
--   @KnownNat@ evidence for that dimension into scope (TODO confirm);
--
-- * @SUnknown ()@ head: carries only a unit value.
lemKnownShX :: StaticShX sh -> Dict KnownShX sh
lemKnownShX ZKX = Dict
lemKnownShX (SKnown SNat :!% rest) =
  case lemKnownShX rest of
    Dict -> Dict
lemKnownShX (SUnknown () :!% rest) =
  case lemKnownShX rest of
    Dict -> Dict
|