diff options
Diffstat (limited to 'src/Interpreter.hs')
| -rw-r--r-- | src/Interpreter.hs | 448 |
1 files changed, 0 insertions, 448 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs deleted file mode 100644 index 58d79a5..0000000 --- a/src/Interpreter.hs +++ /dev/null @@ -1,448 +0,0 @@ -{-# LANGUAGE BangPatterns #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE ImplicitParams #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE MultiWayIf #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TupleSections #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -module Interpreter ( - interpret, - interpretOpen, - Value(..), -) where - -import Control.Monad (foldM, join, when, forM_) -import Data.Bitraversable (bitraverse) -import Data.Char (isSpace) -import Data.Functor.Identity -import qualified Data.Functor.Product as Product -import Data.Int (Int64) -import Data.IORef -import System.IO (hPutStrLn, stderr) -import System.IO.Unsafe (unsafePerformIO) - -import Debug.Trace - -import Array -import AST -import AST.Pretty -import CHAD.Types -import Data -import Interpreter.Rep - - -newtype AcM s a = AcM { unAcM :: IO a } - deriving newtype (Functor, Applicative, Monad) - -runAcM :: (forall s. AcM s a) -> a -runAcM (AcM m) = unsafePerformIO m - -acmDebugLog :: String -> AcM s () -acmDebugLog s = AcM (hPutStrLn stderr s) - -data V t = V (STy t) (Rep t) - -interpret :: Ex '[] t -> Rep t -interpret = interpretOpen False SNil SNil - --- | Bool: whether to trace execution with debug prints (very verbose) -interpretOpen :: Bool -> SList STy env -> SList Value env -> Ex env t -> Rep t -interpretOpen prints env venv e = - runAcM $ - let ?depth = 0 - ?prints = prints - in interpret' (slistMap (\(Product.Pair t (Value v)) -> V t v) (slistZip env venv)) e - -interpret' :: forall env t s. (?prints :: Bool, ?depth :: Int) - => SList V env -> Ex env t -> AcM s (Rep t) -interpret' env e = do - let tenv = slistMap (\(V t _) -> t) env - let dep = ?depth - let lenlimit = max 20 (100 - dep) - let replace a b = map (\c -> if c == a then b else c) - let trunc s | length s > lenlimit = take (lenlimit - 3) (replace '\n' ' ' s) ++ "..." - | otherwise = replace '\n' ' ' s - when ?prints $ acmDebugLog $ replicate dep ' ' ++ "ev: " ++ trunc (ppExpr tenv e) - res <- let ?depth = dep + 1 in interpret'Rec env e - when ?prints $ acmDebugLog $ replicate dep ' ' ++ "<- " ++ showValue 0 (typeOf e) res "" - return res - -interpret'Rec :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList V env -> Ex env t -> AcM s (Rep t) -interpret'Rec env = \case - EVar _ _ i -> case slistIdx env i of V _ x -> return x - ELet _ a b -> do - x <- interpret' env a - let ?depth = ?depth - 1 in interpret' (V (typeOf a) x `SCons` env) b - expr | False && trace ("<i> " ++ takeWhile (not . isSpace) (show expr)) False -> undefined - EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b - EFst _ e -> fst <$> interpret' env e - ESnd _ e -> snd <$> interpret' env e - ENil _ -> return () - EInl _ _ e -> Left <$> interpret' env e - EInr _ _ e -> Right <$> interpret' env e - ECase _ e a b -> - let STEither t1 t2 = typeOf e - in interpret' env e >>= \case - Left x -> interpret' (V t1 x `SCons` env) a - Right y -> interpret' (V t2 y `SCons` env) b - ENothing _ _ -> return Nothing - EJust _ e -> Just <$> interpret' env e - EMaybe _ a b e -> - let STMaybe t1 = typeOf e - in maybe (interpret' env a) (\x -> interpret' (V t1 x `SCons` env) b) =<< interpret' env e - EConstArr _ _ _ v -> return v - EBuild _ dim a b -> do - sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a - arrayGenerateM sh (\idx -> interpret' (V (tTup (sreplicate dim tIx)) (tupRepIdx ixUncons dim idx) `SCons` env) b) - EFold1Inner _ _ a b c -> do - let t = typeOf b - let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a - x0 <- interpret' env b - arr <- interpret' env c - let sh `ShCons` n = arrayShape arr - arrayGenerateM sh $ \idx -> foldM f x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] - ESum1Inner _ e -> do - arr <- interpret' env e - let STArr _ (STScal t) = typeOf e - sh `ShCons` n = arrayShape arr - numericIsNum t $ return $ arrayGenerate sh $ \idx -> sum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] - EUnit _ e -> arrayGenerateLinM ShNil (\_ -> interpret' env e) - EReplicate1Inner _ a b -> do - n <- fromIntegral @Int64 @Int <$> interpret' env a - arr <- interpret' env b - let sh = arrayShape arr - return $ arrayGenerate (sh `ShCons` n) (\(idx `IxCons` _) -> arrayIndex arr idx) - EMaximum1Inner _ e -> do - arr <- interpret' env e - let STArr _ (STScal t) = typeOf e - sh `ShCons` n = arrayShape arr - numericIsNum t $ return $ - arrayGenerate sh (\idx -> maximum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]]) - EMinimum1Inner _ e -> do - arr <- interpret' env e - let STArr _ (STScal t) = typeOf e - sh `ShCons` n = arrayShape arr - numericIsNum t $ return $ - arrayGenerate sh (\idx -> minimum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]]) - EConst _ _ v -> return v - EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e - EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b) - EIdx _ a b - | STArr n _ <- typeOf a - -> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b) - EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e - EOp _ op e -> interpretOp op <$> interpret' env e - ECustom _ t1 t2 _ pr _ _ e1 e2 -> do - e1' <- interpret' env e1 - e2' <- interpret' env e2 - interpret' (V t2 e2' `SCons` V t1 e1' `SCons` SNil) pr - EWith _ t e1 e2 -> do - initval <- interpret' env e1 - withAccum t (typeOf e2) initval $ \accum -> - interpret' (V (STAccum t) accum `SCons` env) e2 - EAccum _ t p e1 e2 e3 -> do - idx <- interpret' env e1 - val <- interpret' env e2 - accum <- interpret' env e3 - accumAddSparse t p accum idx val - EZero _ t -> do - return $ zeroD2 t - EPlus _ t a b -> do - a' <- interpret' env a - b' <- interpret' env b - return $ addD2s t a' b' - EOneHot _ t p a b -> do - a' <- interpret' env a - b' <- interpret' env b - return $ onehotD2 p t a' b' - EError _ _ s -> error $ "Interpreter: Program threw error: " ++ s - -interpretOp :: SOp a t -> Rep a -> Rep t -interpretOp op arg = case op of - OAdd st -> numericIsNum st $ uncurry (+) arg - OMul st -> numericIsNum st $ uncurry (*) arg - ONeg st -> numericIsNum st $ negate arg - OLt st -> numericIsNum st $ uncurry (<) arg - OLe st -> numericIsNum st $ uncurry (<=) arg - OEq st -> styIsEq st $ uncurry (==) arg - ONot -> not arg - OAnd -> uncurry (&&) arg - OOr -> uncurry (||) arg - OIf -> if arg then Left () else Right () - ORound64 -> round arg - OToFl64 -> fromIntegral arg - ORecip st -> floatingIsFractional st $ recip arg - OExp st -> floatingIsFractional st $ exp arg - OLog st -> floatingIsFractional st $ log arg - OIDiv st -> integralIsIntegral st $ uncurry quot arg - OMod st -> integralIsIntegral st $ uncurry rem arg - where - styIsEq :: SScalTy t -> (Eq (Rep (TScal t)) => r) -> r - styIsEq STI32 = id - styIsEq STI64 = id - styIsEq STF32 = id - styIsEq STF64 = id - styIsEq STBool = id - -zeroD2 :: STy t -> Rep (D2 t) -zeroD2 typ = case typ of - STNil -> () - STPair _ _ -> Nothing - STEither _ _ -> Nothing - STMaybe _ -> Nothing - STArr _ _ -> Nothing - STScal sty -> case sty of - STI32 -> () - STI64 -> () - STF32 -> 0.0 - STF64 -> 0.0 - STBool -> () - STAccum{} -> error "Zero of Accum" - -addD2s :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t) -addD2s typ a b = case typ of - STNil -> () - STPair t1 t2 -> case (a, b) of - (Nothing, _) -> b - (_, Nothing) -> a - (Just (x1, x2), Just (y1, y2)) -> Just (addD2s t1 x1 y1, addD2s t2 x2 y2) - STEither t1 t2 -> case (a, b) of - (Nothing, _) -> b - (_, Nothing) -> a - (Just (Left x), Just (Left y)) -> Just (Left (addD2s t1 x y)) - (Just (Right x), Just (Right y)) -> Just (Right (addD2s t2 x y)) - _ -> error "Plus of inconsistent Eithers" - STMaybe t -> case (a, b) of - (Nothing, _) -> b - (_, Nothing) -> a - (Just x, Just y) -> Just (addD2s t x y) - STArr _ t -> case (a, b) of - (Nothing, _) -> b - (_, Nothing) -> a - (Just x, Just y) -> - let sh1 = arrayShape x - sh2 = arrayShape y - in if | shapeSize sh1 == 0 -> Just y - | shapeSize sh2 == 0 -> Just x - | sh1 == sh2 -> Just $ arrayGenerateLin sh1 (\i -> addD2s t (arrayIndexLinear x i) (arrayIndexLinear y i)) - | otherwise -> error "Plus of inconsistently shaped arrays" - STScal sty -> case sty of - STI32 -> () - STI64 -> () - STF32 -> a + b - STF64 -> a + b - STBool -> () - STAccum{} -> error "Plus of Accum" - -onehotD2 :: SAcPrj p a b -> STy a -> Rep (AcIdx p a) -> Rep (D2 b) -> Rep (D2 a) -onehotD2 SAPHere _ _ val = val -onehotD2 (SAPFst prj) (STPair a b) idx val = Just (onehotD2 prj a idx val, zeroD2 b) -onehotD2 (SAPSnd prj) (STPair a b) idx val = Just (zeroD2 a, onehotD2 prj b idx val) -onehotD2 (SAPLeft prj) (STEither a _) idx val = Just (Left (onehotD2 prj a idx val)) -onehotD2 (SAPRight prj) (STEither _ b) idx val = Just (Right (onehotD2 prj b idx val)) -onehotD2 (SAPJust prj) (STMaybe a) idx val = Just (onehotD2 prj a idx val) -onehotD2 (SAPArrIdx prj _) (STArr n a) idx val = - Just $ runIdentity $ onehotArray (\idx' -> Identity (onehotD2 prj a idx' val)) (Identity (zeroD2 a)) n prj idx - -withAccum :: STy t -> STy a -> Rep (D2 t) -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep (D2 t)) -withAccum t _ initval f = AcM $ do - accum <- newAcSparse t SAPHere () initval - out <- unAcM $ f accum - val <- readAcSparse t accum - return (out, val) - -newAcZero :: STy t -> IO (RepAc t) -newAcZero = \case - STNil -> return () - STPair{} -> newIORef Nothing - STEither{} -> newIORef Nothing - STMaybe _ -> newIORef Nothing - STArr _ _ -> newIORef Nothing - STScal sty -> case sty of - STI32 -> return () - STI64 -> return () - STF32 -> newIORef 0.0 - STF64 -> newIORef 0.0 - STBool -> return () - STAccum{} -> error "Nested accumulators" - -newAcSparse :: STy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep (D2 b) -> IO (RepAc a) -newAcSparse typ prj idx val = case (typ, prj) of - (STNil, SAPHere) -> return () - (STPair t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val - (STEither t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val - (STMaybe t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val - (STArr _ t1, SAPHere) -> newIORef =<< traverse (traverse (newAcSparse t1 SAPHere ())) val - (STScal sty, SAPHere) -> case sty of - STI32 -> return () - STI64 -> return () - STF32 -> newIORef val - STF64 -> newIORef val - STBool -> return () - - (STPair t1 t2, SAPFst prj') -> - newIORef . Just =<< (,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2 - (STPair t1 t2, SAPSnd prj') -> - newIORef . Just =<< (,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val - - (STEither t1 _, SAPLeft prj') -> newIORef . Just . Left =<< newAcSparse t1 prj' idx val - (STEither _ t2, SAPRight prj') -> newIORef . Just . Right =<< newAcSparse t2 prj' idx val - - (STMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val - - (STArr n t, SAPArrIdx prj' _) -> newIORef . Just =<< newAcArray n t prj' idx val - - (STAccum{}, _) -> error "Accumulators not allowed in source program" - -newAcArray :: SNat n -> STy a -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> Rep (D2 b) -> IO (Array n (RepAc a)) -newAcArray n t prj idx val = onehotArray (\idx' -> newAcSparse t prj idx' val) (newAcZero t) n prj idx - -onehotArray :: Monad m - => (Rep (AcIdx p a) -> m v) -- ^ the "one" - -> m v -- ^ the "zero" - -> SNat n -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> m (Array n v) -onehotArray mkone mkzero n _ ((arrindex', arrsh'), idx) = - let arrindex = unTupRepIdx IxNil IxCons n arrindex' - arrsh = unTupRepIdx ShNil ShCons n arrsh' - !linindex = toLinearIndex arrsh arrindex - in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero) - -readAcSparse :: STy t -> RepAc t -> IO (Rep (D2 t)) -readAcSparse typ val = case typ of - STNil -> return () - STPair t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val - STEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val - STMaybe t -> traverse (readAcSparse t) =<< readIORef val - STArr _ t -> traverse (traverse (readAcSparse t)) =<< readIORef val - STScal sty -> case sty of - STI32 -> return () - STI64 -> return () - STF32 -> readIORef val - STF64 -> readIORef val - STBool -> return () - STAccum{} -> error "Nested accumulators" - -accumAddSparse :: STy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep (D2 b) -> AcM s () -accumAddSparse typ prj ref idx val = case (typ, prj) of - (STNil, SAPHere) -> return () - - (STPair t1 t2, SAPHere) -> - case val of - Nothing -> return () - Just (val1, val2) -> - realiseMaybeSparse ref ((,) <$> newAcSparse t1 SAPHere () val1 - <*> newAcSparse t2 SAPHere () val2) - (\(ac1, ac2) -> do accumAddSparse t1 SAPHere ac1 () val1 - accumAddSparse t2 SAPHere ac2 () val2) - (STPair t1 t2, SAPFst prj') -> - realiseMaybeSparse ref ((,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2) - (\(ac1, _) -> do accumAddSparse t1 prj' ac1 idx val) - (STPair t1 t2, SAPSnd prj') -> - realiseMaybeSparse ref ((,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val) - (\(_, ac2) -> do accumAddSparse t2 prj' ac2 idx val) - - (STEither{}, SAPHere) -> - case val of - Nothing -> return () - Just (Left val1) -> accumAddSparse typ (SAPLeft SAPHere) ref () val1 - Just (Right val2) -> accumAddSparse typ (SAPRight SAPHere) ref () val2 - (STEither t1 _, SAPLeft prj') -> - realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val) - (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val - Right{} -> error "Mismatched Either in accumAddSparse (r +l)") - (STEither _ t2, SAPRight prj') -> - realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val) - (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val - Left{} -> error "Mismatched Either in accumAddSparse (l +r)") - - (STMaybe{}, SAPHere) -> - case val of - Nothing -> return () - Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val' - (STMaybe t1, SAPJust prj') -> - realiseMaybeSparse ref (newAcSparse t1 prj' idx val) - (\ac -> accumAddSparse t1 prj' ac idx val) - - (STArr _ t1, SAPHere) -> - case val of - Nothing -> return () - Just val' -> - realiseMaybeSparse ref - (arrayMapM (newAcSparse t1 SAPHere ()) val') - (\ac -> forM_ [0 .. arraySize ac - 1] $ \i -> - accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val' i)) - (STArr n t1, SAPArrIdx prj' _) -> - let ((arrindex', arrsh'), idx') = idx - arrindex = unTupRepIdx IxNil IxCons n arrindex' - arrsh = unTupRepIdx ShNil ShCons n arrsh' - linindex = toLinearIndex arrsh arrindex - in realiseMaybeSparse ref - (onehotArray (\_ -> newAcSparse t1 prj' idx' val) (newAcZero t1) n prj' idx) - (\ac -> accumAddSparse t1 prj' (arrayIndexLinear ac linindex) idx' val) - - (STScal sty, SAPHere) -> AcM $ case sty of - STI32 -> return () - STI64 -> return () - STF32 -> atomicModifyIORef' ref (\x -> (x + val, ())) - STF64 -> atomicModifyIORef' ref (\x -> (x + val, ())) - STBool -> return () - - (STAccum{}, _) -> error "Accumulators not allowed in source program" - -realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s () -realiseMaybeSparse ref makeval modifyval = - -- Try modifying what's already in ref. The 'join' makes the snd - -- of the function's return value a _continuation_ that is run after - -- the critical section ends. - AcM $ join $ atomicModifyIORef' ref $ \ac -> case ac of - -- Oops, ref's contents was still sparse. Have to initialise - -- it first, then try again. - Nothing -> (ac, do val <- makeval - join $ atomicModifyIORef' ref $ \ac' -> case ac' of - Nothing -> (Just val, return ()) - Just val' -> (ac', unAcM $ modifyval val')) - -- Yep, ref already had a value in there, so we can just add - -- val' to it recursively. - Just val -> (ac, unAcM $ modifyval val) - - -numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r -numericIsNum STI32 = id -numericIsNum STI64 = id -numericIsNum STF32 = id -numericIsNum STF64 = id - -floatingIsFractional :: ScalIsFloating st ~ True => SScalTy st -> ((Floating (ScalRep st), Ord (ScalRep st), ScalIsNumeric st ~ True, ScalIsFloating st ~ True) => r) -> r -floatingIsFractional STF32 = id -floatingIsFractional STF64 = id - -integralIsIntegral :: ScalIsIntegral st ~ True => SScalTy st -> ((Integral (ScalRep st), Ord (ScalRep st), ScalIsNumeric st ~ True, ScalIsIntegral st ~ True) => r) -> r -integralIsIntegral STI32 = id -integralIsIntegral STI64 = id - -unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m)) - -> SNat n -> Rep (Tup (Replicate n TIx)) -> f n -unTupRepIdx nil _ SZ _ = nil -unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx nil cons n idx `cons` fromIntegral @Int64 @Int i - -tupRepIdx :: (forall m. f (S m) -> (f m, Int)) - -> SNat n -> f n -> Rep (Tup (Replicate n TIx)) -tupRepIdx _ SZ _ = () -tupRepIdx uncons (SS n) tup = - let (tup', i) = uncons tup - in ((,) $! tupRepIdx uncons n tup') $! fromIntegral @Int @Int64 i - -ixUncons :: Index (S n) -> (Index n, Int) -ixUncons (IxCons idx i) = (idx, i) - -shUncons :: Shape (S n) -> (Shape n, Int) -shUncons (ShCons idx i) = (idx, i) |
