aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal/Convert.hs
blob: 183d62ca4d3b5eb880da26e99df7eb012b565ab5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Nested.Internal.Convert where

import Control.Category
import Data.Proxy
import Data.Type.Equality

import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
import Data.Array.Nested.Internal.Lemmas
import Data.Array.Nested.Internal.Mixed
import Data.Array.Nested.Internal.Ranked
import Data.Array.Nested.Internal.Shape
import Data.Array.Nested.Internal.Shaped


-- | Convert a shaped array into a ranked array of the same rank, dropping the
-- statically-known dimension sizes.
--
-- The work is done on the wrapped 'Mixed' array by 'mtoRanked';
-- 'lemRankMapJust' supplies the type equality (relating the rank of
-- @MapJust sh@ to @Rank sh@) that lets the result have type
-- @Ranked (Rank sh) a@.
stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a
stoRanked shapedArr@(Shaped inner) =
  case lemRankMapJust (sshape shapedArr) of
    Refl -> mtoRanked inner

-- | Reinterpret a ranked array as a shaped array with the target shape
-- @shTarget@.
--
-- The types guarantee that the input rank equals the rank of @shTarget@; any
-- comparison of the actual dimension sizes against @shTarget@ is left to
-- 'mcastToShaped'. The two rank lemmas provide the type equalities that
-- 'mcastToShaped' needs on the underlying 'Mixed' array.
rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a
rcastToShaped (Ranked inner) shTarget =
  case lemRankReplicate (shxRank (shCvtSX shTarget)) of
    Refl -> case lemRankMapJust shTarget of
      Refl -> mcastToShaped inner shTarget

-- | A first-order description of a conversion from a value of type @a@ to a
-- value of type @b@, executed by 'castCastable'. Each constructor either
-- changes the outer array wrapper ('Ranked', 'Shaped' or 'Mixed') or keeps
-- it. Every constructor except 'CastId' and 'CastCmp' also carries a
-- conversion for the elements.
--
-- The only constructor that performs runtime shape checking is 'CastXS''.
-- For the other constructors, the types ensure that the shapes are already
-- compatible. To convert between 'Ranked' and 'Shaped', go via 'Mixed'.
data Castable a b where
  -- | The identity conversion.
  CastId  :: Castable a a
  -- | Composition: @CastCmp f g@ applies @g@ first and then @f@, matching the
  -- argument order of function composition.
  CastCmp :: Castable b c -> Castable a b -> Castable a c

  -- | View a 'Ranked' array as a 'Mixed' array with shape type
  -- @Replicate n Nothing@, converting the elements with the inner cast.
  CastRX  :: Castable a b -> Castable (Ranked n a)           (Mixed (Replicate n Nothing) b)
  -- | View a 'Shaped' array as a 'Mixed' array with shape type @MapJust sh@,
  -- converting the elements with the inner cast.
  CastSX  :: Castable a b -> Castable (Shaped sh a)          (Mixed (MapJust sh) b)

  -- | Turn a 'Mixed' array into a 'Ranked' array of the same rank, forgetting
  -- its shape type; the elements are converted with the inner cast.
  CastXR  :: Castable a b -> Castable (Mixed sh a)           (Ranked (Rank sh) b)
  -- | Turn a 'Mixed' array whose shape type is already @MapJust sh@ into a
  -- 'Shaped' array; no runtime check is needed.
  CastXS  :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b)
  -- | Turn a 'Mixed' array of any shape type @sh@ into a 'Shaped' array with
  -- the given shape @sh'@ of the same rank. 'castCastable' checks the shape
  -- compatibility of this constructor at runtime. The 'Elt' constraint on @b@
  -- is carried along for that check (NOTE(review): presumably required by
  -- 'mcast' — confirm).
  CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh'
          -> Castable a b -> Castable (Mixed sh a)           (Shaped sh' b)

  -- | Keep the 'Ranked' wrapper and rank; convert only the elements.
  CastRR  :: Castable a b -> Castable (Ranked n a)           (Ranked n b)
  -- | Keep the 'Shaped' wrapper and shape; convert only the elements.
  CastSS  :: Castable a b -> Castable (Shaped sh a)          (Shaped sh b)
  -- | Keep the 'Mixed' wrapper and shape type; convert only the elements.
  CastXX  :: Castable a b -> Castable (Mixed sh a)           (Mixed sh b)

-- | 'Castable' conversions form a category. The identity is 'CastId'.
-- Composition simply builds a 'CastCmp' node without simplifying it; the
-- left operand runs after the right one.
instance Category Castable where
  outer . inner = CastCmp outer inner
  id = CastId

-- | Run a 'Castable' conversion on a value.
--
-- The steps are:
--
-- 1. 'mscalar' wraps the value into a 'Mixed' array.
-- 2. The local worker @go@ converts that array.
-- 3. 'munScalar' unwraps the result again.
--
-- Only 'CastXS'' steps involve a runtime shape check (via 'mcast'). Every
-- other step is justified by the types alone.
castCastable :: (Elt a, Elt b) => Castable a b -> a -> b
castCastable = \c x -> munScalar (go c (mscalar x))
  where
    -- The 'esh' is the extension shape. The cast is applied underneath a
    -- prefix of additional outer dimensions that it does not touch; these
    -- dimensions are 'esh'.
    -- The strategy is to unwind step-by-step to a large Mixed array, and to
    -- perform the required checks and castings when re-nesting back up.
    go :: Castable a b -> Mixed esh a -> Mixed esh b
    -- Identity: nothing to do.
    go CastId x = x
    -- Composition: run the right-hand (first) conversion, then the left one.
    go (CastCmp c1 c2) x = go c1 (go c2 x)
    -- Ranked/Shaped -> Mixed: strip the outer wrapper. The nested Mixed array
    -- underneath already has the target shape type, so only the elements
    -- need converting.
    go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x)
    go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x)
    -- Mixed -> Ranked: convert the elements, then apply a statically
    -- described 'mcastSafe'. It leaves the @esh@ prefix alone ('MCastId') and
    -- turns the @sh@ part into @Replicate (Rank sh) Nothing@ ('MCastForget').
    -- Finally re-wrap the result as 'Ranked'.
    go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) =
      M_Ranked (M_Nest esh (mcastSafe @(MCastApp esh sh esh (Replicate (Rank sh) Nothing) MCastId MCastForget) Proxy
                                      (go c x)))
    -- Mixed (MapJust sh) -> Shaped: the shape types already line up, so
    -- convert the elements and re-wrap.
    go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x))
    -- Mixed sh -> Shaped sh': the runtime-checked case. After converting the
    -- elements, 'mcast' casts the array to the static shape built from the
    -- actual extension shape 'esh' followed by the requested 'sh''. The rank
    -- equality from 'lemRankAppMapJust' is what makes this typecheck.
    -- NOTE(review): presumably 'mcast' requires equal source/target ranks;
    -- confirm.
    go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x)
      | Refl <- lemRankAppMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')
      = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh')))
                                    (go c x)))
    -- Same wrapper and shape on both sides: convert only the elements.
    go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
    go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
    go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x)

    -- Justification for the unsafe equality:
    --   Rank (esh ++ sh) = Rank esh + Rank sh = Rank esh + Rank sh'
    --     = Rank esh + Rank (MapJust sh') = Rank (esh ++ MapJust sh').
    -- GHC cannot prove this chain on its own, so it is asserted with
    -- 'unsafeCoerceRefl'. The reasoning assumes 'Rank' counts the entries of
    -- the shape list, so that it distributes over '++' and is unchanged by
    -- 'MapJust'. The middle step is exactly the @Rank sh ~ Rank sh'@
    -- constraint required by the signature.
    lemRankAppMapJust :: Rank sh ~ Rank sh'
                      => Proxy esh -> Proxy sh -> Proxy sh'
                      -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh')
    lemRankAppMapJust _ _ _ = unsafeCoerceRefl