summaryrefslogtreecommitdiff
path: root/src/Interpreter.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-26 15:25:13 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-26 15:25:13 +0100
commitae2b1b71a91d60d3bd1dfb21fce98c05c1a4fcbb (patch)
tree1f6afda4b1d6925fe8224ee4f2ca40212fe11aa6 /src/Interpreter.hs
parent7774da51c532006da82617ce307d136897693280 (diff)
WIP accum top-level args
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r--src/Interpreter.hs42
1 files changed, 21 insertions, 21 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 37d4a83..56ebf82 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -30,6 +30,7 @@ import System.IO (hPutStrLn, stderr)
import System.IO.Unsafe (unsafePerformIO)
import Debug.Trace
+import GHC.Stack
import Array
import AST
@@ -185,8 +186,8 @@ interpretOp op arg = case op of
zeroD2 :: STy t -> Rep (D2 t)
zeroD2 typ = case typ of
STNil -> ()
- STPair _ _ -> Left ()
- STEither _ _ -> Left ()
+ STPair _ _ -> Nothing
+ STEither _ _ -> Nothing
STMaybe _ -> Nothing
STArr SZ t -> arrayUnit (zeroD2 t)
STArr n _ -> emptyArray n
@@ -202,14 +203,14 @@ addD2s :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t)
addD2s typ a b = case typ of
STNil -> ()
STPair t1 t2 -> case (a, b) of
- (Left (), _) -> b
- (_, Left ()) -> a
- (Right (x1, x2), Right (y1, y2)) -> Right (addD2s t1 x1 y1, addD2s t2 x2 y2)
+ (Nothing, _) -> b
+ (_, Nothing) -> a
+ (Just (x1, x2), Just (y1, y2)) -> Just (addD2s t1 x1 y1, addD2s t2 x2 y2)
STEither t1 t2 -> case (a, b) of
- (Left (), _) -> b
- (_, Left ()) -> a
- (Right (Left x), Right (Left y)) -> Right (Left (addD2s t1 x y))
- (Right (Right x), Right (Right y)) -> Right (Right (addD2s t2 x y))
+ (Nothing, _) -> b
+ (_, Nothing) -> a
+ (Just (Left x), Just (Left y)) -> Just (Left (addD2s t1 x y))
+ (Just (Right x), Just (Right y)) -> Just (Right (addD2s t2 x y))
_ -> error "Plus of inconsistent Eithers"
STMaybe t -> case (a, b) of
(Nothing, _) -> b
@@ -233,16 +234,14 @@ addD2s typ a b = case typ of
onehotD2 :: SNat i -> STy t -> Rep (AcIdx (D2 t) i) -> Rep (AcVal (D2 t) i) -> Rep (D2 t)
onehotD2 SZ _ () v = v
onehotD2 _ STNil _ _ = ()
-onehotD2 (SS _ ) (STPair _ _ ) (Left _ ) (Left _ ) = Left ()
-onehotD2 (SS SZ ) (STPair _ _ ) (Right () ) (Right val ) = Right val
-onehotD2 (SS (SS i)) (STPair t1 t2) (Right (Left idx)) (Right (Left val)) = Right (onehotD2 i t1 idx val, zeroD2 t2)
-onehotD2 (SS (SS i)) (STPair t1 t2) (Right (Right idx)) (Right (Right val)) = Right (zeroD2 t1, onehotD2 i t2 idx val)
-onehotD2 (SS _ ) (STPair _ _ ) _ _ = error "onehotD2: pair: mismatched index and value"
-onehotD2 (SS _ ) (STEither _ _ ) (Left _ ) (Left _ ) = Left ()
-onehotD2 (SS SZ ) (STEither _ _ ) (Right () ) (Right val ) = Right val
-onehotD2 (SS (SS i)) (STEither t1 _ ) (Right (Left idx)) (Right (Left val)) = Right (Left (onehotD2 i t1 idx val))
-onehotD2 (SS (SS i)) (STEither _ t2) (Right (Right idx)) (Right (Right val)) = Right (Right (onehotD2 i t2 idx val))
-onehotD2 (SS _ ) (STEither _ _ ) _ _ = error "onehotD2: either: mismatched index and value"
+onehotD2 (SS SZ ) (STPair _ _ ) () val = Just val
+onehotD2 (SS (SS i)) (STPair t1 t2) (Left idx) (Left val) = Just (onehotD2 i t1 idx val, zeroD2 t2)
+onehotD2 (SS (SS i)) (STPair t1 t2) (Right idx) (Right val) = Just (zeroD2 t1, onehotD2 i t2 idx val)
+onehotD2 (SS _ ) (STPair _ _ ) _ _ = error "onehotD2: pair: mismatched index and value"
+onehotD2 (SS SZ ) (STEither _ _ ) () val = Just val
+onehotD2 (SS (SS i)) (STEither t1 _ ) (Left idx) (Left val) = Just (Left (onehotD2 i t1 idx val))
+onehotD2 (SS (SS i)) (STEither _ t2) (Right idx) (Right val) = Just (Right (onehotD2 i t2 idx val))
+onehotD2 (SS _ ) (STEither _ _ ) _ _ = error "onehotD2: either: mismatched index and value"
onehotD2 (SS i ) (STMaybe t) idx val = Just (onehotD2 i t idx val)
onehotD2 (SS i ) (STArr n t) idx val = runIdentity $
onehotArray (d2 t) (\i' idx' v' -> Identity (onehotD2 i' t idx' v')) (Identity (zeroD2 t)) n (SS i) idx val
@@ -251,6 +250,7 @@ onehotD2 _ STAccum{} _ _ = error "onehotD2: cannot index into accumulator"
withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t)
withAccum t _ initval f = AcM $ do
+ putStrLn $ "withAccum: " ++ show t
accum <- newAcSparse t SZ () initval
out <- case f accum of AcM m -> m
val <- readAcSparse t accum
@@ -324,7 +324,7 @@ piindexConcat :: PartialInvIndex n m -> InvIndex m -> InvIndex n
piindexConcat PIIxEnd ix = ix
piindexConcat (PIIxCons i pix) ix = InvCons i (piindexConcat pix ix)
-newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t)
+newAcSparse :: HasCallStack => STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t)
newAcSparse typ SZ () val = case typ of
STNil -> return ()
STPair t1 t2 -> newIORef =<< (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val)
@@ -372,7 +372,7 @@ onehotArray (_ :: STy t) mkone mkzero dim dep@SS{} idx val = do
go mk dep' dim' idx' val' $ \arr pish ->
k arr (PIIxCons (fromIntegral @Int64 @Int i) pish)
-newAcDense :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t)
+newAcDense :: HasCallStack => STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t)
newAcDense typ SZ () val = case typ of
STEither t1 t2 -> case val of
Left x -> Left <$> newAcSparse t1 SZ () x