diff options
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r-- | src/Interpreter.hs | 20 |
1 files changed, 4 insertions, 16 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 2c63b24..da5b73c 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -244,18 +244,6 @@ data Inverted (f :: Nat -> Type) n where type InvShape = Inverted Shape type InvIndex = Inverted Index -pattern IIxNil :: () => n ~ Z => InvIndex n -pattern IIxNil = InvNil -pattern IIxCons :: () => S n ~ succn => Int -> InvIndex n -> InvIndex succn -pattern IIxCons i ix = InvCons i ix -{-# COMPLETE IIxNil, IIxCons #-} - -pattern IShNil :: () => n ~ Z => InvShape n -pattern IShNil = InvNil -pattern IShCons :: () => S n ~ succn => Int -> InvShape n -> InvShape succn -pattern IShCons n sh = InvCons n sh -{-# COMPLETE IShNil, IShCons #-} - class Shapey f where shapeyNil :: f Z shapeyCons :: f n -> Int -> f (S n) @@ -288,13 +276,13 @@ uninvert = go shapeyNil piindexMatch :: PartialInvIndex n m -> InvIndex n -> Maybe (InvIndex m) piindexMatch PIIxEnd ix = Just ix -piindexMatch (PIIxCons i pix) (IIxCons i' ix) +piindexMatch (PIIxCons i pix) (InvCons i' ix) | i == i' = piindexMatch pix ix | otherwise = Nothing piindexConcat :: PartialInvIndex n m -> InvIndex m -> InvIndex n piindexConcat PIIxEnd ix = ix -piindexConcat (PIIxCons i pix) ix = IIxCons i (piindexConcat pix ix) +piindexConcat (PIIxCons i pix) ix = InvCons i (piindexConcat pix ix) newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t) newAcSparse typ SZ () val = case typ of @@ -456,8 +444,8 @@ accumAddSparse typ (SS dep) ref idx val = case typ of -> (forall m. PartialInvIndex n m -> Shape m -> Rep (TArr m t) -> r) -- ^ Accumulating onto a subarray -> r go SZ ish () val' _ k0 = k0 PIIxEnd (uninvert ish) val' -- ^ Ran out of AcIdx: accumulating onto subarray - go (SS dep') IShNil idx' val' kj _ = kj dep' IxNil idx' val' -- ^ Ran out of array dimensions: accumulating into (part of) element - go (SS dep') (IShCons _ ish) (i, idx') val' kj k0 = + go (SS dep') InvNil idx' val' kj _ = kj dep' IxNil idx' val' -- ^ Ran out of array dimensions: accumulating into (part of) element + go (SS dep') (InvCons _ ish) (i, idx') val' kj k0 = go dep' ish idx' val' (\j index idxj valj -> kj j (IxCons index (fromIntegral @Int64 @Int i)) idxj valj) (\pidxm shm valm -> k0 (PIIxCons (fromIntegral @Int64 @Int i) pidxm) shm valm) |