diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-11-26 15:25:13 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-26 15:25:13 +0100 | 
| commit | ae2b1b71a91d60d3bd1dfb21fce98c05c1a4fcbb (patch) | |
| tree | 1f6afda4b1d6925fe8224ee4f2ca40212fe11aa6 /src/Interpreter.hs | |
| parent | 7774da51c532006da82617ce307d136897693280 (diff) | |
WIP accum top-level args
Diffstat (limited to 'src/Interpreter.hs')
| -rw-r--r-- | src/Interpreter.hs | 42 | 
1 files changed, 21 insertions, 21 deletions
| diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 37d4a83..56ebf82 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -30,6 +30,7 @@ import System.IO (hPutStrLn, stderr)  import System.IO.Unsafe (unsafePerformIO)  import Debug.Trace +import GHC.Stack  import Array  import AST @@ -185,8 +186,8 @@ interpretOp op arg = case op of  zeroD2 :: STy t -> Rep (D2 t)  zeroD2 typ = case typ of    STNil -> () -  STPair _ _ -> Left () -  STEither _ _ -> Left () +  STPair _ _ -> Nothing +  STEither _ _ -> Nothing    STMaybe _ -> Nothing    STArr SZ t -> arrayUnit (zeroD2 t)    STArr n _ -> emptyArray n @@ -202,14 +203,14 @@ addD2s :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t)  addD2s typ a b = case typ of    STNil -> ()    STPair t1 t2 -> case (a, b) of -    (Left (), _) -> b -    (_, Left ()) -> a -    (Right (x1, x2), Right (y1, y2)) -> Right (addD2s t1 x1 y1, addD2s t2 x2 y2) +    (Nothing, _) -> b +    (_, Nothing) -> a +    (Just (x1, x2), Just (y1, y2)) -> Just (addD2s t1 x1 y1, addD2s t2 x2 y2)    STEither t1 t2 -> case (a, b) of -    (Left (), _) -> b -    (_, Left ()) -> a -    (Right (Left x), Right (Left y)) -> Right (Left (addD2s t1 x y)) -    (Right (Right x), Right (Right y)) -> Right (Right (addD2s t2 x y)) +    (Nothing, _) -> b +    (_, Nothing) -> a +    (Just (Left x), Just (Left y)) -> Just (Left (addD2s t1 x y)) +    (Just (Right x), Just (Right y)) -> Just (Right (addD2s t2 x y))      _ -> error "Plus of inconsistent Eithers"    STMaybe t -> case (a, b) of      (Nothing, _) -> b @@ -233,16 +234,14 @@ addD2s typ a b = case typ of  onehotD2 :: SNat i -> STy t -> Rep (AcIdx (D2 t) i) -> Rep (AcVal (D2 t) i) -> Rep (D2 t)  onehotD2 SZ _ () v = v  onehotD2 _ STNil _ _ = () -onehotD2 (SS _     ) (STPair _  _ ) (Left  _          ) (Left  _          ) = Left () -onehotD2 (SS SZ    ) (STPair _  _ ) (Right ()         ) (Right val        ) = Right val -onehotD2 (SS (SS i)) (STPair t1 t2) (Right (Left  idx)) (Right (Left  val)) = Right (onehotD2 i t1 idx val, zeroD2 t2) -onehotD2 (SS (SS i)) (STPair t1 t2) (Right (Right idx)) (Right (Right val)) = Right (zeroD2 t1, onehotD2 i t2 idx val) -onehotD2 (SS _     ) (STPair _  _ ) _                   _                   = error "onehotD2: pair: mismatched index and value" -onehotD2 (SS _     ) (STEither _  _ ) (Left  _          ) (Left  _          ) = Left () -onehotD2 (SS SZ    ) (STEither _  _ ) (Right ()         ) (Right val        ) = Right val -onehotD2 (SS (SS i)) (STEither t1 _ ) (Right (Left  idx)) (Right (Left  val)) = Right (Left (onehotD2 i t1 idx val)) -onehotD2 (SS (SS i)) (STEither _  t2) (Right (Right idx)) (Right (Right val)) = Right (Right (onehotD2 i t2 idx val)) -onehotD2 (SS _     ) (STEither _  _ ) _                   _                   = error "onehotD2: either: mismatched index and value" +onehotD2 (SS SZ    ) (STPair _  _ ) ()          val         = Just val +onehotD2 (SS (SS i)) (STPair t1 t2) (Left  idx) (Left  val) = Just (onehotD2 i t1 idx val, zeroD2 t2) +onehotD2 (SS (SS i)) (STPair t1 t2) (Right idx) (Right val) = Just (zeroD2 t1, onehotD2 i t2 idx val) +onehotD2 (SS _     ) (STPair _  _ ) _           _           = error "onehotD2: pair: mismatched index and value" +onehotD2 (SS SZ    ) (STEither _  _ ) ()          val         = Just val +onehotD2 (SS (SS i)) (STEither t1 _ ) (Left  idx) (Left  val) = Just (Left (onehotD2 i t1 idx val)) +onehotD2 (SS (SS i)) (STEither _  t2) (Right idx) (Right val) = Just (Right (onehotD2 i t2 idx val)) +onehotD2 (SS _     ) (STEither _  _ ) _           _           = error "onehotD2: either: mismatched index and value"  onehotD2 (SS i     ) (STMaybe t) idx val = Just (onehotD2 i t idx val)  onehotD2 (SS i     ) (STArr n t) idx val = runIdentity $    onehotArray (d2 t) (\i' idx' v' -> Identity (onehotD2 i' t idx' v')) (Identity (zeroD2 t)) n (SS i) idx val @@ -251,6 +250,7 @@ onehotD2 _ STAccum{} _ _ = error "onehotD2: cannot index into accumulator"  withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t)  withAccum t _ initval f = AcM $ do +  putStrLn $ "withAccum: " ++ show t    accum <- newAcSparse t SZ () initval    out <- case f accum of AcM m -> m    val <- readAcSparse t accum @@ -324,7 +324,7 @@ piindexConcat :: PartialInvIndex n m -> InvIndex m -> InvIndex n  piindexConcat PIIxEnd ix = ix  piindexConcat (PIIxCons i pix) ix = InvCons i (piindexConcat pix ix) -newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t) +newAcSparse :: HasCallStack => STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t)  newAcSparse typ SZ () val = case typ of    STNil -> return ()    STPair t1 t2 -> newIORef =<< (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val) @@ -372,7 +372,7 @@ onehotArray (_ :: STy t) mkone mkzero dim dep@SS{} idx val = do        go mk dep' dim' idx' val' $ \arr pish ->          k arr (PIIxCons (fromIntegral @Int64 @Int i) pish) -newAcDense :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t) +newAcDense :: HasCallStack => STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t)  newAcDense typ SZ () val = case typ of    STEither t1 t2 -> case val of      Left x -> Left <$> newAcSparse t1 SZ () x | 
