summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-05 21:55:52 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-05 21:55:52 +0100
commit889aa1757a0fdf003f38f9d565a4a91660757f38 (patch)
tree7e142f72eabcee4af0d2d2fc58a7c18344797d74
parent6fce8a75e239988d2ce154f5411dd2d8c742f3f6 (diff)
Support EOneHot
-rw-r--r--src/AST.hs4
-rw-r--r--src/Interpreter.hs55
2 files changed, 41 insertions, 18 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 328a670..b9b10ad 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -103,7 +103,7 @@ data Expr x env t where
-- monoidal operations (to be desugared to regular operations after simplification)
EZero :: STy t -> Expr x env (D2 t)
EPlus :: STy t -> Expr x env (D2 t) -> Expr x env (D2 t) -> Expr x env (D2 t)
- EOneHot :: STy t -> SNat i -> Expr x env (AcIdx t i) -> Expr x env (D2 (AcVal t i)) -> Expr x env (D2 t)
+ EOneHot :: STy t -> SNat i -> Expr x env (AcIdx (D2 t) i) -> Expr x env (AcVal (D2 t) i) -> Expr x env (D2 t)
-- partiality
EError :: STy a -> String -> Expr x env a
@@ -120,7 +120,7 @@ type family Tup env where
mkTup :: f TNil -> (forall a b. f a -> f b -> f (TPair a b))
-> SList f list -> f (Tup list)
-mkTup nil _ SNil = nil
+mkTup nil _ SNil = nil
mkTup nil pair (e `SCons` es) = pair (mkTup nil pair es) e
tTup :: SList STy env -> STy (Tup env)
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 3eb8995..36543e9 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -22,6 +22,7 @@ module Interpreter (
import Control.Monad (foldM, join, when)
import Data.Bifunctor (bimap)
import Data.Char (isSpace)
+import Data.Functor.Identity
import Data.Kind (Type)
import Data.Int (Int64)
import Data.IORef
@@ -140,7 +141,7 @@ interpret'Rec env = \case
EOneHot t i a b -> do
a' <- interpret' env a
b' <- interpret' env b
- return $ onehotD2 t i a' b'
+ return $ onehotD2 i t a' b'
EError _ s -> error $ "Interpreter: Program threw error: " ++ s
interpretOp :: SOp a t -> Rep a -> Rep t
@@ -213,12 +214,24 @@ addD2s typ a b = case typ of
STBool -> ()
STAccum{} -> error "Plus of Accum"
-onehotD2 :: SNat i -> STy t -> Rep (AcIdx t i) -> Rep (D2 (AcVal t i)) -> Rep (D2 t)
+onehotD2 :: SNat i -> STy t -> Rep (AcIdx (D2 t) i) -> Rep (AcVal (D2 t) i) -> Rep (D2 t)
onehotD2 SZ _ () v = v
-onehotD2 (SS _) (STPair _ _ ) _ (Left () ) = Left ()
-onehotD2 (SS i) (STPair t1 t2) (Left idx) (Right (Left val)) = Right (onehotD2 i t1 idx val, zeroD2 t2)
-onehotD2 (SS i) (STPair t1 t2) (Right idx) (Right (Right val)) = Right (zeroD2 t1, onehotD2 i t2 idx val)
-onehotD2 (SS _) (STPair _ _ ) _ _ = error "onehotD2: pair: mismatched index and value"
+onehotD2 _ STNil _ _ = ()
+onehotD2 (SS _ ) (STPair _ _ ) (Left _ ) (Left _ ) = Left ()
+onehotD2 (SS SZ ) (STPair _ _ ) (Right () ) (Right val ) = Right val
+onehotD2 (SS (SS i)) (STPair t1 t2) (Right (Left idx)) (Right (Left val)) = Right (onehotD2 i t1 idx val, zeroD2 t2)
+onehotD2 (SS (SS i)) (STPair t1 t2) (Right (Right idx)) (Right (Right val)) = Right (zeroD2 t1, onehotD2 i t2 idx val)
+onehotD2 (SS _ ) (STPair _ _ ) _ _ = error "onehotD2: pair: mismatched index and value"
+onehotD2 (SS _ ) (STEither _ _ ) (Left _ ) (Left _ ) = Left ()
+onehotD2 (SS SZ ) (STEither _ _ ) (Right () ) (Right val ) = Right val
+onehotD2 (SS (SS i)) (STEither t1 _ ) (Right (Left idx)) (Right (Left val)) = Right (Left (onehotD2 i t1 idx val))
+onehotD2 (SS (SS i)) (STEither _ t2) (Right (Right idx)) (Right (Right val)) = Right (Right (onehotD2 i t2 idx val))
+onehotD2 (SS _ ) (STEither _ _ ) _ _ = error "onehotD2: either: mismatched index and value"
+onehotD2 (SS i ) (STMaybe t) idx val = Just (onehotD2 i t idx val)
+onehotD2 (SS i ) (STArr n t) idx val = runIdentity $
+ onehotArray (d2 t) (\i' idx' v' -> Identity (onehotD2 i' t idx' v')) (Identity (zeroD2 t)) n (SS i) idx val
+onehotD2 SS{} STScal{} _ _ = error "onehotD2: cannot index into scalar"
+onehotD2 _ STAccum{} _ _ = error "onehotD2: cannot index into accumulator"
withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t)
withAccum t _ initval f = AcM $ do
@@ -317,20 +330,30 @@ newAcSparse typ (SS dep) idx val = case typ of
STEither{} -> error "Bare Either in accumulator"
newAcArray :: SNat n -> STy t -> SNat i -> Rep (AcIdx (TArr n t) i) -> Rep (AcVal (TArr n t) i) -> IO (Array n (RepAcSparse t))
-newAcArray _ t SZ _ val =
- traverse (newAcSparse t SZ ()) val
-newAcArray dim (t :: STy t) dep@SS{} idx val = do
+newAcArray n t = onehotArray t (newAcSparse t) (newAcZero t) n
+
+onehotArray :: Monad m
+ => STy t
+ -> (forall n'. SNat n' -> Rep (AcIdx t n') -> Rep (AcVal t n') -> m v) -- ^ the "one"
+ -> m v -- ^ generate a zero value for elsewhere
+ -> SNat n -> SNat i -> Rep (AcIdx (TArr n t) i) -> Rep (AcVal (TArr n t) i) -> m (Array n v)
+onehotArray _ mkone _ _ SZ _ val =
+ traverse (mkone SZ ()) val
+onehotArray (_ :: STy t) mkone mkzero dim dep@SS{} idx val = do
let sh = unTupRepIdx ShNil ShCons dim (fst val)
- go dep dim idx (snd val) $ \arr position ->
+ go mkone dep dim idx (snd val) $ \arr position ->
arrayGenerateM sh (\i -> case uninvert <$> piindexMatch position (invert i) of
Just i' -> return $ arr `arrayIndex` i'
- Nothing -> newAcZero t)
+ Nothing -> mkzero)
where
- go :: SNat i -> SNat n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i) -> (forall m. Array m (RepAcSparse t) -> PartialInvIndex n m -> IO r) -> IO r
- go SZ _ () val' k = arrayMapM (newAcSparse t SZ ()) val' >>= \arr -> k arr PIIxEnd
- go (SS dep') SZ idx' val' k = newAcSparse t dep' idx' val' >>= \arr -> k (arrayUnit arr) PIIxEnd
- go (SS dep') (SS dim') (i, idx') val' k =
- go dep' dim' idx' val' $ \arr pish ->
+ go :: Monad m
+ => (forall n'. SNat n' -> Rep (AcIdx t n') -> Rep (AcVal t n') -> m v)
+ -> SNat i -> SNat n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i)
+ -> (forall n'. Array n' v -> PartialInvIndex n n' -> m r) -> m r
+ go mk SZ _ () val' k = arrayMapM (mk SZ ()) val' >>= \arr -> k arr PIIxEnd
+ go mk (SS dep') SZ idx' val' k = mk dep' idx' val' >>= \arr -> k (arrayUnit arr) PIIxEnd
+ go mk (SS dep') (SS dim') (i, idx') val' k =
+ go mk dep' dim' idx' val' $ \arr pish ->
k arr (PIIxCons (fromIntegral @Int64 @Int i) pish)
newAcDense :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t)