diff options
-rw-r--r-- | src/AST.hs | 3 | ||||
-rw-r--r-- | src/AST/Count.hs | 1 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 6 | ||||
-rw-r--r-- | src/CHAD.hs | 1 | ||||
-rw-r--r-- | src/ForwardAD/DualNumbers.hs | 1 | ||||
-rw-r--r-- | src/Interpreter.hs | 11 | ||||
-rw-r--r-- | src/Simplify.hs | 2 |
7 files changed, 25 insertions, 0 deletions
@@ -103,6 +103,7 @@ data Expr x env t where -- monoidal operations (to be desugared to regular operations after simplification) EZero :: STy t -> Expr x env (D2 t) EPlus :: STy t -> Expr x env (D2 t) -> Expr x env (D2 t) -> Expr x env (D2 t) + EOneHot :: STy t -> SNat i -> Expr x env (AcIdx t i) -> Expr x env (D2 (AcVal t i)) -> Expr x env (D2 t) -- partiality EError :: STy a -> String -> Expr x env a @@ -206,6 +207,7 @@ typeOf = \case EZero t -> d2 t EPlus t _ _ -> d2 t + EOneHot t _ _ _ -> d2 t EError t _ -> t @@ -277,6 +279,7 @@ subst' f w = \case EAccum i e1 e2 e3 -> EAccum i (subst' f w e1) (subst' f w e2) (subst' f w e3) EZero t -> EZero t EPlus t a b -> EPlus t (subst' f w a) (subst' f w b) + EOneHot t i a b -> EOneHot t i (subst' f w a) (subst' f w b) EError t s -> EError t s where sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) diff --git a/src/AST/Count.hs b/src/AST/Count.hs index a928743..f3e3d74 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -130,6 +130,7 @@ occCountGeneral onehot unpush alter many = go WId EAccum _ a b e -> re a <> re b <> re e EZero _ -> mempty EPlus _ a b -> re a <> re b + EOneHot _ _ a b -> re a <> re b EError{} -> mempty where re :: Monoid (r env') => Expr x env' t'' -> r env' diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index a05b49e..677c767 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -204,6 +204,12 @@ ppExpr' d val = \case b' <- ppExpr' 11 val b return $ showParen (d > 10) $ showString "plus " . a' . showString " " . b' + EOneHot _ i a b -> do + a' <- ppExpr' 11 val a + b' <- ppExpr' 11 val b + return $ showParen (d > 10) $ + showString ("onehot " ++ show (fromSNat i) ++ " ") . a' . showString " " . b' + EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s) ppExprLet :: Int -> SVal env -> Expr x env t -> M ShowS diff --git a/src/CHAD.hs b/src/CHAD.hs index d45898a..b3e2358 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1198,6 +1198,7 @@ drev des = \case EAccum{} -> err_accum EZero{} -> err_monoid EPlus{} -> err_monoid + EOneHot{} -> err_monoid where err_accum = error "Accumulator operations unsupported in the source program" diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index beb93da..8b4acb3 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -186,6 +186,7 @@ dfwdDN = \case EAccum{} -> err_accum EZero{} -> err_monoid EPlus{} -> err_monoid + EOneHot{} -> err_monoid where err_accum = error "Accumulator operations unsupported in the source program" err_monoid = error "Monoid operations unsupported in the source program" diff --git a/src/Interpreter.hs b/src/Interpreter.hs index da5b73c..3eb8995 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -137,6 +137,10 @@ interpret'Rec env = \case a' <- interpret' env a b' <- interpret' env b return $ addD2s t a' b' + EOneHot t i a b -> do + a' <- interpret' env a + b' <- interpret' env b + return $ onehotD2 t i a' b' EError _ s -> error $ "Interpreter: Program threw error: " ++ s interpretOp :: SOp a t -> Rep a -> Rep t @@ -209,6 +213,13 @@ addD2s typ a b = case typ of STBool -> () STAccum{} -> error "Plus of Accum" +onehotD2 :: SNat i -> STy t -> Rep (AcIdx t i) -> Rep (D2 (AcVal t i)) -> Rep (D2 t) +onehotD2 SZ _ () v = v +onehotD2 (SS _) (STPair _ _ ) _ (Left () ) = Left () +onehotD2 (SS i) (STPair t1 t2) (Left idx) (Right (Left val)) = Right (onehotD2 i t1 idx val, zeroD2 t2) +onehotD2 (SS i) (STPair t1 t2) (Right idx) (Right (Right val)) = Right (zeroD2 t1, onehotD2 i t2 idx val) +onehotD2 (SS _) (STPair _ _ ) _ _ = error "onehotD2: pair: mismatched index and value" + withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t) withAccum t _ initval f = AcM $ do accum <- newAcSparse t SZ () initval diff --git a/src/Simplify.hs b/src/Simplify.hs index cfbdbb9..0ce5594 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -111,6 +111,7 @@ simplify' = \case EAccum i e1 e2 e3 -> EAccum i <$> simplify' e1 <*> simplify' e2 <*> simplify' e3 EZero t -> pure $ EZero t EPlus t a b -> EPlus t <$> simplify' a <*> simplify' b + EOneHot t i a b -> EOneHot t i <$> simplify' a <*> simplify' b EError t s -> pure $ EError t s acted :: (Any, a) -> (Any, a) @@ -156,6 +157,7 @@ hasAdds = \case EAccum _ _ _ _ -> True EZero _ -> False EPlus _ a b -> hasAdds a || hasAdds b + EOneHot _ _ a b -> hasAdds a || hasAdds b EError _ _ -> False checkAccumInScope :: SList STy env -> Bool |