diff options
Diffstat (limited to 'src/Interpreter.hs')
| -rw-r--r-- | src/Interpreter.hs | 261 | 
1 files changed, 152 insertions, 109 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs index d80a76e..11caac0 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -21,6 +21,7 @@ module Interpreter (  import Control.Monad (foldM, join, when)  import Data.Bifunctor (bimap) +import Data.Bitraversable (bitraverse)  import Data.Char (isSpace)  import Data.Functor.Identity  import Data.Kind (Type) @@ -134,26 +135,25 @@ interpret'Rec env = \case      e1' <- interpret' env e1      e2' <- interpret' env e2      interpret' (Value e2' `SCons` Value e1' `SCons` SNil) pr -  EWith _ e1 e2 -> do +  EWith _ t e1 e2 -> do      initval <- interpret' env e1 -    withAccum (typeOf e1) (typeOf e2) initval $ \accum -> +    withAccum t (typeOf e2) initval $ \accum ->        interpret' (Value accum `SCons` env) e2 -  EAccum _ i e1 e2 e3 -> do -    let STAccum t = typeOf e3 +  EAccum _ t p e1 e2 e3 -> do      idx <- interpret' env e1      val <- interpret' env e2      accum <- interpret' env e3 -    accumAddSparse t i accum idx val +    accumAddSparse t p accum idx val    EZero _ t -> do      return $ zeroD2 t    EPlus _ t a b -> do      a' <- interpret' env a      b' <- interpret' env b      return $ addD2s t a' b' -  EOneHot _ t i a b -> do +  EOneHot _ t p a b -> do      a' <- interpret' env a      b' <- interpret' env b -    return $ onehotD2 i t a' b' +    return $ onehotD2 p t a' b'    EError _ _ s -> error $ "Interpreter: Program threw error: " ++ s  interpretOp :: SOp a t -> Rep a -> Rep t @@ -230,44 +230,37 @@ addD2s typ a b = case typ of      STBool -> ()    STAccum{} -> error "Plus of Accum" -onehotD2 :: SNat i -> STy t -> Rep (AcIdx (D2 t) i) -> Rep (AcVal (D2 t) i) -> Rep (D2 t) -onehotD2 SZ _ () v = v -onehotD2 _ STNil _ _ = () -onehotD2 (SS SZ    ) (STPair _  _ ) ()          val         = Just val -onehotD2 (SS (SS i)) (STPair t1 t2) (Left  idx) (Left  val) = Just (onehotD2 i t1 idx val, zeroD2 t2) -onehotD2 (SS (SS i)) (STPair t1 t2) (Right idx) (Right val) = Just (zeroD2 t1, onehotD2 i t2 idx val) -onehotD2 (SS _     ) (STPair _  _ ) _           _           = error "onehotD2: pair: mismatched index and value" -onehotD2 (SS SZ    ) (STEither _  _ ) ()          val         = Just val -onehotD2 (SS (SS i)) (STEither t1 _ ) (Left  idx) (Left  val) = Just (Left (onehotD2 i t1 idx val)) -onehotD2 (SS (SS i)) (STEither _  t2) (Right idx) (Right val) = Just (Right (onehotD2 i t2 idx val)) -onehotD2 (SS _     ) (STEither _  _ ) _           _           = error "onehotD2: either: mismatched index and value" -onehotD2 (SS i     ) (STMaybe t) idx val = Just (onehotD2 i t idx val) -onehotD2 (SS i     ) (STArr n t) idx val = runIdentity $ -  onehotArray (d2 t) (\i' idx' v' -> Identity (onehotD2 i' t idx' v')) (Identity (zeroD2 t)) n (SS i) idx val -onehotD2 SS{} STScal{} _ _ = error "onehotD2: cannot index into scalar" -onehotD2 _ STAccum{} _ _ = error "onehotD2: cannot index into accumulator" +onehotD2 :: SAcPrj p a b -> STy a -> Rep (AcIdx p a) -> Rep (D2 b) -> Rep (D2 a) +onehotD2 SAPHere _ _ val = val +onehotD2 (SAPFst prj) (STPair a b) idx val = Just (onehotD2 prj a idx val, zeroD2 b) +onehotD2 (SAPSnd prj) (STPair a b) idx val = Just (zeroD2 a, onehotD2 prj b idx val) +onehotD2 (SAPLeft prj) (STEither a _) idx val = Just (Left (onehotD2 prj a idx val)) +onehotD2 (SAPRight prj) (STEither _ b) idx val = Just (Right (onehotD2 prj b idx val)) +onehotD2 (SAPJust prj) (STMaybe a) idx val = Just (onehotD2 prj a idx val) +onehotD2 (SAPArrIdx prj _) (STArr n a) idx val = +  runIdentity $ onehotArray (\idx' -> Identity (onehotD2 prj a idx' val)) (Identity (zeroD2 a)) n prj idx -withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t) +withAccum :: STy t -> STy a -> Rep (D2 t) -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep (D2 t))  withAccum t _ initval f = AcM $ do -  accum <- newAcSparse t SZ () initval +  accum <- newAcSparse t SAPHere () initval    out <- case f accum of AcM m -> m    val <- readAcSparse t accum    return (out, val) -newAcZero :: STy t -> IO (RepAcSparse t) +newAcZero :: STy t -> IO (RepAc t)  newAcZero = \case    STNil -> return () -  STPair t1 t2 -> newIORef =<< (,) <$> newAcZero t1 <*> newAcZero t2 +  STPair{} -> newIORef Nothing +  STEither{} -> newIORef Nothing    STMaybe _ -> newIORef Nothing    STArr n _ -> newIORef (emptyArray n)    STScal sty -> case sty of -    STI32 -> newIORef 0 -    STI64 -> newIORef 0 +    STI32 -> return () +    STI64 -> return ()      STF32 -> newIORef 0.0      STF64 -> newIORef 0.0 -    STBool -> error "Accumulator of Bool" +    STBool -> return ()    STAccum{} -> error "Nested accumulators" -  STEither{} -> error "Bare Either in accumulator"  -- | Inverted index: the outermost index is at the /outside/ of this list.  data PartialInvIndex n m where @@ -322,95 +315,144 @@ piindexConcat :: PartialInvIndex n m -> InvIndex m -> InvIndex n  piindexConcat PIIxEnd ix = ix  piindexConcat (PIIxCons i pix) ix = InvCons i (piindexConcat pix ix) -newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t) -newAcSparse typ SZ () val = case typ of -  STNil -> return () -  STPair t1 t2 -> newIORef =<< (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val) -  STMaybe t -> newIORef =<< traverse (newAcDense t SZ ()) val -  STArr _ t -> newIORef =<< traverse (newAcSparse t SZ ()) val -  STScal{} -> newIORef val -  STAccum{} -> error "Nested accumulators" -  STEither{} -> error "Bare Either in accumulator" -newAcSparse typ (SS dep) idx val = case typ of -  STNil -> return () -  STPair t1 t2 -> newIORef =<< case (idx, val) of -                                 (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2 -                                 (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val' -                                 _ -> error "Index/value mismatch in newAc pair" -  STMaybe t -> newIORef =<< Just <$> newAcDense t dep idx val -  STArr dim (t :: STy t) -> newIORef =<< newAcArray dim t (SS dep) idx val -  STScal{} -> error "Cannot index into scalar" -  STAccum{} -> error "Nested accumulators" -  STEither{} -> error "Bare Either in accumulator" +newAcSparse :: STy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep (D2 b) -> IO (RepAc a) +newAcSparse typ prj idx val = case (typ, prj) of +  (STNil, SAPHere) -> return () +  (STPair t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val +  (STEither t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val +  (STMaybe t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val +  (STArr _ t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val +  (STScal sty, SAPHere) -> case sty of +    STI32 -> return () +    STI64 -> return () +    STF32 -> newIORef val +    STF64 -> newIORef val +    STBool -> return () -newAcArray :: SNat n -> STy t -> SNat i -> Rep (AcIdx (TArr n t) i) -> Rep (AcVal (TArr n t) i) -> IO (Array n (RepAcSparse t)) -newAcArray n t = onehotArray t (newAcSparse t) (newAcZero t) n +  (STPair t1 t2, SAPFst prj') -> +    newIORef . Just =<< (,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2 +  (STPair t1 t2, SAPSnd prj') -> +    newIORef . Just =<< (,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val + +  (STEither t1 _, SAPLeft prj') -> newIORef . Just . Left =<< newAcSparse t1 prj' idx val +  (STEither _ t2, SAPRight prj') -> newIORef . Just . Right =<< newAcSparse t2 prj' idx val + +  (STMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val + +  (STArr n t, SAPArrIdx prj' _) -> newIORef =<< newAcArray n t prj' idx val + +  (STAccum{}, _) -> error "Accumulators not allowed in source program" + +newAcArray :: SNat n -> STy a -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> Rep (D2 b) -> IO (Array n (RepAc a)) +newAcArray n t prj idx val = onehotArray (\idx' -> newAcSparse t prj idx' val) (newAcZero t) n prj idx  onehotArray :: Monad m -            => STy t -            -> (forall n'. SNat n' -> Rep (AcIdx t n') -> Rep (AcVal t n') -> m v)  -- ^ the "one" -            -> m v  -- ^ generate a zero value for elsewhere -            -> SNat n -> SNat i -> Rep (AcIdx (TArr n t) i) -> Rep (AcVal (TArr n t) i) -> m (Array n v) -onehotArray _ mkone _ _ SZ _ val = -  traverse (mkone SZ ()) val -onehotArray (_ :: STy t) mkone mkzero dim dep@SS{} idx val = do -  let sh = unTupRepIdx ShNil ShCons dim (fst val) -  go mkone dep dim idx (snd val) $ \arr position -> -    arrayGenerateM sh (\i -> case uninvert <$> piindexMatch position (invert i) of -                               Just i' -> return $ arr `arrayIndex` i' -                               Nothing -> mkzero) -  where -    go :: Monad m -       => (forall n'. SNat n' -> Rep (AcIdx t n') -> Rep (AcVal t n') -> m v) -       -> SNat i -> SNat n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i) -       -> (forall n'. Array n' v -> PartialInvIndex n n' -> m r) -> m r -    go mk SZ _ () val' k = arrayMapM (mk SZ ()) val' >>= \arr -> k arr PIIxEnd -    go mk (SS dep') SZ idx' val' k = mk dep' idx' val' >>= \arr -> k (arrayUnit arr) PIIxEnd -    go mk (SS dep') (SS dim') (i, idx') val' k = -      go mk dep' dim' idx' val' $ \arr pish -> -        k arr (PIIxCons (fromIntegral @Int64 @Int i) pish) +            => (Rep (AcIdx p a) -> m v)  -- ^ the "one" +            -> m v  -- ^ the "zero" +            -> SNat n -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> m (Array n v) +onehotArray mkone mkzero n _ ((arrindex', arrsh'), idx) = +  let arrindex = unTupRepIdx IxNil IxCons n arrindex' +      arrsh = unTupRepIdx ShNil ShCons n arrsh' +  in arrayGenerateM arrsh (\i -> if i == arrindex then mkone idx else mkzero) -newAcDense :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t) -newAcDense typ SZ () val = case typ of -  STPair t1 t2 -> (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val) -  STEither t1 t2 -> case val of -    Left x -> Left <$> newAcSparse t1 SZ () x -    Right y -> Right <$> newAcSparse t2 SZ () y -  _ -> error "newAcDense: invalid dense type" -newAcDense typ (SS dep) idx val = case typ of -  STPair t1 t2 -> -    case (idx, val) of -      (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2 -      (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val' -      _ -> error "Index/value mismatch in newAc pair" -  STEither t1 t2 -> -    case (idx, val) of -      (Left idx', Left val') -> Left <$> newAcSparse t1 dep idx' val' -      (Right idx', Right val') -> Right <$> newAcSparse t2 dep idx' val' -      _ -> error "Index/value mismatch in newAc either" -  _ -> error "newAcDense: invalid dense type" +-- newAcDense :: STy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep (D2 b) -> IO (RepAcDense (D2 a)) +-- newAcDense typ SZ () val = case typ of +--   STPair t1 t2 -> (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val) +--   STEither t1 t2 -> case val of +--     Left x -> Left <$> newAcSparse t1 SZ () x +--     Right y -> Right <$> newAcSparse t2 SZ () y +--   _ -> error "newAcDense: invalid dense type" +-- newAcDense typ (SS dep) idx val = case typ of +--   STPair t1 t2 -> +--     case (idx, val) of +--       (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2 +--       (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val' +--       _ -> error "Index/value mismatch in newAc pair" +--   STEither t1 t2 -> +--     case (idx, val) of +--       (Left idx', Left val') -> Left <$> newAcSparse t1 dep idx' val' +--       (Right idx', Right val') -> Right <$> newAcSparse t2 dep idx' val' +--       _ -> error "Index/value mismatch in newAc either" +--   _ -> error "newAcDense: invalid dense type" -readAcSparse :: STy t -> RepAcSparse t -> IO (Rep t) +readAcSparse :: STy t -> RepAc t -> IO (Rep (D2 t))  readAcSparse typ val = case typ of    STNil -> return () -  STPair t1 t2 -> do -    (a, b) <- readIORef val -    (,) <$> readAcSparse t1 a <*> readAcSparse t2 b -  STMaybe t -> traverse (readAcDense t) =<< readIORef val +  STPair t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val +  STEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val +  STMaybe t -> traverse (readAcSparse t) =<< readIORef val    STArr _ t -> traverse (readAcSparse t) =<< readIORef val -  STScal{} -> readIORef val +  STScal sty -> case sty of +    STI32 -> return () +    STI64 -> return () +    STF32 -> readIORef val +    STF64 -> readIORef val +    STBool -> return ()    STAccum{} -> error "Nested accumulators" -  STEither{} -> error "Bare Either in accumulator" -readAcDense :: STy t -> RepAcDense t -> IO (Rep t) -readAcDense typ val = case typ of -  STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val) -  STEither t1 t2 -> case val of -    Left x -> Left <$> readAcSparse t1 x -    Right y -> Right <$> readAcSparse t2 y -  _ -> error "readAcDense: invalid dense type" +-- readAcDense :: STy t -> RepAcDense t -> IO (Rep t) +-- readAcDense typ val = case typ of +--   STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val) +--   STEither t1 t2 -> case val of +--     Left x -> Left <$> readAcSparse t1 x +--     Right y -> Right <$> readAcSparse t2 y +--   _ -> error "readAcDense: invalid dense type" + +accumAddSparse :: STy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep (D2 b) -> AcM s () +accumAddSparse typ prj ref idx val = case (typ, prj) of +  (STNil, SAPHere) -> return () + +  (STPair t1 t2, SAPHere) -> +    case val of +      Nothing -> return () +      Just (val1, val2) -> +        AcM $ realiseMaybeSparse ref ((,) <$> newAcSparse t1 SAPHere () val1 +                                          <*> newAcSparse t2 SAPHere () val2) +                                     (\(ac1, ac2) -> do unAcM $ accumAddSparse t1 SAPHere ac1 () val1 +                                                        unAcM $ accumAddSparse t2 SAPHere ac2 () val2) +  (STPair t1 t2, SAPFst prj') -> +    AcM $ realiseMaybeSparse ref ((,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2) +                                 (\(ac1, _) -> do unAcM $ accumAddSparse t1 prj' ac1 idx val) +  (STPair t1 t2, SAPSnd prj') -> +    AcM $ realiseMaybeSparse ref ((,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val) +                                 (\(_, ac2) -> do unAcM $ accumAddSparse t2 prj' ac2 idx val) + +  (STEither t1 t2, SAPHere) -> _ ref val +  (STEither t1 _, SAPLeft prj') -> _ ref idx val +  (STEither _ t2, SAPRight prj') -> _ ref idx val + +  (STMaybe t1, SAPHere) -> _ ref val +  (STMaybe t1, SAPJust prj') -> _ ref idx val + +  (STArr _ t1, SAPHere) -> _ ref val +  (STArr n t, SAPArrIdx prj' _) -> _ ref idx val + +  (STScal sty, SAPHere) -> AcM $ case sty of +    STI32 -> return () +    STI64 -> return () +    STF32 -> atomicModifyIORef' ref (\x -> (x + val, ())) +    STF64 -> atomicModifyIORef' ref (\x -> (x + val, ())) +    STBool -> return () + +  (STAccum{}, _) -> error "Accumulators not allowed in source program" + +realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> IO ()) -> IO () +realiseMaybeSparse ref makeval modifyval = +  -- Try modifying what's already in ref. The 'join' makes the snd +  -- of the function's return value a _continuation_ that is run after +  -- the critical section ends. +  join $ atomicModifyIORef' ref $ \ac -> case ac of +           -- Oops, ref's contents was still sparse. Have to initialise +           -- it first, then try again. +           Nothing -> (ac, do val <- makeval +                              join $ atomicModifyIORef' ref $ \ac' -> case ac' of +                                       Nothing -> (Just val, return ()) +                                       Just val' -> (ac', modifyval val')) +           -- Yep, ref already had a value in there, so we can just add +           -- val' to it recursively. +           Just val -> (ac, modifyval val) -accumAddSparse :: STy t -> SNat i -> RepAcSparse t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s () +{-  accumAddSparse typ SZ ref () val = case typ of    STNil -> return ()    STPair t1 t2 -> AcM $ do @@ -532,6 +574,7 @@ accumAddDense typ (SS dep) ref idx val = case typ of        (Right ref', Right idx', Right val') -> (Right ref', accumAddSparse t2 dep ref' idx' val')        _ -> error "Mismatched Either in accumAddDense either"    _ -> error "accumAddDense: invalid dense type" +-}  numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r  | 
