-- blob: d0e8d2490054c98142acef2aae89fee0661829d9 (plain)
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed.Lemmas where
import Data.Proxy
import Data.Type.Equality
import GHC.TypeLits
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
-- | The rank of a concatenated shape equals the sum of the ranks of its
-- two parts.
--
-- The proof recurses over the first static shape; the second one is only
-- handed on to the recursive call. The empty case is closed directly with
-- 'Refl'.
lemRankApp :: forall sh1 sh2.
              StaticShX sh1 -> StaticShX sh2
           -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
lemRankApp ZKX _ = Refl
lemRankApp (_ :!% (sshTail :: StaticShX sh1T)) ssh2 =
  -- Read inside-out: take the induction hypothesis for the tail, turn it
  -- around, then add one on both sides to reach the required result type.
  bumpBoth (Proxy @(Rank sh1T)) Proxy Proxy
    ( flipEq (Proxy @(Rank sh2))
             (Proxy @(Rank sh1T))
             (Proxy @(Rank (sh1T ++ sh2)))
             (lemRankApp sshTail ssh2) )
  where
    -- Swap the two sides of an equality whose right-hand side is a sum.
    -- The proxies only serve to pin down the type variables.
    flipEq :: proxy y -> proxy x -> proxy s
           -> s :~: x + y
           -> x + y :~: s
    flipEq _ _ _ Refl = Refl

    -- From @x + y = s@ conclude @s + 1 = x + 1 + y@.
    -- The proxies only serve to pin down the type variables.
    bumpBoth :: proxy x -> proxy y -> proxy s
             -> (x + y :~: s)
             -> s + 1 :~: (x + 1 + y)
    bumpBoth _ _ _ Refl = Refl
-- | Concatenating two shapes in either order yields the same rank.
--
-- NOTE: this equality is not derived here. Both arguments are ignored and
-- the result is taken from @unsafeCoerceRefl@ (defined outside this
-- module), so the type checker does not verify the claim.
lemRankAppComm :: StaticShX sh1 -> StaticShX sh2
               -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)
lemRankAppComm _ssh1 _ssh2 = unsafeCoerceRefl  -- TODO improve this
-- | Build a 'KnownNat' dictionary for the rank of a shape by recursing
-- over the shape value.
lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh)
lemKnownNatRank ZSX = Dict
lemKnownNatRank (_ :$% rest) =
  -- Matching on the tail's dictionary brings its 'KnownNat' evidence into
  -- scope, which is what lets the outer 'Dict' be constructed.
  case lemKnownNatRank rest of
    Dict -> Dict
-- | Build a 'KnownNat' dictionary for the rank of a static shape by
-- recursing over the static shape value.
lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh)
lemKnownNatRankSSX ZKX = Dict
lemKnownNatRankSSX (_ :!% sshRest) =
  -- Matching on the tail's dictionary brings its 'KnownNat' evidence into
  -- scope, which is what lets the outer 'Dict' be constructed.
  case lemKnownNatRankSSX sshRest of
    Dict -> Dict
|