summaryrefslogtreecommitdiff
path: root/src/Interpreter.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Interpreter.hs')
-rw-r--r--src/Interpreter.hs11
1 files changed, 11 insertions, 0 deletions
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index da5b73c..3eb8995 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -137,6 +137,10 @@ interpret'Rec env = \case
a' <- interpret' env a
b' <- interpret' env b
return $ addD2s t a' b'
+ EOneHot t i a b -> do
+ a' <- interpret' env a
+ b' <- interpret' env b
+ return $ onehotD2 t i a' b'
EError _ s -> error $ "Interpreter: Program threw error: " ++ s
interpretOp :: SOp a t -> Rep a -> Rep t
@@ -209,6 +213,13 @@ addD2s typ a b = case typ of
STBool -> ()
STAccum{} -> error "Plus of Accum"
+onehotD2 :: SNat i -> STy t -> Rep (AcIdx t i) -> Rep (D2 (AcVal t i)) -> Rep (D2 t)
+onehotD2 SZ _ () v = v
+onehotD2 (SS _) (STPair _ _ ) _ (Left () ) = Left ()
+onehotD2 (SS i) (STPair t1 t2) (Left idx) (Right (Left val)) = Right (onehotD2 i t1 idx val, zeroD2 t2)
+onehotD2 (SS i) (STPair t1 t2) (Right idx) (Right (Right val)) = Right (zeroD2 t1, onehotD2 i t2 idx val)
+onehotD2 (SS _) (STPair _ _ ) _ _ = error "onehotD2: pair: mismatched index and value"
+
withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t)
withAccum t _ initval f = AcM $ do
accum <- newAcSparse t SZ () initval