summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-04 23:33:34 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-04 23:33:34 +0100
commit6fce8a75e239988d2ce154f5411dd2d8c742f3f6 (patch)
tree2edd579d69ab9168c10965a86135daf807f127a4 /src
parent4e41364e73a2fbb902e41281c59991b6c789723f (diff)
WIP EOneHot
Diffstat (limited to 'src')
-rw-r--r--src/AST.hs3
-rw-r--r--src/AST/Count.hs1
-rw-r--r--src/AST/Pretty.hs6
-rw-r--r--src/CHAD.hs1
-rw-r--r--src/ForwardAD/DualNumbers.hs1
-rw-r--r--src/Interpreter.hs11
-rw-r--r--src/Simplify.hs2
7 files changed, 25 insertions, 0 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 71001e7..328a670 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -103,6 +103,7 @@ data Expr x env t where
-- monoidal operations (to be desugared to regular operations after simplification)
EZero :: STy t -> Expr x env (D2 t)
EPlus :: STy t -> Expr x env (D2 t) -> Expr x env (D2 t) -> Expr x env (D2 t)
+ EOneHot :: STy t -> SNat i -> Expr x env (AcIdx t i) -> Expr x env (D2 (AcVal t i)) -> Expr x env (D2 t)
-- partiality
EError :: STy a -> String -> Expr x env a
@@ -206,6 +207,7 @@ typeOf = \case
EZero t -> d2 t
EPlus t _ _ -> d2 t
+ EOneHot t _ _ _ -> d2 t
EError t _ -> t
@@ -277,6 +279,7 @@ subst' f w = \case
EAccum i e1 e2 e3 -> EAccum i (subst' f w e1) (subst' f w e2) (subst' f w e3)
EZero t -> EZero t
EPlus t a b -> EPlus t (subst' f w a) (subst' f w b)
+ EOneHot t i a b -> EOneHot t i (subst' f w a) (subst' f w b)
EError t s -> EError t s
where
sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index a928743..f3e3d74 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -130,6 +130,7 @@ occCountGeneral onehot unpush alter many = go WId
EAccum _ a b e -> re a <> re b <> re e
EZero _ -> mempty
EPlus _ a b -> re a <> re b
+ EOneHot _ _ a b -> re a <> re b
EError{} -> mempty
where
re :: Monoid (r env') => Expr x env' t'' -> r env'
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index a05b49e..677c767 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -204,6 +204,12 @@ ppExpr' d val = \case
b' <- ppExpr' 11 val b
return $ showParen (d > 10) $ showString "plus " . a' . showString " " . b'
+ EOneHot _ i a b -> do
+ a' <- ppExpr' 11 val a
+ b' <- ppExpr' 11 val b
+ return $ showParen (d > 10) $
+ showString ("onehot " ++ show (fromSNat i) ++ " ") . a' . showString " " . b'
+
EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s)
ppExprLet :: Int -> SVal env -> Expr x env t -> M ShowS
diff --git a/src/CHAD.hs b/src/CHAD.hs
index d45898a..b3e2358 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1198,6 +1198,7 @@ drev des = \case
EAccum{} -> err_accum
EZero{} -> err_monoid
EPlus{} -> err_monoid
+ EOneHot{} -> err_monoid
where
err_accum = error "Accumulator operations unsupported in the source program"
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index beb93da..8b4acb3 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -186,6 +186,7 @@ dfwdDN = \case
EAccum{} -> err_accum
EZero{} -> err_monoid
EPlus{} -> err_monoid
+ EOneHot{} -> err_monoid
where
err_accum = error "Accumulator operations unsupported in the source program"
err_monoid = error "Monoid operations unsupported in the source program"
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index da5b73c..3eb8995 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -137,6 +137,10 @@ interpret'Rec env = \case
a' <- interpret' env a
b' <- interpret' env b
return $ addD2s t a' b'
+ EOneHot t i a b -> do
+ a' <- interpret' env a
+ b' <- interpret' env b
+ return $ onehotD2 t i a' b'
EError _ s -> error $ "Interpreter: Program threw error: " ++ s
interpretOp :: SOp a t -> Rep a -> Rep t
@@ -209,6 +213,13 @@ addD2s typ a b = case typ of
STBool -> ()
STAccum{} -> error "Plus of Accum"
+onehotD2 :: SNat i -> STy t -> Rep (AcIdx t i) -> Rep (D2 (AcVal t i)) -> Rep (D2 t)
+onehotD2 SZ _ () v = v
+onehotD2 (SS _) (STPair _ _ ) _ (Left () ) = Left ()
+onehotD2 (SS i) (STPair t1 t2) (Left idx) (Right (Left val)) = Right (onehotD2 i t1 idx val, zeroD2 t2)
+onehotD2 (SS i) (STPair t1 t2) (Right idx) (Right (Right val)) = Right (zeroD2 t1, onehotD2 i t2 idx val)
+onehotD2 (SS _) (STPair _ _ ) _ _ = error "onehotD2: pair: mismatched index and value"
+
withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t)
withAccum t _ initval f = AcM $ do
accum <- newAcSparse t SZ () initval
diff --git a/src/Simplify.hs b/src/Simplify.hs
index cfbdbb9..0ce5594 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -111,6 +111,7 @@ simplify' = \case
EAccum i e1 e2 e3 -> EAccum i <$> simplify' e1 <*> simplify' e2 <*> simplify' e3
EZero t -> pure $ EZero t
EPlus t a b -> EPlus t <$> simplify' a <*> simplify' b
+ EOneHot t i a b -> EOneHot t i <$> simplify' a <*> simplify' b
EError t s -> pure $ EError t s
acted :: (Any, a) -> (Any, a)
@@ -156,6 +157,7 @@ hasAdds = \case
EAccum _ _ _ _ -> True
EZero _ -> False
EPlus _ a b -> hasAdds a || hasAdds b
+ EOneHot _ _ a b -> hasAdds a || hasAdds b
EError _ _ -> False
checkAccumInScope :: SList STy env -> Bool