aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal/Convert.hs
blob: 8458efe1467aca37c369d6b7a4559a93d1782647 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Nested.Internal.Convert (
  stoRanked,
  rcastToShaped,
  castCastable,
  Castable(.., CastXR, CastXS, CastRS),
  castRR', castSS',
) where

import Control.Category
import Data.Proxy
import Data.Type.Equality
import GHC.TypeNats

import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
import Data.Array.Nested.Internal.Lemmas
import Data.Array.Nested.Internal.Mixed
import Data.Array.Nested.Internal.Ranked
import Data.Array.Nested.Internal.Shape
import Data.Array.Nested.Internal.Shaped


-- | Convert a 'Shaped' array into a 'Ranked' array of the same rank.
--
-- The underlying 'Mixed' array is converted with 'mtoRanked'; the lemma
-- 'lemRankMapJust' supplies the fact that @MapJust sh@ has rank @Rank sh@.
stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a
stoRanked shaped@(Shaped inner) =
  case lemRankMapJust (sshape shaped) of
    Refl -> mtoRanked inner

-- | Convert a 'Ranked' array into a 'Shaped' array with the given target
-- shape.
--
-- The rank already agrees statically. The conversion itself is delegated to
-- 'mcastToShaped' on the underlying 'Mixed' array. Any check that the
-- dimension sizes match is presumably done there (not visible here).
rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a
rcastToShaped (Ranked inner) shTarget =
  -- Two rank lemmas are needed to relate the all-unknown mixed shape of the
  -- ranked array to the statically known shape of the target.
  case lemRankReplicate (shxRank (shCvtSX shTarget)) of
    Refl -> case lemRankMapJust shTarget of
      Refl -> mcastToShaped inner shTarget

-- | A first-order description of a conversion from values of type @a@ to
-- values of type @b@, between the array types 'Ranked', 'Shaped' and
-- 'Mixed' (possibly nested). A description is interpreted by
-- 'castCastable'.
data Castable a b where
  -- | The identity conversion.
  CastId  :: Castable a a
  -- | Composition: @'CastCmp' c1 c2@ first performs @c2@, then @c1@ (the
  -- same order as function composition).
  CastCmp :: Castable b c -> Castable a b -> Castable a c
  -- | The inverse of a conversion.
  CastInv :: Castable b a -> Castable a b

  -- | Convert a 'Ranked' array to a 'Mixed' array whose dimensions are all
  -- statically unknown, converting the elements with the given cast.
  CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b)
  -- | Convert a 'Shaped' array to a 'Mixed' array whose dimensions are all
  -- statically known, converting the elements with the given cast.
  CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b)
  -- | Convert a 'Shaped' array to a 'Ranked' array of the same rank,
  -- converting the elements with the given cast.
  -- NOTE(review): the 'Elt' constraint appears to be what 'castCastable'
  -- needs to cast the source elements when this constructor is inverted.
  CastSR :: Elt a => ShS sh  -- ^ The singleton is required in case this constructor appears under 'CastInv'.
         -> Castable a b -> Castable (Shaped sh a) (Ranked (Rank sh) b)

  -- | Convert the elements of a 'Ranked' array, keeping its rank.
  CastRR :: Castable a b -> Castable (Ranked n a)  (Ranked n b)
  -- | Convert the elements of a 'Shaped' array, keeping its shape.
  CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b)
  -- | Convert the elements of a 'Mixed' array, keeping its shape type.
  CastXX :: Castable a b -> Castable (Mixed sh a)  (Mixed sh b)

  -- | Change the shape type of a 'Mixed' array to another one of the same
  -- rank, converting the elements with the given cast. The two shape
  -- arguments are the source and target shapes; 'castCastable' only uses
  -- them through 'ssxFromShape'.
  CastXX' :: (Rank sh ~ Rank sh', Elt a, Elt b) => IShX sh -> IShX sh'
          -> Castable a b -> Castable (Mixed sh a) (Mixed sh' b)

-- | Convert a 'Mixed' array whose dimensions are all statically unknown into
-- a 'Ranked' array, converting the elements with the given cast. This is
-- the inverse of 'CastRX' applied to the inverted element cast.
pattern CastXR :: Castable a b -> Castable (Mixed (Replicate n Nothing) a) (Ranked n b)
pattern CastXR elemCast <- CastInv (CastRX (CastInv elemCast))
  where CastXR elemCast = CastInv (CastRX (CastInv elemCast))

-- | Convert a 'Mixed' array whose dimensions are all statically known into a
-- 'Shaped' array, converting the elements with the given cast. This is the
-- inverse of 'CastSX' applied to the inverted element cast.
pattern CastXS  :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b)
pattern CastXS elemCast <- CastInv (CastSX (CastInv elemCast))
  where CastXS elemCast = CastInv (CastSX (CastInv elemCast))

-- | Convert a 'Ranked' array into a 'Shaped' array with the given shape,
-- converting the elements with the given cast. This is the inverse of
-- 'CastSR' applied to the inverted element cast; the 'Elt' constraint is
-- the one 'CastSR' requires of its (here: target) element type.
pattern CastRS :: Elt b => ShS sh -> Castable a b -> Castable (Ranked (Rank sh) a) (Shaped sh b)
pattern CastRS shp elemCast <- CastInv (CastSR shp (CastInv elemCast))
  where CastRS shp elemCast = CastInv (CastSR shp (CastInv elemCast))

-- | Like 'CastRR', but takes the source and target ranks as singletons and
-- compares them at runtime with 'sameNat'. Calls 'error' if they differ.
castRR' :: SNat n -> SNat n' -> Castable a b -> Castable (Ranked n a) (Ranked n' b)
castRR' rankFrom@SNat rankTo@SNat c =
  case sameNat rankFrom rankTo of
    Just Refl -> CastRR c
    Nothing   -> error "castRR': Ranks unequal"

-- | Like 'CastSS', but takes the source and target shapes as singletons and
-- compares them at runtime with 'testEquality'. Calls 'error' if they
-- differ.
castSS' :: ShS sh -> ShS sh' -> Castable a b -> Castable (Shaped sh a) (Shaped sh' b)
castSS' shFrom shTo c =
  case testEquality shFrom shTo of
    Just Refl -> CastSS c
    Nothing   -> error "castSS': Shapes unequal"

-- | 'Castable' conversions form a category: the identity is 'CastId' and
-- composition is 'CastCmp', so @f . g@ performs @g@ first and then @f@.
instance Category Castable where
  f . g = CastCmp f g
  id = CastId

-- | Perform the conversion described by a 'Castable' on a value.
--
-- The value is wrapped with 'mscalar', converted by interpreting the
-- 'Castable' description, and unwrapped again with 'munScalar'.
castCastable :: (Elt a, Elt b) => Castable a b -> a -> b
castCastable = \c x -> munScalar (go c (mscalar x))
  where
    -- The 'esh' is the extension shape: the casting happens under a whole
    -- bunch of additional dimensions that it does not touch. These dimensions
    -- are 'esh'.
    -- The strategy is to unwind step-by-step to a large Mixed array, and to
    -- perform the required checks and castings when re-nesting back up.

    -- 'go' interprets a cast in its own direction.
    go :: Castable a b -> Mixed esh a -> Mixed esh b
    go CastId x = x
    -- 'CastCmp c1 c2' performs c2 first, then c1.
    go (CastCmp c1 c2) x = go c1 (go c2 x)
    go (CastInv c) x = goInv c x
    go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x)
    go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x)
    -- CastSR's type variables are ordered (element type, shape, ...), hence
    -- '@_ @sh'. The static cast keeps the 'esh' part as is ('MCastId') and
    -- takes the 'MapJust sh' part to 'Replicate (Rank sh) Nothing'
    -- ('MCastForget').
    go (CastSR @_ @sh sh c) (M_Shaped (M_Nest @esh esh x))
      | Refl <- lemRankMapJust sh
      = M_Ranked (M_Nest esh (mcastSafe @(MCastApp esh (MapJust sh) esh (Replicate (Rank sh) Nothing) MCastId MCastForget) Proxy
                                        (go c x)))
    go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
    go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
    go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x)
    -- Together with the constructor's 'Rank sh ~ Rank sh'', these lemmas show
    -- that 'esh ++ sh' and 'esh ++ sh'' have the same rank, as 'mcast' needs.
    go (CastXX' sh sh' c) (M_Nest esh x)
      | Refl <- lemRankApp (ssxFromShape esh) (ssxFromShape sh)
      , Refl <- lemRankApp (ssxFromShape esh) (ssxFromShape sh')
      = M_Nest esh (mcast (ssxFromShape (shxAppend esh sh')) (go c x))

    -- 'goInv' interprets a cast in the reverse direction: given a
    -- 'Castable b a' it converts elements of type 'a' back to type 'b'.
    goInv :: Castable b a -> Mixed esh a -> Mixed esh b
    goInv CastId x = x
    -- Undoing a composition reverses the order: undo c1 first, then c2.
    goInv (CastCmp c1 c2) x = goInv c2 (goInv c1 x)
    goInv (CastInv c) x = go c x
    goInv (CastRX c) (M_Nest esh x) = M_Ranked (M_Nest esh (goInv c x))
    goInv (CastSX c) (M_Nest esh x) = M_Shaped (M_Nest esh (goInv c x))
    -- No type abstraction is bound here: CastSR's first type variable is the
    -- element type, not the shape (cf. '@_ @sh' in 'go'), and the body does
    -- not need either. The static target shape 'esh ++ MapJust sh' is
    -- rebuilt from the 'ShS' singleton stored in the constructor.
    goInv (CastSR sh c) (M_Ranked (M_Nest esh x))
      | Refl <- lemRankApp (ssxFromShape esh) (ssxFromSNat (shsRank sh))
      , Refl <- lemRankApp (ssxFromShape esh) (ssxFromShape (shCvtSX sh))
      , Refl <- lemRankReplicate (shsRank sh)
      , Refl <- lemRankMapJust sh
      = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh)))
                                    (goInv c x)))
    goInv (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (goInv c x))
    goInv (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (goInv c x))
    goInv (CastXX c) (M_Nest esh x) = M_Nest esh (goInv c x)
    -- Mirror of the CastXX' case in 'go', casting back to 'esh ++ sh'.
    goInv (CastXX' sh sh' c) (M_Nest esh x)
      | Refl <- lemRankApp (ssxFromShape esh) (ssxFromShape sh)
      , Refl <- lemRankApp (ssxFromShape esh) (ssxFromShape sh')
      = M_Nest esh (mcast (ssxFromShape (shxAppend esh sh)) (goInv c x))