diff options
| -rw-r--r-- | src/AST.hs | 6 | ||||
| -rw-r--r-- | src/AST/Accum.hs | 33 | ||||
| -rw-r--r-- | src/AST/Pretty.hs | 29 | ||||
| -rw-r--r-- | src/AST/Types.hs | 2 | ||||
| -rw-r--r-- | src/AST/UnMonoid.hs | 119 | ||||
| -rw-r--r-- | src/Analysis/Identity.hs | 8 | ||||
| -rw-r--r-- | src/CHAD/Accum.hs | 2 | ||||
| -rw-r--r-- | src/CHAD/Types.hs | 4 | ||||
| -rw-r--r-- | src/Interpreter.hs | 261 | ||||
| -rw-r--r-- | src/Interpreter/Rep.hs | 35 | ||||
| -rw-r--r-- | src/Simplify.hs | 57 | 
11 files changed, 286 insertions, 270 deletions
| @@ -88,8 +88,8 @@ data Expr x env t where            -> Expr x env t    -- accumulation effect on monoids -  EWith :: x (TPair a (D2 t)) -> STy t -> Expr x env (D2 t) -> Expr x (TAccum (D2 t) : env) a -> Expr x env (TPair a (D2 t)) -  EAccum :: x TNil -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (TAccum (D2 a)) -> Expr x env TNil +  EWith :: x (TPair a (D2 t)) -> STy t -> Expr x env (D2 t) -> Expr x (TAccum t : env) a -> Expr x env (TPair a (D2 t)) +  EAccum :: x TNil -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (TAccum t) -> Expr x env TNil    -- monoidal operations (to be desugared to regular operations after simplification)    EZero :: x (D2 t) -> STy t -> Expr x env (D2 t) @@ -381,6 +381,8 @@ ebuildUp1 n sh size f =  eidxEq :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool)  eidxEq SZ _ _ = EConst ext STBool True +eidxEq (SS SZ) a b = +  EOp ext (OEq STI64) (EPair ext (ESnd ext a) (ESnd ext b))  eidxEq (SS n) a b    | let ty = tTup (sreplicate (SS n) tIx)    = ELet ext a $ diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs index 163f1c3..6c46ad5 100644 --- a/src/AST/Accum.hs +++ b/src/AST/Accum.hs @@ -21,16 +21,17 @@ data AcPrj    | APArrIdx AcPrj    | APArrSlice Nat --- | @b@ is a small part of @a@, indicated by the projection. +-- | @b@ is a small part of @a@, indicated by the projection @p@.  data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where    SAPHere :: SAcPrj APHere a a -  SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair t a) b -  SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair a t) b -  SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TEither t a) b +  SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair a t) b +  SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair t a) b +  SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TEither a t) b    SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TEither t a) b -  SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe t) b -  SAPArrIdx :: SAcPrj p a b -> SNat n -> SAcPrj (APArrIdx p) (TArr n t) b -  SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t) +  SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe a) b +  SAPArrIdx :: SAcPrj p a b -> SNat n -> SAcPrj (APArrIdx p) (TArr n a) b +  -- TODO: +  -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t)  deriving instance Show (SAcPrj p a b)  type family AcIdx p t where @@ -40,5 +41,19 @@ type family AcIdx p t where    AcIdx (APLeft p) (TEither a b) = AcIdx p a    AcIdx (APRight p) (TEither a b) = AcIdx p b    AcIdx (APJust p) (TMaybe a) = AcIdx p a -  AcIdx (APArrIdx p) (TArr n a) = TPair (Tup (Replicate n TIx)) (AcIdx p a) -  AcIdx (APArrSlice m) (TArr n a) = Tup (Replicate m TIx) +  AcIdx (APArrIdx p) (TArr n a) = +    -- ((index, array shape), recursive info) +    TPair (TPair (Tup (Replicate n TIx)) (Tup (Replicate n TIx))) +          (AcIdx p a) +  -- AcIdx (APArrSlice m) (TArr n a) = +  --   -- (index, array shape) +  --   TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx)) + +acPrjTy :: SAcPrj p a b -> STy a -> STy b +acPrjTy SAPHere t = t +acPrjTy (SAPFst prj) (STPair t _) = acPrjTy prj t +acPrjTy (SAPSnd prj) (STPair _ t) = acPrjTy prj t +acPrjTy (SAPLeft prj) (STEither t _) = acPrjTy prj t +acPrjTy (SAPRight prj) (STEither _ t) = acPrjTy prj t +acPrjTy (SAPJust prj) (STMaybe t) = acPrjTy prj t +acPrjTy (SAPArrIdx prj _) (STArr _ t) = acPrjTy prj t diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index f91aff2..b9406d7 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -235,9 +235,9 @@ ppExpr' d val expr = case expr of          ,e1'          ,e2'] -  EWith _ e1 e2 -> do +  EWith _ t e1 e2 -> do      e1' <- ppExpr' 11 val e1 -    name <- genNameIfUsedIn' "ac" (STAccum (typeOf e1)) IZ e2 +    name <- genNameIfUsedIn' "ac" (STAccum t) IZ e2      e2' <- ppExpr' 0 (Const name `SCons` val) e2      return $ ppParen (d > 0) $        group $ flatAlt @@ -247,12 +247,12 @@ ppExpr' d val expr = case expr of             <> hardline <> e2')          (ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2']) -  EAccum _ i e1 e2 e3 -> do +  EAccum _ _ prj e1 e2 e3 -> do      e1' <- ppExpr' 11 val e1      e2' <- ppExpr' 11 val e2      e3' <- ppExpr' 11 val e3      return $ ppParen (d > 10) $ -      ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (show (fromSNat i)), e1', e2', e3'] +      ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (ppAcPrj prj), e1', e2', e3']    EZero _ t -> return $ ppParen (d > 0) $      annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSTy' 11 t @@ -263,11 +263,11 @@ ppExpr' d val expr = case expr of      return $ ppParen (d > 10) $        ppApp (annotate AMonoid (ppString "plus") <> ppX expr) [a', b'] -  EOneHot _ _ i a b -> do +  EOneHot _ _ prj a b -> do      a' <- ppExpr' 11 val a      b' <- ppExpr' 11 val b      return $ ppParen (d > 10) $ -      ppApp (annotate AMonoid (ppString "onehot") <> ppX expr) [ppString (show (fromSNat i)), a', b'] +      ppApp (annotate AMonoid (ppString "onehot") <> ppX expr) [ppString (ppAcPrj prj), a', b']    EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s) @@ -300,6 +300,15 @@ ppLam :: [ADoc] -> ADoc -> ADoc  ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"])                                            <> softline <> body <> ppString ")") +ppAcPrj :: SAcPrj p a b -> String +ppAcPrj SAPHere = "@" +ppAcPrj (SAPFst prj) = "(" ++ ppAcPrj prj ++ ",)" +ppAcPrj (SAPSnd prj) = "(," ++ ppAcPrj prj ++ ")" +ppAcPrj (SAPLeft prj) = "(" ++ ppAcPrj prj ++ "|)" +ppAcPrj (SAPRight prj) = "(|" ++ ppAcPrj prj ++ ")" +ppAcPrj (SAPJust prj) = "J" ++ ppAcPrj prj +ppAcPrj (SAPArrIdx prj n) = "[" ++ ppAcPrj prj ++ "]" ++ intSubscript (fromSNat n) +  ppX :: PrettyX x => Expr x env t -> ADoc  ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr) @@ -355,6 +364,14 @@ ppParen :: Bool -> Doc x -> Doc x  ppParen True = parens  ppParen False = id +intSubscript :: Int -> String +intSubscript = \case 0 -> "₀" +                     n | n < 0 -> '₋' : go (-n) "" +                       | otherwise -> go n "" +  where go 0 suff = suff +        go n suff = let (q, r) = n `quotRem` 10 +                    in go q ("₀₁₂₃₄₅₆₇₈₉" !! r : suff) +  data Annot = AKey | AWith | AHighlight | AMonoid | AExt    deriving (Show) diff --git a/src/AST/Types.hs b/src/AST/Types.hs index be7cffe..0b41671 100644 --- a/src/AST/Types.hs +++ b/src/AST/Types.hs @@ -21,7 +21,7 @@ data Ty    | TMaybe Ty    | TArr Nat Ty  -- ^ rank, element type    | TScal ScalTy -  | TAccum Ty +  | TAccum Ty  -- ^ the accumulator contains D2 of this type    deriving (Show, Eq, Ord)  data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index ec5e11e..ae9728a 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -2,7 +2,7 @@  {-# LANGUAGE GADTs #-}  {-# LANGUAGE LambdaCase #-}  {-# LANGUAGE TypeOperators #-} -module AST.UnMonoid where +module AST.UnMonoid (unMonoid, zero, plus) where  import AST  import CHAD.Types @@ -117,110 +117,21 @@ plusSparse t a b adder =        (weakenExpr WSink a)  onehot :: STy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env (D2 a) -> Ex env (D2 t) -onehot _ topprj arg = case topprj of -  SAPHere -> arg +onehot typ topprj idx arg = case (typ, topprj) of +  (_, SAPHere) -> arg -  SAPFst prj -> _ +  (STPair t1 t2, SAPFst prj) -> EJust ext (EPair ext (onehot t1 prj idx arg) (zero t2)) +  (STPair t1 t2, SAPSnd prj) -> EJust ext (EPair ext (zero t1) (onehot t2 prj idx arg)) -onehot t (SS dep) arg = case t of -  STPair t1 t2 -> -    case dep of -      SZ -> EJust ext val -      SS dep' -> -        let STEither tidx1 tidx2 = typeOf idx -            STEither tval1 tval2 = typeOf val -        in EJust ext $ -            ECase ext idx -              (ECase ext (weakenExpr WSink val) -                (EPair ext (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ)) -                           (zero t2)) -                (EError ext (STPair (d2 t1) (d2 t2)) "onehot pair l/r")) -              (ECase ext (weakenExpr WSink val) -                (EError ext (STPair (d2 t1) (d2 t2)) "onehot pair r/l") -                (EPair ext (zero t1) -                           (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ)))) +  (STEither t1 t2, SAPLeft prj) -> EJust ext (EInl ext (d2 t2) (onehot t1 prj idx arg)) +  (STEither t1 t2, SAPRight prj) -> EJust ext (EInr ext (d2 t1) (onehot t2 prj idx arg)) -  STEither t1 t2 -> -    case dep of -      SZ -> EJust ext val -      SS dep' -> -        let STEither tidx1 tidx2 = typeOf idx -            STEither tval1 tval2 = typeOf val -        in EJust ext $ -            ECase ext idx -              (ECase ext (weakenExpr WSink val) -                (EInl ext (d2 t2) (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ))) -                (EError ext (STEither (d2 t1) (d2 t2)) "onehot either l/r")) -              (ECase ext (weakenExpr WSink val) -                (EError ext (STEither (d2 t1) (d2 t2)) "onehot either r/l") -                (EInr ext (d2 t1) (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ)))) +  (STMaybe t1, SAPJust prj) -> EJust ext (onehot t1 prj idx arg) -  STMaybe t1 -> EJust ext (onehot t1 dep idx val) - -  STArr n t1 -> -    ELet ext val $ -      EBuild ext n (EFst ext (EVar ext (typeOf val) IZ)) -                   (onehotArrayElem t1 n (SS dep) -                                    (EVar ext (tTup (sreplicate n tIx)) IZ) -                                    (weakenExpr (WSink .> WSink) idx) -                                    (ESnd ext (EVar ext (typeOf val) (IS IZ)))) - -  STNil -> error "Cannot index into nil" -  STScal{} -> error "Cannot index into scalar" -  STAccum{} -> error "Accumulators not allowed in input program" - --- onehotArrayElem ---   :: STy t -> SNat n -> SNat i ---   -> Ex env (Tup (Replicate n TIx))    -- ^ where are we now, OUTSIDE-IN as produced by 'outsideInIndex' ---   -> Ex env (AcIdx (TArr n (D2 t)) i)  -- ^ where to put the one-hot ---   -> Ex env (AcValArr n (D2 t) i)      -- ^ value to put in the hole ---   -> Ex env (D2 t) --- onehotArrayElem t n dep eltidx idx val = ---   ELet ext eltidx $ ---   ELet ext (weakenExpr WSink idx) $ ---     let (cond, elt) = onehotArrayElemRec t n dep ---                                          (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ))) ---                                          (EVar ext (typeOf idx) IZ) ---                                          (weakenExpr (WSink .> WSink) val) ---     in eif cond elt (zero t) - --- -- AcIdx must be duplicable --- onehotArrayElemRec ---   :: STy t -> SNat n -> SNat i ---   -> [Ex env TIx] ---   -> Ex env (AcIdx (TArr n (D2 t)) i) ---   -> Ex env (AcValArr n (D2 t) i) ---   -> (Ex env (TScal TBool), Ex env (D2 t)) --- onehotArrayElemRec _ n SZ eltidx _ val = ---   (EConst ext STBool True ---   ,EIdx ext val (reconstructFromOutsideIn n eltidx)) --- onehotArrayElemRec t SZ (SS dep) eltidx idx val = ---   case eltidx of ---     [] -> (EConst ext STBool True, onehot t dep idx val) ---     _ -> error "onehotArrayElemRec: mismatched list length" --- onehotArrayElemRec t (SS n) (SS dep) eltidx idx val = ---   case eltidx of ---     i : eltidx' -> ---       let (cond, elt) = onehotArrayElemRec t n dep eltidx' (ESnd ext idx) val ---       in (EOp ext OAnd (EPair ext (EOp ext (OEq STI64) (EPair ext i (EFst ext idx))) cond) ---          ,elt) ---     [] -> error "onehotArrayElemRec: mismatched list length" - --- | Outermost index at the head. The input expression must be duplicable. -outsideInIndex :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx] -outsideInIndex = \n idx -> go n idx [] -  where -    go :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx] -> [Ex env TIx] -    go SZ _ acc = acc -    go (SS n) idx acc = go n (EFst ext idx) (ESnd ext idx : acc) - --- Takes a list with the outermost index at the head. Returns a tuple with the --- innermost index on the right. -reconstructFromOutsideIn :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx)) -reconstructFromOutsideIn = \n list -> go n (reverse list) -  where -    -- Takes list with the _innermost_ index at the head. -    go :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx)) -    go SZ [] = ENil ext -    go (SS n) (i:is) = EPair ext (go n is) i -    go _ _ = error "reconstructFromOutsideIn: mismatched list length" +  (STArr n t1, SAPArrIdx prj _) -> +    let tidx = tTup (sreplicate n tIx) +    in ELet ext idx $ +         EBuild ext n (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ))) $ +           eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ))))) +             (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) +             (zero t1) diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index 020ca34..5e36dde 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -237,20 +237,20 @@ idana env expr = case expr of      res <- genIds t4      pure (res, ECustom res t1 t2 t3 e1' e2' e3' e4' e5') -  EWith _ e1 e2 -> do +  EWith _ t e1 e2 -> do      let t1 = typeOf e1      (_, e1') <- idana env e1      x1 <- VIAccum <$> genId      (v2, e2') <- idana (x1 `SCons` env) e2      x2 <- genIds t1      let res = VIPair v2 x2 -    pure (res, EWith res e1' e2') +    pure (res, EWith res t e1' e2') -  EAccum _ i e1 e2 e3 -> do +  EAccum _ t prj e1 e2 e3 -> do      (_, e1') <- idana env e1      (_, e2') <- idana env e2      (_, e3') <- idana env e3 -    pure (VINil, EAccum VINil i e1' e2' e3') +    pure (VINil, EAccum VINil t prj e1' e2' e3')    EZero _ t -> do      res <- genIds (d2 t) diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs index 14a1d3b..b61b5ff 100644 --- a/src/CHAD/Accum.hs +++ b/src/CHAD/Accum.hs @@ -12,7 +12,7 @@ makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex e  makeAccumulators SNil e = e  makeAccumulators (t `SCons` envpro) e =    makeAccumulators envpro $ -    EWith ext (EZero ext t) e +    EWith ext t (EZero ext t) e  uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list))  uninvertTup SNil _ e = EPair ext e (ENil ext) diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs index a8614cf..e8ec0c9 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Types.hs @@ -40,7 +40,7 @@ type family D2E env where  type family D2AcE env where    D2AcE '[] = '[] -  D2AcE (t : env) = TAccum (D2 t) : D2AcE env +  D2AcE (t : env) = TAccum t : D2AcE env  d1 :: STy t -> STy (D1 t)  d1 STNil = STNil @@ -75,7 +75,7 @@ d2e (t `SCons` ts) = d2 t `SCons` d2e ts  d2ace :: SList STy env -> SList STy (D2AcE env)  d2ace SNil = SNil -d2ace (t `SCons` ts) = STAccum (d2 t) `SCons` d2ace ts +d2ace (t `SCons` ts) = STAccum t `SCons` d2ace ts  data CHADConfig = CHADConfig diff --git a/src/Interpreter.hs b/src/Interpreter.hs index d80a76e..11caac0 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -21,6 +21,7 @@ module Interpreter (  import Control.Monad (foldM, join, when)  import Data.Bifunctor (bimap) +import Data.Bitraversable (bitraverse)  import Data.Char (isSpace)  import Data.Functor.Identity  import Data.Kind (Type) @@ -134,26 +135,25 @@ interpret'Rec env = \case      e1' <- interpret' env e1      e2' <- interpret' env e2      interpret' (Value e2' `SCons` Value e1' `SCons` SNil) pr -  EWith _ e1 e2 -> do +  EWith _ t e1 e2 -> do      initval <- interpret' env e1 -    withAccum (typeOf e1) (typeOf e2) initval $ \accum -> +    withAccum t (typeOf e2) initval $ \accum ->        interpret' (Value accum `SCons` env) e2 -  EAccum _ i e1 e2 e3 -> do -    let STAccum t = typeOf e3 +  EAccum _ t p e1 e2 e3 -> do      idx <- interpret' env e1      val <- interpret' env e2      accum <- interpret' env e3 -    accumAddSparse t i accum idx val +    accumAddSparse t p accum idx val    EZero _ t -> do      return $ zeroD2 t    EPlus _ t a b -> do      a' <- interpret' env a      b' <- interpret' env b      return $ addD2s t a' b' -  EOneHot _ t i a b -> do +  EOneHot _ t p a b -> do      a' <- interpret' env a      b' <- interpret' env b -    return $ onehotD2 i t a' b' +    return $ onehotD2 p t a' b'    EError _ _ s -> error $ "Interpreter: Program threw error: " ++ s  interpretOp :: SOp a t -> Rep a -> Rep t @@ -230,44 +230,37 @@ addD2s typ a b = case typ of      STBool -> ()    STAccum{} -> error "Plus of Accum" -onehotD2 :: SNat i -> STy t -> Rep (AcIdx (D2 t) i) -> Rep (AcVal (D2 t) i) -> Rep (D2 t) -onehotD2 SZ _ () v = v -onehotD2 _ STNil _ _ = () -onehotD2 (SS SZ    ) (STPair _  _ ) ()          val         = Just val -onehotD2 (SS (SS i)) (STPair t1 t2) (Left  idx) (Left  val) = Just (onehotD2 i t1 idx val, zeroD2 t2) -onehotD2 (SS (SS i)) (STPair t1 t2) (Right idx) (Right val) = Just (zeroD2 t1, onehotD2 i t2 idx val) -onehotD2 (SS _     ) (STPair _  _ ) _           _           = error "onehotD2: pair: mismatched index and value" -onehotD2 (SS SZ    ) (STEither _  _ ) ()          val         = Just val -onehotD2 (SS (SS i)) (STEither t1 _ ) (Left  idx) (Left  val) = Just (Left (onehotD2 i t1 idx val)) -onehotD2 (SS (SS i)) (STEither _  t2) (Right idx) (Right val) = Just (Right (onehotD2 i t2 idx val)) -onehotD2 (SS _     ) (STEither _  _ ) _           _           = error "onehotD2: either: mismatched index and value" -onehotD2 (SS i     ) (STMaybe t) idx val = Just (onehotD2 i t idx val) -onehotD2 (SS i     ) (STArr n t) idx val = runIdentity $ -  onehotArray (d2 t) (\i' idx' v' -> Identity (onehotD2 i' t idx' v')) (Identity (zeroD2 t)) n (SS i) idx val -onehotD2 SS{} STScal{} _ _ = error "onehotD2: cannot index into scalar" -onehotD2 _ STAccum{} _ _ = error "onehotD2: cannot index into accumulator" +onehotD2 :: SAcPrj p a b -> STy a -> Rep (AcIdx p a) -> Rep (D2 b) -> Rep (D2 a) +onehotD2 SAPHere _ _ val = val +onehotD2 (SAPFst prj) (STPair a b) idx val = Just (onehotD2 prj a idx val, zeroD2 b) +onehotD2 (SAPSnd prj) (STPair a b) idx val = Just (zeroD2 a, onehotD2 prj b idx val) +onehotD2 (SAPLeft prj) (STEither a _) idx val = Just (Left (onehotD2 prj a idx val)) +onehotD2 (SAPRight prj) (STEither _ b) idx val = Just (Right (onehotD2 prj b idx val)) +onehotD2 (SAPJust prj) (STMaybe a) idx val = Just (onehotD2 prj a idx val) +onehotD2 (SAPArrIdx prj _) (STArr n a) idx val = +  runIdentity $ onehotArray (\idx' -> Identity (onehotD2 prj a idx' val)) (Identity (zeroD2 a)) n prj idx -withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t) +withAccum :: STy t -> STy a -> Rep (D2 t) -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep (D2 t))  withAccum t _ initval f = AcM $ do -  accum <- newAcSparse t SZ () initval +  accum <- newAcSparse t SAPHere () initval    out <- case f accum of AcM m -> m    val <- readAcSparse t accum    return (out, val) -newAcZero :: STy t -> IO (RepAcSparse t) +newAcZero :: STy t -> IO (RepAc t)  newAcZero = \case    STNil -> return () -  STPair t1 t2 -> newIORef =<< (,) <$> newAcZero t1 <*> newAcZero t2 +  STPair{} -> newIORef Nothing +  STEither{} -> newIORef Nothing    STMaybe _ -> newIORef Nothing    STArr n _ -> newIORef (emptyArray n)    STScal sty -> case sty of -    STI32 -> newIORef 0 -    STI64 -> newIORef 0 +    STI32 -> return () +    STI64 -> return ()      STF32 -> newIORef 0.0      STF64 -> newIORef 0.0 -    STBool -> error "Accumulator of Bool" +    STBool -> return ()    STAccum{} -> error "Nested accumulators" -  STEither{} -> error "Bare Either in accumulator"  -- | Inverted index: the outermost index is at the /outside/ of this list.  data PartialInvIndex n m where @@ -322,95 +315,144 @@ piindexConcat :: PartialInvIndex n m -> InvIndex m -> InvIndex n  piindexConcat PIIxEnd ix = ix  piindexConcat (PIIxCons i pix) ix = InvCons i (piindexConcat pix ix) -newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t) -newAcSparse typ SZ () val = case typ of -  STNil -> return () -  STPair t1 t2 -> newIORef =<< (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val) -  STMaybe t -> newIORef =<< traverse (newAcDense t SZ ()) val -  STArr _ t -> newIORef =<< traverse (newAcSparse t SZ ()) val -  STScal{} -> newIORef val -  STAccum{} -> error "Nested accumulators" -  STEither{} -> error "Bare Either in accumulator" -newAcSparse typ (SS dep) idx val = case typ of -  STNil -> return () -  STPair t1 t2 -> newIORef =<< case (idx, val) of -                                 (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2 -                                 (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val' -                                 _ -> error "Index/value mismatch in newAc pair" -  STMaybe t -> newIORef =<< Just <$> newAcDense t dep idx val -  STArr dim (t :: STy t) -> newIORef =<< newAcArray dim t (SS dep) idx val -  STScal{} -> error "Cannot index into scalar" -  STAccum{} -> error "Nested accumulators" -  STEither{} -> error "Bare Either in accumulator" +newAcSparse :: STy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep (D2 b) -> IO (RepAc a) +newAcSparse typ prj idx val = case (typ, prj) of +  (STNil, SAPHere) -> return () +  (STPair t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val +  (STEither t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val +  (STMaybe t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val +  (STArr _ t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val +  (STScal sty, SAPHere) -> case sty of +    STI32 -> return () +    STI64 -> return () +    STF32 -> newIORef val +    STF64 -> newIORef val +    STBool -> return () -newAcArray :: SNat n -> STy t -> SNat i -> Rep (AcIdx (TArr n t) i) -> Rep (AcVal (TArr n t) i) -> IO (Array n (RepAcSparse t)) -newAcArray n t = onehotArray t (newAcSparse t) (newAcZero t) n +  (STPair t1 t2, SAPFst prj') -> +    newIORef . Just =<< (,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2 +  (STPair t1 t2, SAPSnd prj') -> +    newIORef . Just =<< (,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val + +  (STEither t1 _, SAPLeft prj') -> newIORef . Just . Left =<< newAcSparse t1 prj' idx val +  (STEither _ t2, SAPRight prj') -> newIORef . Just . Right =<< newAcSparse t2 prj' idx val + +  (STMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val + +  (STArr n t, SAPArrIdx prj' _) -> newIORef =<< newAcArray n t prj' idx val + +  (STAccum{}, _) -> error "Accumulators not allowed in source program" + +newAcArray :: SNat n -> STy a -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> Rep (D2 b) -> IO (Array n (RepAc a)) +newAcArray n t prj idx val = onehotArray (\idx' -> newAcSparse t prj idx' val) (newAcZero t) n prj idx  onehotArray :: Monad m -            => STy t -            -> (forall n'. SNat n' -> Rep (AcIdx t n') -> Rep (AcVal t n') -> m v)  -- ^ the "one" -            -> m v  -- ^ generate a zero value for elsewhere -            -> SNat n -> SNat i -> Rep (AcIdx (TArr n t) i) -> Rep (AcVal (TArr n t) i) -> m (Array n v) -onehotArray _ mkone _ _ SZ _ val = -  traverse (mkone SZ ()) val -onehotArray (_ :: STy t) mkone mkzero dim dep@SS{} idx val = do -  let sh = unTupRepIdx ShNil ShCons dim (fst val) -  go mkone dep dim idx (snd val) $ \arr position -> -    arrayGenerateM sh (\i -> case uninvert <$> piindexMatch position (invert i) of -                               Just i' -> return $ arr `arrayIndex` i' -                               Nothing -> mkzero) -  where -    go :: Monad m -       => (forall n'. SNat n' -> Rep (AcIdx t n') -> Rep (AcVal t n') -> m v) -       -> SNat i -> SNat n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i) -       -> (forall n'. Array n' v -> PartialInvIndex n n' -> m r) -> m r -    go mk SZ _ () val' k = arrayMapM (mk SZ ()) val' >>= \arr -> k arr PIIxEnd -    go mk (SS dep') SZ idx' val' k = mk dep' idx' val' >>= \arr -> k (arrayUnit arr) PIIxEnd -    go mk (SS dep') (SS dim') (i, idx') val' k = -      go mk dep' dim' idx' val' $ \arr pish -> -        k arr (PIIxCons (fromIntegral @Int64 @Int i) pish) +            => (Rep (AcIdx p a) -> m v)  -- ^ the "one" +            -> m v  -- ^ the "zero" +            -> SNat n -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> m (Array n v) +onehotArray mkone mkzero n _ ((arrindex', arrsh'), idx) = +  let arrindex = unTupRepIdx IxNil IxCons n arrindex' +      arrsh = unTupRepIdx ShNil ShCons n arrsh' +  in arrayGenerateM arrsh (\i -> if i == arrindex then mkone idx else mkzero) -newAcDense :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t) -newAcDense typ SZ () val = case typ of -  STPair t1 t2 -> (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val) -  STEither t1 t2 -> case val of -    Left x -> Left <$> newAcSparse t1 SZ () x -    Right y -> Right <$> newAcSparse t2 SZ () y -  _ -> error "newAcDense: invalid dense type" -newAcDense typ (SS dep) idx val = case typ of -  STPair t1 t2 -> -    case (idx, val) of -      (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2 -      (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val' -      _ -> error "Index/value mismatch in newAc pair" -  STEither t1 t2 -> -    case (idx, val) of -      (Left idx', Left val') -> Left <$> newAcSparse t1 dep idx' val' -      (Right idx', Right val') -> Right <$> newAcSparse t2 dep idx' val' -      _ -> error "Index/value mismatch in newAc either" -  _ -> error "newAcDense: invalid dense type" +-- newAcDense :: STy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep (D2 b) -> IO (RepAcDense (D2 a)) +-- newAcDense typ SZ () val = case typ of +--   STPair t1 t2 -> (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val) +--   STEither t1 t2 -> case val of +--     Left x -> Left <$> newAcSparse t1 SZ () x +--     Right y -> Right <$> newAcSparse t2 SZ () y +--   _ -> error "newAcDense: invalid dense type" +-- newAcDense typ (SS dep) idx val = case typ of +--   STPair t1 t2 -> +--     case (idx, val) of +--       (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2 +--       (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val' +--       _ -> error "Index/value mismatch in newAc pair" +--   STEither t1 t2 -> +--     case (idx, val) of +--       (Left idx', Left val') -> Left <$> newAcSparse t1 dep idx' val' +--       (Right idx', Right val') -> Right <$> newAcSparse t2 dep idx' val' +--       _ -> error "Index/value mismatch in newAc either" +--   _ -> error "newAcDense: invalid dense type" -readAcSparse :: STy t -> RepAcSparse t -> IO (Rep t) +readAcSparse :: STy t -> RepAc t -> IO (Rep (D2 t))  readAcSparse typ val = case typ of    STNil -> return () -  STPair t1 t2 -> do -    (a, b) <- readIORef val -    (,) <$> readAcSparse t1 a <*> readAcSparse t2 b -  STMaybe t -> traverse (readAcDense t) =<< readIORef val +  STPair t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val +  STEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val +  STMaybe t -> traverse (readAcSparse t) =<< readIORef val    STArr _ t -> traverse (readAcSparse t) =<< readIORef val -  STScal{} -> readIORef val +  STScal sty -> case sty of +    STI32 -> return () +    STI64 -> return () +    STF32 -> readIORef val +    STF64 -> readIORef val +    STBool -> return ()    STAccum{} -> error "Nested accumulators" -  STEither{} -> error "Bare Either in accumulator" -readAcDense :: STy t -> RepAcDense t -> IO (Rep t) -readAcDense typ val = case typ of -  STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val) -  STEither t1 t2 -> case val of -    Left x -> Left <$> readAcSparse t1 x -    Right y -> Right <$> readAcSparse t2 y -  _ -> error "readAcDense: invalid dense type" +-- readAcDense :: STy t -> RepAcDense t -> IO (Rep t) +-- readAcDense typ val = case typ of +--   STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val) +--   STEither t1 t2 -> case val of +--     Left x -> Left <$> readAcSparse t1 x +--     Right y -> Right <$> readAcSparse t2 y +--   _ -> error "readAcDense: invalid dense type" + +accumAddSparse :: STy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep (D2 b) -> AcM s () +accumAddSparse typ prj ref idx val = case (typ, prj) of +  (STNil, SAPHere) -> return () + +  (STPair t1 t2, SAPHere) -> +    case val of +      Nothing -> return () +      Just (val1, val2) -> +        AcM $ realiseMaybeSparse ref ((,) <$> newAcSparse t1 SAPHere () val1 +                                          <*> newAcSparse t2 SAPHere () val2) +                                     (\(ac1, ac2) -> do unAcM $ accumAddSparse t1 SAPHere ac1 () val1 +                                                        unAcM $ accumAddSparse t2 SAPHere ac2 () val2) +  (STPair t1 t2, SAPFst prj') -> +    AcM $ realiseMaybeSparse ref ((,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2) +                                 (\(ac1, _) -> do unAcM $ accumAddSparse t1 prj' ac1 idx val) +  (STPair t1 t2, SAPSnd prj') -> +    AcM $ realiseMaybeSparse ref ((,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val) +                                 (\(_, ac2) -> do unAcM $ accumAddSparse t2 prj' ac2 idx val) + +  (STEither t1 t2, SAPHere) -> _ ref val +  (STEither t1 _, SAPLeft prj') -> _ ref idx val +  (STEither _ t2, SAPRight prj') -> _ ref idx val + +  (STMaybe t1, SAPHere) -> _ ref val +  (STMaybe t1, SAPJust prj') -> _ ref idx val + +  (STArr _ t1, SAPHere) -> _ ref val +  (STArr n t, SAPArrIdx prj' _) -> _ ref idx val + +  (STScal sty, SAPHere) -> AcM $ case sty of +    STI32 -> return () +    STI64 -> return () +    STF32 -> atomicModifyIORef' ref (\x -> (x + val, ())) +    STF64 -> atomicModifyIORef' ref (\x -> (x + val, ())) +    STBool -> return () + +  (STAccum{}, _) -> error "Accumulators not allowed in source program" + +realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> IO ()) -> IO () +realiseMaybeSparse ref makeval modifyval = +  -- Try modifying what's already in ref. The 'join' makes the snd +  -- of the function's return value a _continuation_ that is run after +  -- the critical section ends. +  join $ atomicModifyIORef' ref $ \ac -> case ac of +           -- Oops, ref's contents was still sparse. Have to initialise +           -- it first, then try again. +           Nothing -> (ac, do val <- makeval +                              join $ atomicModifyIORef' ref $ \ac' -> case ac' of +                                       Nothing -> (Just val, return ()) +                                       Just val' -> (ac', modifyval val')) +           -- Yep, ref already had a value in there, so we can just add +           -- val' to it recursively. +           Just val -> (ac, modifyval val) -accumAddSparse :: STy t -> SNat i -> RepAcSparse t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s () +{-  accumAddSparse typ SZ ref () val = case typ of    STNil -> return ()    STPair t1 t2 -> AcM $ do @@ -532,6 +574,7 @@ accumAddDense typ (SS dep) ref idx val = case typ of        (Right ref', Right idx', Right val') -> (Right ref', accumAddSparse t2 dep ref' idx' val')        _ -> error "Mismatched Either in accumAddDense either"    _ -> error "accumAddDense: invalid dense type" +-}  numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index ac06915..f84f4e7 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -21,28 +21,25 @@ type family Rep t where    Rep (TMaybe t) = Maybe (Rep t)    Rep (TArr n t) = Array n (Rep t)    Rep (TScal sty) = ScalRep sty -  Rep (TAccum t) = RepAcSparse t +  Rep (TAccum t) = RepAc t --- Mutable, and has a zero. The zero may not be O(1), but RepAcSparse (D2 t) will have an O(1) zero. -type family RepAcSparse t where -  RepAcSparse TNil = () -  RepAcSparse (TPair a b) = IORef (RepAcSparse a, RepAcSparse b) -  RepAcSparse (TEither a b) = TypeError (Text "Non-sparse coproduct is not a monoid") -  RepAcSparse (TMaybe t) = IORef (Maybe (RepAcDense t))  -- allow the value to be dense, because the Maybe's zero can be used for the contents +-- Mutable, represents D2 of t. Has an O(1) zero. +type family RepAc t where +  RepAc TNil = () +  RepAc (TPair a b) = IORef (Maybe (RepAc a, RepAc b)) +  RepAc (TEither a b) = IORef (Maybe (Either (RepAc a) (RepAc b))) +  RepAc (TMaybe t) = IORef (Maybe (RepAc t))    -- TODO: an empty array is invalid for a zero-dimensional array, so zero-dimensional arrays don't actually have an O(1) zero. -  RepAcSparse (TArr n t) = IORef (Array n (RepAcSparse t))  -- empty array is zero -  RepAcSparse (TScal sty) = IORef (ScalRep sty) -  RepAcSparse (TAccum t) = TypeError (Text "RepAcSparse: Nested accumulators") +  RepAc (TArr n t) = IORef (Array n (RepAc t))  -- empty array is zero +  RepAc (TScal sty) = RepAcScal sty +  RepAc (TAccum t) = TypeError (Text "RepAcSparse: Nested accumulators") --- Immutable, and does not necessarily have a zero. -type family RepAcDense t where -  RepAcDense TNil = () -  RepAcDense (TPair a b) = (RepAcSparse a, RepAcSparse b) -  RepAcDense (TEither a b) = Either (RepAcSparse a) (RepAcSparse b) -  -- RepAcDense (TMaybe t) = RepAcSparse (TMaybe t)  -- ^ This can be optimised to TMaybe (RepAcSparse t), but that makes accumAddDense very hard to write. And in any case, we don't need it because D2 will not produce Maybe of Maybe. -  -- RepAcDense (TArr n t) = Array n (RepAcSparse t) -  -- RepAcDense (TScal sty) = ScalRep sty -  -- RepAcDense (TAccum t) = TypeError (Text "RepAcDense: Nested accumulators") +type family RepAcScal t where +  RepAcScal TI32 = () +  RepAcScal TI64 = () +  RepAcScal TF32 = IORef Float +  RepAcScal TF64 = IORef Double +  RepAcScal TBool = ()  newtype Value t = Value { unValue :: Rep t } diff --git a/src/Simplify.hs b/src/Simplify.hs index 2177789..0aa7a66 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -14,6 +14,7 @@ module Simplify (  import Data.Function (fix)  import Data.Monoid (Any(..)) +import Data.Type.Equality (testEquality)  import AST  import AST.Count @@ -105,10 +106,10 @@ simplify' = \case    EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e)))    -- projection down-commuting -  EFst _ (ECase _ e1 e2@EPair{} e3@EPair{}) -> +  EFst _ (ECase _ e1 e2 e3) ->      acted $ simplify' $        ECase ext e1 (EFst ext e2) (EFst ext e3) -  ESnd _ (ECase _ e1 e2@EPair{} e3@EPair{}) -> +  ESnd _ (ECase _ e1 e2 e3) ->      acted $ simplify' $        ECase ext e1 (ESnd ext e2) (ESnd ext e3) @@ -118,16 +119,22 @@ simplify' = \case    -- TODO: constant folding for operations -  -- TODO: properly concatenate accum/onehot -  EAccum _ SZ _ (EOneHot _ _ i idx val) acc -> -    acted $ simplify' $ -      EAccum ext i idx val acc -  EAccum _ _ _ (EZero _ _) _ -> (Any True, ENil ext) +  -- monoid rules +  EAccum _ t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val) acc +    | Just Refl <- testEquality (acPrjTy prj1 t1) t2 +    -> concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 -> +         acted $ simplify' (EAccum ext t1 prj12 idx12 val acc) +  EAccum _ _ _ _ (EZero _ _) _ -> (Any True, ENil ext)    EPlus _ _ (EZero _ _) e -> acted $ simplify' e    EPlus _ _ e (EZero _ _) -> acted $ simplify' e -  EOneHot _ _ SZ _ e -> acted $ simplify' e +  EOneHot _ t _ _ (EZero _ _) -> (Any True, EZero ext t) +  EOneHot _ _ SAPHere _ e -> acted $ simplify' e +  EOneHot _ t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val) +    | Just Refl <- testEquality (acPrjTy prj1 t1) t2 +    -> concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 -> +         acted $ simplify' (EOneHot ext t1 prj12 idx12 val) -  -- equations for plus +  -- type-specific equations for plus    EPlus _ STNil _ _ -> (Any True, ENil ext)    EPlus _ (STPair t1 t2) (EJust _ (EPair _ a1 b1)) (EJust _ (EPair _ a2 b2)) -> @@ -180,8 +187,8 @@ simplify' = \case        <*> (let ?accumInScope = False in simplify' b)        <*> (let ?accumInScope = False in simplify' c)        <*> simplify' e1 <*> simplify' e2 -  EWith _ e1 e2 -> EWith ext <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2) -  EAccum _ i e1 e2 e3 -> EAccum ext i <$> simplify' e1 <*> simplify' e2 <*> simplify' e3 +  EWith _ t e1 e2 -> EWith ext t <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2) +  EAccum _ t i e1 e2 e3 -> EAccum ext t i <$> simplify' e1 <*> simplify' e2 <*> simplify' e3    EZero _ t -> pure $ EZero ext t    EPlus _ t a b -> EPlus ext t <$> simplify' a <*> simplify' b    EOneHot _ t i a b -> EOneHot ext t i <$> simplify' a <*> simplify' b @@ -230,8 +237,8 @@ hasAdds = \case    EIdx _ a b -> hasAdds a || hasAdds b    EShape _ e -> hasAdds e    EOp _ _ e -> hasAdds e -  EWith _ a b -> hasAdds a || hasAdds b -  EAccum _ _ _ _ _ -> True +  EWith _ _ a b -> hasAdds a || hasAdds b +  EAccum _ _ _ _ _ _ -> True    EZero _ _ -> False    EPlus _ _ a b -> hasAdds a || hasAdds b    EOneHot _ _ _ a b -> hasAdds a || hasAdds b @@ -249,3 +256,27 @@ checkAccumInScope = \case SNil -> False      check (STArr _ t) = check t      check (STScal _) = False      check STAccum{} = True + +concatOneHots :: STy a +              -> SAcPrj p1 a b -> Ex env (AcIdx p1 a) +              -> SAcPrj p2 b c -> Ex env (AcIdx p2 b) +              -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx p12 a) -> r) -> r +concatOneHots t1 prj1 idx1 prj2 idx2 k = case (t1, prj1) of +  (_, SAPHere) -> k prj2 idx2 + +  (STPair a _, SAPFst prj1') -> +    concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPFst prj12) idx12 +  (STPair _ b, SAPSnd prj1') -> +    concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPSnd prj12) idx12 + +  (STEither a _, SAPLeft prj1') -> +    concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12 +  (STEither _ b, SAPRight prj1') -> +    concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12 + +  (STMaybe a, SAPJust prj1') -> +    concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12 + +  (STArr n a, SAPArrIdx prj1' _) -> +    concatOneHots a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> +      k (SAPArrIdx prj12 n) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) | 
