summaryrefslogtreecommitdiff
path: root/src/Interpreter.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r--src/Interpreter.hs25
1 files changed, 19 insertions, 6 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 56ebf82..bb4952c 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -30,7 +30,6 @@ import System.IO (hPutStrLn, stderr)
import System.IO.Unsafe (unsafePerformIO)
import Debug.Trace
-import GHC.Stack
import Array
import AST
@@ -250,7 +249,6 @@ onehotD2 _ STAccum{} _ _ = error "onehotD2: cannot index into accumulator"
withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t)
withAccum t _ initval f = AcM $ do
- putStrLn $ "withAccum: " ++ show t
accum <- newAcSparse t SZ () initval
out <- case f accum of AcM m -> m
val <- readAcSparse t accum
@@ -324,7 +322,7 @@ piindexConcat :: PartialInvIndex n m -> InvIndex m -> InvIndex n
piindexConcat PIIxEnd ix = ix
piindexConcat (PIIxCons i pix) ix = InvCons i (piindexConcat pix ix)
-newAcSparse :: HasCallStack => STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t)
+newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t)
newAcSparse typ SZ () val = case typ of
STNil -> return ()
STPair t1 t2 -> newIORef =<< (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val)
@@ -372,13 +370,19 @@ onehotArray (_ :: STy t) mkone mkzero dim dep@SS{} idx val = do
go mk dep' dim' idx' val' $ \arr pish ->
k arr (PIIxCons (fromIntegral @Int64 @Int i) pish)
-newAcDense :: HasCallStack => STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t)
+newAcDense :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t)
newAcDense typ SZ () val = case typ of
+ STPair t1 t2 -> (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val)
STEither t1 t2 -> case val of
Left x -> Left <$> newAcSparse t1 SZ () x
Right y -> Right <$> newAcSparse t2 SZ () y
_ -> error "newAcDense: invalid dense type"
newAcDense typ (SS dep) idx val = case typ of
+ STPair t1 t2 ->
+ case (idx, val) of
+ (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2
+ (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val'
+ _ -> error "Index/value mismatch in newAc pair"
STEither t1 t2 ->
case (idx, val) of
(Left idx', Left val') -> Left <$> newAcSparse t1 dep idx' val'
@@ -400,6 +404,7 @@ readAcSparse typ val = case typ of
readAcDense :: STy t -> RepAcDense t -> IO (Rep t)
readAcDense typ val = case typ of
+ STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val)
STEither t1 t2 -> case val of
Left x -> Left <$> readAcSparse t1 x
Right y -> Right <$> readAcSparse t2 y
@@ -505,19 +510,27 @@ accumAddSparse typ (SS dep) ref idx val = case typ of
accumAddDense :: forall t i s. STy t -> SNat i -> RepAcDense t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAcDense t, AcM s ())
accumAddDense typ SZ ref () val = case typ of
+ STPair t1 t2 ->
+ (ref, do accumAddSparse t1 SZ (fst ref) () (fst val)
+ accumAddSparse t2 SZ (snd ref) () (snd val))
STEither t1 t2 ->
case (ref, val) of
(Left ref', Left val') -> (ref, accumAddSparse t1 SZ ref' () val')
(Right ref', Right val') -> (ref, accumAddSparse t2 SZ ref' () val')
- _ -> error "Mismatched Either in accumAdd"
+ _ -> error "Mismatched Either in accumAddDense either"
_ -> error "accumAddDense: invalid dense type"
accumAddDense typ (SS dep) ref idx val = case typ of
+ STPair t1 t2 ->
+ case (idx, val) of
+ (Left idx', Left val') -> (ref, accumAddSparse t1 dep (fst ref) idx' val')
+ (Right idx', Right val') -> (ref, accumAddSparse t2 dep (snd ref) idx' val')
+ _ -> error "Mismatched Either in accumAddDense pair"
STEither t1 t2 ->
case (ref, idx, val) of
(Left ref', Left idx', Left val') -> (Left ref', accumAddSparse t1 dep ref' idx' val')
(Right ref', Right idx', Right val') -> (Right ref', accumAddSparse t2 dep ref' idx' val')
- _ -> error "Mismatched Either in accumAdd"
+ _ -> error "Mismatched Either in accumAddDense either"
_ -> error "accumAddDense: invalid dense type"