diff options
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r-- | src/Interpreter.hs | 25 |
1 files changed, 19 insertions, 6 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 56ebf82..bb4952c 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -30,7 +30,6 @@ import System.IO (hPutStrLn, stderr) import System.IO.Unsafe (unsafePerformIO) import Debug.Trace -import GHC.Stack import Array import AST @@ -250,7 +249,6 @@ onehotD2 _ STAccum{} _ _ = error "onehotD2: cannot index into accumulator" withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t) withAccum t _ initval f = AcM $ do - putStrLn $ "withAccum: " ++ show t accum <- newAcSparse t SZ () initval out <- case f accum of AcM m -> m val <- readAcSparse t accum @@ -324,7 +322,7 @@ piindexConcat :: PartialInvIndex n m -> InvIndex m -> InvIndex n piindexConcat PIIxEnd ix = ix piindexConcat (PIIxCons i pix) ix = InvCons i (piindexConcat pix ix) -newAcSparse :: HasCallStack => STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t) +newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t) newAcSparse typ SZ () val = case typ of STNil -> return () STPair t1 t2 -> newIORef =<< (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val) @@ -372,13 +370,19 @@ onehotArray (_ :: STy t) mkone mkzero dim dep@SS{} idx val = do go mk dep' dim' idx' val' $ \arr pish -> k arr (PIIxCons (fromIntegral @Int64 @Int i) pish) -newAcDense :: HasCallStack => STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t) +newAcDense :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t) newAcDense typ SZ () val = case typ of + STPair t1 t2 -> (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val) STEither t1 t2 -> case val of Left x -> Left <$> newAcSparse t1 SZ () x Right y -> Right <$> newAcSparse t2 SZ () y _ -> error "newAcDense: invalid dense type" newAcDense typ (SS dep) idx val = case typ of + STPair t1 t2 -> + case (idx, val) of + (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2 + (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val' + _ -> error "Index/value mismatch in newAc pair" STEither t1 t2 -> case (idx, val) of (Left idx', Left val') -> Left <$> newAcSparse t1 dep idx' val' @@ -400,6 +404,7 @@ readAcSparse typ val = case typ of readAcDense :: STy t -> RepAcDense t -> IO (Rep t) readAcDense typ val = case typ of + STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val) STEither t1 t2 -> case val of Left x -> Left <$> readAcSparse t1 x Right y -> Right <$> readAcSparse t2 y @@ -505,19 +510,27 @@ accumAddSparse typ (SS dep) ref idx val = case typ of accumAddDense :: forall t i s. STy t -> SNat i -> RepAcDense t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAcDense t, AcM s ()) accumAddDense typ SZ ref () val = case typ of + STPair t1 t2 -> + (ref, do accumAddSparse t1 SZ (fst ref) () (fst val) + accumAddSparse t2 SZ (snd ref) () (snd val)) STEither t1 t2 -> case (ref, val) of (Left ref', Left val') -> (ref, accumAddSparse t1 SZ ref' () val') (Right ref', Right val') -> (ref, accumAddSparse t2 SZ ref' () val') - _ -> error "Mismatched Either in accumAdd" + _ -> error "Mismatched Either in accumAddDense either" _ -> error "accumAddDense: invalid dense type" accumAddDense typ (SS dep) ref idx val = case typ of + STPair t1 t2 -> + case (idx, val) of + (Left idx', Left val') -> (ref, accumAddSparse t1 dep (fst ref) idx' val') + (Right idx', Right val') -> (ref, accumAddSparse t2 dep (snd ref) idx' val') + _ -> error "Mismatched Either in accumAddDense pair" STEither t1 t2 -> case (ref, idx, val) of (Left ref', Left idx', Left val') -> (Left ref', accumAddSparse t1 dep ref' idx' val') (Right ref', Right idx', Right val') -> (Right ref', accumAddSparse t2 dep ref' idx' val') - _ -> error "Mismatched Either in accumAdd" + _ -> error "Mismatched Either in accumAddDense either" _ -> error "accumAddDense: invalid dense type" |