diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-05 21:55:52 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-05 21:55:52 +0100 |
commit | 889aa1757a0fdf003f38f9d565a4a91660757f38 (patch) | |
tree | 7e142f72eabcee4af0d2d2fc58a7c18344797d74 | |
parent | 6fce8a75e239988d2ce154f5411dd2d8c742f3f6 (diff) |
Support EOneHot
-rw-r--r-- | src/AST.hs | 4 | ||||
-rw-r--r-- | src/Interpreter.hs | 55 |
2 files changed, 41 insertions, 18 deletions
@@ -103,7 +103,7 @@ data Expr x env t where -- monoidal operations (to be desugared to regular operations after simplification) EZero :: STy t -> Expr x env (D2 t) EPlus :: STy t -> Expr x env (D2 t) -> Expr x env (D2 t) -> Expr x env (D2 t) - EOneHot :: STy t -> SNat i -> Expr x env (AcIdx t i) -> Expr x env (D2 (AcVal t i)) -> Expr x env (D2 t) + EOneHot :: STy t -> SNat i -> Expr x env (AcIdx (D2 t) i) -> Expr x env (AcVal (D2 t) i) -> Expr x env (D2 t) -- partiality EError :: STy a -> String -> Expr x env a @@ -120,7 +120,7 @@ type family Tup env where mkTup :: f TNil -> (forall a b. f a -> f b -> f (TPair a b)) -> SList f list -> f (Tup list) -mkTup nil _ SNil = nil +mkTup nil _ SNil = nil mkTup nil pair (e `SCons` es) = pair (mkTup nil pair es) e tTup :: SList STy env -> STy (Tup env) diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 3eb8995..36543e9 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -22,6 +22,7 @@ module Interpreter ( import Control.Monad (foldM, join, when) import Data.Bifunctor (bimap) import Data.Char (isSpace) +import Data.Functor.Identity import Data.Kind (Type) import Data.Int (Int64) import Data.IORef @@ -140,7 +141,7 @@ interpret'Rec env = \case EOneHot t i a b -> do a' <- interpret' env a b' <- interpret' env b - return $ onehotD2 t i a' b' + return $ onehotD2 i t a' b' EError _ s -> error $ "Interpreter: Program threw error: " ++ s interpretOp :: SOp a t -> Rep a -> Rep t @@ -213,12 +214,24 @@ addD2s typ a b = case typ of STBool -> () STAccum{} -> error "Plus of Accum" -onehotD2 :: SNat i -> STy t -> Rep (AcIdx t i) -> Rep (D2 (AcVal t i)) -> Rep (D2 t) +onehotD2 :: SNat i -> STy t -> Rep (AcIdx (D2 t) i) -> Rep (AcVal (D2 t) i) -> Rep (D2 t) onehotD2 SZ _ () v = v -onehotD2 (SS _) (STPair _ _ ) _ (Left () ) = Left () -onehotD2 (SS i) (STPair t1 t2) (Left idx) (Right (Left val)) = Right (onehotD2 i t1 idx val, zeroD2 t2) -onehotD2 (SS i) (STPair t1 t2) (Right idx) (Right (Right val)) = Right (zeroD2 t1, onehotD2 i t2 idx val) -onehotD2 (SS _) (STPair _ _ ) _ _ = error "onehotD2: pair: mismatched index and value" +onehotD2 _ STNil _ _ = () +onehotD2 (SS _ ) (STPair _ _ ) (Left _ ) (Left _ ) = Left () +onehotD2 (SS SZ ) (STPair _ _ ) (Right () ) (Right val ) = Right val +onehotD2 (SS (SS i)) (STPair t1 t2) (Right (Left idx)) (Right (Left val)) = Right (onehotD2 i t1 idx val, zeroD2 t2) +onehotD2 (SS (SS i)) (STPair t1 t2) (Right (Right idx)) (Right (Right val)) = Right (zeroD2 t1, onehotD2 i t2 idx val) +onehotD2 (SS _ ) (STPair _ _ ) _ _ = error "onehotD2: pair: mismatched index and value" +onehotD2 (SS _ ) (STEither _ _ ) (Left _ ) (Left _ ) = Left () +onehotD2 (SS SZ ) (STEither _ _ ) (Right () ) (Right val ) = Right val +onehotD2 (SS (SS i)) (STEither t1 _ ) (Right (Left idx)) (Right (Left val)) = Right (Left (onehotD2 i t1 idx val)) +onehotD2 (SS (SS i)) (STEither _ t2) (Right (Right idx)) (Right (Right val)) = Right (Right (onehotD2 i t2 idx val)) +onehotD2 (SS _ ) (STEither _ _ ) _ _ = error "onehotD2: either: mismatched index and value" +onehotD2 (SS i ) (STMaybe t) idx val = Just (onehotD2 i t idx val) +onehotD2 (SS i ) (STArr n t) idx val = runIdentity $ + onehotArray (d2 t) (\i' idx' v' -> Identity (onehotD2 i' t idx' v')) (Identity (zeroD2 t)) n (SS i) idx val +onehotD2 SS{} STScal{} _ _ = error "onehotD2: cannot index into scalar" +onehotD2 _ STAccum{} _ _ = error "onehotD2: cannot index into accumulator" withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t) withAccum t _ initval f = AcM $ do @@ -317,20 +330,30 @@ newAcSparse typ (SS dep) idx val = case typ of STEither{} -> error "Bare Either in accumulator" newAcArray :: SNat n -> STy t -> SNat i -> Rep (AcIdx (TArr n t) i) -> Rep (AcVal (TArr n t) i) -> IO (Array n (RepAcSparse t)) -newAcArray _ t SZ _ val = - traverse (newAcSparse t SZ ()) val -newAcArray dim (t :: STy t) dep@SS{} idx val = do +newAcArray n t = onehotArray t (newAcSparse t) (newAcZero t) n + +onehotArray :: Monad m + => STy t + -> (forall n'. SNat n' -> Rep (AcIdx t n') -> Rep (AcVal t n') -> m v) -- ^ the "one" + -> m v -- ^ generate a zero value for elsewhere + -> SNat n -> SNat i -> Rep (AcIdx (TArr n t) i) -> Rep (AcVal (TArr n t) i) -> m (Array n v) +onehotArray _ mkone _ _ SZ _ val = + traverse (mkone SZ ()) val +onehotArray (_ :: STy t) mkone mkzero dim dep@SS{} idx val = do let sh = unTupRepIdx ShNil ShCons dim (fst val) - go dep dim idx (snd val) $ \arr position -> + go mkone dep dim idx (snd val) $ \arr position -> arrayGenerateM sh (\i -> case uninvert <$> piindexMatch position (invert i) of Just i' -> return $ arr `arrayIndex` i' - Nothing -> newAcZero t) + Nothing -> mkzero) where - go :: SNat i -> SNat n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i) -> (forall m. Array m (RepAcSparse t) -> PartialInvIndex n m -> IO r) -> IO r - go SZ _ () val' k = arrayMapM (newAcSparse t SZ ()) val' >>= \arr -> k arr PIIxEnd - go (SS dep') SZ idx' val' k = newAcSparse t dep' idx' val' >>= \arr -> k (arrayUnit arr) PIIxEnd - go (SS dep') (SS dim') (i, idx') val' k = - go dep' dim' idx' val' $ \arr pish -> + go :: Monad m + => (forall n'. SNat n' -> Rep (AcIdx t n') -> Rep (AcVal t n') -> m v) + -> SNat i -> SNat n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i) + -> (forall n'. Array n' v -> PartialInvIndex n n' -> m r) -> m r + go mk SZ _ () val' k = arrayMapM (mk SZ ()) val' >>= \arr -> k arr PIIxEnd + go mk (SS dep') SZ idx' val' k = mk dep' idx' val' >>= \arr -> k (arrayUnit arr) PIIxEnd + go mk (SS dep') (SS dim') (i, idx') val' k = + go mk dep' dim' idx' val' $ \arr pish -> k arr (PIIxCons (fromIntegral @Int64 @Int i) pish) newAcDense :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t) |