summaryrefslogtreecommitdiff
path: root/src/Array.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-03-27 21:13:31 +0100
committerTom Smeding <tom@tomsmeding.com>2024-03-27 21:13:31 +0100
commitcb31e179971293c519a530d8ce8ccc004458b1c4 (patch)
treea760f9ca2ea4048f1410a2b24500560e35f8ab19 /src/Array.hs
parent95f48df1b97529311a41245bbaaf4781b5ffaa4b (diff)
Nats
Diffstat (limited to 'src/Array.hs')
-rw-r--r--src/Array.hs36
1 files changed, 19 insertions, 17 deletions
diff --git a/src/Array.hs b/src/Array.hs
index 25db19e..5140eaf 100644
--- a/src/Array.hs
+++ b/src/Array.hs
@@ -8,8 +8,6 @@
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Array where
import qualified Data.Array.RankedU as U
@@ -17,12 +15,10 @@ import Data.Kind
import Data.Proxy
import Data.Type.Equality
import qualified Data.Vector.Unboxed as VU
-import GHC.TypeLits
import Unsafe.Coerce (unsafeCoerce)
+import Nats
-data Dict c a where
- Dict :: c a => Dict c a
type family l1 ++ l2 where
'[] ++ l2 = l2
@@ -54,16 +50,16 @@ class KnownShapeX sh where
instance KnownShapeX '[] where
knownShapeX = SZX
instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where
- knownShapeX = natSing @n :$@ knownShapeX
+ knownShapeX = knownNat :$@ knownShapeX
instance KnownShapeX sh => KnownShapeX (Nothing : sh) where
knownShapeX = () :$? knownShapeX
type family Rank sh where
- Rank '[] = 0
- Rank (_ : sh) = 1 + Rank sh
+ Rank '[] = Z
+ Rank (_ : sh) = S (Rank sh)
type XArray :: [Maybe Nat] -> Type -> Type
-data XArray sh a = XArray (U.Array (Rank sh) a)
+data XArray sh a = XArray (U.Array (GNat (Rank sh)) a)
zeroIdx :: StaticShapeX sh -> IxX sh
zeroIdx SZX = IZX
@@ -150,27 +146,33 @@ lemKnownNatRank (_ ::? sh) | Dict <- lemKnownNatRank sh = Dict
lemKnownShapeX :: StaticShapeX sh -> Dict KnownShapeX sh
lemKnownShapeX SZX = Dict
-lemKnownShapeX (SNat :$@ ssh) | Dict <- lemKnownShapeX ssh = Dict
+lemKnownShapeX (n :$@ ssh) | Dict <- lemKnownShapeX ssh, Dict <- snatKnown n = Dict
lemKnownShapeX (() :$? ssh) | Dict <- lemKnownShapeX ssh = Dict
-lemKnownShapeX (_ :$@ _) = error "SNat does not have a COMPLETE pragma"
lemAppKnownShapeX :: StaticShapeX sh1 -> StaticShapeX sh2 -> Dict KnownShapeX (sh1 ++ sh2)
lemAppKnownShapeX SZX ssh' = lemKnownShapeX ssh'
-lemAppKnownShapeX (SNat :$@ ssh) ssh' | Dict <- lemAppKnownShapeX ssh ssh' = Dict
-lemAppKnownShapeX (() :$? ssh) ssh' | Dict <- lemAppKnownShapeX ssh ssh' = Dict
-lemAppKnownShapeX (_ :$@ _) _ = error "SNat does not have a COMPLETE pragma"
+lemAppKnownShapeX (n :$@ ssh) ssh'
+ | Dict <- lemAppKnownShapeX ssh ssh'
+ , Dict <- snatKnown n
+ = Dict
+lemAppKnownShapeX (() :$? ssh) ssh'
+ | Dict <- lemAppKnownShapeX ssh ssh'
+ = Dict
shape :: forall sh a. KnownShapeX sh => XArray sh a -> IxX sh
shape (XArray arr) = go (knownShapeX @sh) (U.shapeL arr)
where
go :: StaticShapeX sh' -> [Int] -> IxX sh'
go SZX [] = IZX
- go (n :$@ ssh) (_ : l) = fromIntegral (fromSNat n) ::@ go ssh l
+ go (n :$@ ssh) (_ : l) = fromIntegral (unSNat n) ::@ go ssh l
go (() :$? ssh) (n : l) = n ::? go ssh l
go _ _ = error "Invalid shapeL"
-fromVector :: U.Unbox a => IxX sh -> VU.Vector a -> XArray sh a
-fromVector sh v | Dict <- lemKnownNatRank sh = XArray (U.fromVector (shapeLshape sh) v)
+fromVector :: forall sh a. U.Unbox a => IxX sh -> VU.Vector a -> XArray sh a
+fromVector sh v
+ | Dict <- lemKnownNatRank sh
+ , Dict <- gknownNat (Proxy @(Rank sh))
+ = XArray (U.fromVector (shapeLshape sh) v)
toVector :: U.Unbox a => XArray sh a -> VU.Vector a
toVector (XArray arr) = U.toVector arr