aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed/Lemmas.hs
blob: ff2e45ca5f26b45e134b316631a9d22ffc26709f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed.Lemmas where

import Data.Proxy
import Data.Type.Equality
import GHC.TypeLits

import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types


-- * Reasoning helpers

-- | Congruence for a unary type constructor: an equality between @a@ and @b@
-- lifts to an equality between @f a@ and @f b@.
--
-- @f@ is an ordinary (matchable) type constructor; type families cannot be
-- passed here unsaturated.
subst1 :: forall f a b. a :~: b -> f a :~: f b
subst1 eq = gcastWith eq Refl

-- | Congruence in the first argument of a binary type constructor: from
-- @a = b@ conclude @f a c = f b c@, with @c@ held fixed.
subst2 :: forall f c a b. a :~: b -> f a c :~: f b c
subst2 eq = gcastWith eq Refl


-- * Lemmas

-- ** Nat

-- | From @k + 1 <= n@ conclude @k <= n - 1@.
--
-- The premise forces @n >= 1@, so the (truncating) subtraction @n - 1@ is
-- exact and the conclusion follows. The equation is postulated rather than
-- proven: the body is 'unsafeCoerceRefl' and inspects no evidence, so the
-- type checker does not verify it.
--
-- The result uses the ticked promoted constructor @'True@, matching
-- 'lemLeqPlus' and avoiding @-Wunticked-promoted-constructors@.
lemLeqSuccSucc :: (k + 1 <= n) => Proxy k -> Proxy n -> (k <=? n - 1) :~: 'True
lemLeqSuccSucc _ _ = unsafeCoerceRefl

-- | Weakening of @<=@ on the right: if @n <= m@, then @n <= m + k@ for
-- every @k@.
--
-- The body is a bare 'Refl'. Discharging the equation is left to the
-- type-checker plugins enabled in this module's @OPTIONS_GHC@ pragmas.
lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True
lemLeqPlus = \_ _ _ -> Refl


-- ** Append

-- | Right identity of type-level list append: @l ++ '[] = l@.
--
-- Postulated, not proven: the body is 'unsafeCoerceRefl' and takes no
-- evidence about @l@, so the type checker does not verify this equation.
lemAppNil :: l ++ '[] :~: l
lemAppNil = unsafeCoerceRefl

-- | Associativity of type-level list append:
-- @(a ++ b) ++ c = a ++ (b ++ c)@.
--
-- The 'Proxy' arguments only fix @a@, @b@ and @c@ at the call site. The
-- equation is postulated via 'unsafeCoerceRefl', so the type checker does
-- not verify it.
lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
lemAppAssoc _ _ _ = unsafeCoerceRefl

-- | Appending the same list @l@ on the right preserves an equality: from
-- @a = b@ conclude @a ++ l = b ++ l@. The 'Proxy' only fixes @l@.
lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l
lemAppLeft _ eq = gcastWith eq Refl


-- ** Rank

-- | The rank of an appended shape is the sum of the ranks of its parts:
-- @Rank (sh1 ++ sh2) = Rank sh1 + Rank sh2@.
--
-- Proven by induction on the first 'StaticShX'; the second is only passed
-- through to the recursive call. The arithmetic rearrangements are stated in
-- the local helpers over plain type variables, where the type-checker
-- plugins enabled for this module can discharge them.
lemRankApp :: forall sh1 sh2.
              StaticShX sh1 -> StaticShX sh2
           -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
-- Base case: the empty prefix; the type checker accepts 'Refl' directly.
lemRankApp ZKX _ = Refl
-- Step case: the induction hypothesis gives
-- @Rank (sh1T ++ sh2) = Rank sh1T + Rank sh2@; 'lem' flips it around and
-- 'lem2' turns it into the equation for the shape with one more dimension.
lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2
  = lem2 (Proxy @(Rank sh1T)) Proxy Proxy $
      lem (Proxy @(Rank sh2)) (Proxy @(Rank sh1T)) (Proxy @(Rank (sh1T ++ sh2))) $
        lemRankApp ssh1 ssh2
  where
    -- Symmetry, specialised to a sum: from @c = b + a@ conclude @b + a = c@.
    -- The proxies pin down the type variables, which could not otherwise be
    -- inferred through the type family '+'.
    lem :: proxy a -> proxy b -> proxy c
        -> c :~: b + a
        -> b + a :~: c
    lem _ _ _ Refl = Refl

    -- From @a + b = c@ conclude @c + 1 = a + 1 + b@; the reassociation and
    -- commutation are left to the typelits normalisation plugin.
    lem2 :: proxy a -> proxy b -> proxy c
         -> (a + b :~: c)
         -> c + 1 :~: (a + 1 + b)
    lem2 _ _ _ Refl = Refl

-- | Appending two shapes in either order gives the same rank:
-- @Rank (sh1 ++ sh2) = Rank (sh2 ++ sh1)@.
--
-- Previously this was an unchecked 'unsafeCoerceRefl' postulate. It is now
-- derived from 'lemRankApp' applied in both orders, so the type checker
-- verifies it. The leftover goal @a + b = b + a@ is stated over plain type
-- variables in the local helper, where the typelits normalisation plugin can
-- discharge it; 'lemRankApp' uses the same technique. Forcing the result walks
-- both shapes once.
lemRankAppComm :: forall sh1 sh2.
                  StaticShX sh1 -> StaticShX sh2
               -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)
lemRankAppComm ssh1 ssh2 =
  plusComm (Proxy @(Rank sh1)) (Proxy @(Rank sh2))
           (lemRankApp ssh1 ssh2)
           (lemRankApp ssh2 ssh1)
  where
    -- From @c = a + b@ and @d = b + a@ conclude @c = d@. The proxies pin down
    -- @a@ and @b@, which cannot be inferred through the non-injective type
    -- family '+'.
    plusComm :: proxy a -> proxy b
             -> c :~: a + b
             -> d :~: b + a
             -> c :~: d
    plusComm _ _ Refl Refl = Refl

-- | A shape consisting of @n@ copies of @Nothing@ has rank @n@.
--
-- Proven by induction on the 'SNat'. 'lemReplicateSucc' is defined outside
-- this module. It is presumably the equation relating @Replicate (nm1 + 1) x@
-- to @x : Replicate nm1 x@ (NOTE(review): confirm against its definition),
-- and it is used here to peel one element off the 'Replicate'.
lemRankReplicate :: SNat n -> Rank (Replicate n (Nothing @Nat)) :~: n
lemRankReplicate SZ = Refl
lemRankReplicate (SS (n :: SNat nm1))
  -- Unfold Replicate by one step, then use the induction hypothesis on nm1.
  | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1
  , Refl <- lemRankReplicate n
  = Refl


-- ** Various type families

-- | Replicating @a@ @n + m@ times equals replicating it @n@ times and then
-- appending @m@ more copies.
--
-- Proven by induction on @n@ (the 'SNat'). The 'Proxy' arguments only fix
-- @m@ and @a@, which the local worker picks up via ScopedTypeVariables.
lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a
                    -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
lemReplicatePlusApp sn _ _ = go sn
  where
    -- Induction on n'; @m@ and @a@ are the ones bound by the outer signature.
    go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a
    go SZ = Refl
    go (SS (n :: SNat n'm1))
      -- Unfold the right-hand Replicate by one step and apply the induction
      -- hypothesis; the final 'sym' step rewrites the left-hand
      -- @Replicate (n'm1 + m + 1) a@ the same way.
      | Refl <- lemReplicateSucc @a @n'm1
      , Refl <- go n
      = sym (lemReplicateSucc @a @(n'm1 + m))

-- | Dropping a prefix distributes over appending a suffix:
-- @DropLen l1 l2 ++ rest = DropLen l1 (l2 ++ rest)@, given
-- @Rank l1 <= Rank l2@.
--
-- Presumably @DropLen l1 l2@ drops the first @Rank l1@ elements of @l2@
-- (defined elsewhere — confirm). The rank constraint is the precondition
-- that makes the equation hold, but the body does not use it. The equation
-- is postulated via 'unsafeCoerceRefl' and is not checked by the type checker.
lemDropLenApp :: Rank l1 <= Rank l2
              => Proxy l1 -> Proxy l2 -> Proxy rest
              -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest)
lemDropLenApp _ _ _ = unsafeCoerceRefl

-- | Taking a prefix is unaffected by appending a suffix:
-- @TakeLen l1 l2 = TakeLen l1 (l2 ++ rest)@, given @Rank l1 <= Rank l2@.
--
-- Presumably @TakeLen l1 l2@ keeps the first @Rank l1@ elements of @l2@
-- (defined elsewhere — confirm). The rank constraint is the precondition
-- that makes the equation hold, but the body does not use it. The equation
-- is postulated via 'unsafeCoerceRefl' and is not checked by the type checker.
lemTakeLenApp :: Rank l1 <= Rank l2
              => Proxy l1 -> Proxy l2 -> Proxy rest
              -> TakeLen l1 l2 :~: TakeLen l1 (l2 ++ rest)
lemTakeLenApp _ _ _ = unsafeCoerceRefl


-- ** KnownNat

-- | If @n@ is known, so is @n + 1@.
--
-- The body is a bare 'Dict'. The required @KnownNat (n + 1)@ instance is
-- supplied by the @GHC.TypeLits.KnownNat.Solver@ plugin enabled at the top
-- of this module.
lemKnownNatSucc :: KnownNat n => Dict KnownNat (n + 1)
lemKnownNatSucc = Dict

-- | Build a 'KnownNat' dictionary for the rank of a shape by recursing over
-- the 'ShX' one dimension at a time. Each step extends the dictionary for
-- the tail of the shape.
lemKnownNatRank :: ShX sh i -> Dict KnownNat (Rank sh)
lemKnownNatRank shx = case shx of
  ZSX -> Dict
  _ :$% rest -> case lemKnownNatRank rest of
    Dict -> Dict

-- | Like 'lemKnownNatRank', but driven by a 'StaticShX'. Recurses over the
-- static shape and extends the tail's 'KnownNat' dictionary at each step.
lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh)
lemKnownNatRankSSX ssx = case ssx of
  ZKX -> Dict
  _ :!% rest -> case lemKnownNatRankSSX rest of
    Dict -> Dict


-- ** Known shapes

-- | A 'KnownShX' dictionary for a shape of @n@ copies of @Nothing@. It is
-- obtained by turning the 'SNat' into a static shape with 'ssxFromSNat' and
-- handing that to 'lemKnownShX'.
lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing)
lemKnownReplicate = lemKnownShX . ssxFromSNat

-- | Recover a 'KnownShX' dictionary from a 'StaticShX' by structural
-- recursion.
--
-- A known dimension is matched with the 'SNat' pattern to bring its
-- 'KnownNat' evidence into scope. An unknown dimension has its @()@ payload
-- matched. Either way, the tail's dictionary is built recursively.
lemKnownShX :: StaticShX sh -> Dict KnownShX sh
lemKnownShX ssx = case ssx of
  ZKX -> Dict
  SKnown SNat :!% rest -> case lemKnownShX rest of
    Dict -> Dict
  SUnknown () :!% rest -> case lemKnownShX rest of
    Dict -> Dict