aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed/Lemmas.hs
blob: ded6af5e9f46a7b7616276a051e1301f14650761 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed.Lemmas where

import Data.Proxy
import Data.Type.Equality
import GHC.TypeLits

import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Permutation
import Data.Array.Nested.Types


-- * Reasoning helpers

-- | Congruence of type equality under a one-argument type constructor: a
-- proof that @a@ equals @b@ is lifted to a proof that @f a@ equals @f b@.
-- @f@ is the first quantified variable of the signature.
subst1 :: forall f a b. a :~: b -> f a :~: f b
subst1 eq = case eq of
  Refl -> Refl

-- | Congruence of type equality in the first argument of a two-argument type
-- constructor: a proof that @a@ equals @b@ is lifted to a proof that @f a c@
-- equals @f b c@, with the second argument @c@ held fixed.
subst2 :: forall f c a b. a :~: b -> f a c :~: f b c
subst2 eq = case eq of
  Refl -> Refl


-- * Lemmas

-- ** Nat

-- | Shifting a successor across an inequality: if @k + 1 <= n@ then
-- @k <= n - 1@.
--
-- Asserted with 'unsafeCoerceRefl' (defined outside this module; presumably
-- an unchecked coercion of 'Refl'), so GHC does not verify the equation; the
-- constraint in the signature is what has to keep it true. The proxies only
-- fix @k@ and @n@.
--
-- The promoted constructor is written with an explicit tick, matching
-- 'lemLeqPlus'; the type itself is unchanged.
lemLeqSuccSucc :: k + 1 <= n => Proxy k -> Proxy n -> (k <=? n - 1) :~: 'True
lemLeqSuccSucc _ _ = unsafeCoerceRefl

-- | Weakening on the right: if @n <= m@ then @n <= m + k@ for any @k@.
--
-- The body is a plain 'Refl', so the equality is established by the type
-- checker itself (presumably with the help of the @GHC.TypeLits.Normalise@
-- plugin enabled for this module); no unchecked coercion is involved. The
-- proxies only fix @n@, @m@ and @k@.
lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True
lemLeqPlus _ _ _ = Refl


-- ** Append

-- | The empty list is a right identity of type-level append: @l ++ '[]@
-- equals @l@.
--
-- Asserted with 'unsafeCoerceRefl' (defined outside this module; presumably
-- an unchecked coercion of 'Refl'), so GHC does not verify the equation.
-- NOTE(review): presumably needed because @++@ reduces on its left argument,
-- which is abstract here -- confirm against the definition of @++@.
lemAppNil :: l ++ '[] :~: l
lemAppNil = unsafeCoerceRefl

-- | Associativity of type-level append: @(a ++ b) ++ c@ equals
-- @a ++ (b ++ c)@.
--
-- Asserted with 'unsafeCoerceRefl' (defined outside this module; presumably
-- an unchecked coercion of 'Refl'), so GHC does not verify the equation. The
-- proxies only fix the three lists.
lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
lemAppAssoc _ _ _ = unsafeCoerceRefl

-- | Congruence of type-level append in its left operand: from a proof that
-- @a@ equals @b@, conclude that @a ++ l@ equals @b ++ l@.
--
-- This is a genuine proof: matching on the given 'Refl' unifies @a@ and @b@,
-- after which both sides are the same type. The proxy only fixes @l@.
lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l
lemAppLeft _ eq = case eq of
  Refl -> Refl


-- ** Rank

-- | The rank of an appended shape is the sum of the ranks of its two parts.
--
-- Proved by recursion on the first static shape rather than asserted: no
-- unchecked coercion is used, and the remaining arithmetic is left to the
-- type checker.
lemRankApp :: forall sh1 sh2.
              StaticShX sh1 -> StaticShX sh2
           -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
lemRankApp prefix ssh2 = case prefix of
  -- Base case: the first shape is empty.
  ZKX -> Refl
  -- Inductive step: flip the induction hypothesis for the tail and add one
  -- to both sides of it.
  _ :!% (rest :: StaticShX shTail) ->
    bumpBoth (Proxy @(Rank shTail)) Proxy Proxy (sym (lemRankApp rest ssh2))
  where
    -- From @a + b = c@ conclude @c + 1 = a + 1 + b@. The proxies only pin
    -- down which @a@, @b@ and @c@ are meant.
    bumpBoth :: proxy a -> proxy b -> proxy c
             -> (a + b :~: c)
             -> c + 1 :~: (a + 1 + b)
    bumpBoth _ _ _ Refl = Refl

-- | The rank of an append does not depend on the order of its operands:
-- @Rank (sh1 ++ sh2)@ equals @Rank (sh2 ++ sh1)@.
--
-- Asserted with 'unsafeCoerceRefl' (defined outside this module; presumably
-- an unchecked coercion of 'Refl'), so GHC does not verify the equation. Any
-- proxy-like type constructor may be used to fix @sh1@ and @sh2@.
lemRankAppComm :: proxy sh1 -> proxy sh2
               -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)
lemRankAppComm _ _ = unsafeCoerceRefl

-- | The rank of @Replicate n Nothing@ (with @Nothing@ taken at kind
-- @Maybe Nat@) is @n@.
--
-- Asserted with 'unsafeCoerceRefl' (defined outside this module; presumably
-- an unchecked coercion of 'Refl'), so GHC does not verify the equation. The
-- proxy only fixes @n@.
lemRankReplicate :: proxy n -> Rank (Replicate n (Nothing @Nat)) :~: n
lemRankReplicate _ = unsafeCoerceRefl


-- ** Various type families

-- | Replicating @n + m@ times equals replicating @n@ times and appending
-- @m@ further replications.
--
-- Proved by recursion on the singleton for @n@; the proxies only fix @m@ and
-- @a@. No unchecked coercion is used in this definition itself.
lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a
                    -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
lemReplicatePlusApp count _ _ = induct count
  where
    -- @m@ and @a@ are the scoped type variables of the outer signature; only
    -- the counter @i@ varies during the recursion.
    induct :: SNat i -> Replicate (i + m) a :~: Replicate i a ++ Replicate m a
    induct SZ = Refl
    induct (SS (prev :: SNat j)) =
      -- First bring the successor equation for @j@ into scope, then the
      -- induction hypothesis; the goal is then the flipped successor
      -- equation for @j + m@. (NOTE(review): 'lemReplicateSucc' is defined
      -- elsewhere; presumably it relates @Replicate (k + 1) a@ to
      -- @a : Replicate k a@ -- confirm.)
      case lemReplicateSucc @a @j of
        Refl -> case induct prev of
          Refl -> sym (lemReplicateSucc @a @(j + m))

-- | Under the constraint @Rank l1 <= Rank l2@, appending @rest@ after
-- 'DropLen' gives the same list as appending it before:
-- @DropLen l1 l2 ++ rest@ equals @DropLen l1 (l2 ++ rest)@.
--
-- Asserted with 'unsafeCoerceRefl' (defined outside this module; presumably
-- an unchecked coercion of 'Refl'), so GHC does not verify the equation; the
-- rank constraint in the signature is what has to keep it true.
-- NOTE(review): 'DropLen' is defined elsewhere; presumably it drops from its
-- second argument as many elements as the first has -- confirm.
lemDropLenApp :: Rank l1 <= Rank l2
              => Proxy l1 -> Proxy l2 -> Proxy rest
              -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest)
lemDropLenApp _ _ _ = unsafeCoerceRefl

-- | Under the constraint @Rank l1 <= Rank l2@, 'TakeLen' is unaffected by
-- appending @rest@ to its second argument: @TakeLen l1 l2@ equals
-- @TakeLen l1 (l2 ++ rest)@.
--
-- Asserted with 'unsafeCoerceRefl' (defined outside this module; presumably
-- an unchecked coercion of 'Refl'), so GHC does not verify the equation; the
-- rank constraint in the signature is what has to keep it true.
-- NOTE(review): 'TakeLen' is defined elsewhere; presumably it keeps from its
-- second argument as many elements as the first has -- confirm.
lemTakeLenApp :: Rank l1 <= Rank l2
              => Proxy l1 -> Proxy l2 -> Proxy rest
              -> TakeLen l1 l2 :~: TakeLen l1 (l2 ++ rest)
lemTakeLenApp _ _ _ = unsafeCoerceRefl

-- | 'Init' undoes appending a single element: @Init (l ++ '[x])@ equals @l@.
--
-- Asserted with 'unsafeCoerceRefl' (defined outside this module; presumably
-- an unchecked coercion of 'Refl'), so GHC does not verify the equation. The
-- proxies only fix @l@ and @x@.
lemInitApp :: Proxy l -> Proxy x -> Init (l ++ '[x]) :~: l
lemInitApp _ _ = unsafeCoerceRefl

-- | 'Last' of a list with a single element appended is that element:
-- @Last (l ++ '[x])@ equals @x@.
--
-- Asserted with 'unsafeCoerceRefl' (defined outside this module; presumably
-- an unchecked coercion of 'Refl'), so GHC does not verify the equation. The
-- proxies only fix @l@ and @x@.
lemLastApp :: Proxy l -> Proxy x -> Last (l ++ '[x]) :~: x
lemLastApp _ _ = unsafeCoerceRefl


-- ** KnownNat

-- | From a 'KnownNat' instance for @n@, package up one for @n + 1@.
--
-- The bare 'Dict' requires @KnownNat (n + 1)@ to be solved from the given
-- @KnownNat n@; presumably the @GHC.TypeLits.KnownNat.Solver@ plugin enabled
-- for this module is what discharges it.
lemKnownNatSucc :: KnownNat n => Dict KnownNat (n + 1)
lemKnownNatSucc = Dict

-- | Recover a 'KnownNat' dictionary for the rank of a shape by recursing
-- over the shape value.
lemKnownNatRank :: ShX sh i -> Dict KnownNat (Rank sh)
lemKnownNatRank shape = case shape of
  -- Empty shape: the instance is solved directly.
  ZSX -> Dict
  -- Matching on the tail's dictionary brings its 'KnownNat' instance into
  -- scope; the instance for this shape's rank is then solved from it.
  _ :$% rest -> case lemKnownNatRank rest of
    Dict -> Dict

-- | Recover a 'KnownNat' dictionary for the rank of a static shape by
-- recursing over the static shape value.
lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh)
lemKnownNatRankSSX ZKX = Dict
lemKnownNatRankSSX (_ :!% rest) =
  -- The tail's dictionary puts its 'KnownNat' instance in scope, from which
  -- the instance for the whole rank is solved.
  case lemKnownNatRankSSX rest of
    Dict -> Dict


-- ** Known shapes

-- | A 'KnownShX' dictionary for @Replicate n Nothing@, given the singleton
-- for @n@: the static shape is built with 'ssxFromSNat' and then turned into
-- a dictionary by 'lemKnownShX'.
lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing)
lemKnownReplicate = lemKnownShX . ssxFromSNat

-- | Turn a static shape value into the corresponding 'KnownShX' dictionary
-- by recursing over its elements.
lemKnownShX :: StaticShX sh -> Dict KnownShX sh
lemKnownShX ssh = case ssh of
  ZKX -> Dict
  -- Statically known dimension. NOTE(review): matching on 'SNat' presumably
  -- brings the dimension's 'KnownNat' instance into scope -- confirm.
  SKnown SNat :!% rest -> case lemKnownShX rest of
    Dict -> Dict
  -- Statically unknown dimension: only the tail's dictionary is needed.
  SUnknown () :!% rest -> case lemKnownShX rest of
    Dict -> Dict