From 919a36f8eed21501357185a90e2b7a4d9eaf7f08 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 28 Apr 2025 12:05:06 +0200 Subject: WIP interpreter support for new monoidal accumulators --- src/Interpreter.hs | 209 +++++++++++++++++++++++------------------------------ 1 file changed, 89 insertions(+), 120 deletions(-) (limited to 'src/Interpreter.hs') diff --git a/src/Interpreter.hs b/src/Interpreter.hs index f8e7e98..af11de8 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -164,15 +164,16 @@ interpret'Rec env = \case accum <- interpret' env e3 accumAddSparse t p accum idx val EZero _ t ezi -> do - return $ zeroD2 t ezi + zi <- interpret' env ezi + return $ zeroM t zi EPlus _ t a b -> do a' <- interpret' env a b' <- interpret' env b - return $ addD2s t a' b' + return $ addM t a' b' EOneHot _ t p a b -> do a' <- interpret' env a b' <- interpret' env b - return $ onehotD2 p t a' b' + return $ onehotM p t a' b' EError _ _ s -> error $ "Interpreter: Program threw error: " ++ s interpretOp :: SOp a t -> Rep a -> Rep t @@ -202,145 +203,113 @@ interpretOp op arg = case op of styIsEq STF64 = id styIsEq STBool = id -zeroD2 :: STy t -> Rep (D2 t) -zeroD2 typ = case typ of - STNil -> () - STPair _ _ -> Nothing - STEither _ _ -> Nothing - STMaybe _ -> Nothing - STArr _ _ -> Nothing - STScal sty -> case sty of - STI32 -> () - STI64 -> () +zeroM :: SMTy t -> Rep (ZeroInfo t) -> Rep t +zeroM typ zi = case typ of + SMTNil -> () + SMTPair t1 t2 -> (zeroM t1 (fst zi), zeroM t2 (snd zi)) + SMTLEither _ _ -> Nothing + SMTMaybe _ -> Nothing + SMTArr _ t -> arrayMap (zeroM t) zi + SMTScal sty -> case sty of + STI32 -> 0 + STI64 -> 0 STF32 -> 0.0 STF64 -> 0.0 - STBool -> () - STAccum{} -> error "Zero of Accum" -addD2s :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t) -addD2s typ a b = case typ of - STNil -> () - STPair t1 t2 -> case (a, b) of +addM :: SMTy t -> Rep t -> Rep t -> Rep t +addM typ a b = case typ of + SMTNil -> () + SMTPair t1 t2 -> (addM t1 (fst a) (fst b), addM t2 (snd a) (snd b)) + SMTLEither t1 t2 -> case (a, b) of (Nothing, _) -> b (_, Nothing) -> a - (Just (x1, x2), Just (y1, y2)) -> Just (addD2s t1 x1 y1, addD2s t2 x2 y2) - STEither t1 t2 -> case (a, b) of + (Just (Left x), Just (Left y)) -> Just (Left (addM t1 x y)) + (Just (Right x), Just (Right y)) -> Just (Right (addM t2 x y)) + _ -> error "Plus of inconsistent LEithers" + SMTMaybe t -> case (a, b) of (Nothing, _) -> b (_, Nothing) -> a - (Just (Left x), Just (Left y)) -> Just (Left (addD2s t1 x y)) - (Just (Right x), Just (Right y)) -> Just (Right (addD2s t2 x y)) - _ -> error "Plus of inconsistent Eithers" - STMaybe t -> case (a, b) of - (Nothing, _) -> b - (_, Nothing) -> a - (Just x, Just y) -> Just (addD2s t x y) - STArr _ t -> case (a, b) of - (Nothing, _) -> b - (_, Nothing) -> a - (Just x, Just y) -> - let sh1 = arrayShape x - sh2 = arrayShape y - in if | shapeSize sh1 == 0 -> Just y - | shapeSize sh2 == 0 -> Just x - | sh1 == sh2 -> Just $ arrayGenerateLin sh1 (\i -> addD2s t (arrayIndexLinear x i) (arrayIndexLinear y i)) - | otherwise -> error "Plus of inconsistently shaped arrays" - STScal sty -> case sty of - STI32 -> () - STI64 -> () - STF32 -> a + b - STF64 -> a + b - STBool -> () - STAccum{} -> error "Plus of Accum" - -onehotD2 :: SAcPrj p a b -> STy a -> Rep (AcIdx p a) -> Rep (D2 b) -> Rep (D2 a) -onehotD2 SAPHere _ _ val = val -onehotD2 (SAPFst prj) (STPair a b) idx val = Just (onehotD2 prj a idx val, zeroD2 b) -onehotD2 (SAPSnd prj) (STPair a b) idx val = Just (zeroD2 a, onehotD2 prj b idx val) -onehotD2 (SAPLeft prj) (STEither a _) idx val = Just (Left (onehotD2 prj a idx val)) -onehotD2 (SAPRight prj) (STEither _ b) idx val = Just (Right (onehotD2 prj b idx val)) -onehotD2 (SAPJust prj) (STMaybe a) idx val = Just (onehotD2 prj a idx val) -onehotD2 (SAPArrIdx prj) (STArr n a) idx val = - Just $ runIdentity $ onehotArray (\idx' -> Identity (onehotD2 prj a idx' val)) (Identity (zeroD2 a)) n prj idx - -withAccum :: STy t -> STy a -> Rep (D2 t) -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep (D2 t)) + (Just x, Just y) -> Just (addM t x y) + SMTArr _ t -> + let sh1 = arrayShape a + sh2 = arrayShape b + in if | shapeSize sh1 == 0 -> b + | shapeSize sh2 == 0 -> a + | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> addM t (arrayIndexLinear a i) (arrayIndexLinear b i)) + | otherwise -> error "Plus of inconsistently shaped arrays" + SMTScal sty -> numericIsNum sty $ a + b + +onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdx p a) -> Rep b -> Rep a +onehotM SAPHere _ _ val = val +onehotM (SAPFst prj) (SMTPair a b) idx val = (onehotM prj a (fst idx) val, zeroM b (snd idx)) +onehotM (SAPSnd prj) (SMTPair a b) idx val = (zeroM a (fst idx), onehotM prj b (snd idx) val) +onehotM (SAPLeft prj) (SMTLEither a _) idx val = Just (Left (onehotM prj a idx val)) +onehotM (SAPRight prj) (SMTLEither _ b) idx val = Just (Right (onehotM prj b idx val)) +onehotM (SAPJust prj) (SMTMaybe a) idx val = Just (onehotM prj a idx val) +onehotM (SAPArrIdx prj) (SMTArr n a) idx val = + runIdentity $ onehotArray (\idx' -> Identity (onehotM prj a idx' val)) (\zi -> Identity (zeroM a zi)) n prj idx + +withAccum :: SMTy t -> STy a -> Rep t -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t) withAccum t _ initval f = AcM $ do - accum <- newAcSparse t SAPHere () initval + accum <- newAcDense t initval out <- unAcM $ f accum val <- readAcSparse t accum return (out, val) -newAcZero :: STy t -> IO (RepAc t) -newAcZero = \case - STNil -> return () - STPair{} -> newIORef Nothing - STEither{} -> newIORef Nothing - STMaybe _ -> newIORef Nothing - STArr _ _ -> newIORef Nothing - STScal sty -> case sty of - STI32 -> return () - STI64 -> return () - STF32 -> newIORef 0.0 - STF64 -> newIORef 0.0 - STBool -> return () - STAccum{} -> error "Nested accumulators" - -newAcSparse :: STy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep (D2 b) -> IO (RepAc a) +newAcZero :: SMTy t -> Rep (ZeroInfo t) -> IO (RepAc t) +newAcZero typ zi = case typ of + SMTNil -> return () + SMTPair t1 t2 -> bitraverse (newAcZero t1) (newAcZero t2) zi + SMTLEither{} -> newIORef Nothing + SMTMaybe _ -> newIORef Nothing + SMTArr _ t -> arrayMapM (newAcZero t) zi + SMTScal sty -> numericIsNum sty $ newIORef 0 + +newAcDense :: SMTy a -> Rep a -> IO (RepAc a) +newAcDense typ val = case typ of + SMTNil -> return () + SMTPair t1 t2 -> bitraverse (newAcDense t1) (newAcDense t2) val + SMTLEither t1 t2 -> newIORef =<< traverse (bitraverse (newAcDense t1) (newAcDense t2)) val + SMTMaybe t1 -> newIORef =<< traverse (newAcDense t1) val + SMTArr _ t1 -> arrayMapM (newAcDense t1) val + SMTScal _ -> newIORef val + +newAcSparse :: SMTy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep b -> IO (RepAc a) newAcSparse typ prj idx val = case (typ, prj) of - (STNil, SAPHere) -> return () - (STPair t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val - (STEither t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val - (STMaybe t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val - (STArr _ t1, SAPHere) -> newIORef =<< traverse (traverse (newAcSparse t1 SAPHere ())) val - (STScal sty, SAPHere) -> case sty of - STI32 -> return () - STI64 -> return () - STF32 -> newIORef val - STF64 -> newIORef val - STBool -> return () - - (STPair t1 t2, SAPFst prj') -> - newIORef . Just =<< (,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2 - (STPair t1 t2, SAPSnd prj') -> - newIORef . Just =<< (,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val - - (STEither t1 _, SAPLeft prj') -> newIORef . Just . Left =<< newAcSparse t1 prj' idx val - (STEither _ t2, SAPRight prj') -> newIORef . Just . Right =<< newAcSparse t2 prj' idx val + (_, SAPHere) -> newAcDense typ val - (STMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val + (SMTPair t1 t2, SAPFst prj') -> + (,) <$> newAcSparse t1 prj' (fst idx) val <*> newAcZero t2 (snd idx) + (SMTPair t1 t2, SAPSnd prj') -> + (,) <$> newAcZero t1 (fst idx) <*> newAcSparse t2 prj' (snd idx) val - (STArr n t, SAPArrIdx prj') -> newIORef . Just =<< newAcArray n t prj' idx val + (SMTLEither t1 _, SAPLeft prj') -> newIORef . Just . Left =<< newAcSparse t1 prj' idx val + (SMTLEither _ t2, SAPRight prj') -> newIORef . Just . Right =<< newAcSparse t2 prj' idx val - (STAccum{}, _) -> error "Accumulators not allowed in source program" + (SMTMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val -newAcArray :: SNat n -> STy a -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> Rep (D2 b) -> IO (Array n (RepAc a)) -newAcArray n t prj idx val = onehotArray (\idx' -> newAcSparse t prj idx' val) (newAcZero t) n prj idx + (SMTArr n t, SAPArrIdx prj') -> onehotArray (\idx' -> newAcSparse t prj' idx' val) (newAcZero t) n prj' idx onehotArray :: Monad m => (Rep (AcIdx p a) -> m v) -- ^ the "one" - -> m v -- ^ the "zero" + -> (Rep (ZeroInfo a) -> m v) -- ^ the "zero" -> SNat n -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> m (Array n v) -onehotArray mkone mkzero n _ ((arrindex', arrsh'), idx) = +onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) = let arrindex = unTupRepIdx IxNil IxCons n arrindex' - arrsh = unTupRepIdx ShNil ShCons n arrsh' + arrsh = arrayShape ziarr !linindex = toLinearIndex arrsh arrindex - in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero) + in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero (ziarr `arrayIndexLinear` i)) -readAcSparse :: STy t -> RepAc t -> IO (Rep (D2 t)) +readAcSparse :: SMTy t -> RepAc t -> IO (Rep t) readAcSparse typ val = case typ of - STNil -> return () - STPair t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val - STEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val - STMaybe t -> traverse (readAcSparse t) =<< readIORef val - STArr _ t -> traverse (traverse (readAcSparse t)) =<< readIORef val - STScal sty -> case sty of - STI32 -> return () - STI64 -> return () - STF32 -> readIORef val - STF64 -> readIORef val - STBool -> return () - STAccum{} -> error "Nested accumulators" - -accumAddSparse :: STy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep (D2 b) -> AcM s () + SMTNil -> return () + SMTPair t1 t2 -> bitraverse (readAcSparse t1) (readAcSparse t2) val + SMTLEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val + SMTMaybe t -> traverse (readAcSparse t) =<< readIORef val + SMTArr _ t -> traverse (readAcSparse t) val + SMTScal _ -> readIORef val + +accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep b -> AcM s () accumAddSparse typ prj ref idx val = case (typ, prj) of (STNil, SAPHere) -> return () @@ -348,8 +317,8 @@ accumAddSparse typ prj ref idx val = case (typ, prj) of case val of Nothing -> return () Just (val1, val2) -> - realiseMaybeSparse ref ((,) <$> newAcSparse t1 SAPHere () val1 - <*> newAcSparse t2 SAPHere () val2) + realiseMaybeSparse ref ((,) <$> newAcDense t1 val1 + <*> newAcDense t2 val2) (\(ac1, ac2) -> do accumAddSparse t1 SAPHere ac1 () val1 accumAddSparse t2 SAPHere ac2 () val2) (STPair t1 t2, SAPFst prj') -> @@ -386,7 +355,7 @@ accumAddSparse typ prj ref idx val = case (typ, prj) of Nothing -> return () Just val' -> realiseMaybeSparse ref - (arrayMapM (newAcSparse t1 SAPHere ()) val') + (arrayMapM (newAcDense t1) val') (\ac -> forM_ [0 .. arraySize ac - 1] $ \i -> accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val' i)) (STArr n t1, SAPArrIdx prj') -> -- cgit v1.2.3-70-g09d2