diff options
Diffstat (limited to 'src/AST')
| -rw-r--r-- | src/AST/Accum.hs | 33 | ||||
| -rw-r--r-- | src/AST/Pretty.hs | 29 | ||||
| -rw-r--r-- | src/AST/Types.hs | 2 | ||||
| -rw-r--r-- | src/AST/UnMonoid.hs | 119 | 
4 files changed, 63 insertions, 120 deletions
| diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs index 163f1c3..6c46ad5 100644 --- a/src/AST/Accum.hs +++ b/src/AST/Accum.hs @@ -21,16 +21,17 @@ data AcPrj    | APArrIdx AcPrj    | APArrSlice Nat --- | @b@ is a small part of @a@, indicated by the projection. +-- | @b@ is a small part of @a@, indicated by the projection @p@.  data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where    SAPHere :: SAcPrj APHere a a -  SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair t a) b -  SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair a t) b -  SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TEither t a) b +  SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair a t) b +  SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair t a) b +  SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TEither a t) b    SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TEither t a) b -  SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe t) b -  SAPArrIdx :: SAcPrj p a b -> SNat n -> SAcPrj (APArrIdx p) (TArr n t) b -  SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t) +  SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe a) b +  SAPArrIdx :: SAcPrj p a b -> SNat n -> SAcPrj (APArrIdx p) (TArr n a) b +  -- TODO: +  -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t)  deriving instance Show (SAcPrj p a b)  type family AcIdx p t where @@ -40,5 +41,19 @@ type family AcIdx p t where    AcIdx (APLeft p) (TEither a b) = AcIdx p a    AcIdx (APRight p) (TEither a b) = AcIdx p b    AcIdx (APJust p) (TMaybe a) = AcIdx p a -  AcIdx (APArrIdx p) (TArr n a) = TPair (Tup (Replicate n TIx)) (AcIdx p a) -  AcIdx (APArrSlice m) (TArr n a) = Tup (Replicate m TIx) +  AcIdx (APArrIdx p) (TArr n a) = +    -- ((index, array shape), recursive info) +    TPair (TPair (Tup (Replicate n TIx)) (Tup (Replicate n TIx))) +          (AcIdx p a) +  -- AcIdx (APArrSlice m) (TArr n a) = +  --   -- (index, array shape) +  --   TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx)) + +acPrjTy :: SAcPrj p a b -> STy a -> STy b +acPrjTy SAPHere t = t +acPrjTy (SAPFst prj) (STPair t _) = acPrjTy prj t +acPrjTy (SAPSnd prj) (STPair _ t) = acPrjTy prj t +acPrjTy (SAPLeft prj) (STEither t _) = acPrjTy prj t +acPrjTy (SAPRight prj) (STEither _ t) = acPrjTy prj t +acPrjTy (SAPJust prj) (STMaybe t) = acPrjTy prj t +acPrjTy (SAPArrIdx prj _) (STArr _ t) = acPrjTy prj t diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index f91aff2..b9406d7 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -235,9 +235,9 @@ ppExpr' d val expr = case expr of          ,e1'          ,e2'] -  EWith _ e1 e2 -> do +  EWith _ t e1 e2 -> do      e1' <- ppExpr' 11 val e1 -    name <- genNameIfUsedIn' "ac" (STAccum (typeOf e1)) IZ e2 +    name <- genNameIfUsedIn' "ac" (STAccum t) IZ e2      e2' <- ppExpr' 0 (Const name `SCons` val) e2      return $ ppParen (d > 0) $        group $ flatAlt @@ -247,12 +247,12 @@ ppExpr' d val expr = case expr of             <> hardline <> e2')          (ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2']) -  EAccum _ i e1 e2 e3 -> do +  EAccum _ _ prj e1 e2 e3 -> do      e1' <- ppExpr' 11 val e1      e2' <- ppExpr' 11 val e2      e3' <- ppExpr' 11 val e3      return $ ppParen (d > 10) $ -      ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (show (fromSNat i)), e1', e2', e3'] +      ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (ppAcPrj prj), e1', e2', e3']    EZero _ t -> return $ ppParen (d > 0) $      annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSTy' 11 t @@ -263,11 +263,11 @@ ppExpr' d val expr = case expr of      return $ ppParen (d > 10) $        ppApp (annotate AMonoid (ppString "plus") <> ppX expr) [a', b'] -  EOneHot _ _ i a b -> do +  EOneHot _ _ prj a b -> do      a' <- ppExpr' 11 val a      b' <- ppExpr' 11 val b      return $ ppParen (d > 10) $ -      ppApp (annotate AMonoid (ppString "onehot") <> ppX expr) [ppString (show (fromSNat i)), a', b'] +      ppApp (annotate AMonoid (ppString "onehot") <> ppX expr) [ppString (ppAcPrj prj), a', b']    EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s) @@ -300,6 +300,15 @@ ppLam :: [ADoc] -> ADoc -> ADoc  ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"])                                            <> softline <> body <> ppString ")") +ppAcPrj :: SAcPrj p a b -> String +ppAcPrj SAPHere = "@" +ppAcPrj (SAPFst prj) = "(" ++ ppAcPrj prj ++ ",)" +ppAcPrj (SAPSnd prj) = "(," ++ ppAcPrj prj ++ ")" +ppAcPrj (SAPLeft prj) = "(" ++ ppAcPrj prj ++ "|)" +ppAcPrj (SAPRight prj) = "(|" ++ ppAcPrj prj ++ ")" +ppAcPrj (SAPJust prj) = "J" ++ ppAcPrj prj +ppAcPrj (SAPArrIdx prj n) = "[" ++ ppAcPrj prj ++ "]" ++ intSubscript (fromSNat n) +  ppX :: PrettyX x => Expr x env t -> ADoc  ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr) @@ -355,6 +364,14 @@ ppParen :: Bool -> Doc x -> Doc x  ppParen True = parens  ppParen False = id +intSubscript :: Int -> String +intSubscript = \case 0 -> "₀" +                     n | n < 0 -> '₋' : go (-n) "" +                       | otherwise -> go n "" +  where go 0 suff = suff +        go n suff = let (q, r) = n `quotRem` 10 +                    in go q ("₀₁₂₃₄₅₆₇₈₉" !! r : suff) +  data Annot = AKey | AWith | AHighlight | AMonoid | AExt    deriving (Show) diff --git a/src/AST/Types.hs b/src/AST/Types.hs index be7cffe..0b41671 100644 --- a/src/AST/Types.hs +++ b/src/AST/Types.hs @@ -21,7 +21,7 @@ data Ty    | TMaybe Ty    | TArr Nat Ty  -- ^ rank, element type    | TScal ScalTy -  | TAccum Ty +  | TAccum Ty  -- ^ the accumulator contains D2 of this type    deriving (Show, Eq, Ord)  data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index ec5e11e..ae9728a 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -2,7 +2,7 @@  {-# LANGUAGE GADTs #-}  {-# LANGUAGE LambdaCase #-}  {-# LANGUAGE TypeOperators #-} -module AST.UnMonoid where +module AST.UnMonoid (unMonoid, zero, plus) where  import AST  import CHAD.Types @@ -117,110 +117,21 @@ plusSparse t a b adder =        (weakenExpr WSink a)  onehot :: STy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env (D2 a) -> Ex env (D2 t) -onehot _ topprj arg = case topprj of -  SAPHere -> arg +onehot typ topprj idx arg = case (typ, topprj) of +  (_, SAPHere) -> arg -  SAPFst prj -> _ +  (STPair t1 t2, SAPFst prj) -> EJust ext (EPair ext (onehot t1 prj idx arg) (zero t2)) +  (STPair t1 t2, SAPSnd prj) -> EJust ext (EPair ext (zero t1) (onehot t2 prj idx arg)) -onehot t (SS dep) arg = case t of -  STPair t1 t2 -> -    case dep of -      SZ -> EJust ext val -      SS dep' -> -        let STEither tidx1 tidx2 = typeOf idx -            STEither tval1 tval2 = typeOf val -        in EJust ext $ -            ECase ext idx -              (ECase ext (weakenExpr WSink val) -                (EPair ext (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ)) -                           (zero t2)) -                (EError ext (STPair (d2 t1) (d2 t2)) "onehot pair l/r")) -              (ECase ext (weakenExpr WSink val) -                (EError ext (STPair (d2 t1) (d2 t2)) "onehot pair r/l") -                (EPair ext (zero t1) -                           (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ)))) +  (STEither t1 t2, SAPLeft prj) -> EJust ext (EInl ext (d2 t2) (onehot t1 prj idx arg)) +  (STEither t1 t2, SAPRight prj) -> EJust ext (EInr ext (d2 t1) (onehot t2 prj idx arg)) -  STEither t1 t2 -> -    case dep of -      SZ -> EJust ext val -      SS dep' -> -        let STEither tidx1 tidx2 = typeOf idx -            STEither tval1 tval2 = typeOf val -        in EJust ext $ -            ECase ext idx -              (ECase ext (weakenExpr WSink val) -                (EInl ext (d2 t2) (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ))) -                (EError ext (STEither (d2 t1) (d2 t2)) "onehot either l/r")) -              (ECase ext (weakenExpr WSink val) -                (EError ext (STEither (d2 t1) (d2 t2)) "onehot either r/l") -                (EInr ext (d2 t1) (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ)))) +  (STMaybe t1, SAPJust prj) -> EJust ext (onehot t1 prj idx arg) -  STMaybe t1 -> EJust ext (onehot t1 dep idx val) - -  STArr n t1 -> -    ELet ext val $ -      EBuild ext n (EFst ext (EVar ext (typeOf val) IZ)) -                   (onehotArrayElem t1 n (SS dep) -                                    (EVar ext (tTup (sreplicate n tIx)) IZ) -                                    (weakenExpr (WSink .> WSink) idx) -                                    (ESnd ext (EVar ext (typeOf val) (IS IZ)))) - -  STNil -> error "Cannot index into nil" -  STScal{} -> error "Cannot index into scalar" -  STAccum{} -> error "Accumulators not allowed in input program" - --- onehotArrayElem ---   :: STy t -> SNat n -> SNat i ---   -> Ex env (Tup (Replicate n TIx))    -- ^ where are we now, OUTSIDE-IN as produced by 'outsideInIndex' ---   -> Ex env (AcIdx (TArr n (D2 t)) i)  -- ^ where to put the one-hot ---   -> Ex env (AcValArr n (D2 t) i)      -- ^ value to put in the hole ---   -> Ex env (D2 t) --- onehotArrayElem t n dep eltidx idx val = ---   ELet ext eltidx $ ---   ELet ext (weakenExpr WSink idx) $ ---     let (cond, elt) = onehotArrayElemRec t n dep ---                                          (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ))) ---                                          (EVar ext (typeOf idx) IZ) ---                                          (weakenExpr (WSink .> WSink) val) ---     in eif cond elt (zero t) - --- -- AcIdx must be duplicable --- onehotArrayElemRec ---   :: STy t -> SNat n -> SNat i ---   -> [Ex env TIx] ---   -> Ex env (AcIdx (TArr n (D2 t)) i) ---   -> Ex env (AcValArr n (D2 t) i) ---   -> (Ex env (TScal TBool), Ex env (D2 t)) --- onehotArrayElemRec _ n SZ eltidx _ val = ---   (EConst ext STBool True ---   ,EIdx ext val (reconstructFromOutsideIn n eltidx)) --- onehotArrayElemRec t SZ (SS dep) eltidx idx val = ---   case eltidx of ---     [] -> (EConst ext STBool True, onehot t dep idx val) ---     _ -> error "onehotArrayElemRec: mismatched list length" --- onehotArrayElemRec t (SS n) (SS dep) eltidx idx val = ---   case eltidx of ---     i : eltidx' -> ---       let (cond, elt) = onehotArrayElemRec t n dep eltidx' (ESnd ext idx) val ---       in (EOp ext OAnd (EPair ext (EOp ext (OEq STI64) (EPair ext i (EFst ext idx))) cond) ---          ,elt) ---     [] -> error "onehotArrayElemRec: mismatched list length" - --- | Outermost index at the head. The input expression must be duplicable. -outsideInIndex :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx] -outsideInIndex = \n idx -> go n idx [] -  where -    go :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx] -> [Ex env TIx] -    go SZ _ acc = acc -    go (SS n) idx acc = go n (EFst ext idx) (ESnd ext idx : acc) - --- Takes a list with the outermost index at the head. Returns a tuple with the --- innermost index on the right. -reconstructFromOutsideIn :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx)) -reconstructFromOutsideIn = \n list -> go n (reverse list) -  where -    -- Takes list with the _innermost_ index at the head. -    go :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx)) -    go SZ [] = ENil ext -    go (SS n) (i:is) = EPair ext (go n is) i -    go _ _ = error "reconstructFromOutsideIn: mismatched list length" +  (STArr n t1, SAPArrIdx prj _) -> +    let tidx = tTup (sreplicate n tIx) +    in ELet ext idx $ +         EBuild ext n (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ))) $ +           eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ))))) +             (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) +             (zero t1) | 
