summaryrefslogtreecommitdiff
path: root/src/Interpreter.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-26 15:11:48 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-26 15:11:48 +0100
commita00234388d1b4e14481067d030bf90031258b756 (patch)
tree501b6778fc5779ce220aba1e22f56ae60f68d970 /src/Interpreter.hs
parent7971f6dff12bc7b66a5d4ae91a6791ac08872c31 (diff)
D2[Array] now has a Maybe instead of zero-size for zero
Remaining problem: 'add' in Compile doesn't use the D2 stuff
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r--src/Interpreter.hs61
1 files changed, 26 insertions, 35 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 3cc7ae4..ddc3479 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -188,8 +188,7 @@ zeroD2 typ = case typ of
STPair _ _ -> Nothing
STEither _ _ -> Nothing
STMaybe _ -> Nothing
- STArr SZ t -> arrayUnit (zeroD2 t)
- STArr n _ -> emptyArray n
+ STArr _ _ -> Nothing
STScal sty -> case sty of
STI32 -> ()
STI64 -> ()
@@ -215,13 +214,16 @@ addD2s typ a b = case typ of
(Nothing, _) -> b
(_, Nothing) -> a
(Just x, Just y) -> Just (addD2s t x y)
- STArr _ t ->
- let sh1 = arrayShape a
- sh2 = arrayShape b
- in if | shapeSize sh1 == 0 -> b
- | shapeSize sh2 == 0 -> a
- | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> addD2s t (arrayIndexLinear a i) (arrayIndexLinear b i))
- | otherwise -> error "Plus of inconsistently shaped arrays"
+ STArr _ t -> case (a, b) of
+ (Nothing, _) -> b
+ (_, Nothing) -> a
+ (Just x, Just y) ->
+ let sh1 = arrayShape x
+ sh2 = arrayShape y
+ in if | shapeSize sh1 == 0 -> Just y
+ | shapeSize sh2 == 0 -> Just x
+ | sh1 == sh2 -> Just $ arrayGenerateLin sh1 (\i -> addD2s t (arrayIndexLinear x i) (arrayIndexLinear y i))
+ | otherwise -> error "Plus of inconsistently shaped arrays"
STScal sty -> case sty of
STI32 -> ()
STI64 -> ()
@@ -238,7 +240,7 @@ onehotD2 (SAPLeft prj) (STEither a _) idx val = Just (Left (onehotD2 prj a idx v
onehotD2 (SAPRight prj) (STEither _ b) idx val = Just (Right (onehotD2 prj b idx val))
onehotD2 (SAPJust prj) (STMaybe a) idx val = Just (onehotD2 prj a idx val)
onehotD2 (SAPArrIdx prj _) (STArr n a) idx val =
- runIdentity $ onehotArray (\idx' -> Identity (onehotD2 prj a idx' val)) (Identity (zeroD2 a)) n prj idx
+ Just $ runIdentity $ onehotArray (\idx' -> Identity (onehotD2 prj a idx' val)) (Identity (zeroD2 a)) n prj idx
withAccum :: STy t -> STy a -> Rep (D2 t) -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep (D2 t))
withAccum t _ initval f = AcM $ do
@@ -253,7 +255,7 @@ newAcZero = \case
STPair{} -> newIORef Nothing
STEither{} -> newIORef Nothing
STMaybe _ -> newIORef Nothing
- STArr n _ -> newIORef (emptyArray n)
+ STArr _ _ -> newIORef Nothing
STScal sty -> case sty of
STI32 -> return ()
STI64 -> return ()
@@ -268,7 +270,7 @@ newAcSparse typ prj idx val = case (typ, prj) of
(STPair t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val
(STEither t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val
(STMaybe t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val
- (STArr _ t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val
+ (STArr _ t1, SAPHere) -> newIORef =<< traverse (traverse (newAcSparse t1 SAPHere ())) val
(STScal sty, SAPHere) -> case sty of
STI32 -> return ()
STI64 -> return ()
@@ -286,7 +288,7 @@ newAcSparse typ prj idx val = case (typ, prj) of
(STMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val
- (STArr n t, SAPArrIdx prj' _) -> newIORef =<< newAcArray n t prj' idx val
+ (STArr n t, SAPArrIdx prj' _) -> newIORef . Just =<< newAcArray n t prj' idx val
(STAccum{}, _) -> error "Accumulators not allowed in source program"
@@ -309,7 +311,7 @@ readAcSparse typ val = case typ of
STPair t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val
STEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val
STMaybe t -> traverse (readAcSparse t) =<< readIORef val
- STArr _ t -> traverse (readAcSparse t) =<< readIORef val
+ STArr _ t -> traverse (traverse (readAcSparse t)) =<< readIORef val
STScal sty -> case sty of
STI32 -> return ()
STI64 -> return ()
@@ -360,32 +362,21 @@ accumAddSparse typ prj ref idx val = case (typ, prj) of
(\ac -> accumAddSparse t1 prj' ac idx val)
(STArr _ t1, SAPHere) ->
- let add ac = forM_ [0 .. arraySize ac - 1] $ \i ->
- unAcM $ accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val i)
- in if arraySize val == 0
- then return ()
- else AcM $ join $ atomicModifyIORef' ref $ \ac ->
- if arraySize ac == 0
- then (ac, do newac <- arrayMapM (newAcSparse t1 SAPHere ()) val
- join $ atomicModifyIORef' ref $ \ac' ->
- if arraySize ac == 0
- then (newac, return ())
- else (ac', add ac'))
- else (ac, add ac)
+ case val of
+ Nothing -> return ()
+ Just val' ->
+ realiseMaybeSparse ref
+ (arrayMapM (newAcSparse t1 SAPHere ()) val')
+ (\ac -> forM_ [0 .. arraySize ac - 1] $ \i ->
+ accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val' i))
(STArr n t1, SAPArrIdx prj' _) ->
let ((arrindex', arrsh'), idx') = idx
arrindex = unTupRepIdx IxNil IxCons n arrindex'
arrsh = unTupRepIdx ShNil ShCons n arrsh'
linindex = toLinearIndex arrsh arrindex
- add ac = unAcM $ accumAddSparse t1 prj' (arrayIndexLinear ac linindex) idx' val
- in AcM $ join $ atomicModifyIORef' ref $ \ac ->
- if arraySize ac == 0
- then (ac, do newac <- onehotArray (\_ -> newAcSparse t1 prj' idx' val) (newAcZero t1) n prj' idx
- join $ atomicModifyIORef' ref $ \ac' ->
- if arraySize ac == 0
- then (newac, return ())
- else (ac', add ac'))
- else (ac, add ac)
+ in realiseMaybeSparse ref
+ (onehotArray (\_ -> newAcSparse t1 prj' idx' val) (newAcZero t1) n prj' idx)
+ (\ac -> accumAddSparse t1 prj' (arrayIndexLinear ac linindex) idx' val)
(STScal sty, SAPHere) -> AcM $ case sty of
STI32 -> return ()