summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/Interpreter.hs20
1 files changed, 4 insertions, 16 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 2c63b24..da5b73c 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -244,18 +244,6 @@ data Inverted (f :: Nat -> Type) n where
type InvShape = Inverted Shape
type InvIndex = Inverted Index
-pattern IIxNil :: () => n ~ Z => InvIndex n
-pattern IIxNil = InvNil
-pattern IIxCons :: () => S n ~ succn => Int -> InvIndex n -> InvIndex succn
-pattern IIxCons i ix = InvCons i ix
-{-# COMPLETE IIxNil, IIxCons #-}
-
-pattern IShNil :: () => n ~ Z => InvShape n
-pattern IShNil = InvNil
-pattern IShCons :: () => S n ~ succn => Int -> InvShape n -> InvShape succn
-pattern IShCons n sh = InvCons n sh
-{-# COMPLETE IShNil, IShCons #-}
-
class Shapey f where
shapeyNil :: f Z
shapeyCons :: f n -> Int -> f (S n)
@@ -288,13 +276,13 @@ uninvert = go shapeyNil
piindexMatch :: PartialInvIndex n m -> InvIndex n -> Maybe (InvIndex m)
piindexMatch PIIxEnd ix = Just ix
-piindexMatch (PIIxCons i pix) (IIxCons i' ix)
+piindexMatch (PIIxCons i pix) (InvCons i' ix)
| i == i' = piindexMatch pix ix
| otherwise = Nothing
piindexConcat :: PartialInvIndex n m -> InvIndex m -> InvIndex n
piindexConcat PIIxEnd ix = ix
-piindexConcat (PIIxCons i pix) ix = IIxCons i (piindexConcat pix ix)
+piindexConcat (PIIxCons i pix) ix = InvCons i (piindexConcat pix ix)
newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t)
newAcSparse typ SZ () val = case typ of
@@ -456,8 +444,8 @@ accumAddSparse typ (SS dep) ref idx val = case typ of
-> (forall m. PartialInvIndex n m -> Shape m -> Rep (TArr m t) -> r) -- ^ Accumulating onto a subarray
-> r
go SZ ish () val' _ k0 = k0 PIIxEnd (uninvert ish) val' -- ^ Ran out of AcIdx: accumulating onto subarray
- go (SS dep') IShNil idx' val' kj _ = kj dep' IxNil idx' val' -- ^ Ran out of array dimensions: accumulating into (part of) element
- go (SS dep') (IShCons _ ish) (i, idx') val' kj k0 =
+ go (SS dep') InvNil idx' val' kj _ = kj dep' IxNil idx' val' -- ^ Ran out of array dimensions: accumulating into (part of) element
+ go (SS dep') (InvCons _ ish) (i, idx') val' kj k0 =
go dep' ish idx' val'
(\j index idxj valj -> kj j (IxCons index (fromIntegral @Int64 @Int i)) idxj valj)
(\pidxm shm valm -> k0 (PIIxCons (fromIntegral @Int64 @Int i) pidxm) shm valm)