aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed/Lemmas.hs
blob: d0e8d2490054c98142acef2aae89fee0661829d9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed.Lemmas where

import Data.Proxy
import Data.Type.Equality
import GHC.TypeLits

import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types


-- | The rank of an appended shape is the sum of the ranks of its parts.
--
-- Proved by induction on the first 'StaticShX'. The empty case is immediate.
-- The cons case rearranges the induction hypothesis with the two local helpers
-- @lem@ and @lem2@. Forcing the resulting 'Refl' walks the whole first shape,
-- so it costs time linear in that shape's length.
--
-- NOTE(review): the arithmetic steps (inside @lem2@, and matching its result
-- against the cons-case goal) are presumably discharged by the
-- GHC.TypeLits.Normalise plugin enabled at the top of this module. GHC's
-- built-in solver does not reason about associativity/commutativity of '+'.
lemRankApp :: forall sh1 sh2.
              StaticShX sh1 -> StaticShX sh2
           -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
lemRankApp ZKX _ = Refl
-- Induction step. The pattern signature binds the tail's index as sh1T so it
-- can be named in the Proxy type applications below. Reading bottom-up:
--   IH (recursive call): Rank (sh1T ++ sh2) :~: Rank sh1T + Rank sh2
--   after lem:           Rank sh1T + Rank sh2 :~: Rank (sh1T ++ sh2)
--   after lem2:          Rank (sh1T ++ sh2) + 1 :~: Rank sh1T + 1 + Rank sh2
-- The last equation must then be accepted as the goal for the cons case.
-- NOTE(review): that relies on how Rank and ++ reduce on a cons, which is
-- defined outside this module.
lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2
  = lem2 (Proxy @(Rank sh1T)) Proxy Proxy $
      lem (Proxy @(Rank sh2)) (Proxy @(Rank sh1T)) (Proxy @(Rank (sh1T ++ sh2))) $
        lemRankApp ssh1 ssh2
  where
    -- Symmetry of an equation whose right-hand side is a sum. The proxies fix
    -- a, b and c explicitly: '+' is a (non-injective) type family, so GHC
    -- cannot recover a and b from a type of the form @b + a@ by itself.
    lem :: proxy a -> proxy b -> proxy c
        -> c :~: b + a
        -> b + a :~: c
    lem _ _ _ Refl = Refl

    -- Adds one to both sides of @a + b :~: c@, placing the new 1 between a and
    -- b on the right. This matches the cons-case goal shape. Working over
    -- plain type variables keeps the arithmetic in a form the plugin can solve.
    lem2 :: proxy a -> proxy b -> proxy c
         -> (a + b :~: c)
         -> c + 1 :~: (a + 1 + b)
    lem2 _ _ _ Refl = Refl

-- | Appending two shapes in either order yields the same rank.
--
-- Derived from two instances of 'lemRankApp' plus commutativity of '+', so the
-- type checker verifies the equation instead of trusting an
-- @unsafeCoerceRefl@ axiom.
--
-- Note that both 'StaticShX' arguments are now inspected: forcing the returned
-- 'Refl' walks both shapes, at a cost linear in their combined length.
lemRankAppComm :: forall sh1 sh2.
                  StaticShX sh1 -> StaticShX sh2
               -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)
lemRankAppComm ssh1 ssh2 =
  commute (Proxy @(Rank sh1)) (Proxy @(Rank sh2))
          (lemRankApp ssh1 ssh2) (lemRankApp ssh2 ssh1)
  where
    -- Given c = a + b and d = b + a, conclude c = d. Over rigid type variables,
    -- the only arithmetic fact needed is a + b = b + a, which the
    -- GHC.TypeLits.Normalise plugin enabled for this module can discharge.
    -- The proxies pin a and b: '+' is a non-injective type family, so they
    -- could neither be inferred at the call site nor pass the ambiguity check
    -- if they occurred only under '+'.
    commute :: proxy a -> proxy b
            -> c :~: a + b
            -> d :~: b + a
            -> c :~: d
    commute _ _ Refl Refl = Refl

-- | Evidence that the rank of a shape is a known type-level natural, built by
-- structural recursion on the shape.
--
-- The empty shape is handled directly. For a non-empty shape, the dictionary
-- for the tail is brought into scope before constructing the result.
-- NOTE(review): that cons step is presumably solved by the KnownNat solver
-- plugin enabled for this module. Rank's definition lives outside this file.
lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh)
lemKnownNatRank shx = case shx of
  ZSX -> Dict
  _ :$% rest -> case lemKnownNatRank rest of
    Dict -> Dict

-- | Like 'lemKnownNatRank', but recursing over a 'StaticShX' instead of a
-- runtime shape.
--
-- The empty case needs nothing. The cons case first obtains the dictionary for
-- the tail, then constructs the one for the whole shape.
-- NOTE(review): as with the shape variant, the cons step presumably relies on
-- the KnownNat solver plugin enabled for this module.
lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh)
lemKnownNatRankSSX ssx = case ssx of
  ZKX -> Dict
  _ :!% ssxTail -> case lemKnownNatRankSSX ssxTail of
    Dict -> Dict