diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-26 15:25:13 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-26 15:25:13 +0100 |
commit | ae2b1b71a91d60d3bd1dfb21fce98c05c1a4fcbb (patch) | |
tree | 1f6afda4b1d6925fe8224ee4f2ca40212fe11aa6 /src/Interpreter.hs | |
parent | 7774da51c532006da82617ce307d136897693280 (diff) |
WIP accum top-level args
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r-- | src/Interpreter.hs | 42 |
1 files changed, 21 insertions, 21 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 37d4a83..56ebf82 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -30,6 +30,7 @@ import System.IO (hPutStrLn, stderr) import System.IO.Unsafe (unsafePerformIO) import Debug.Trace +import GHC.Stack import Array import AST @@ -185,8 +186,8 @@ interpretOp op arg = case op of zeroD2 :: STy t -> Rep (D2 t) zeroD2 typ = case typ of STNil -> () - STPair _ _ -> Left () - STEither _ _ -> Left () + STPair _ _ -> Nothing + STEither _ _ -> Nothing STMaybe _ -> Nothing STArr SZ t -> arrayUnit (zeroD2 t) STArr n _ -> emptyArray n @@ -202,14 +203,14 @@ addD2s :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t) addD2s typ a b = case typ of STNil -> () STPair t1 t2 -> case (a, b) of - (Left (), _) -> b - (_, Left ()) -> a - (Right (x1, x2), Right (y1, y2)) -> Right (addD2s t1 x1 y1, addD2s t2 x2 y2) + (Nothing, _) -> b + (_, Nothing) -> a + (Just (x1, x2), Just (y1, y2)) -> Just (addD2s t1 x1 y1, addD2s t2 x2 y2) STEither t1 t2 -> case (a, b) of - (Left (), _) -> b - (_, Left ()) -> a - (Right (Left x), Right (Left y)) -> Right (Left (addD2s t1 x y)) - (Right (Right x), Right (Right y)) -> Right (Right (addD2s t2 x y)) + (Nothing, _) -> b + (_, Nothing) -> a + (Just (Left x), Just (Left y)) -> Just (Left (addD2s t1 x y)) + (Just (Right x), Just (Right y)) -> Just (Right (addD2s t2 x y)) _ -> error "Plus of inconsistent Eithers" STMaybe t -> case (a, b) of (Nothing, _) -> b @@ -233,16 +234,14 @@ addD2s typ a b = case typ of onehotD2 :: SNat i -> STy t -> Rep (AcIdx (D2 t) i) -> Rep (AcVal (D2 t) i) -> Rep (D2 t) onehotD2 SZ _ () v = v onehotD2 _ STNil _ _ = () -onehotD2 (SS _ ) (STPair _ _ ) (Left _ ) (Left _ ) = Left () -onehotD2 (SS SZ ) (STPair _ _ ) (Right () ) (Right val ) = Right val -onehotD2 (SS (SS i)) (STPair t1 t2) (Right (Left idx)) (Right (Left val)) = Right (onehotD2 i t1 idx val, zeroD2 t2) -onehotD2 (SS (SS i)) (STPair t1 t2) (Right (Right idx)) (Right (Right val)) = Right (zeroD2 t1, onehotD2 i t2 idx val) -onehotD2 (SS _ ) (STPair _ _ ) _ _ = error "onehotD2: pair: mismatched index and value" -onehotD2 (SS _ ) (STEither _ _ ) (Left _ ) (Left _ ) = Left () -onehotD2 (SS SZ ) (STEither _ _ ) (Right () ) (Right val ) = Right val -onehotD2 (SS (SS i)) (STEither t1 _ ) (Right (Left idx)) (Right (Left val)) = Right (Left (onehotD2 i t1 idx val)) -onehotD2 (SS (SS i)) (STEither _ t2) (Right (Right idx)) (Right (Right val)) = Right (Right (onehotD2 i t2 idx val)) -onehotD2 (SS _ ) (STEither _ _ ) _ _ = error "onehotD2: either: mismatched index and value" +onehotD2 (SS SZ ) (STPair _ _ ) () val = Just val +onehotD2 (SS (SS i)) (STPair t1 t2) (Left idx) (Left val) = Just (onehotD2 i t1 idx val, zeroD2 t2) +onehotD2 (SS (SS i)) (STPair t1 t2) (Right idx) (Right val) = Just (zeroD2 t1, onehotD2 i t2 idx val) +onehotD2 (SS _ ) (STPair _ _ ) _ _ = error "onehotD2: pair: mismatched index and value" +onehotD2 (SS SZ ) (STEither _ _ ) () val = Just val +onehotD2 (SS (SS i)) (STEither t1 _ ) (Left idx) (Left val) = Just (Left (onehotD2 i t1 idx val)) +onehotD2 (SS (SS i)) (STEither _ t2) (Right idx) (Right val) = Just (Right (onehotD2 i t2 idx val)) +onehotD2 (SS _ ) (STEither _ _ ) _ _ = error "onehotD2: either: mismatched index and value" onehotD2 (SS i ) (STMaybe t) idx val = Just (onehotD2 i t idx val) onehotD2 (SS i ) (STArr n t) idx val = runIdentity $ onehotArray (d2 t) (\i' idx' v' -> Identity (onehotD2 i' t idx' v')) (Identity (zeroD2 t)) n (SS i) idx val @@ -251,6 +250,7 @@ onehotD2 _ STAccum{} _ _ = error "onehotD2: cannot index into accumulator" withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t) withAccum t _ initval f = AcM $ do + putStrLn $ "withAccum: " ++ show t accum <- newAcSparse t SZ () initval out <- case f accum of AcM m -> m val <- readAcSparse t accum @@ -324,7 +324,7 @@ piindexConcat :: PartialInvIndex n m -> InvIndex m -> InvIndex n piindexConcat PIIxEnd ix = ix piindexConcat (PIIxCons i pix) ix = InvCons i (piindexConcat pix ix) -newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t) +newAcSparse :: HasCallStack => STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t) newAcSparse typ SZ () val = case typ of STNil -> return () STPair t1 t2 -> newIORef =<< (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val) @@ -372,7 +372,7 @@ onehotArray (_ :: STy t) mkone mkzero dim dep@SS{} idx val = do go mk dep' dim' idx' val' $ \arr pish -> k arr (PIIxCons (fromIntegral @Int64 @Int i) pish) -newAcDense :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t) +newAcDense :: HasCallStack => STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t) newAcDense typ SZ () val = case typ of STEither t1 t2 -> case val of Left x -> Left <$> newAcSparse t1 SZ () x |