-- Source path: src/Data/Array/Nested/Lemmas.hs
-- Blob: a3b20c6f46bf1fcc00f7c01712b0adb3b6b6aa06
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Nested.Lemmas where

import Data.Proxy
import Data.Type.Equality
import GHC.TypeLits

import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Permutation
import Data.Array.Nested.Shaped.Shape
import Data.Array.Nested.Types


-- * Basic Lemmas (for the most part they don't mention shape types and don't require typing plugins; the KnownNat lemmas below do take shapes)

-- ** Nat

-- | Shifting a successor across an inequality: from @k + 1 <= n@ conclude
-- @k <= n - 1@.
--
-- The two proxies only pin down @k@ and @n@; they are never inspected.
-- The equality is asserted with 'unsafeCoerceRefl' instead of being derived
-- by the type checker. The promoted constructor is written @'True@, the same
-- spelling 'lemLeqPlus' uses.
lemLeqSuccSucc :: k + 1 <= n => Proxy k -> Proxy n -> (k <=? n - 1) :~: 'True
lemLeqSuccSucc _ _ = unsafeCoerceRefl

-- | Weakening an upper bound: from @n <= m@ conclude @n <= m + k@.
--
-- The proxies only name @n@, @m@ and @k@ and are ignored. The body is a bare
-- 'Refl', so the type checker itself must discharge the equality
-- (presumably with help from the solver plugins enabled for this module --
-- TODO confirm which one).
lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True
lemLeqPlus _ _ _ = Refl

-- ** Append

-- | Right identity of @(++)@: @l ++ '[]@ equals @l@.
--
-- Asserted with 'unsafeCoerceRefl' rather than proved by induction on @l@;
-- note that the lemma takes no argument that could drive such an induction.
lemAppNil :: l ++ '[] :~: l
lemAppNil = unsafeCoerceRefl

-- | Associativity of @(++)@: @(a ++ b) ++ c@ equals @a ++ (b ++ c)@.
--
-- The proxies only name the three lists and are ignored. The equality is
-- asserted with 'unsafeCoerceRefl', not derived by the type checker.
lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
lemAppAssoc _ _ _ = unsafeCoerceRefl

-- | Congruence of @(++)@ in its left operand: if @a@ equals @b@, then
-- @a ++ l@ equals @b ++ l@. The proxy merely names the suffix @l@.
lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l
lemAppLeft _ prefixEq =
  case prefixEq of
    -- Once @a ~ b@ is in scope both sides are syntactically the same type.
    Refl -> Refl

-- ** Simple type families

-- | @Replicate@ turns addition of counts into concatenation:
-- @Replicate (n + m) a@ equals @Replicate n a ++ Replicate m a@.
--
-- Only the singleton for @n@ is inspected; the proxies for @m@ and @a@ are
-- there to fix the type variables.
lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a
                    -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
lemReplicatePlusApp sn _ _ = induct sn
  where
    -- Induction on the first count; @m@ and @a@ are the outer (scoped) ones.
    induct :: SNat k -> Replicate (k + m) a :~: Replicate k a ++ Replicate m a
    induct SZ = Refl
    induct (SS (prev :: SNat km1)) =
      -- Unfold one step of Replicate on the right-hand side ...
      case lemReplicateSucc @a @km1 of
        Refl ->
          -- ... apply the induction hypothesis to the predecessor ...
          case induct prev of
            -- ... and fold one step of Replicate back on the left-hand side.
            Refl -> sym (lemReplicateSucc @a @(km1 + m))

-- | @DropLen l1@ commutes with appending on the right:
-- @DropLen l1 l2 ++ rest@ equals @DropLen l1 (l2 ++ rest)@, provided
-- @Rank l1 <= Rank l2@.
--
-- The rank constraint is demanded from callers but not used by the body:
-- the proxies are ignored and the equality is asserted with
-- 'unsafeCoerceRefl'.
lemDropLenApp :: Rank l1 <= Rank l2
              => Proxy l1 -> Proxy l2 -> Proxy rest
              -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest)
lemDropLenApp _ _ _ = unsafeCoerceRefl

-- | @TakeLen l1@ is unaffected by appending on the right:
-- @TakeLen l1 l2@ equals @TakeLen l1 (l2 ++ rest)@, provided
-- @Rank l1 <= Rank l2@.
--
-- The rank constraint is demanded from callers but not used by the body:
-- the proxies are ignored and the equality is asserted with
-- 'unsafeCoerceRefl'.
lemTakeLenApp :: Rank l1 <= Rank l2
              => Proxy l1 -> Proxy l2 -> Proxy rest
              -> TakeLen l1 l2 :~: TakeLen l1 (l2 ++ rest)
lemTakeLenApp _ _ _ = unsafeCoerceRefl

-- | @Init@ undoes appending a single element: @Init (l ++ '[x])@ equals @l@.
--
-- The proxies only name @l@ and @x@; the equality is asserted with
-- 'unsafeCoerceRefl'.
lemInitApp :: Proxy l -> Proxy x -> Init (l ++ '[x]) :~: l
lemInitApp _ _ = unsafeCoerceRefl

-- | @Last@ recovers an appended single element: @Last (l ++ '[x])@ equals
-- @x@.
--
-- The proxies only name @l@ and @x@; the equality is asserted with
-- 'unsafeCoerceRefl'.
lemLastApp :: Proxy l -> Proxy x -> Last (l ++ '[x]) :~: x
lemLastApp _ _ = unsafeCoerceRefl


-- ** KnownNat

-- | 'KnownNat' is preserved by successor: given @KnownNat n@, produce a
-- dictionary for @KnownNat (n + 1)@.
--
-- The bare 'Dict' leaves the instance to constraint solving (presumably the
-- KnownNat solver plugin enabled for this module -- TODO confirm).
lemKnownNatSucc :: KnownNat n => Dict KnownNat (n + 1)
lemKnownNatSucc = Dict

-- | Walk a shape value to obtain a 'KnownNat' dictionary for @Rank sh@.
--
-- Only the spine of the shape is traversed; the element stored at each
-- position is ignored.
lemKnownNatRank :: ShX sh i -> Dict KnownNat (Rank sh)
lemKnownNatRank shape = case shape of
  ZSX -> Dict
  _ :$% rest ->
    -- The dictionary for the tail is what lets the constraint for the whole
    -- shape be solved.
    case lemKnownNatRank rest of
      Dict -> Dict

-- | Walk a static shape to obtain a 'KnownNat' dictionary for @Rank sh@.
--
-- Only the spine is traversed; the per-dimension entries are ignored.
lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh)
lemKnownNatRankSSX shape = case shape of
  ZKX -> Dict
  _ :!% rest ->
    -- The dictionary for the tail is what lets the constraint for the whole
    -- shape be solved.
    case lemKnownNatRankSSX rest of
      Dict -> Dict


-- * Complex lemmas (they mention shape types)

-- ** Known shapes

-- | 'KnownShX' dictionary for @Replicate n Nothing@, obtained by building
-- the static shape with 'ssxFromSNat' and handing it to 'lemKnownShX'.
lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing)
lemKnownReplicate = lemKnownShX . ssxFromSNat

-- | Turn an explicit static shape into a 'KnownShX' dictionary by recursion
-- over its entries.
lemKnownShX :: StaticShX sh -> Dict KnownShX sh
lemKnownShX shape = case shape of
  ZKX -> Dict
  -- Known dimension: matching on 'SNat' is part of the pattern, then recurse
  -- on the tail.
  SKnown SNat :!% rest ->
    case lemKnownShX rest of
      Dict -> Dict
  -- Unknown dimension: only the tail needs a dictionary.
  SUnknown () :!% rest ->
    case lemKnownShX rest of
      Dict -> Dict

-- | 'KnownShX' dictionary for @MapJust sh@, given @KnownShS sh@.
--
-- The proxy only names @sh@. The shape singleton from 'knownShS' is
-- converted into a static shape made entirely of 'SKnown' entries, which
-- 'lemKnownShX' then turns into the dictionary.
lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh)
lemKnownMapJust _ = lemKnownShX $ toStatic (knownShS @sh)
  where
    -- Re-wrap every dimension as a known entry of a static shape.
    toStatic :: ShS dims -> StaticShX (MapJust dims)
    toStatic shape = case shape of
      ZSS -> ZKX
      dim :$$ rest -> SKnown dim :!% toStatic rest

-- ** Rank

-- | Rank is additive over append: @Rank (sh1 ++ sh2)@ equals
-- @Rank sh1 + Rank sh2@.
--
-- Proved by recursion on the first static shape; the second one is only
-- passed along to the recursive call.
lemRankApp :: forall sh1 sh2.
              StaticShX sh1 -> StaticShX sh2
           -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
lemRankApp ssh1 ssh2 = case ssh1 of
  ZKX -> Refl
  _ :!% (tl :: StaticShX tlT) ->
    -- Flip the induction hypothesis so it has the orientation 'bump' expects.
    bump (Proxy @(Rank tlT)) Proxy Proxy (sym (lemRankApp tl ssh2))
  where
    -- Arithmetic step: from @a + b ~ c@ conclude @c + 1 ~ a + 1 + b@.
    bump :: proxy a -> proxy b -> proxy c
         -> (a + b :~: c)
         -> c + 1 :~: (a + 1 + b)
    bump _ _ _ Refl = Refl

-- | The rank of an append does not depend on operand order:
-- @Rank (sh1 ++ sh2)@ equals @Rank (sh2 ++ sh1)@.
--
-- The proxies only name the two lists and are ignored. The equality is
-- asserted with 'unsafeCoerceRefl', not derived by the type checker.
lemRankAppComm :: proxy sh1 -> proxy sh2
               -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)
lemRankAppComm _ _ = unsafeCoerceRefl

-- | @Rank (Replicate n Nothing)@ equals @n@, with @Nothing@ taken at kind
-- @Maybe Nat@.
--
-- The proxy only names @n@; the equality is asserted with
-- 'unsafeCoerceRefl'.
lemRankReplicate :: proxy n -> Rank (Replicate n (Nothing @Nat)) :~: n
lemRankReplicate _ = unsafeCoerceRefl

-- | @MapJust@ preserves rank: @Rank (MapJust sh)@ equals @Rank sh@.
-- Proved by recursion over the shape singleton.
lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh
lemRankMapJust shape = case shape of
  ZSS -> Refl
  _ :$$ rest ->
    -- Induction hypothesis on the tail.
    case lemRankMapJust rest of
      Refl -> Refl

-- ** Related to MapJust and/or Permutation

-- | @TakeLen is@ commutes with @MapJust@:
-- @TakeLen is (MapJust sh)@ equals @MapJust (TakeLen is sh)@.
--
-- Recurses over the permutation and the shape in lock step. Calls 'error'
-- if the shape runs out while the permutation still has entries.
lemTakeLenMapJust :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)
lemTakeLenMapJust perm shape = case perm of
  PNil -> Refl
  _ `PCons` permRest -> case shape of
    _ :$$ shapeRest ->
      -- Induction hypothesis on both tails.
      case lemTakeLenMapJust permRest shapeRest of
        Refl -> Refl
    ZSS -> error "TakeLen of empty"

-- | @DropLen is@ commutes with @MapJust@:
-- @DropLen is (MapJust sh)@ equals @MapJust (DropLen is sh)@.
--
-- Recurses over the permutation and the shape in lock step. Calls 'error'
-- if the shape runs out while the permutation still has entries.
lemDropLenMapJust :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh)
lemDropLenMapJust perm shape = case perm of
  PNil -> Refl
  _ `PCons` permRest -> case shape of
    _ :$$ shapeRest ->
      -- Induction hypothesis on both tails.
      case lemDropLenMapJust permRest shapeRest of
        Refl -> Refl
    ZSS -> error "DropLen of empty"

-- | Indexing commutes with @MapJust@:
-- @Index i (MapJust sh)@ equals @Just (Index i sh)@.
--
-- Recurses over the index and the shape together. Calls 'error' when the
-- shape is empty, whatever the index is.
lemIndexMapJust :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh)
lemIndexMapJust idx shape = case (idx, shape) of
  -- Index zero selects the head on both sides.
  (SZ, _ :$$ _) -> Refl
  (SS (j :: SNat j'), (_ :: SNat x) :$$ (xs :: ShS xs')) ->
    -- Induction hypothesis for the predecessor index on the tail ...
    case lemIndexMapJust j xs of
      Refl ->
        -- ... then step the index past the head of the mapped list ...
        case lemIndexSucc (Proxy @j') (Proxy @(Just x)) (Proxy @(MapJust xs')) of
          Refl ->
            -- ... and past the head of the original list.
            case lemIndexSucc (Proxy @j') (Proxy @x) (Proxy @xs') of
              Refl -> Refl
  (_, ZSS) -> error "Index of empty"

-- | @Permute is@ commutes with @MapJust@:
-- @Permute is (MapJust sh)@ equals @MapJust (Permute is sh)@.
--
-- Recurses over the permutation only; the full shape is passed unchanged to
-- every recursive call and to 'lemIndexMapJust'.
lemPermuteMapJust :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh)
lemPermuteMapJust perm sh = case perm of
  PNil -> Refl
  ix `PCons` permRest ->
    -- Tail of the permutation first, then the entry selected by its head.
    case lemPermuteMapJust permRest sh of
      Refl -> case lemIndexMapJust ix sh of
        Refl -> Refl

-- | @MapJust@ distributes over append:
-- @MapJust (sh1 ++ sh2)@ equals @MapJust sh1 ++ MapJust sh2@.
--
-- Recurses over the singleton for @sh1@; the proxy for @sh2@ is only passed
-- along.
lemMapJustApp :: ShS sh1 -> Proxy sh2
              -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2
lemMapJustApp shape suffix = case shape of
  ZSS -> Refl
  _ :$$ rest ->
    -- Induction hypothesis on the tail of the first shape.
    case lemMapJustApp rest suffix of
      Refl -> Refl