diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-03-27 21:13:31 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-03-27 21:13:31 +0100 |
commit | cb31e179971293c519a530d8ce8ccc004458b1c4 (patch) | |
tree | a760f9ca2ea4048f1410a2b24500560e35f8ab19 /src/Array.hs | |
parent | 95f48df1b97529311a41245bbaaf4781b5ffaa4b (diff) |
Nats
Diffstat (limited to 'src/Array.hs')
-rw-r--r-- | src/Array.hs | 36 |
1 files changed, 19 insertions, 17 deletions
diff --git a/src/Array.hs b/src/Array.hs index 25db19e..5140eaf 100644 --- a/src/Array.hs +++ b/src/Array.hs @@ -8,8 +8,6 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module Array where import qualified Data.Array.RankedU as U @@ -17,12 +15,10 @@ import Data.Kind import Data.Proxy import Data.Type.Equality import qualified Data.Vector.Unboxed as VU -import GHC.TypeLits import Unsafe.Coerce (unsafeCoerce) +import Nats -data Dict c a where - Dict :: c a => Dict c a type family l1 ++ l2 where '[] ++ l2 = l2 @@ -54,16 +50,16 @@ class KnownShapeX sh where instance KnownShapeX '[] where knownShapeX = SZX instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where - knownShapeX = natSing @n :$@ knownShapeX + knownShapeX = knownNat :$@ knownShapeX instance KnownShapeX sh => KnownShapeX (Nothing : sh) where knownShapeX = () :$? knownShapeX type family Rank sh where - Rank '[] = 0 - Rank (_ : sh) = 1 + Rank sh + Rank '[] = Z + Rank (_ : sh) = S (Rank sh) type XArray :: [Maybe Nat] -> Type -> Type -data XArray sh a = XArray (U.Array (Rank sh) a) +data XArray sh a = XArray (U.Array (GNat (Rank sh)) a) zeroIdx :: StaticShapeX sh -> IxX sh zeroIdx SZX = IZX @@ -150,27 +146,33 @@ lemKnownNatRank (_ ::? sh) | Dict <- lemKnownNatRank sh = Dict lemKnownShapeX :: StaticShapeX sh -> Dict KnownShapeX sh lemKnownShapeX SZX = Dict -lemKnownShapeX (SNat :$@ ssh) | Dict <- lemKnownShapeX ssh = Dict +lemKnownShapeX (n :$@ ssh) | Dict <- lemKnownShapeX ssh, Dict <- snatKnown n = Dict lemKnownShapeX (() :$? ssh) | Dict <- lemKnownShapeX ssh = Dict -lemKnownShapeX (_ :$@ _) = error "SNat does not have a COMPLETE pragma" lemAppKnownShapeX :: StaticShapeX sh1 -> StaticShapeX sh2 -> Dict KnownShapeX (sh1 ++ sh2) lemAppKnownShapeX SZX ssh' = lemKnownShapeX ssh' -lemAppKnownShapeX (SNat :$@ ssh) ssh' | Dict <- lemAppKnownShapeX ssh ssh' = Dict -lemAppKnownShapeX (() :$? ssh) ssh' | Dict <- lemAppKnownShapeX ssh ssh' = Dict -lemAppKnownShapeX (_ :$@ _) _ = error "SNat does not have a COMPLETE pragma" +lemAppKnownShapeX (n :$@ ssh) ssh' + | Dict <- lemAppKnownShapeX ssh ssh' + , Dict <- snatKnown n + = Dict +lemAppKnownShapeX (() :$? ssh) ssh' + | Dict <- lemAppKnownShapeX ssh ssh' + = Dict shape :: forall sh a. KnownShapeX sh => XArray sh a -> IxX sh shape (XArray arr) = go (knownShapeX @sh) (U.shapeL arr) where go :: StaticShapeX sh' -> [Int] -> IxX sh' go SZX [] = IZX - go (n :$@ ssh) (_ : l) = fromIntegral (fromSNat n) ::@ go ssh l + go (n :$@ ssh) (_ : l) = fromIntegral (unSNat n) ::@ go ssh l go (() :$? ssh) (n : l) = n ::? go ssh l go _ _ = error "Invalid shapeL" -fromVector :: U.Unbox a => IxX sh -> VU.Vector a -> XArray sh a -fromVector sh v | Dict <- lemKnownNatRank sh = XArray (U.fromVector (shapeLshape sh) v) +fromVector :: forall sh a. U.Unbox a => IxX sh -> VU.Vector a -> XArray sh a +fromVector sh v + | Dict <- lemKnownNatRank sh + , Dict <- gknownNat (Proxy @(Rank sh)) + = XArray (U.fromVector (shapeLshape sh) v) toVector :: U.Unbox a => XArray sh a -> VU.Vector a toVector (XArray arr) = U.toVector arr |