diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-26 15:11:48 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-26 15:11:48 +0100 |
commit | a00234388d1b4e14481067d030bf90031258b756 (patch) | |
tree | 501b6778fc5779ce220aba1e22f56ae60f68d970 /src/Interpreter.hs | |
parent | 7971f6dff12bc7b66a5d4ae91a6791ac08872c31 (diff) |
D2[Array] now has a Maybe instead of zero-size for zero
Remaining problem: 'add' in Compile doesn't use the D2 stuff
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r-- | src/Interpreter.hs | 61 |
1 files changed, 26 insertions, 35 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 3cc7ae4..ddc3479 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -188,8 +188,7 @@ zeroD2 typ = case typ of STPair _ _ -> Nothing STEither _ _ -> Nothing STMaybe _ -> Nothing - STArr SZ t -> arrayUnit (zeroD2 t) - STArr n _ -> emptyArray n + STArr _ _ -> Nothing STScal sty -> case sty of STI32 -> () STI64 -> () @@ -215,13 +214,16 @@ addD2s typ a b = case typ of (Nothing, _) -> b (_, Nothing) -> a (Just x, Just y) -> Just (addD2s t x y) - STArr _ t -> - let sh1 = arrayShape a - sh2 = arrayShape b - in if | shapeSize sh1 == 0 -> b - | shapeSize sh2 == 0 -> a - | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> addD2s t (arrayIndexLinear a i) (arrayIndexLinear b i)) - | otherwise -> error "Plus of inconsistently shaped arrays" + STArr _ t -> case (a, b) of + (Nothing, _) -> b + (_, Nothing) -> a + (Just x, Just y) -> + let sh1 = arrayShape x + sh2 = arrayShape y + in if | shapeSize sh1 == 0 -> Just y + | shapeSize sh2 == 0 -> Just x + | sh1 == sh2 -> Just $ arrayGenerateLin sh1 (\i -> addD2s t (arrayIndexLinear x i) (arrayIndexLinear y i)) + | otherwise -> error "Plus of inconsistently shaped arrays" STScal sty -> case sty of STI32 -> () STI64 -> () @@ -238,7 +240,7 @@ onehotD2 (SAPLeft prj) (STEither a _) idx val = Just (Left (onehotD2 prj a idx v onehotD2 (SAPRight prj) (STEither _ b) idx val = Just (Right (onehotD2 prj b idx val)) onehotD2 (SAPJust prj) (STMaybe a) idx val = Just (onehotD2 prj a idx val) onehotD2 (SAPArrIdx prj _) (STArr n a) idx val = - runIdentity $ onehotArray (\idx' -> Identity (onehotD2 prj a idx' val)) (Identity (zeroD2 a)) n prj idx + Just $ runIdentity $ onehotArray (\idx' -> Identity (onehotD2 prj a idx' val)) (Identity (zeroD2 a)) n prj idx withAccum :: STy t -> STy a -> Rep (D2 t) -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep (D2 t)) withAccum t _ initval f = AcM $ do @@ -253,7 +255,7 @@ newAcZero = \case STPair{} -> newIORef Nothing STEither{} -> newIORef Nothing STMaybe _ -> newIORef Nothing - STArr n _ -> newIORef (emptyArray n) + STArr _ _ -> newIORef Nothing STScal sty -> case sty of STI32 -> return () STI64 -> return () @@ -268,7 +270,7 @@ newAcSparse typ prj idx val = case (typ, prj) of (STPair t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val (STEither t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val (STMaybe t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val - (STArr _ t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val + (STArr _ t1, SAPHere) -> newIORef =<< traverse (traverse (newAcSparse t1 SAPHere ())) val (STScal sty, SAPHere) -> case sty of STI32 -> return () STI64 -> return () @@ -286,7 +288,7 @@ newAcSparse typ prj idx val = case (typ, prj) of (STMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val - (STArr n t, SAPArrIdx prj' _) -> newIORef =<< newAcArray n t prj' idx val + (STArr n t, SAPArrIdx prj' _) -> newIORef . Just =<< newAcArray n t prj' idx val (STAccum{}, _) -> error "Accumulators not allowed in source program" @@ -309,7 +311,7 @@ readAcSparse typ val = case typ of STPair t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val STEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val STMaybe t -> traverse (readAcSparse t) =<< readIORef val - STArr _ t -> traverse (readAcSparse t) =<< readIORef val + STArr _ t -> traverse (traverse (readAcSparse t)) =<< readIORef val STScal sty -> case sty of STI32 -> return () STI64 -> return () @@ -360,32 +362,21 @@ accumAddSparse typ prj ref idx val = case (typ, prj) of (\ac -> accumAddSparse t1 prj' ac idx val) (STArr _ t1, SAPHere) -> - let add ac = forM_ [0 .. arraySize ac - 1] $ \i -> - unAcM $ accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val i) - in if arraySize val == 0 - then return () - else AcM $ join $ atomicModifyIORef' ref $ \ac -> - if arraySize ac == 0 - then (ac, do newac <- arrayMapM (newAcSparse t1 SAPHere ()) val - join $ atomicModifyIORef' ref $ \ac' -> - if arraySize ac == 0 - then (newac, return ()) - else (ac', add ac')) - else (ac, add ac) + case val of + Nothing -> return () + Just val' -> + realiseMaybeSparse ref + (arrayMapM (newAcSparse t1 SAPHere ()) val') + (\ac -> forM_ [0 .. arraySize ac - 1] $ \i -> + accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val' i)) (STArr n t1, SAPArrIdx prj' _) -> let ((arrindex', arrsh'), idx') = idx arrindex = unTupRepIdx IxNil IxCons n arrindex' arrsh = unTupRepIdx ShNil ShCons n arrsh' linindex = toLinearIndex arrsh arrindex - add ac = unAcM $ accumAddSparse t1 prj' (arrayIndexLinear ac linindex) idx' val - in AcM $ join $ atomicModifyIORef' ref $ \ac -> - if arraySize ac == 0 - then (ac, do newac <- onehotArray (\_ -> newAcSparse t1 prj' idx' val) (newAcZero t1) n prj' idx - join $ atomicModifyIORef' ref $ \ac' -> - if arraySize ac == 0 - then (newac, return ()) - else (ac', add ac')) - else (ac, add ac) + in realiseMaybeSparse ref + (onehotArray (\_ -> newAcSparse t1 prj' idx' val) (newAcZero t1) n prj' idx) + (\ac -> accumAddSparse t1 prj' (arrayIndexLinear ac linindex) idx' val) (STScal sty, SAPHere) -> AcM $ case sty of STI32 -> return () |