diff options
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r-- | src/Interpreter.hs | 267 |
1 files changed, 155 insertions, 112 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs index d80a76e..11caac0 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -21,6 +21,7 @@ module Interpreter ( import Control.Monad (foldM, join, when) import Data.Bifunctor (bimap) +import Data.Bitraversable (bitraverse) import Data.Char (isSpace) import Data.Functor.Identity import Data.Kind (Type) @@ -134,26 +135,25 @@ interpret'Rec env = \case e1' <- interpret' env e1 e2' <- interpret' env e2 interpret' (Value e2' `SCons` Value e1' `SCons` SNil) pr - EWith _ e1 e2 -> do + EWith _ t e1 e2 -> do initval <- interpret' env e1 - withAccum (typeOf e1) (typeOf e2) initval $ \accum -> + withAccum t (typeOf e2) initval $ \accum -> interpret' (Value accum `SCons` env) e2 - EAccum _ i e1 e2 e3 -> do - let STAccum t = typeOf e3 + EAccum _ t p e1 e2 e3 -> do idx <- interpret' env e1 val <- interpret' env e2 accum <- interpret' env e3 - accumAddSparse t i accum idx val + accumAddSparse t p accum idx val EZero _ t -> do return $ zeroD2 t EPlus _ t a b -> do a' <- interpret' env a b' <- interpret' env b return $ addD2s t a' b' - EOneHot _ t i a b -> do + EOneHot _ t p a b -> do a' <- interpret' env a b' <- interpret' env b - return $ onehotD2 i t a' b' + return $ onehotD2 p t a' b' EError _ _ s -> error $ "Interpreter: Program threw error: " ++ s interpretOp :: SOp a t -> Rep a -> Rep t @@ -230,44 +230,37 @@ addD2s typ a b = case typ of STBool -> () STAccum{} -> error "Plus of Accum" -onehotD2 :: SNat i -> STy t -> Rep (AcIdx (D2 t) i) -> Rep (AcVal (D2 t) i) -> Rep (D2 t) -onehotD2 SZ _ () v = v -onehotD2 _ STNil _ _ = () -onehotD2 (SS SZ ) (STPair _ _ ) () val = Just val -onehotD2 (SS (SS i)) (STPair t1 t2) (Left idx) (Left val) = Just (onehotD2 i t1 idx val, zeroD2 t2) -onehotD2 (SS (SS i)) (STPair t1 t2) (Right idx) (Right val) = Just (zeroD2 t1, onehotD2 i t2 idx val) -onehotD2 (SS _ ) (STPair _ _ ) _ _ = error "onehotD2: pair: mismatched index and value" -onehotD2 (SS SZ ) (STEither _ _ ) () val = Just val -onehotD2 (SS (SS i)) (STEither t1 _ ) (Left idx) (Left val) = Just (Left (onehotD2 i t1 idx val)) -onehotD2 (SS (SS i)) (STEither _ t2) (Right idx) (Right val) = Just (Right (onehotD2 i t2 idx val)) -onehotD2 (SS _ ) (STEither _ _ ) _ _ = error "onehotD2: either: mismatched index and value" -onehotD2 (SS i ) (STMaybe t) idx val = Just (onehotD2 i t idx val) -onehotD2 (SS i ) (STArr n t) idx val = runIdentity $ - onehotArray (d2 t) (\i' idx' v' -> Identity (onehotD2 i' t idx' v')) (Identity (zeroD2 t)) n (SS i) idx val -onehotD2 SS{} STScal{} _ _ = error "onehotD2: cannot index into scalar" -onehotD2 _ STAccum{} _ _ = error "onehotD2: cannot index into accumulator" - -withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t) +onehotD2 :: SAcPrj p a b -> STy a -> Rep (AcIdx p a) -> Rep (D2 b) -> Rep (D2 a) +onehotD2 SAPHere _ _ val = val +onehotD2 (SAPFst prj) (STPair a b) idx val = Just (onehotD2 prj a idx val, zeroD2 b) +onehotD2 (SAPSnd prj) (STPair a b) idx val = Just (zeroD2 a, onehotD2 prj b idx val) +onehotD2 (SAPLeft prj) (STEither a _) idx val = Just (Left (onehotD2 prj a idx val)) +onehotD2 (SAPRight prj) (STEither _ b) idx val = Just (Right (onehotD2 prj b idx val)) +onehotD2 (SAPJust prj) (STMaybe a) idx val = Just (onehotD2 prj a idx val) +onehotD2 (SAPArrIdx prj _) (STArr n a) idx val = + runIdentity $ onehotArray (\idx' -> Identity (onehotD2 prj a idx' val)) (Identity (zeroD2 a)) n prj idx + +withAccum :: STy t -> STy a -> Rep (D2 t) -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep (D2 t)) withAccum t _ initval f = AcM $ do - accum <- newAcSparse t SZ () initval + accum <- newAcSparse t SAPHere () initval out <- case f accum of AcM m -> m val <- readAcSparse t accum return (out, val) -newAcZero :: STy t -> IO (RepAcSparse t) +newAcZero :: STy t -> IO (RepAc t) newAcZero = \case STNil -> return () - STPair t1 t2 -> newIORef =<< (,) <$> newAcZero t1 <*> newAcZero t2 + STPair{} -> newIORef Nothing + STEither{} -> newIORef Nothing STMaybe _ -> newIORef Nothing STArr n _ -> newIORef (emptyArray n) STScal sty -> case sty of - STI32 -> newIORef 0 - STI64 -> newIORef 0 + STI32 -> return () + STI64 -> return () STF32 -> newIORef 0.0 STF64 -> newIORef 0.0 - STBool -> error "Accumulator of Bool" + STBool -> return () STAccum{} -> error "Nested accumulators" - STEither{} -> error "Bare Either in accumulator" -- | Inverted index: the outermost index is at the /outside/ of this list. data PartialInvIndex n m where @@ -322,95 +315,144 @@ piindexConcat :: PartialInvIndex n m -> InvIndex m -> InvIndex n piindexConcat PIIxEnd ix = ix piindexConcat (PIIxCons i pix) ix = InvCons i (piindexConcat pix ix) -newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t) -newAcSparse typ SZ () val = case typ of - STNil -> return () - STPair t1 t2 -> newIORef =<< (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val) - STMaybe t -> newIORef =<< traverse (newAcDense t SZ ()) val - STArr _ t -> newIORef =<< traverse (newAcSparse t SZ ()) val - STScal{} -> newIORef val - STAccum{} -> error "Nested accumulators" - STEither{} -> error "Bare Either in accumulator" -newAcSparse typ (SS dep) idx val = case typ of - STNil -> return () - STPair t1 t2 -> newIORef =<< case (idx, val) of - (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2 - (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val' - _ -> error "Index/value mismatch in newAc pair" - STMaybe t -> newIORef =<< Just <$> newAcDense t dep idx val - STArr dim (t :: STy t) -> newIORef =<< newAcArray dim t (SS dep) idx val - STScal{} -> error "Cannot index into scalar" - STAccum{} -> error "Nested accumulators" - STEither{} -> error "Bare Either in accumulator" +newAcSparse :: STy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep (D2 b) -> IO (RepAc a) +newAcSparse typ prj idx val = case (typ, prj) of + (STNil, SAPHere) -> return () + (STPair t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val + (STEither t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val + (STMaybe t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val + (STArr _ t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val + (STScal sty, SAPHere) -> case sty of + STI32 -> return () + STI64 -> return () + STF32 -> newIORef val + STF64 -> newIORef val + STBool -> return () -newAcArray :: SNat n -> STy t -> SNat i -> Rep (AcIdx (TArr n t) i) -> Rep (AcVal (TArr n t) i) -> IO (Array n (RepAcSparse t)) -newAcArray n t = onehotArray t (newAcSparse t) (newAcZero t) n + (STPair t1 t2, SAPFst prj') -> + newIORef . Just =<< (,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2 + (STPair t1 t2, SAPSnd prj') -> + newIORef . Just =<< (,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val -onehotArray :: Monad m - => STy t - -> (forall n'. SNat n' -> Rep (AcIdx t n') -> Rep (AcVal t n') -> m v) -- ^ the "one" - -> m v -- ^ generate a zero value for elsewhere - -> SNat n -> SNat i -> Rep (AcIdx (TArr n t) i) -> Rep (AcVal (TArr n t) i) -> m (Array n v) -onehotArray _ mkone _ _ SZ _ val = - traverse (mkone SZ ()) val -onehotArray (_ :: STy t) mkone mkzero dim dep@SS{} idx val = do - let sh = unTupRepIdx ShNil ShCons dim (fst val) - go mkone dep dim idx (snd val) $ \arr position -> - arrayGenerateM sh (\i -> case uninvert <$> piindexMatch position (invert i) of - Just i' -> return $ arr `arrayIndex` i' - Nothing -> mkzero) - where - go :: Monad m - => (forall n'. SNat n' -> Rep (AcIdx t n') -> Rep (AcVal t n') -> m v) - -> SNat i -> SNat n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i) - -> (forall n'. Array n' v -> PartialInvIndex n n' -> m r) -> m r - go mk SZ _ () val' k = arrayMapM (mk SZ ()) val' >>= \arr -> k arr PIIxEnd - go mk (SS dep') SZ idx' val' k = mk dep' idx' val' >>= \arr -> k (arrayUnit arr) PIIxEnd - go mk (SS dep') (SS dim') (i, idx') val' k = - go mk dep' dim' idx' val' $ \arr pish -> - k arr (PIIxCons (fromIntegral @Int64 @Int i) pish) - -newAcDense :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t) -newAcDense typ SZ () val = case typ of - STPair t1 t2 -> (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val) - STEither t1 t2 -> case val of - Left x -> Left <$> newAcSparse t1 SZ () x - Right y -> Right <$> newAcSparse t2 SZ () y - _ -> error "newAcDense: invalid dense type" -newAcDense typ (SS dep) idx val = case typ of - STPair t1 t2 -> - case (idx, val) of - (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2 - (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val' - _ -> error "Index/value mismatch in newAc pair" - STEither t1 t2 -> - case (idx, val) of - (Left idx', Left val') -> Left <$> newAcSparse t1 dep idx' val' - (Right idx', Right val') -> Right <$> newAcSparse t2 dep idx' val' - _ -> error "Index/value mismatch in newAc either" - _ -> error "newAcDense: invalid dense type" + (STEither t1 _, SAPLeft prj') -> newIORef . Just . Left =<< newAcSparse t1 prj' idx val + (STEither _ t2, SAPRight prj') -> newIORef . Just . Right =<< newAcSparse t2 prj' idx val + + (STMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val + + (STArr n t, SAPArrIdx prj' _) -> newIORef =<< newAcArray n t prj' idx val + + (STAccum{}, _) -> error "Accumulators not allowed in source program" -readAcSparse :: STy t -> RepAcSparse t -> IO (Rep t) +newAcArray :: SNat n -> STy a -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> Rep (D2 b) -> IO (Array n (RepAc a)) +newAcArray n t prj idx val = onehotArray (\idx' -> newAcSparse t prj idx' val) (newAcZero t) n prj idx + +onehotArray :: Monad m + => (Rep (AcIdx p a) -> m v) -- ^ the "one" + -> m v -- ^ the "zero" + -> SNat n -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> m (Array n v) +onehotArray mkone mkzero n _ ((arrindex', arrsh'), idx) = + let arrindex = unTupRepIdx IxNil IxCons n arrindex' + arrsh = unTupRepIdx ShNil ShCons n arrsh' + in arrayGenerateM arrsh (\i -> if i == arrindex then mkone idx else mkzero) + +-- newAcDense :: STy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep (D2 b) -> IO (RepAcDense (D2 a)) +-- newAcDense typ SZ () val = case typ of +-- STPair t1 t2 -> (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val) +-- STEither t1 t2 -> case val of +-- Left x -> Left <$> newAcSparse t1 SZ () x +-- Right y -> Right <$> newAcSparse t2 SZ () y +-- _ -> error "newAcDense: invalid dense type" +-- newAcDense typ (SS dep) idx val = case typ of +-- STPair t1 t2 -> +-- case (idx, val) of +-- (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2 +-- (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val' +-- _ -> error "Index/value mismatch in newAc pair" +-- STEither t1 t2 -> +-- case (idx, val) of +-- (Left idx', Left val') -> Left <$> newAcSparse t1 dep idx' val' +-- (Right idx', Right val') -> Right <$> newAcSparse t2 dep idx' val' +-- _ -> error "Index/value mismatch in newAc either" +-- _ -> error "newAcDense: invalid dense type" + +readAcSparse :: STy t -> RepAc t -> IO (Rep (D2 t)) readAcSparse typ val = case typ of STNil -> return () - STPair t1 t2 -> do - (a, b) <- readIORef val - (,) <$> readAcSparse t1 a <*> readAcSparse t2 b - STMaybe t -> traverse (readAcDense t) =<< readIORef val + STPair t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val + STEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val + STMaybe t -> traverse (readAcSparse t) =<< readIORef val STArr _ t -> traverse (readAcSparse t) =<< readIORef val - STScal{} -> readIORef val + STScal sty -> case sty of + STI32 -> return () + STI64 -> return () + STF32 -> readIORef val + STF64 -> readIORef val + STBool -> return () STAccum{} -> error "Nested accumulators" - STEither{} -> error "Bare Either in accumulator" -readAcDense :: STy t -> RepAcDense t -> IO (Rep t) -readAcDense typ val = case typ of - STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val) - STEither t1 t2 -> case val of - Left x -> Left <$> readAcSparse t1 x - Right y -> Right <$> readAcSparse t2 y - _ -> error "readAcDense: invalid dense type" +-- readAcDense :: STy t -> RepAcDense t -> IO (Rep t) +-- readAcDense typ val = case typ of +-- STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val) +-- STEither t1 t2 -> case val of +-- Left x -> Left <$> readAcSparse t1 x +-- Right y -> Right <$> readAcSparse t2 y +-- _ -> error "readAcDense: invalid dense type" + +accumAddSparse :: STy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep (D2 b) -> AcM s () +accumAddSparse typ prj ref idx val = case (typ, prj) of + (STNil, SAPHere) -> return () -accumAddSparse :: STy t -> SNat i -> RepAcSparse t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s () + (STPair t1 t2, SAPHere) -> + case val of + Nothing -> return () + Just (val1, val2) -> + AcM $ realiseMaybeSparse ref ((,) <$> newAcSparse t1 SAPHere () val1 + <*> newAcSparse t2 SAPHere () val2) + (\(ac1, ac2) -> do unAcM $ accumAddSparse t1 SAPHere ac1 () val1 + unAcM $ accumAddSparse t2 SAPHere ac2 () val2) + (STPair t1 t2, SAPFst prj') -> + AcM $ realiseMaybeSparse ref ((,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2) + (\(ac1, _) -> do unAcM $ accumAddSparse t1 prj' ac1 idx val) + (STPair t1 t2, SAPSnd prj') -> + AcM $ realiseMaybeSparse ref ((,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val) + (\(_, ac2) -> do unAcM $ accumAddSparse t2 prj' ac2 idx val) + + (STEither t1 t2, SAPHere) -> _ ref val + (STEither t1 _, SAPLeft prj') -> _ ref idx val + (STEither _ t2, SAPRight prj') -> _ ref idx val + + (STMaybe t1, SAPHere) -> _ ref val + (STMaybe t1, SAPJust prj') -> _ ref idx val + + (STArr _ t1, SAPHere) -> _ ref val + (STArr n t, SAPArrIdx prj' _) -> _ ref idx val + + (STScal sty, SAPHere) -> AcM $ case sty of + STI32 -> return () + STI64 -> return () + STF32 -> atomicModifyIORef' ref (\x -> (x + val, ())) + STF64 -> atomicModifyIORef' ref (\x -> (x + val, ())) + STBool -> return () + + (STAccum{}, _) -> error "Accumulators not allowed in source program" + +realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> IO ()) -> IO () +realiseMaybeSparse ref makeval modifyval = + -- Try modifying what's already in ref. The 'join' makes the snd + -- of the function's return value a _continuation_ that is run after + -- the critical section ends. + join $ atomicModifyIORef' ref $ \ac -> case ac of + -- Oops, ref's contents was still sparse. Have to initialise + -- it first, then try again. + Nothing -> (ac, do val <- makeval + join $ atomicModifyIORef' ref $ \ac' -> case ac' of + Nothing -> (Just val, return ()) + Just val' -> (ac', modifyval val')) + -- Yep, ref already had a value in there, so we can just add + -- val' to it recursively. + Just val -> (ac, modifyval val) + +{- accumAddSparse typ SZ ref () val = case typ of STNil -> return () STPair t1 t2 -> AcM $ do @@ -532,6 +574,7 @@ accumAddDense typ (SS dep) ref idx val = case typ of (Right ref', Right idx', Right val') -> (Right ref', accumAddSparse t2 dep ref' idx' val') _ -> error "Mismatched Either in accumAddDense either" _ -> error "accumAddDense: invalid dense type" +-} numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r |