diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-03 23:09:37 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-03 23:10:23 +0100 | 
| commit | 81d88dbc430ca6ec8390636f8b7162887b390873 (patch) | |
| tree | 849c126fad3b923c2e5b815aa5c8488907bc2318 | |
| parent | 2ca218d2e97e521bcc49dea8f4774737ba083ede (diff) | |
WIP map + zip
| -rw-r--r-- | src/AST.hs | 91 | ||||
| -rw-r--r-- | src/AST/Count.hs | 49 | ||||
| -rw-r--r-- | src/AST/Pretty.hs | 15 | ||||
| -rw-r--r-- | src/AST/UnMonoid.hs | 2 | ||||
| -rw-r--r-- | src/AST/Weaken.hs | 2 | ||||
| -rw-r--r-- | src/Analysis/Identity.hs | 16 | ||||
| -rw-r--r-- | src/CHAD.hs | 29 | ||||
| -rw-r--r-- | src/Compile.hs | 19 | ||||
| -rw-r--r-- | src/ForwardAD/DualNumbers.hs | 2 | ||||
| -rw-r--r-- | src/Interpreter.hs | 10 | ||||
| -rw-r--r-- | src/Simplify.hs | 8 | 
11 files changed, 214 insertions, 29 deletions
@@ -4,7 +4,9 @@  {-# LANGUAGE DeriveTraversable #-}  {-# LANGUAGE EmptyCase #-}  {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-}  {-# LANGUAGE GADTs #-} +{-# LANGUAGE ImpredicativeTypes #-}  {-# LANGUAGE LambdaCase #-}  {-# LANGUAGE PolyKinds #-}  {-# LANGUAGE QuantifiedConstraints #-} @@ -16,7 +18,6 @@  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-}  {-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE FlexibleInstances #-}  module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where  import Data.Functor.Const @@ -62,6 +63,7 @@ data Expr x env t where    -- array operations    EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t))    EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t) +  EMap :: x (TArr n t) -> Expr x (a : env) t -> Expr x env (TArr n a) -> Expr x env (TArr n t)    -- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right)    EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)    ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) @@ -70,6 +72,7 @@ data Expr x env t where    EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))    EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))    EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t) +  EZip :: x (TArr n (TPair a b)) -> Expr x env (TArr n a) -> Expr x env (TArr n b) -> Expr x env (TArr n (TPair a b))    -- Primal of EFold1Inner. Looks like a mapAccumL, but differs semantically:    -- an implementation is allowed to parallelise this thing and store the b @@ -231,6 +234,7 @@ typeOf = \case    EConstArr _ n t _ -> STArr n (STScal t)    EBuild _ n _ e -> STArr n (typeOf e) +  EMap _ a b | STArr n _ <- typeOf b -> STArr n (typeOf a)    EFold1Inner _ _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t    ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t    EUnit _ e -> STArr SZ (typeOf e) @@ -238,6 +242,7 @@ typeOf = \case    EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t    EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t    EReshape _ n _ e | STArr _ t <- typeOf e -> STArr n t +  EZip _ a b | STArr n t1 <- typeOf a, STArr _ t2 <- typeOf b -> STArr n (STPair t1 t2)    EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb)    EFold1InnerD2 _ _ _ _ e3 | STArr n t2 <- typeOf e3 -> STPair (STArr n t2) (STArr (SS n) t2) @@ -282,6 +287,7 @@ extOf = \case    ELCase x _ _ _ _ -> x    EConstArr x _ _ _ -> x    EBuild x _ _ _ -> x +  EMap x _ _ -> x    EFold1Inner x _ _ _ _ -> x    ESum1Inner x _ -> x    EUnit x _ -> x @@ -289,6 +295,7 @@ extOf = \case    EMaximum1Inner x _ -> x    EMinimum1Inner x _ -> x    EReshape x _ _ _ -> x +  EZip x _ _ -> x    EFold1InnerD1 x _ _ _ _ -> x    EFold1InnerD2 x _ _ _ _ -> x    EConst x _ _ -> x @@ -331,12 +338,14 @@ travExt f = \case    ELCase x e a b c -> ELCase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b <*> travExt f c    EConstArr x n t a -> EConstArr <$> f x <*> pure n <*> pure t <*> pure a    EBuild x n a b -> EBuild <$> f x <*> pure n <*> travExt f a <*> travExt f b +  EMap x a b -> EMap <$> f x <*> travExt f a <*> travExt f b    EFold1Inner x cm a b c -> EFold1Inner <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c    ESum1Inner x e -> ESum1Inner <$> f x <*> travExt f e    EUnit x e -> EUnit <$> f x <*> travExt f e    EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b    EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e    EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e +  EZip x a b -> EZip <$> f x <*> travExt f a <*> travExt f b    EReshape x n a b -> EReshape <$> f x <*> pure n <*> travExt f a <*> travExt f b    EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c    EFold1InnerD2 x cm a b c -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c @@ -393,6 +402,7 @@ subst' f w = \case    ELCase x e a b c -> ELCase x (subst' f w e) (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' (sinkF f) (WCopy w) c)    EConstArr x n t a -> EConstArr x n t a    EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b) +  EMap x a b -> EMap x (subst' (sinkF f) (WCopy w) a) (subst' f w b)    EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)    ESum1Inner x e -> ESum1Inner x (subst' f w e)    EUnit x e -> EUnit x (subst' f w e) @@ -400,6 +410,7 @@ subst' f w = \case    EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e)    EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e)    EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b) +  EZip x a b -> EZip x (subst' f w a) (subst' f w b)    EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)    EFold1InnerD2 x cm a b c -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)    EConst x t v -> EConst x t v @@ -521,34 +532,24 @@ eidxEq (SS n) a b  emap :: (KnownTy a => Ex (a : env) b) -> Ex env (TArr n a) -> Ex env (TArr n b)  emap f arr -  | STArr n t <- typeOf arr +  | STArr _ t <- typeOf arr    , Dict <- styKnown t -  = ELet ext arr $ -      EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $ -        ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ)) -                           (EVar ext (tTup (sreplicate n tIx)) IZ)) $ -          weakenExpr (WCopy (WSink .> WSink)) f +  = EMap ext f arr  ezipWith :: ((KnownTy a, KnownTy b) => Ex (b : a : env) c) -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c)  ezipWith f arr1 arr2 -  | STArr n t1 <- typeOf arr1 +  | STArr _ t1 <- typeOf arr1    , STArr _ t2 <- typeOf arr2    , Dict <- styKnown t1    , Dict <- styKnown t2 -  = ELet ext arr1 $ -    ELet ext (weakenExpr WSink arr2) $ -      EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ -        ELet ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ))) -                           (EVar ext (tTup (sreplicate n tIx)) IZ)) $ -        ELet ext (EIdx ext (EVar ext (STArr n t2) (IS (IS IZ))) -                           (EVar ext (tTup (sreplicate n tIx)) (IS IZ))) $ -          weakenExpr (WCopy (WCopy (WSink .> WSink .> WSink))) f +  = EMap ext (subst (\_ t -> \case IZ -> ESnd ext (EVar ext (STPair t1 t2) IZ) +                                   IS IZ -> EFst ext (EVar ext (STPair t1 t2) IZ) +                                   IS (IS i) -> EVar ext t (IS i)) +                    f) +             (EZip ext arr1 arr2)  ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) -ezip arr1 arr2 = -  let STArr _ t1 = typeOf arr1 -      STArr _ t2 = typeOf arr2 -  in ezipWith (EPair ext (EVar ext t1 (IS IZ)) (EVar ext t2 IZ)) arr1 arr2 +ezip = EZip ext  eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a  eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c) @@ -652,3 +653,53 @@ makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy t      go SMTMaybe{} _ = ENil ext      go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e      go SMTScal{} _ = ENil ext + +splitSparsePair +  :: -- given a sparsity +     STy (TPair a b) -> Sparse (TPair a b) t' +  -> (forall a' b'. +          -- I give you back two sparsities for a and b +          Sparse a a' -> Sparse b b' +          -- furthermore, I tell you that either your t' is already this (a', b') pair... +       -> Either +            (t' :~: TPair a' b') +            -- or I tell you how to construct a' and b' from t', given an actual t' +            (forall r' env. +                 Idx env t' +              -> (forall env'. +                     (forall c. Ex env' c -> Ex env c) +                  -> Ex env' a' -> Ex env' b' -> r') +              -> r') +       -> r) +  -> r +splitSparsePair _ SpAbsent k = +  k SpAbsent SpAbsent $ Right $ \_ k2 -> +  k2 id (ENil ext) (ENil ext) +splitSparsePair _ (SpPair s1 s2) k1 = +  k1 s1 s2 $ Left Refl +splitSparsePair t@(STPair t1 t2) (SpSparse s@(SpPair s1 s2)) k = +  let t' = STPair (STMaybe (applySparse s1 t1)) (STMaybe (applySparse s2 t2)) in +  k (SpSparse s1) (SpSparse s2) $ Right $ \i k2 -> +  k2 (elet $ +       emaybe (EVar ext (STMaybe (applySparse s t)) i) +         (EPair ext (ENothing ext (applySparse s1 t1)) (ENothing ext (applySparse s2 t2))) +         (EPair ext (EJust ext (EFst ext (evar IZ))) (EJust ext (ESnd ext (evar IZ))))) +     (EFst ext (EVar ext t' IZ)) (ESnd ext (EVar ext t' IZ)) + +splitSparsePair _ (SpSparse SpAbsent) k = +  k SpAbsent SpAbsent $ Right $ \_ k2 -> +  k2 id (ENil ext) (ENil ext) +-- -- TODO: having to handle sparse-of-sparse at all is ridiculous +splitSparsePair t (SpSparse (SpSparse s)) k = +  splitSparsePair t (SpSparse s) $ \s1 s2 eres -> +  k s1 s2 $ Right $ \i k2 -> +  case eres of +    Left refl -> case refl of {} +    Right f -> +      f IZ $ \wrap e1 e2 -> +        k2 (\body -> +              elet (emaybe (EVar ext (STMaybe (STMaybe (applySparse s t))) i) +                     (ENothing ext (applySparse s t)) +                     (evar IZ)) $ +                wrap body) +           e1 e2 diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 296c021..bc02417 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -523,9 +523,8 @@ occCountX initialS topexpr k = case topexpr of        SsNone ->          occCountX SsFull a $ \env1 mka ->          occCountX SsNone b $ \env2'' mkb -> -        withSome (scaleMany (Some env2'')) $ \env2' -> -        occEnvPop' env2' $ \env2 s2 -> -        withSome (Some env1 <> Some env2) $ \env -> +        occEnvPop' env2'' $ \env2' s2 -> +        withSome (Some env1 <> scaleMany (Some env2')) $ \env ->          k env $ \env' ->            use (EBuild ext n (mka env') $                   use (elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ @@ -535,14 +534,31 @@ occCountX initialS topexpr k = case topexpr of        SsArr' s' ->          occCountX SsFull a $ \env1 mka ->          occCountX s' b $ \env2'' mkb -> -        withSome (scaleMany (Some env2'')) $ \env2' -> -        occEnvPop' env2' $ \env2 s2 -> -        withSome (Some env1 <> Some env2) $ \env -> +        occEnvPop' env2'' $ \env2' s2 -> +        withSome (Some env1 <> scaleMany (Some env2')) $ \env ->          k env $ \env' ->            EBuild ext n (mka env') $              elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $                weakenExpr (WCopy WSink) (mkb (OccPush env' () s2)) +  EMap _ a b -> +    case s of +      SsNone -> +        occCountX SsNone a $ \env1'' mka -> +        occEnvPop' env1'' $ \env1' s1 -> +        occCountX (SsArr s1) b $ \env2 mkb -> +        withSome (scaleMany (Some env1') <> Some env2) $ \env -> +        k env $ \env' -> +          use (EMap ext (mka (OccPush env' () s1)) (mkb env')) $ +            ENil ext +      SsArr' s' -> +        occCountX s' a $ \env1'' mka -> +        occEnvPop' env1'' $ \env1' s1 -> +        occCountX (SsArr s1) b $ \env2 mkb -> +        withSome (scaleMany (Some env1') <> Some env2) $ \env -> +        k env $ \env' -> +          EMap ext (mka (OccPush env' () s1)) (mkb env') +    EFold1Inner _ commut a b c ->      occCountX SsFull a $ \env1''' mka ->      withSome (scaleMany (Some env1''')) $ \env1'' -> @@ -608,6 +624,27 @@ occCountX initialS topexpr k = case topexpr of          k env $ \env' ->            EReshape ext n (mkesh env') (mke env') +  EZip _ a b -> +    case s of +      SsNone -> +        occCountX SsNone a $ \env1 mka -> +        occCountX SsNone b $ \env2 mkb -> +        withSome (Some env1 <> Some env2) $ \env -> +        k env $ \env' -> +          use (mka env') $ use (mkb env') $ ENil ext +      SsArr' SsNone -> +        occCountX (SsArr SsNone) a $ \env1 mka -> +        occCountX SsNone b $ \env2 mkb -> +        withSome (Some env1 <> Some env2) $ \env -> +        k env $ \env' -> +          use (mkb env') $ mka env' +      SsArr' (SsPair' s1 s2) -> +        occCountX (SsArr s1) a $ \env1 mka -> +        occCountX (SsArr s2) b $ \env2 mkb -> +        withSome (Some env1 <> Some env2) $ \env -> +        k env $ \env' -> +          EZip ext (mka env') (mkb env') +    EFold1InnerD1 _ cm e1 e2 e3 ->      case s of        -- If nothing is necessary, we can execute a fold and then proceed to ignore it diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 68fc629..2c51b85 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -204,6 +204,14 @@ ppExpr' d val expr = case expr of             <> hardline <> e')          (ppApp (annotate AHighlight primName <> ppX expr) [a', ppLam [ppString name] e']) +  EMap _ a b -> do +    let STArr _ t1 = typeOf b +    name <- genNameIfUsedIn' "i" t1 IZ a +    a' <- ppExpr' 0 (Const name `SCons` val) a +    b' <- ppExpr' 11 val b +    return $ ppParen (d > 0) $ +      ppApp (annotate AHighlight (ppString "map") <> ppX expr) [ppLam [ppString name] a', b'] +    EFold1Inner _ cm a b c -> do      name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a      name2 <- genNameIfUsedIn (typeOf a) IZ a @@ -238,7 +246,12 @@ ppExpr' d val expr = case expr of    EReshape _ n esh e -> do      esh' <- ppExpr' 11 val esh      e' <- ppExpr' 11 val e -    return $ ppParen (d > 10) $ ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr <+> esh' <+> e' +    return $ ppParen (d > 10) $ ppApp (ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr) [esh', e'] + +  EZip _ e1 e2 -> do +    e1' <- ppExpr' 11 val e1 +    e2' <- ppExpr' 11 val e2 +    return $ ppParen (d > 10) $ ppApp (ppString "zip" <> ppX expr) [e1', e2']    EFold1InnerD1 _ cm a b c -> do      name1 <- genNameIfUsedIn (typeOf b) (IS IZ) a diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index a22b73f..1712ba5 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -38,6 +38,7 @@ unMonoid = \case    ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c)    EConstArr _ n t x -> EConstArr ext n t x    EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b) +  EMap _ a b -> EMap ext (unMonoid a) (unMonoid b)    EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c)    ESum1Inner _ e -> ESum1Inner ext (unMonoid e)    EUnit _ e -> EUnit ext (unMonoid e) @@ -45,6 +46,7 @@ unMonoid = \case    EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)    EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)    EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b) +  EZip _ a b -> EZip ext (unMonoid a) (unMonoid b)    EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c)    EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c)    EConst _ t x -> EConst ext t x diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs index 3a97fd1..f0820b8 100644 --- a/src/AST/Weaken.hs +++ b/src/AST/Weaken.hs @@ -129,7 +129,7 @@ wCopies bs w =    let bs' = slistMap (\_ -> Const ()) bs    in WStack bs' bs' WId w -wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env +wRaiseAbove :: SList f env1 -> proxy env -> env1 :> Append env1 env  wRaiseAbove SNil _ = WClosed  wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env) diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index 6301dc1..71da793 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -202,6 +202,15 @@ idana env expr = case expr of      res <- VIArr <$> genId <*> shidsToVec dim shids      pure (res, EBuild res dim e1' e2') +  EMap _ e1 e2 -> do +    let STArr _ t = typeOf e2 +    x1 <- genIds t +    (_, e1') <- idana (x1 `SCons` env) e1 +    (v2, e2') <- idana env e2 +    let VIArr _ sh = v2 +    res <- VIArr <$> genId <*> pure sh +    pure (res, EMap res e1' e2') +    EFold1Inner _ cm e1 e2 e3 -> do      let t1 = typeOf e1      x1 <- genIds t1 @@ -250,6 +259,13 @@ idana env expr = case expr of      res <- VIArr <$> genId <*> shidsToVec dim v1      pure (res, EReshape res dim e1' e2') +  EZip _ e1 e2 -> do +    (v1, e1') <- idana env e1 +    (_, e2') <- idana env e2 +    let VIArr _ sh = v1 +    res <- VIArr <$> genId <*> pure sh +    pure (res, EZip res e1' e2') +    EFold1InnerD1 _ cm e1 e2 e3 -> do      let t1 = typeOf e2      x1 <- genIds t1 diff --git a/src/CHAD.hs b/src/CHAD.hs index 7594a0f..67ffe12 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1116,6 +1116,8 @@ drev des accumMap sd = \case                                   e2)      }} +  EMap{} -> undefined +    EFold1Inner _ commut origef ex₀ earr      | SpArr @_ @sdElt sdElt <- sd      , STArr (SS ndim) eltty :: STy (TArr (S n) elt) <- typeOf earr @@ -1346,6 +1348,33 @@ drev des accumMap sd = \case                                    (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $            weakenExpr (WCopy (WSink .> WSink)) e2) +  EZip _ a b +    | SpArr sd' <- sd +    , STArr n t1 <- typeOf a +    , STArr _ t2 <- typeOf b -> +    splitSparsePair (STPair (d2 t1) (d2 t2)) sd' $ \sd1 sd2 pairSplitE -> +    case retConcat des (toSingleRet (drev des accumMap (SpArr sd1) a) `SCons` +                        toSingleRet (drev des accumMap (SpArr sd2) b) `SCons` SNil) of +    { Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) -> +    subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B -> +    Ret binds +        subtape +        (EZip ext a1 b1) +        subBoth +        (case pairSplitE of +           Left Refl -> +             let t' = STArr n (STPair (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 t2))) in +             plus_A_B +               (elet (emap (EFst ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) a2) +               (elet (emap (ESnd ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) b2) +           Right f -> f IZ $ \wrapPair pick1 pick2 -> +             elet (emap (wrapPair (EPair ext pick1 pick2)) +                    (EVar ext (applySparse (SpArr sd') (STArr n (STPair (d2 t1) (d2 t2)))) IZ)) $ +             plus_A_B +               (elet (emap (EFst ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) a2) +               (elet (emap (ESnd ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) b2)) +    } +    ENothing{} -> err_unsupported "ENothing"    EJust{} -> err_unsupported "EJust"    EMaybe{} -> err_unsupported "EMaybe" diff --git a/src/Compile.hs b/src/Compile.hs index bf7817a..d6ad7ec 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -802,6 +802,15 @@ compile' env = \case      return (CELit arrname) +  -- TODO: actually generate decent code here +  EMap _ e1 e2 -> do +    let STArr n _ = typeOf e2 +    compile' env $ +      elet e2 $ +        EBuild ext n (EShape ext (evar IZ)) $ +          elet (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ)) $ +            weakenExpr (WCopy (WSink .> WSink)) e1 +    EFold1Inner _ commut efun ex0 earr -> do      let STArr (SS n) t = typeOf earr @@ -951,6 +960,16 @@ compile' env = \case                [("buf", CEProj (CELit arrname) "buf")                ,("sh", CELit ("{" ++ intercalate ", " [printCExpr 0 e "" | e <- indexTupleComponents dim shname] ++ "}"))]) +  -- TODO: actually generate decent code here +  EZip _ e1 e2 -> do +    let STArr n _ = typeOf e1 +    compile' env $ +      elet e1 $ +      elet (weakenExpr WSink e2) $ +        EBuild ext n (EShape ext (evar (IS IZ))) $ +          EPair ext (EIdx ext (evar (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ)) +                    (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ)) +    EFold1InnerD1 _ commut efun ex0 earr -> do      let STArr (SS n) t = typeOf earr          STPair _ bty = typeOf efun diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index 44bdbb2..a1e9d0d 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -153,6 +153,7 @@ dfwdDN = \case      (EConstArr ext n t x)    EBuild _ n a b      | Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b) +  EMap _ a b -> EMap ext (dfwdDN a) (dfwdDN b)    EFold1Inner _ cm a b c -> EFold1Inner ext cm (dfwdDN a) (dfwdDN b) (dfwdDN c)    ESum1Inner _ e ->      let STArr n (STScal t) = typeOf e @@ -168,6 +169,7 @@ dfwdDN = \case    EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b)    EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e    EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e +  EZip _ a b -> EZip ext (dfwdDN a) (dfwdDN b)    EReshape _ n esh e      | Refl <- dnPreservesTupIx n -> EReshape ext n (dfwdDN esh) (dfwdDN e)    EConst _ t x -> scalTyCase t diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 9e3d2a6..d982261 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -116,6 +116,9 @@ interpret'Rec env = \case    EBuild _ dim a b -> do      sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a      arrayGenerateM sh (\idx -> interpret' (V (tTup (sreplicate dim tIx)) (tupRepIdx ixUncons dim idx) `SCons` env) b) +  EMap _ a b -> do +    let STArr _ t = typeOf b +    arrayMapM (\x -> interpret' (V t x `SCons` env) a) =<< interpret' env b    EFold1Inner _ _ a b c -> do      let t = typeOf b      let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a @@ -150,6 +153,13 @@ interpret'Rec env = \case      sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env esh      arr <- interpret' env e      return $ arrayReshape sh arr +  EZip _ a b -> do +    arr1 <- interpret' env a +    arr2 <- interpret' env b +    let sh = arrayShape arr1 +    when (sh /= arrayShape arr2) $ +      error "Interpreter: mismatched shapes in EZip" +    return $ arrayGenerateLin sh (\i -> (arr1 `arrayIndexLinear` i, arr2 `arrayIndexLinear` i))    EFold1InnerD1 _ _ a b c -> do      let t = typeOf b      let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a diff --git a/src/Simplify.hs b/src/Simplify.hs index b89d7f6..1889adc 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -200,12 +200,14 @@ simplify'Rec = \case        EMaybe ext (ESnd ext e1) (ESnd ext e2) e3    -- TODO: more array indexing -  EIdx _ (EBuild _ _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ elet e3 e2 +  EIdx _ (EBuild _ _ e1 e2) e3 | not (hasAdds e1), not (hasAdds e2) -> acted $ simplify' $ elet e3 e2 +  EIdx _ (EMap _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ elet (EIdx ext e2 e3) e1    EIdx _ (EReplicate1Inner _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ EIdx ext e2 (EFst ext e3)    EIdx _ (EUnit _ e1) e2 | not (hasAdds e2) -> acted $ simplify' $ e1    -- TODO: more array shape    EShape _ (EBuild _ _ e1 e2) | not (hasAdds e2) -> acted $ simplify' e1 +  EShape _ (EMap _ e1 e2) | not (hasAdds e1) -> acted $ simplify' (EShape ext e2)    -- TODO: more constant folding    EOp _ OIf (EConst _ STBool True) -> acted $ return (EInl ext STNil (ENil ext)) @@ -308,6 +310,7 @@ simplify'Rec = \case    ELCase _ e a b c -> [simprec| ELCase ext *e *a *b *c |]    EConstArr _ n t v -> pure $ EConstArr ext n t v    EBuild _ n a b -> [simprec| EBuild ext n *a *b |] +  EMap _ a b -> [simprec| EMap ext *a *b |]    EFold1Inner _ cm a b c -> [simprec| EFold1Inner ext cm *a *b *c |]    ESum1Inner _ e -> [simprec| ESum1Inner ext *e |]    EUnit _ e -> [simprec| EUnit ext *e |] @@ -315,6 +318,7 @@ simplify'Rec = \case    EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |]    EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |]    EReshape _ n a b -> [simprec| EReshape ext n *a *b |] +  EZip _ a b -> [simprec| EZip ext *a *b |]    EFold1InnerD1 _ cm a b c -> [simprec| EFold1InnerD1 ext cm *a *b *c |]    EFold1InnerD2 _ cm a b c -> [simprec| EFold1InnerD2 ext cm *a *b *c |]    EConst _ t v -> pure $ EConst ext t v @@ -364,6 +368,7 @@ hasAdds = \case    ELCase _ e a b c -> hasAdds e || hasAdds a || hasAdds b || hasAdds c    EConstArr _ _ _ _ -> False    EBuild _ _ a b -> hasAdds a || hasAdds b +  EMap _ a b -> hasAdds a || hasAdds b    EFold1Inner _ _ a b c -> hasAdds a || hasAdds b || hasAdds c    ESum1Inner _ e -> hasAdds e    EUnit _ e -> hasAdds e @@ -371,6 +376,7 @@ hasAdds = \case    EMaximum1Inner _ e -> hasAdds e    EMinimum1Inner _ e -> hasAdds e    EReshape _ _ a b -> hasAdds a || hasAdds b +  EZip _ a b -> hasAdds a || hasAdds b    EFold1InnerD1 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c    EFold1InnerD2 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c    ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e  | 
