summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/AST/Pretty.hs2
-rw-r--r--src/Interpreter.hs25
-rw-r--r--src/Interpreter/Rep.hs2
3 files changed, 21 insertions, 8 deletions
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index ec8574f..663e9b0 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -218,7 +218,7 @@ ppExpr' d val = \case
return $ showParen (d > 10) $
showString ("accum " ++ show (fromSNat i) ++ " ") . e1' . showString " " . e2' . showString " " . e3'
- EZero _ -> return $ showString "zero"
+ EZero t -> return $ showString ("(zero :: " ++ ppTy 0 t ++ ")")
EPlus _ a b -> do
a' <- ppExpr' 11 val a
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 56ebf82..bb4952c 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -30,7 +30,6 @@ import System.IO (hPutStrLn, stderr)
import System.IO.Unsafe (unsafePerformIO)
import Debug.Trace
-import GHC.Stack
import Array
import AST
@@ -250,7 +249,6 @@ onehotD2 _ STAccum{} _ _ = error "onehotD2: cannot index into accumulator"
withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t)
withAccum t _ initval f = AcM $ do
- putStrLn $ "withAccum: " ++ show t
accum <- newAcSparse t SZ () initval
out <- case f accum of AcM m -> m
val <- readAcSparse t accum
@@ -324,7 +322,7 @@ piindexConcat :: PartialInvIndex n m -> InvIndex m -> InvIndex n
piindexConcat PIIxEnd ix = ix
piindexConcat (PIIxCons i pix) ix = InvCons i (piindexConcat pix ix)
-newAcSparse :: HasCallStack => STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t)
+newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t)
newAcSparse typ SZ () val = case typ of
STNil -> return ()
STPair t1 t2 -> newIORef =<< (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val)
@@ -372,13 +370,19 @@ onehotArray (_ :: STy t) mkone mkzero dim dep@SS{} idx val = do
go mk dep' dim' idx' val' $ \arr pish ->
k arr (PIIxCons (fromIntegral @Int64 @Int i) pish)
-newAcDense :: HasCallStack => STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t)
+newAcDense :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t)
newAcDense typ SZ () val = case typ of
+ STPair t1 t2 -> (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val)
STEither t1 t2 -> case val of
Left x -> Left <$> newAcSparse t1 SZ () x
Right y -> Right <$> newAcSparse t2 SZ () y
_ -> error "newAcDense: invalid dense type"
newAcDense typ (SS dep) idx val = case typ of
+ STPair t1 t2 ->
+ case (idx, val) of
+ (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2
+ (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val'
+ _ -> error "Index/value mismatch in newAc pair"
STEither t1 t2 ->
case (idx, val) of
(Left idx', Left val') -> Left <$> newAcSparse t1 dep idx' val'
@@ -400,6 +404,7 @@ readAcSparse typ val = case typ of
readAcDense :: STy t -> RepAcDense t -> IO (Rep t)
readAcDense typ val = case typ of
+ STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val)
STEither t1 t2 -> case val of
Left x -> Left <$> readAcSparse t1 x
Right y -> Right <$> readAcSparse t2 y
@@ -505,19 +510,27 @@ accumAddSparse typ (SS dep) ref idx val = case typ of
accumAddDense :: forall t i s. STy t -> SNat i -> RepAcDense t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAcDense t, AcM s ())
accumAddDense typ SZ ref () val = case typ of
+ STPair t1 t2 ->
+ (ref, do accumAddSparse t1 SZ (fst ref) () (fst val)
+ accumAddSparse t2 SZ (snd ref) () (snd val))
STEither t1 t2 ->
case (ref, val) of
(Left ref', Left val') -> (ref, accumAddSparse t1 SZ ref' () val')
(Right ref', Right val') -> (ref, accumAddSparse t2 SZ ref' () val')
- _ -> error "Mismatched Either in accumAdd"
+ _ -> error "Mismatched Either in accumAddDense either"
_ -> error "accumAddDense: invalid dense type"
accumAddDense typ (SS dep) ref idx val = case typ of
+ STPair t1 t2 ->
+ case (idx, val) of
+ (Left idx', Left val') -> (ref, accumAddSparse t1 dep (fst ref) idx' val')
+ (Right idx', Right val') -> (ref, accumAddSparse t2 dep (snd ref) idx' val')
+ _ -> error "Mismatched Either in accumAddDense pair"
STEither t1 t2 ->
case (ref, idx, val) of
(Left ref', Left idx', Left val') -> (Left ref', accumAddSparse t1 dep ref' idx' val')
(Right ref', Right idx', Right val') -> (Right ref', accumAddSparse t2 dep ref' idx' val')
- _ -> error "Mismatched Either in accumAdd"
+ _ -> error "Mismatched Either in accumAddDense either"
_ -> error "accumAddDense: invalid dense type"
diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs
index 0007991..335ad1f 100644
--- a/src/Interpreter/Rep.hs
+++ b/src/Interpreter/Rep.hs
@@ -37,7 +37,7 @@ type family RepAcSparse t where
-- Immutable, and does not necessarily have a zero.
type family RepAcDense t where
RepAcDense TNil = ()
- -- RepAcDense (TPair a b) = (RepAcSparse a, RepAcSparse b)
+ RepAcDense (TPair a b) = (RepAcSparse a, RepAcSparse b)
RepAcDense (TEither a b) = Either (RepAcSparse a) (RepAcSparse b)
-- RepAcDense (TMaybe t) = RepAcSparse (TMaybe t) -- ^ This can be optimised to TMaybe (RepAcSparse t), but that makes accumAddDense very hard to write. And in any case, we don't need it because D2 will not produce Maybe of Maybe.
-- RepAcDense (TArr n t) = Array n (RepAcSparse t)