diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-11-05 21:55:52 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-05 21:55:52 +0100 | 
| commit | 889aa1757a0fdf003f38f9d565a4a91660757f38 (patch) | |
| tree | 7e142f72eabcee4af0d2d2fc58a7c18344797d74 | |
| parent | 6fce8a75e239988d2ce154f5411dd2d8c742f3f6 (diff) | |
Support EOneHot
| -rw-r--r-- | src/AST.hs | 4 | ||||
| -rw-r--r-- | src/Interpreter.hs | 55 | 
2 files changed, 41 insertions, 18 deletions
| @@ -103,7 +103,7 @@ data Expr x env t where    -- monoidal operations (to be desugared to regular operations after simplification)    EZero :: STy t -> Expr x env (D2 t)    EPlus :: STy t -> Expr x env (D2 t) -> Expr x env (D2 t) -> Expr x env (D2 t) -  EOneHot :: STy t -> SNat i -> Expr x env (AcIdx t i) -> Expr x env (D2 (AcVal t i)) -> Expr x env (D2 t) +  EOneHot :: STy t -> SNat i -> Expr x env (AcIdx (D2 t) i) -> Expr x env (AcVal (D2 t) i) -> Expr x env (D2 t)    -- partiality    EError :: STy a -> String -> Expr x env a @@ -120,7 +120,7 @@ type family Tup env where  mkTup :: f TNil -> (forall a b. f a -> f b -> f (TPair a b))        -> SList f list -> f (Tup list) -mkTup nil _     SNil = nil +mkTup nil _    SNil = nil  mkTup nil pair (e `SCons` es) = pair (mkTup nil pair es) e  tTup :: SList STy env -> STy (Tup env) diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 3eb8995..36543e9 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -22,6 +22,7 @@ module Interpreter (  import Control.Monad (foldM, join, when)  import Data.Bifunctor (bimap)  import Data.Char (isSpace) +import Data.Functor.Identity  import Data.Kind (Type)  import Data.Int (Int64)  import Data.IORef @@ -140,7 +141,7 @@ interpret'Rec env = \case    EOneHot t i a b -> do      a' <- interpret' env a      b' <- interpret' env b -    return $ onehotD2 t i a' b' +    return $ onehotD2 i t a' b'    EError _ s -> error $ "Interpreter: Program threw error: " ++ s  interpretOp :: SOp a t -> Rep a -> Rep t @@ -213,12 +214,24 @@ addD2s typ a b = case typ of      STBool -> ()    STAccum{} -> error "Plus of Accum" -onehotD2 :: SNat i -> STy t -> Rep (AcIdx t i) -> Rep (D2 (AcVal t i)) -> Rep (D2 t) +onehotD2 :: SNat i -> STy t -> Rep (AcIdx (D2 t) i) -> Rep (AcVal (D2 t) i) -> Rep (D2 t)  onehotD2 SZ _ () v = v -onehotD2 (SS _) (STPair _  _ ) _           (Left  ()         ) = Left () -onehotD2 (SS i) (STPair t1 t2) (Left  idx) (Right (Left  val)) = Right (onehotD2 i t1 idx val, zeroD2 t2) -onehotD2 (SS i) (STPair t1 t2) (Right idx) (Right (Right val)) = Right (zeroD2 t1, onehotD2 i t2 idx val) -onehotD2 (SS _) (STPair _  _ ) _           _                   = error "onehotD2: pair: mismatched index and value" +onehotD2 _ STNil _ _ = () +onehotD2 (SS _     ) (STPair _  _ ) (Left  _          ) (Left  _          ) = Left () +onehotD2 (SS SZ    ) (STPair _  _ ) (Right ()         ) (Right val        ) = Right val +onehotD2 (SS (SS i)) (STPair t1 t2) (Right (Left  idx)) (Right (Left  val)) = Right (onehotD2 i t1 idx val, zeroD2 t2) +onehotD2 (SS (SS i)) (STPair t1 t2) (Right (Right idx)) (Right (Right val)) = Right (zeroD2 t1, onehotD2 i t2 idx val) +onehotD2 (SS _     ) (STPair _  _ ) _                   _                   = error "onehotD2: pair: mismatched index and value" +onehotD2 (SS _     ) (STEither _  _ ) (Left  _          ) (Left  _          ) = Left () +onehotD2 (SS SZ    ) (STEither _  _ ) (Right ()         ) (Right val        ) = Right val +onehotD2 (SS (SS i)) (STEither t1 _ ) (Right (Left  idx)) (Right (Left  val)) = Right (Left (onehotD2 i t1 idx val)) +onehotD2 (SS (SS i)) (STEither _  t2) (Right (Right idx)) (Right (Right val)) = Right (Right (onehotD2 i t2 idx val)) +onehotD2 (SS _     ) (STEither _  _ ) _                   _                   = error "onehotD2: either: mismatched index and value" +onehotD2 (SS i     ) (STMaybe t) idx val = Just (onehotD2 i t idx val) +onehotD2 (SS i     ) (STArr n t) idx val = runIdentity $ +  onehotArray (d2 t) (\i' idx' v' -> Identity (onehotD2 i' t idx' v')) (Identity (zeroD2 t)) n (SS i) idx val +onehotD2 SS{} STScal{} _ _ = error "onehotD2: cannot index into scalar" +onehotD2 _ STAccum{} _ _ = error "onehotD2: cannot index into accumulator"  withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t)  withAccum t _ initval f = AcM $ do @@ -317,20 +330,30 @@ newAcSparse typ (SS dep) idx val = case typ of    STEither{} -> error "Bare Either in accumulator"  newAcArray :: SNat n -> STy t -> SNat i -> Rep (AcIdx (TArr n t) i) -> Rep (AcVal (TArr n t) i) -> IO (Array n (RepAcSparse t)) -newAcArray _ t SZ _ val = -  traverse (newAcSparse t SZ ()) val -newAcArray dim (t :: STy t) dep@SS{} idx val = do +newAcArray n t = onehotArray t (newAcSparse t) (newAcZero t) n + +onehotArray :: Monad m +            => STy t +            -> (forall n'. SNat n' -> Rep (AcIdx t n') -> Rep (AcVal t n') -> m v)  -- ^ the "one" +            -> m v  -- ^ generate a zero value for elsewhere +            -> SNat n -> SNat i -> Rep (AcIdx (TArr n t) i) -> Rep (AcVal (TArr n t) i) -> m (Array n v) +onehotArray _ mkone _ _ SZ _ val = +  traverse (mkone SZ ()) val +onehotArray (_ :: STy t) mkone mkzero dim dep@SS{} idx val = do    let sh = unTupRepIdx ShNil ShCons dim (fst val) -  go dep dim idx (snd val) $ \arr position -> +  go mkone dep dim idx (snd val) $ \arr position ->      arrayGenerateM sh (\i -> case uninvert <$> piindexMatch position (invert i) of                                 Just i' -> return $ arr `arrayIndex` i' -                               Nothing -> newAcZero t) +                               Nothing -> mkzero)    where -    go :: SNat i -> SNat n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i) -> (forall m. Array m (RepAcSparse t) -> PartialInvIndex n m -> IO r) -> IO r -    go SZ _ () val' k = arrayMapM (newAcSparse t SZ ()) val' >>= \arr -> k arr PIIxEnd -    go (SS dep') SZ idx' val' k = newAcSparse t dep' idx' val' >>= \arr -> k (arrayUnit arr) PIIxEnd -    go (SS dep') (SS dim') (i, idx') val' k = -      go dep' dim' idx' val' $ \arr pish -> +    go :: Monad m +       => (forall n'. SNat n' -> Rep (AcIdx t n') -> Rep (AcVal t n') -> m v) +       -> SNat i -> SNat n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i) +       -> (forall n'. Array n' v -> PartialInvIndex n n' -> m r) -> m r +    go mk SZ _ () val' k = arrayMapM (mk SZ ()) val' >>= \arr -> k arr PIIxEnd +    go mk (SS dep') SZ idx' val' k = mk dep' idx' val' >>= \arr -> k (arrayUnit arr) PIIxEnd +    go mk (SS dep') (SS dim') (i, idx') val' k = +      go mk dep' dim' idx' val' $ \arr pish ->          k arr (PIIxCons (fromIntegral @Int64 @Int i) pish)  newAcDense :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t) | 
