diff options
Diffstat (limited to 'src/AST')
-rw-r--r-- | src/AST/Accum.hs | 33 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 29 | ||||
-rw-r--r-- | src/AST/Types.hs | 2 | ||||
-rw-r--r-- | src/AST/UnMonoid.hs | 119 |
4 files changed, 63 insertions, 120 deletions
diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs index 163f1c3..6c46ad5 100644 --- a/src/AST/Accum.hs +++ b/src/AST/Accum.hs @@ -21,16 +21,17 @@ data AcPrj | APArrIdx AcPrj | APArrSlice Nat --- | @b@ is a small part of @a@, indicated by the projection. +-- | @b@ is a small part of @a@, indicated by the projection @p@. data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where SAPHere :: SAcPrj APHere a a - SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair t a) b - SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair a t) b - SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TEither t a) b + SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair a t) b + SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair t a) b + SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TEither a t) b SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TEither t a) b - SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe t) b - SAPArrIdx :: SAcPrj p a b -> SNat n -> SAcPrj (APArrIdx p) (TArr n t) b - SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t) + SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe a) b + SAPArrIdx :: SAcPrj p a b -> SNat n -> SAcPrj (APArrIdx p) (TArr n a) b + -- TODO: + -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t) deriving instance Show (SAcPrj p a b) type family AcIdx p t where @@ -40,5 +41,19 @@ type family AcIdx p t where AcIdx (APLeft p) (TEither a b) = AcIdx p a AcIdx (APRight p) (TEither a b) = AcIdx p b AcIdx (APJust p) (TMaybe a) = AcIdx p a - AcIdx (APArrIdx p) (TArr n a) = TPair (Tup (Replicate n TIx)) (AcIdx p a) - AcIdx (APArrSlice m) (TArr n a) = Tup (Replicate m TIx) + AcIdx (APArrIdx p) (TArr n a) = + -- ((index, array shape), recursive info) + TPair (TPair (Tup (Replicate n TIx)) (Tup (Replicate n TIx))) + (AcIdx p a) + -- AcIdx (APArrSlice m) (TArr n a) = + -- -- (index, array shape) + -- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx)) + +acPrjTy :: SAcPrj p a b -> STy a -> STy b +acPrjTy SAPHere t = t +acPrjTy (SAPFst prj) (STPair t _) = acPrjTy prj t +acPrjTy (SAPSnd prj) (STPair _ t) = acPrjTy prj t +acPrjTy (SAPLeft prj) (STEither t _) = acPrjTy prj t +acPrjTy (SAPRight prj) (STEither _ t) = acPrjTy prj t +acPrjTy (SAPJust prj) (STMaybe t) = acPrjTy prj t +acPrjTy (SAPArrIdx prj _) (STArr _ t) = acPrjTy prj t diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index f91aff2..b9406d7 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -235,9 +235,9 @@ ppExpr' d val expr = case expr of ,e1' ,e2'] - EWith _ e1 e2 -> do + EWith _ t e1 e2 -> do e1' <- ppExpr' 11 val e1 - name <- genNameIfUsedIn' "ac" (STAccum (typeOf e1)) IZ e2 + name <- genNameIfUsedIn' "ac" (STAccum t) IZ e2 e2' <- ppExpr' 0 (Const name `SCons` val) e2 return $ ppParen (d > 0) $ group $ flatAlt @@ -247,12 +247,12 @@ ppExpr' d val expr = case expr of <> hardline <> e2') (ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2']) - EAccum _ i e1 e2 e3 -> do + EAccum _ _ prj e1 e2 e3 -> do e1' <- ppExpr' 11 val e1 e2' <- ppExpr' 11 val e2 e3' <- ppExpr' 11 val e3 return $ ppParen (d > 10) $ - ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (show (fromSNat i)), e1', e2', e3'] + ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (ppAcPrj prj), e1', e2', e3'] EZero _ t -> return $ ppParen (d > 0) $ annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSTy' 11 t @@ -263,11 +263,11 @@ ppExpr' d val expr = case expr of return $ ppParen (d > 10) $ ppApp (annotate AMonoid (ppString "plus") <> ppX expr) [a', b'] - EOneHot _ _ i a b -> do + EOneHot _ _ prj a b -> do a' <- ppExpr' 11 val a b' <- ppExpr' 11 val b return $ ppParen (d > 10) $ - ppApp (annotate AMonoid (ppString "onehot") <> ppX expr) [ppString (show (fromSNat i)), a', b'] + ppApp (annotate AMonoid (ppString "onehot") <> ppX expr) [ppString (ppAcPrj prj), a', b'] EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s) @@ -300,6 +300,15 @@ ppLam :: [ADoc] -> ADoc -> ADoc ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"]) <> softline <> body <> ppString ")") +ppAcPrj :: SAcPrj p a b -> String +ppAcPrj SAPHere = "@" +ppAcPrj (SAPFst prj) = "(" ++ ppAcPrj prj ++ ",)" +ppAcPrj (SAPSnd prj) = "(," ++ ppAcPrj prj ++ ")" +ppAcPrj (SAPLeft prj) = "(" ++ ppAcPrj prj ++ "|)" +ppAcPrj (SAPRight prj) = "(|" ++ ppAcPrj prj ++ ")" +ppAcPrj (SAPJust prj) = "J" ++ ppAcPrj prj +ppAcPrj (SAPArrIdx prj n) = "[" ++ ppAcPrj prj ++ "]" ++ intSubscript (fromSNat n) + ppX :: PrettyX x => Expr x env t -> ADoc ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr) @@ -355,6 +364,14 @@ ppParen :: Bool -> Doc x -> Doc x ppParen True = parens ppParen False = id +intSubscript :: Int -> String +intSubscript = \case 0 -> "₀" + n | n < 0 -> '₋' : go (-n) "" + | otherwise -> go n "" + where go 0 suff = suff + go n suff = let (q, r) = n `quotRem` 10 + in go q ("₀₁₂₃₄₅₆₇₈₉" !! r : suff) + data Annot = AKey | AWith | AHighlight | AMonoid | AExt deriving (Show) diff --git a/src/AST/Types.hs b/src/AST/Types.hs index be7cffe..0b41671 100644 --- a/src/AST/Types.hs +++ b/src/AST/Types.hs @@ -21,7 +21,7 @@ data Ty | TMaybe Ty | TArr Nat Ty -- ^ rank, element type | TScal ScalTy - | TAccum Ty + | TAccum Ty -- ^ the accumulator contains D2 of this type deriving (Show, Eq, Ord) data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index ec5e11e..ae9728a 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -2,7 +2,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE TypeOperators #-} -module AST.UnMonoid where +module AST.UnMonoid (unMonoid, zero, plus) where import AST import CHAD.Types @@ -117,110 +117,21 @@ plusSparse t a b adder = (weakenExpr WSink a) onehot :: STy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env (D2 a) -> Ex env (D2 t) -onehot _ topprj arg = case topprj of - SAPHere -> arg +onehot typ topprj idx arg = case (typ, topprj) of + (_, SAPHere) -> arg - SAPFst prj -> _ + (STPair t1 t2, SAPFst prj) -> EJust ext (EPair ext (onehot t1 prj idx arg) (zero t2)) + (STPair t1 t2, SAPSnd prj) -> EJust ext (EPair ext (zero t1) (onehot t2 prj idx arg)) -onehot t (SS dep) arg = case t of - STPair t1 t2 -> - case dep of - SZ -> EJust ext val - SS dep' -> - let STEither tidx1 tidx2 = typeOf idx - STEither tval1 tval2 = typeOf val - in EJust ext $ - ECase ext idx - (ECase ext (weakenExpr WSink val) - (EPair ext (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ)) - (zero t2)) - (EError ext (STPair (d2 t1) (d2 t2)) "onehot pair l/r")) - (ECase ext (weakenExpr WSink val) - (EError ext (STPair (d2 t1) (d2 t2)) "onehot pair r/l") - (EPair ext (zero t1) - (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ)))) + (STEither t1 t2, SAPLeft prj) -> EJust ext (EInl ext (d2 t2) (onehot t1 prj idx arg)) + (STEither t1 t2, SAPRight prj) -> EJust ext (EInr ext (d2 t1) (onehot t2 prj idx arg)) - STEither t1 t2 -> - case dep of - SZ -> EJust ext val - SS dep' -> - let STEither tidx1 tidx2 = typeOf idx - STEither tval1 tval2 = typeOf val - in EJust ext $ - ECase ext idx - (ECase ext (weakenExpr WSink val) - (EInl ext (d2 t2) (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ))) - (EError ext (STEither (d2 t1) (d2 t2)) "onehot either l/r")) - (ECase ext (weakenExpr WSink val) - (EError ext (STEither (d2 t1) (d2 t2)) "onehot either r/l") - (EInr ext (d2 t1) (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ)))) + (STMaybe t1, SAPJust prj) -> EJust ext (onehot t1 prj idx arg) - STMaybe t1 -> EJust ext (onehot t1 dep idx val) - - STArr n t1 -> - ELet ext val $ - EBuild ext n (EFst ext (EVar ext (typeOf val) IZ)) - (onehotArrayElem t1 n (SS dep) - (EVar ext (tTup (sreplicate n tIx)) IZ) - (weakenExpr (WSink .> WSink) idx) - (ESnd ext (EVar ext (typeOf val) (IS IZ)))) - - STNil -> error "Cannot index into nil" - STScal{} -> error "Cannot index into scalar" - STAccum{} -> error "Accumulators not allowed in input program" - --- onehotArrayElem --- :: STy t -> SNat n -> SNat i --- -> Ex env (Tup (Replicate n TIx)) -- ^ where are we now, OUTSIDE-IN as produced by 'outsideInIndex' --- -> Ex env (AcIdx (TArr n (D2 t)) i) -- ^ where to put the one-hot --- -> Ex env (AcValArr n (D2 t) i) -- ^ value to put in the hole --- -> Ex env (D2 t) --- onehotArrayElem t n dep eltidx idx val = --- ELet ext eltidx $ --- ELet ext (weakenExpr WSink idx) $ --- let (cond, elt) = onehotArrayElemRec t n dep --- (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ))) --- (EVar ext (typeOf idx) IZ) --- (weakenExpr (WSink .> WSink) val) --- in eif cond elt (zero t) - --- -- AcIdx must be duplicable --- onehotArrayElemRec --- :: STy t -> SNat n -> SNat i --- -> [Ex env TIx] --- -> Ex env (AcIdx (TArr n (D2 t)) i) --- -> Ex env (AcValArr n (D2 t) i) --- -> (Ex env (TScal TBool), Ex env (D2 t)) --- onehotArrayElemRec _ n SZ eltidx _ val = --- (EConst ext STBool True --- ,EIdx ext val (reconstructFromOutsideIn n eltidx)) --- onehotArrayElemRec t SZ (SS dep) eltidx idx val = --- case eltidx of --- [] -> (EConst ext STBool True, onehot t dep idx val) --- _ -> error "onehotArrayElemRec: mismatched list length" --- onehotArrayElemRec t (SS n) (SS dep) eltidx idx val = --- case eltidx of --- i : eltidx' -> --- let (cond, elt) = onehotArrayElemRec t n dep eltidx' (ESnd ext idx) val --- in (EOp ext OAnd (EPair ext (EOp ext (OEq STI64) (EPair ext i (EFst ext idx))) cond) --- ,elt) --- [] -> error "onehotArrayElemRec: mismatched list length" - --- | Outermost index at the head. The input expression must be duplicable. -outsideInIndex :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx] -outsideInIndex = \n idx -> go n idx [] - where - go :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx] -> [Ex env TIx] - go SZ _ acc = acc - go (SS n) idx acc = go n (EFst ext idx) (ESnd ext idx : acc) - --- Takes a list with the outermost index at the head. Returns a tuple with the --- innermost index on the right. -reconstructFromOutsideIn :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx)) -reconstructFromOutsideIn = \n list -> go n (reverse list) - where - -- Takes list with the _innermost_ index at the head. - go :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx)) - go SZ [] = ENil ext - go (SS n) (i:is) = EPair ext (go n is) i - go _ _ = error "reconstructFromOutsideIn: mismatched list length" + (STArr n t1, SAPArrIdx prj _) -> + let tidx = tTup (sreplicate n tIx) + in ELet ext idx $ + EBuild ext n (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ))) $ + eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ))))) + (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) + (zero t1) |