diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-04-28 12:05:06 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-04-28 12:05:06 +0200 | 
| commit | 919a36f8eed21501357185a90e2b7a4d9eaf7f08 (patch) | |
| tree | 3c6c975ea9f51c8ad46105e4cbba08d9c7f77003 /src/Interpreter.hs | |
| parent | b1664532eaebdf0409ab6d93fc0ba2ef8dfbf372 (diff) | |
WIP interpreter support for new monoidal accumulators
Diffstat (limited to 'src/Interpreter.hs')
| -rw-r--r-- | src/Interpreter.hs | 201 | 
1 files changed, 85 insertions, 116 deletions
| diff --git a/src/Interpreter.hs b/src/Interpreter.hs index f8e7e98..af11de8 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -164,15 +164,16 @@ interpret'Rec env = \case      accum <- interpret' env e3      accumAddSparse t p accum idx val    EZero _ t ezi -> do -    return $ zeroD2 t ezi +    zi <- interpret' env ezi +    return $ zeroM t zi    EPlus _ t a b -> do      a' <- interpret' env a      b' <- interpret' env b -    return $ addD2s t a' b' +    return $ addM t a' b'    EOneHot _ t p a b -> do      a' <- interpret' env a      b' <- interpret' env b -    return $ onehotD2 p t a' b' +    return $ onehotM p t a' b'    EError _ _ s -> error $ "Interpreter: Program threw error: " ++ s  interpretOp :: SOp a t -> Rep a -> Rep t @@ -202,145 +203,113 @@ interpretOp op arg = case op of      styIsEq STF64 = id      styIsEq STBool = id -zeroD2 :: STy t -> Rep (D2 t) -zeroD2 typ = case typ of -  STNil -> () -  STPair _ _ -> Nothing -  STEither _ _ -> Nothing -  STMaybe _ -> Nothing -  STArr _ _ -> Nothing -  STScal sty -> case sty of -                  STI32 -> () -                  STI64 -> () +zeroM :: SMTy t -> Rep (ZeroInfo t) -> Rep t +zeroM typ zi = case typ of +  SMTNil -> () +  SMTPair t1 t2 -> (zeroM t1 (fst zi), zeroM t2 (snd zi)) +  SMTLEither _ _ -> Nothing +  SMTMaybe _ -> Nothing +  SMTArr _ t -> arrayMap (zeroM t) zi +  SMTScal sty -> case sty of +                  STI32 -> 0 +                  STI64 -> 0                    STF32 -> 0.0                    STF64 -> 0.0 -                  STBool -> () -  STAccum{} -> error "Zero of Accum" -addD2s :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t) -addD2s typ a b = case typ of -  STNil -> () -  STPair t1 t2 -> case (a, b) of +addM :: SMTy t -> Rep t -> Rep t -> Rep t +addM typ a b = case typ of +  SMTNil -> () +  SMTPair t1 t2 -> (addM t1 (fst a) (fst b), addM t2 (snd a) (snd b)) +  SMTLEither t1 t2 -> case (a, b) of      (Nothing, _) -> b      (_, Nothing) -> a -    (Just (x1, x2), Just (y1, y2)) -> Just (addD2s t1 x1 y1, addD2s t2 x2 y2) -  STEither t1 t2 -> case (a, b) of +    (Just (Left x), Just (Left y)) -> Just (Left (addM t1 x y)) +    (Just (Right x), Just (Right y)) -> Just (Right (addM t2 x y)) +    _ -> error "Plus of inconsistent LEithers" +  SMTMaybe t -> case (a, b) of      (Nothing, _) -> b      (_, Nothing) -> a -    (Just (Left x), Just (Left y)) -> Just (Left (addD2s t1 x y)) -    (Just (Right x), Just (Right y)) -> Just (Right (addD2s t2 x y)) -    _ -> error "Plus of inconsistent Eithers" -  STMaybe t -> case (a, b) of -    (Nothing, _) -> b -    (_, Nothing) -> a -    (Just x, Just y) -> Just (addD2s t x y) -  STArr _ t -> case (a, b) of -    (Nothing, _) -> b -    (_, Nothing) -> a -    (Just x, Just y) -> -      let sh1 = arrayShape x -          sh2 = arrayShape y -      in if | shapeSize sh1 == 0 -> Just y -            | shapeSize sh2 == 0 -> Just x -            | sh1 == sh2 -> Just $ arrayGenerateLin sh1 (\i -> addD2s t (arrayIndexLinear x i) (arrayIndexLinear y i)) -            | otherwise -> error "Plus of inconsistently shaped arrays" -  STScal sty -> case sty of -    STI32 -> () -    STI64 -> () -    STF32 -> a + b -    STF64 -> a + b -    STBool -> () -  STAccum{} -> error "Plus of Accum" +    (Just x, Just y) -> Just (addM t x y) +  SMTArr _ t -> +    let sh1 = arrayShape a +        sh2 = arrayShape b +    in if | shapeSize sh1 == 0 -> b +          | shapeSize sh2 == 0 -> a +          | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> addM t (arrayIndexLinear a i) (arrayIndexLinear b i)) +          | otherwise -> error "Plus of inconsistently shaped arrays" +  SMTScal sty -> numericIsNum sty $ a + b -onehotD2 :: SAcPrj p a b -> STy a -> Rep (AcIdx p a) -> Rep (D2 b) -> Rep (D2 a) -onehotD2 SAPHere _ _ val = val -onehotD2 (SAPFst prj) (STPair a b) idx val = Just (onehotD2 prj a idx val, zeroD2 b) -onehotD2 (SAPSnd prj) (STPair a b) idx val = Just (zeroD2 a, onehotD2 prj b idx val) -onehotD2 (SAPLeft prj) (STEither a _) idx val = Just (Left (onehotD2 prj a idx val)) -onehotD2 (SAPRight prj) (STEither _ b) idx val = Just (Right (onehotD2 prj b idx val)) -onehotD2 (SAPJust prj) (STMaybe a) idx val = Just (onehotD2 prj a idx val) -onehotD2 (SAPArrIdx prj) (STArr n a) idx val = -  Just $ runIdentity $ onehotArray (\idx' -> Identity (onehotD2 prj a idx' val)) (Identity (zeroD2 a)) n prj idx +onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdx p a) -> Rep b -> Rep a +onehotM SAPHere _ _ val = val +onehotM (SAPFst prj) (SMTPair a b) idx val = (onehotM prj a (fst idx) val, zeroM b (snd idx)) +onehotM (SAPSnd prj) (SMTPair a b) idx val = (zeroM a (fst idx), onehotM prj b (snd idx) val) +onehotM (SAPLeft prj) (SMTLEither a _) idx val = Just (Left (onehotM prj a idx val)) +onehotM (SAPRight prj) (SMTLEither _ b) idx val = Just (Right (onehotM prj b idx val)) +onehotM (SAPJust prj) (SMTMaybe a) idx val = Just (onehotM prj a idx val) +onehotM (SAPArrIdx prj) (SMTArr n a) idx val = +  runIdentity $ onehotArray (\idx' -> Identity (onehotM prj a idx' val)) (\zi -> Identity (zeroM a zi)) n prj idx -withAccum :: STy t -> STy a -> Rep (D2 t) -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep (D2 t)) +withAccum :: SMTy t -> STy a -> Rep t -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t)  withAccum t _ initval f = AcM $ do -  accum <- newAcSparse t SAPHere () initval +  accum <- newAcDense t initval    out <- unAcM $ f accum    val <- readAcSparse t accum    return (out, val) -newAcZero :: STy t -> IO (RepAc t) -newAcZero = \case -  STNil -> return () -  STPair{} -> newIORef Nothing -  STEither{} -> newIORef Nothing -  STMaybe _ -> newIORef Nothing -  STArr _ _ -> newIORef Nothing -  STScal sty -> case sty of -    STI32 -> return () -    STI64 -> return () -    STF32 -> newIORef 0.0 -    STF64 -> newIORef 0.0 -    STBool -> return () -  STAccum{} -> error "Nested accumulators" - -newAcSparse :: STy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep (D2 b) -> IO (RepAc a) -newAcSparse typ prj idx val = case (typ, prj) of -  (STNil, SAPHere) -> return () -  (STPair t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val -  (STEither t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val -  (STMaybe t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val -  (STArr _ t1, SAPHere) -> newIORef =<< traverse (traverse (newAcSparse t1 SAPHere ())) val -  (STScal sty, SAPHere) -> case sty of -    STI32 -> return () -    STI64 -> return () -    STF32 -> newIORef val -    STF64 -> newIORef val -    STBool -> return () +newAcZero :: SMTy t -> Rep (ZeroInfo t) -> IO (RepAc t) +newAcZero typ zi = case typ of +  SMTNil -> return () +  SMTPair t1 t2 -> bitraverse (newAcZero t1) (newAcZero t2) zi +  SMTLEither{} -> newIORef Nothing +  SMTMaybe _ -> newIORef Nothing +  SMTArr _ t -> arrayMapM (newAcZero t) zi +  SMTScal sty -> numericIsNum sty $ newIORef 0 -  (STPair t1 t2, SAPFst prj') -> -    newIORef . Just =<< (,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2 -  (STPair t1 t2, SAPSnd prj') -> -    newIORef . Just =<< (,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val +newAcDense :: SMTy a -> Rep a -> IO (RepAc a) +newAcDense typ val = case typ of +  SMTNil -> return () +  SMTPair t1 t2 -> bitraverse (newAcDense t1) (newAcDense t2) val +  SMTLEither t1 t2 -> newIORef =<< traverse (bitraverse (newAcDense t1) (newAcDense t2)) val +  SMTMaybe t1 -> newIORef =<< traverse (newAcDense t1) val +  SMTArr _ t1 -> arrayMapM (newAcDense t1) val +  SMTScal _ -> newIORef val -  (STEither t1 _, SAPLeft prj') -> newIORef . Just . Left =<< newAcSparse t1 prj' idx val -  (STEither _ t2, SAPRight prj') -> newIORef . Just . Right =<< newAcSparse t2 prj' idx val +newAcSparse :: SMTy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep b -> IO (RepAc a) +newAcSparse typ prj idx val = case (typ, prj) of +  (_, SAPHere) -> newAcDense typ val -  (STMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val +  (SMTPair t1 t2, SAPFst prj') -> +    (,) <$> newAcSparse t1 prj' (fst idx) val <*> newAcZero t2 (snd idx) +  (SMTPair t1 t2, SAPSnd prj') -> +    (,) <$> newAcZero t1 (fst idx) <*> newAcSparse t2 prj' (snd idx) val -  (STArr n t, SAPArrIdx prj') -> newIORef . Just =<< newAcArray n t prj' idx val +  (SMTLEither t1 _, SAPLeft prj') -> newIORef . Just . Left =<< newAcSparse t1 prj' idx val +  (SMTLEither _ t2, SAPRight prj') -> newIORef . Just . Right =<< newAcSparse t2 prj' idx val -  (STAccum{}, _) -> error "Accumulators not allowed in source program" +  (SMTMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val -newAcArray :: SNat n -> STy a -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> Rep (D2 b) -> IO (Array n (RepAc a)) -newAcArray n t prj idx val = onehotArray (\idx' -> newAcSparse t prj idx' val) (newAcZero t) n prj idx +  (SMTArr n t, SAPArrIdx prj') -> onehotArray (\idx' -> newAcSparse t prj' idx' val) (newAcZero t) n prj' idx  onehotArray :: Monad m              => (Rep (AcIdx p a) -> m v)  -- ^ the "one" -            -> m v  -- ^ the "zero" +            -> (Rep (ZeroInfo a) -> m v)  -- ^ the "zero"              -> SNat n -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> m (Array n v) -onehotArray mkone mkzero n _ ((arrindex', arrsh'), idx) = +onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) =    let arrindex = unTupRepIdx IxNil IxCons n arrindex' -      arrsh = unTupRepIdx ShNil ShCons n arrsh' +      arrsh = arrayShape ziarr        !linindex = toLinearIndex arrsh arrindex -  in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero) +  in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero (ziarr `arrayIndexLinear` i)) -readAcSparse :: STy t -> RepAc t -> IO (Rep (D2 t)) +readAcSparse :: SMTy t -> RepAc t -> IO (Rep t)  readAcSparse typ val = case typ of -  STNil -> return () -  STPair t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val -  STEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val -  STMaybe t -> traverse (readAcSparse t) =<< readIORef val -  STArr _ t -> traverse (traverse (readAcSparse t)) =<< readIORef val -  STScal sty -> case sty of -    STI32 -> return () -    STI64 -> return () -    STF32 -> readIORef val -    STF64 -> readIORef val -    STBool -> return () -  STAccum{} -> error "Nested accumulators" +  SMTNil -> return () +  SMTPair t1 t2 -> bitraverse (readAcSparse t1) (readAcSparse t2) val +  SMTLEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val +  SMTMaybe t -> traverse (readAcSparse t) =<< readIORef val +  SMTArr _ t -> traverse (readAcSparse t) val +  SMTScal _ -> readIORef val -accumAddSparse :: STy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep (D2 b) -> AcM s () +accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep b -> AcM s ()  accumAddSparse typ prj ref idx val = case (typ, prj) of    (STNil, SAPHere) -> return () @@ -348,8 +317,8 @@ accumAddSparse typ prj ref idx val = case (typ, prj) of      case val of        Nothing -> return ()        Just (val1, val2) -> -        realiseMaybeSparse ref ((,) <$> newAcSparse t1 SAPHere () val1 -                                    <*> newAcSparse t2 SAPHere () val2) +        realiseMaybeSparse ref ((,) <$> newAcDense t1 val1 +                                    <*> newAcDense t2 val2)                                 (\(ac1, ac2) -> do accumAddSparse t1 SAPHere ac1 () val1                                                    accumAddSparse t2 SAPHere ac2 () val2)    (STPair t1 t2, SAPFst prj') -> @@ -386,7 +355,7 @@ accumAddSparse typ prj ref idx val = case (typ, prj) of        Nothing -> return ()        Just val' ->          realiseMaybeSparse ref -          (arrayMapM (newAcSparse t1 SAPHere ()) val') +          (arrayMapM (newAcDense t1) val')            (\ac -> forM_ [0 .. arraySize ac - 1] $ \i ->                      accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val' i))    (STArr n t1, SAPArrIdx prj') -> | 
