diff options
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r-- | src/Interpreter.hs | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs index da5b73c..3eb8995 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -137,6 +137,10 @@ interpret'Rec env = \case a' <- interpret' env a b' <- interpret' env b return $ addD2s t a' b' + EOneHot t i a b -> do + a' <- interpret' env a + b' <- interpret' env b + return $ onehotD2 t i a' b' EError _ s -> error $ "Interpreter: Program threw error: " ++ s interpretOp :: SOp a t -> Rep a -> Rep t @@ -209,6 +213,13 @@ addD2s typ a b = case typ of STBool -> () STAccum{} -> error "Plus of Accum" +onehotD2 :: SNat i -> STy t -> Rep (AcIdx t i) -> Rep (D2 (AcVal t i)) -> Rep (D2 t) +onehotD2 SZ _ () v = v +onehotD2 (SS _) (STPair _ _ ) _ (Left () ) = Left () +onehotD2 (SS i) (STPair t1 t2) (Left idx) (Right (Left val)) = Right (onehotD2 i t1 idx val, zeroD2 t2) +onehotD2 (SS i) (STPair t1 t2) (Right idx) (Right (Right val)) = Right (zeroD2 t1, onehotD2 i t2 idx val) +onehotD2 (SS _) (STPair _ _ ) _ _ = error "onehotD2: pair: mismatched index and value" + withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t) withAccum t _ initval f = AcM $ do accum <- newAcSparse t SZ () initval |