1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Nested.Internal.Convert (
stoRanked,
rcastToShaped,
castCastable,
Castable(.., CastXR, CastXS, CastRS),
castRR', castSS',
) where
import Control.Category
import Data.Proxy
import Data.Type.Equality
import GHC.TypeNats
import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
import Data.Array.Nested.Internal.Lemmas
import Data.Array.Nested.Internal.Mixed
import Data.Array.Nested.Internal.Ranked
import Data.Array.Nested.Internal.Shape
import Data.Array.Nested.Internal.Shaped
-- | Convert a shaped array to a ranked array of rank @Rank sh@.
--
-- The underlying 'Mixed' array is handed to 'mtoRanked'; the proof returned
-- by 'lemRankMapJust' is matched first so that the rank equality it carries
-- is available when type checking that call.
stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a
stoRanked sarr@(Shaped mixedArr) =
  case lemRankMapJust (sshape sarr) of
    Refl -> mtoRanked mixedArr
-- | Cast a ranked array of rank @Rank sh@ to a shaped array with the given
-- target shape, via 'mcastToShaped' on the underlying 'Mixed' array.
--
-- The two rank lemmas are matched (in this order) only to bring the type
-- equalities needed by 'mcastToShaped' into scope.
-- NOTE(review): any runtime check that the actual dimensions agree with
-- @targetsh@ is up to 'mcastToShaped'; it is not visible here.
rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a
rcastToShaped (Ranked mixedArr) targetsh =
  case lemRankReplicate (shxRank (shCvtSX targetsh)) of
    Refl -> case lemRankMapJust targetsh of
      Refl -> mcastToShaped mixedArr targetsh
-- | A first-order description of a conversion from values of type @a@ to
-- values of type @b@. Descriptions are built from the constructors below and
-- interpreted by 'castCastable'. They can be composed ('CastCmp', also
-- available through the 'Category' instance) and inverted ('CastInv').
data Castable a b where
  -- | The identity conversion.
  CastId :: Castable a a
  -- | Sequential composition: @CastCmp c1 c2@ performs @c2@ first and then
  -- @c1@, in the same order as function composition.
  CastCmp :: Castable b c -> Castable a b -> Castable a c
  -- | The inverse of a conversion.
  CastInv :: Castable b a -> Castable a b
  -- | Reinterpret a 'Ranked' array of rank @n@ as a 'Mixed' array with shape
  -- type @Replicate n Nothing@, converting the elements with the argument.
  CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b)
  -- | Reinterpret a 'Shaped' array of shape @sh@ as a 'Mixed' array with
  -- shape type @MapJust sh@, converting the elements with the argument.
  CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b)
  -- | Reinterpret a 'Shaped' array of shape @sh@ as a 'Ranked' array of rank
  -- @Rank sh@, converting the elements with the argument.
  CastSR :: Elt a => ShS sh -- ^ The singleton is required in case this constructor appears under 'CastInv'.
         -> Castable a b -> Castable (Shaped sh a) (Ranked (Rank sh) b)
  -- | Convert the elements of a 'Ranked' array; the rank is unchanged.
  CastRR :: Castable a b -> Castable (Ranked n a) (Ranked n b)
  -- | Convert the elements of a 'Shaped' array; the shape is unchanged.
  CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b)
  -- | Convert the elements of a 'Mixed' array; the shape type is unchanged.
  CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b)
  -- | Convert the elements of a 'Mixed' array and change its shape type from
  -- @sh@ to @sh'@, which must have the same rank. 'castCastable' uses the two
  -- shape values to compute the static target shape of the underlying cast.
  CastXX' :: (Rank sh ~ Rank sh', Elt a, Elt b) => IShX sh -> IShX sh'
          -> Castable a b -> Castable (Mixed sh a) (Mixed sh' b)
-- | The reverse direction of 'CastRX': reinterpret a 'Mixed' array with shape
-- type @Replicate n Nothing@ as a 'Ranked' array of rank @n@.
-- It is defined as @CastInv (CastRX (CastInv c))@. The two inversions cancel
-- on the element conversion, so @c@ itself is applied to the elements.
pattern CastXR :: Castable a b -> Castable (Mixed (Replicate n Nothing) a) (Ranked n b)
pattern CastXR c = CastInv (CastRX (CastInv c))
-- | The reverse direction of 'CastSX': reinterpret a 'Mixed' array with shape
-- type @MapJust sh@ as a 'Shaped' array of shape @sh@.
-- It is defined as @CastInv (CastSX (CastInv c))@, so @c@ itself is applied
-- to the elements.
pattern CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b)
pattern CastXS c = CastInv (CastSX (CastInv c))
-- | The reverse direction of 'CastSR': reinterpret a 'Ranked' array of rank
-- @Rank sh@ as a 'Shaped' array of shape @sh@. The 'ShS' singleton supplies
-- the static shape that the ranked side does not carry.
-- It is defined as @CastInv (CastSR sh (CastInv c))@, so @c@ itself is
-- applied to the elements.
pattern CastRS :: Elt b => ShS sh -> Castable a b -> Castable (Ranked (Rank sh) a) (Shaped sh b)
pattern CastRS sh c = CastInv (CastSR sh (CastInv c))
-- | Like 'CastRR', but for two ranks that are only known to be equal at
-- runtime. The 'SNat's are compared with 'sameNat'; if they differ, this
-- calls 'error'.
castRR' :: SNat n -> SNat n' -> Castable a b -> Castable (Ranked n a) (Ranked n' b)
castRR' rankL@SNat rankR@SNat c =
  case sameNat rankL rankR of
    Just Refl -> CastRR c
    Nothing -> error "castRR': Ranks unequal"
-- | Like 'CastSS', but for two shapes that are only known to be equal at
-- runtime. The 'ShS' singletons are compared with 'testEquality'; if they
-- differ, this calls 'error'.
castSS' :: ShS sh -> ShS sh' -> Castable a b -> Castable (Shaped sh a) (Shaped sh' b)
castSS' shL shR c =
  case testEquality shL shR of
    Just Refl -> CastSS c
    Nothing -> error "castSS': Shapes unequal"
-- | 'Castable' descriptions form a category: the identity is 'CastId' and
-- composition is 'CastCmp'. As with functions, in @outer . inner@ the
-- @inner@ conversion is performed first.
instance Category Castable where
  id = CastId
  outer . inner = CastCmp outer inner
-- | Interpret a 'Castable' description as a function on values.
--
-- The input is wrapped into a 'Mixed' array with 'mscalar', converted by the
-- local worker @go@, and unwrapped again with 'munScalar'.
castCastable :: (Elt a, Elt b) => Castable a b -> a -> b
castCastable = \c x -> munScalar (go c (mscalar x))
  where
    -- The 'esh' is the extension shape: the casting happens under a whole
    -- bunch of additional dimensions that it does not touch. These dimensions
    -- are 'esh'.
    -- The strategy is to unwind step-by-step to a large Mixed array, and to
    -- perform the required checks and castings when re-nesting back up.
    --
    -- 'go' runs a description forwards; 'goInv' runs the same description
    -- backwards (this is how 'CastInv' is interpreted).
    go :: Castable a b -> Mixed esh a -> Mixed esh b
    go CastId x = x
    -- The right-hand conversion is performed first.
    go (CastCmp c1 c2) x = go c1 (go c2 x)
    go (CastInv c) x = goInv c x
    go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x)
    go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x)
    -- Shaped -> Ranked. The constructor's type variables are ordered
    -- a, sh, b, so the leading @_ skips the element type. The extension
    -- dimensions are kept as-is ('MCastId'), and the inner dimensions go
    -- through 'MCastForget' (presumably dropping their static sizes).
    go (CastSR @_ @sh sh c) (M_Shaped (M_Nest @esh esh x))
      | Refl <- lemRankMapJust sh
      = M_Ranked (M_Nest esh (mcastSafe @(MCastApp esh (MapJust sh) esh (Replicate (Rank sh) Nothing) MCastId MCastForget) Proxy
                                (go c x)))
    go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
    go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
    go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x)
    -- Change the inner shape type from sh to sh'. The rank lemmas supply the
    -- equal-rank evidence that 'mcast' needs on the appended shapes.
    go (CastXX' sh sh' c) (M_Nest esh x)
      | Refl <- lemRankApp (ssxFromShape esh) (ssxFromShape sh)
      , Refl <- lemRankApp (ssxFromShape esh) (ssxFromShape sh')
      = M_Nest esh (mcast (ssxFromShape (shxAppend esh sh')) (go c x))

    -- 'goInv c' runs the description 'c' backwards, from its target type to
    -- its source type.
    goInv :: Castable b a -> Mixed esh a -> Mixed esh b
    goInv CastId x = x
    -- Undoing a composition reverses the order: undo c1 first, then c2.
    goInv (CastCmp c1 c2) x = goInv c2 (goInv c1 x)
    goInv (CastInv c) x = go c x
    goInv (CastRX c) (M_Nest esh x) = M_Ranked (M_Nest esh (goInv c x))
    goInv (CastSX c) (M_Nest esh x) = M_Shaped (M_Nest esh (goInv c x))
    -- Ranked -> Shaped. The static shape is recovered from the 'ShS'
    -- singleton stored in the constructor and used as the target of 'mcast'.
    -- No type binder is needed here: only the term-level 'sh' is used.
    goInv (CastSR sh c) (M_Ranked (M_Nest esh x))
      | Refl <- lemRankApp (ssxFromShape esh) (ssxFromSNat (shsRank sh))
      , Refl <- lemRankApp (ssxFromShape esh) (ssxFromShape (shCvtSX sh))
      , Refl <- lemRankReplicate (shsRank sh)
      , Refl <- lemRankMapJust sh
      = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh)))
                                    (goInv c x)))
    goInv (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (goInv c x))
    goInv (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (goInv c x))
    goInv (CastXX c) (M_Nest esh x) = M_Nest esh (goInv c x)
    -- Change the inner shape type back from sh' to sh.
    goInv (CastXX' sh sh' c) (M_Nest esh x)
      | Refl <- lemRankApp (ssxFromShape esh) (ssxFromShape sh)
      , Refl <- lemRankApp (ssxFromShape esh) (ssxFromShape sh')
      = M_Nest esh (mcast (ssxFromShape (shxAppend esh sh)) (goInv c x))
|