diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/Interpreter.hs | 209 | ||||
-rw-r--r-- | src/Interpreter/Rep.hs | 20 |
2 files changed, 94 insertions, 135 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs index f8e7e98..af11de8 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -164,15 +164,16 @@ interpret'Rec env = \case accum <- interpret' env e3 accumAddSparse t p accum idx val EZero _ t ezi -> do - return $ zeroD2 t ezi + zi <- interpret' env ezi + return $ zeroM t zi EPlus _ t a b -> do a' <- interpret' env a b' <- interpret' env b - return $ addD2s t a' b' + return $ addM t a' b' EOneHot _ t p a b -> do a' <- interpret' env a b' <- interpret' env b - return $ onehotD2 p t a' b' + return $ onehotM p t a' b' EError _ _ s -> error $ "Interpreter: Program threw error: " ++ s interpretOp :: SOp a t -> Rep a -> Rep t @@ -202,145 +203,113 @@ interpretOp op arg = case op of styIsEq STF64 = id styIsEq STBool = id -zeroD2 :: STy t -> Rep (D2 t) -zeroD2 typ = case typ of - STNil -> () - STPair _ _ -> Nothing - STEither _ _ -> Nothing - STMaybe _ -> Nothing - STArr _ _ -> Nothing - STScal sty -> case sty of - STI32 -> () - STI64 -> () +zeroM :: SMTy t -> Rep (ZeroInfo t) -> Rep t +zeroM typ zi = case typ of + SMTNil -> () + SMTPair t1 t2 -> (zeroM t1 (fst zi), zeroM t2 (snd zi)) + SMTLEither _ _ -> Nothing + SMTMaybe _ -> Nothing + SMTArr _ t -> arrayMap (zeroM t) zi + SMTScal sty -> case sty of + STI32 -> 0 + STI64 -> 0 STF32 -> 0.0 STF64 -> 0.0 - STBool -> () - STAccum{} -> error "Zero of Accum" -addD2s :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t) -addD2s typ a b = case typ of - STNil -> () - STPair t1 t2 -> case (a, b) of +addM :: SMTy t -> Rep t -> Rep t -> Rep t +addM typ a b = case typ of + SMTNil -> () + SMTPair t1 t2 -> (addM t1 (fst a) (fst b), addM t2 (snd a) (snd b)) + SMTLEither t1 t2 -> case (a, b) of (Nothing, _) -> b (_, Nothing) -> a - (Just (x1, x2), Just (y1, y2)) -> Just (addD2s t1 x1 y1, addD2s t2 x2 y2) - STEither t1 t2 -> case (a, b) of + (Just (Left x), Just (Left y)) -> Just (Left (addM t1 x y)) + (Just (Right x), Just (Right y)) -> Just (Right (addM t2 x y)) + _ -> error "Plus of inconsistent LEithers" + SMTMaybe t -> case (a, b) of (Nothing, _) -> b (_, Nothing) -> a - (Just (Left x), Just (Left y)) -> Just (Left (addD2s t1 x y)) - (Just (Right x), Just (Right y)) -> Just (Right (addD2s t2 x y)) - _ -> error "Plus of inconsistent Eithers" - STMaybe t -> case (a, b) of - (Nothing, _) -> b - (_, Nothing) -> a - (Just x, Just y) -> Just (addD2s t x y) - STArr _ t -> case (a, b) of - (Nothing, _) -> b - (_, Nothing) -> a - (Just x, Just y) -> - let sh1 = arrayShape x - sh2 = arrayShape y - in if | shapeSize sh1 == 0 -> Just y - | shapeSize sh2 == 0 -> Just x - | sh1 == sh2 -> Just $ arrayGenerateLin sh1 (\i -> addD2s t (arrayIndexLinear x i) (arrayIndexLinear y i)) - | otherwise -> error "Plus of inconsistently shaped arrays" - STScal sty -> case sty of - STI32 -> () - STI64 -> () - STF32 -> a + b - STF64 -> a + b - STBool -> () - STAccum{} -> error "Plus of Accum" - -onehotD2 :: SAcPrj p a b -> STy a -> Rep (AcIdx p a) -> Rep (D2 b) -> Rep (D2 a) -onehotD2 SAPHere _ _ val = val -onehotD2 (SAPFst prj) (STPair a b) idx val = Just (onehotD2 prj a idx val, zeroD2 b) -onehotD2 (SAPSnd prj) (STPair a b) idx val = Just (zeroD2 a, onehotD2 prj b idx val) -onehotD2 (SAPLeft prj) (STEither a _) idx val = Just (Left (onehotD2 prj a idx val)) -onehotD2 (SAPRight prj) (STEither _ b) idx val = Just (Right (onehotD2 prj b idx val)) -onehotD2 (SAPJust prj) (STMaybe a) idx val = Just (onehotD2 prj a idx val) -onehotD2 (SAPArrIdx prj) (STArr n a) idx val = - Just $ runIdentity $ onehotArray (\idx' -> Identity (onehotD2 prj a idx' val)) (Identity (zeroD2 a)) n prj idx - -withAccum :: STy t -> STy a -> Rep (D2 t) -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep (D2 t)) + (Just x, Just y) -> Just (addM t x y) + SMTArr _ t -> + let sh1 = arrayShape a + sh2 = arrayShape b + in if | shapeSize sh1 == 0 -> b + | shapeSize sh2 == 0 -> a + | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> addM t (arrayIndexLinear a i) (arrayIndexLinear b i)) + | otherwise -> error "Plus of inconsistently shaped arrays" + SMTScal sty -> numericIsNum sty $ a + b + +onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdx p a) -> Rep b -> Rep a +onehotM SAPHere _ _ val = val +onehotM (SAPFst prj) (SMTPair a b) idx val = (onehotM prj a (fst idx) val, zeroM b (snd idx)) +onehotM (SAPSnd prj) (SMTPair a b) idx val = (zeroM a (fst idx), onehotM prj b (snd idx) val) +onehotM (SAPLeft prj) (SMTLEither a _) idx val = Just (Left (onehotM prj a idx val)) +onehotM (SAPRight prj) (SMTLEither _ b) idx val = Just (Right (onehotM prj b idx val)) +onehotM (SAPJust prj) (SMTMaybe a) idx val = Just (onehotM prj a idx val) +onehotM (SAPArrIdx prj) (SMTArr n a) idx val = + runIdentity $ onehotArray (\idx' -> Identity (onehotM prj a idx' val)) (\zi -> Identity (zeroM a zi)) n prj idx + +withAccum :: SMTy t -> STy a -> Rep t -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t) withAccum t _ initval f = AcM $ do - accum <- newAcSparse t SAPHere () initval + accum <- newAcDense t initval out <- unAcM $ f accum val <- readAcSparse t accum return (out, val) -newAcZero :: STy t -> IO (RepAc t) -newAcZero = \case - STNil -> return () - STPair{} -> newIORef Nothing - STEither{} -> newIORef Nothing - STMaybe _ -> newIORef Nothing - STArr _ _ -> newIORef Nothing - STScal sty -> case sty of - STI32 -> return () - STI64 -> return () - STF32 -> newIORef 0.0 - STF64 -> newIORef 0.0 - STBool -> return () - STAccum{} -> error "Nested accumulators" - -newAcSparse :: STy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep (D2 b) -> IO (RepAc a) +newAcZero :: SMTy t -> Rep (ZeroInfo t) -> IO (RepAc t) +newAcZero typ zi = case typ of + SMTNil -> return () + SMTPair t1 t2 -> bitraverse (newAcZero t1) (newAcZero t2) zi + SMTLEither{} -> newIORef Nothing + SMTMaybe _ -> newIORef Nothing + SMTArr _ t -> arrayMapM (newAcZero t) zi + SMTScal sty -> numericIsNum sty $ newIORef 0 + +newAcDense :: SMTy a -> Rep a -> IO (RepAc a) +newAcDense typ val = case typ of + SMTNil -> return () + SMTPair t1 t2 -> bitraverse (newAcDense t1) (newAcDense t2) val + SMTLEither t1 t2 -> newIORef =<< traverse (bitraverse (newAcDense t1) (newAcDense t2)) val + SMTMaybe t1 -> newIORef =<< traverse (newAcDense t1) val + SMTArr _ t1 -> arrayMapM (newAcDense t1) val + SMTScal _ -> newIORef val + +newAcSparse :: SMTy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep b -> IO (RepAc a) newAcSparse typ prj idx val = case (typ, prj) of - (STNil, SAPHere) -> return () - (STPair t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val - (STEither t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val - (STMaybe t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val - (STArr _ t1, SAPHere) -> newIORef =<< traverse (traverse (newAcSparse t1 SAPHere ())) val - (STScal sty, SAPHere) -> case sty of - STI32 -> return () - STI64 -> return () - STF32 -> newIORef val - STF64 -> newIORef val - STBool -> return () - - (STPair t1 t2, SAPFst prj') -> - newIORef . Just =<< (,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2 - (STPair t1 t2, SAPSnd prj') -> - newIORef . Just =<< (,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val - - (STEither t1 _, SAPLeft prj') -> newIORef . Just . Left =<< newAcSparse t1 prj' idx val - (STEither _ t2, SAPRight prj') -> newIORef . Just . Right =<< newAcSparse t2 prj' idx val + (_, SAPHere) -> newAcDense typ val - (STMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val + (SMTPair t1 t2, SAPFst prj') -> + (,) <$> newAcSparse t1 prj' (fst idx) val <*> newAcZero t2 (snd idx) + (SMTPair t1 t2, SAPSnd prj') -> + (,) <$> newAcZero t1 (fst idx) <*> newAcSparse t2 prj' (snd idx) val - (STArr n t, SAPArrIdx prj') -> newIORef . Just =<< newAcArray n t prj' idx val + (SMTLEither t1 _, SAPLeft prj') -> newIORef . Just . Left =<< newAcSparse t1 prj' idx val + (SMTLEither _ t2, SAPRight prj') -> newIORef . Just . Right =<< newAcSparse t2 prj' idx val - (STAccum{}, _) -> error "Accumulators not allowed in source program" + (SMTMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val -newAcArray :: SNat n -> STy a -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> Rep (D2 b) -> IO (Array n (RepAc a)) -newAcArray n t prj idx val = onehotArray (\idx' -> newAcSparse t prj idx' val) (newAcZero t) n prj idx + (SMTArr n t, SAPArrIdx prj') -> onehotArray (\idx' -> newAcSparse t prj' idx' val) (newAcZero t) n prj' idx onehotArray :: Monad m => (Rep (AcIdx p a) -> m v) -- ^ the "one" - -> m v -- ^ the "zero" + -> (Rep (ZeroInfo a) -> m v) -- ^ the "zero" -> SNat n -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> m (Array n v) -onehotArray mkone mkzero n _ ((arrindex', arrsh'), idx) = +onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) = let arrindex = unTupRepIdx IxNil IxCons n arrindex' - arrsh = unTupRepIdx ShNil ShCons n arrsh' + arrsh = arrayShape ziarr !linindex = toLinearIndex arrsh arrindex - in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero) + in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero (ziarr `arrayIndexLinear` i)) -readAcSparse :: STy t -> RepAc t -> IO (Rep (D2 t)) +readAcSparse :: SMTy t -> RepAc t -> IO (Rep t) readAcSparse typ val = case typ of - STNil -> return () - STPair t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val - STEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val - STMaybe t -> traverse (readAcSparse t) =<< readIORef val - STArr _ t -> traverse (traverse (readAcSparse t)) =<< readIORef val - STScal sty -> case sty of - STI32 -> return () - STI64 -> return () - STF32 -> readIORef val - STF64 -> readIORef val - STBool -> return () - STAccum{} -> error "Nested accumulators" - -accumAddSparse :: STy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep (D2 b) -> AcM s () + SMTNil -> return () + SMTPair t1 t2 -> bitraverse (readAcSparse t1) (readAcSparse t2) val + SMTLEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val + SMTMaybe t -> traverse (readAcSparse t) =<< readIORef val + SMTArr _ t -> traverse (readAcSparse t) val + SMTScal _ -> readIORef val + +accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep b -> AcM s () accumAddSparse typ prj ref idx val = case (typ, prj) of (STNil, SAPHere) -> return () @@ -348,8 +317,8 @@ accumAddSparse typ prj ref idx val = case (typ, prj) of case val of Nothing -> return () Just (val1, val2) -> - realiseMaybeSparse ref ((,) <$> newAcSparse t1 SAPHere () val1 - <*> newAcSparse t2 SAPHere () val2) + realiseMaybeSparse ref ((,) <$> newAcDense t1 val1 + <*> newAcDense t2 val2) (\(ac1, ac2) -> do accumAddSparse t1 SAPHere ac1 () val1 accumAddSparse t2 SAPHere ac2 () val2) (STPair t1 t2, SAPFst prj') -> @@ -386,7 +355,7 @@ accumAddSparse typ prj ref idx val = case (typ, prj) of Nothing -> return () Just val' -> realiseMaybeSparse ref - (arrayMapM (newAcSparse t1 SAPHere ()) val') + (arrayMapM (newAcDense t1) val') (\ac -> forM_ [0 .. arraySize ac - 1] $ \i -> accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val' i)) (STArr n t1, SAPArrIdx prj') -> diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index 9056901..1226b0c 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -6,7 +6,6 @@ module Interpreter.Rep where import Data.List (intersperse, intercalate) import Data.Foldable (toList) import Data.IORef -import GHC.TypeError import Array import AST @@ -24,23 +23,14 @@ type family Rep t where Rep (TAccum t) = RepAc t Rep (TLEither a b) = Maybe (Either (Rep a) (Rep b)) --- Mutable, represents D2 of t. Has an O(1) zero. +-- Mutable, represents monoid types t. type family RepAc t where RepAc TNil = () - RepAc (TPair a b) = IORef (Maybe (RepAc a, RepAc b)) - RepAc (TEither a b) = IORef (Maybe (Either (RepAc a) (RepAc b))) - RepAc (TMaybe t) = IORef (Maybe (RepAc t)) - RepAc (TArr n t) = IORef (Maybe (Array n (RepAc t))) - RepAc (TScal sty) = RepAcScal sty - RepAc (TAccum t) = TypeError (Text "RepAcSparse: Nested accumulators") + RepAc (TPair a b) = (RepAc a, RepAc b) RepAc (TLEither a b) = IORef (Maybe (Either (RepAc a) (RepAc b))) - -type family RepAcScal t where - RepAcScal TI32 = () - RepAcScal TI64 = () - RepAcScal TF32 = IORef Float - RepAcScal TF64 = IORef Double - RepAcScal TBool = () + RepAc (TMaybe t) = IORef (Maybe (RepAc t)) + RepAc (TArr n t) = Array n (RepAc t) + RepAc (TScal sty) = IORef (ScalRep sty) newtype Value t = Value { unValue :: Rep t } |