summaryrefslogtreecommitdiff
path: root/src/AST
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST')
-rw-r--r--src/AST/Accum.hs33
-rw-r--r--src/AST/Pretty.hs29
-rw-r--r--src/AST/Types.hs2
-rw-r--r--src/AST/UnMonoid.hs119
4 files changed, 63 insertions, 120 deletions
diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs
index 163f1c3..6c46ad5 100644
--- a/src/AST/Accum.hs
+++ b/src/AST/Accum.hs
@@ -21,16 +21,17 @@ data AcPrj
| APArrIdx AcPrj
| APArrSlice Nat
--- | @b@ is a small part of @a@, indicated by the projection.
+-- | @b@ is a small part of @a@, indicated by the projection @p@.
data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where
SAPHere :: SAcPrj APHere a a
- SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair t a) b
- SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair a t) b
- SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TEither t a) b
+ SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair a t) b
+ SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair t a) b
+ SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TEither a t) b
SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TEither t a) b
- SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe t) b
- SAPArrIdx :: SAcPrj p a b -> SNat n -> SAcPrj (APArrIdx p) (TArr n t) b
- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t)
+ SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe a) b
+ SAPArrIdx :: SAcPrj p a b -> SNat n -> SAcPrj (APArrIdx p) (TArr n a) b
+ -- TODO:
+ -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t)
deriving instance Show (SAcPrj p a b)
type family AcIdx p t where
@@ -40,5 +41,19 @@ type family AcIdx p t where
AcIdx (APLeft p) (TEither a b) = AcIdx p a
AcIdx (APRight p) (TEither a b) = AcIdx p b
AcIdx (APJust p) (TMaybe a) = AcIdx p a
- AcIdx (APArrIdx p) (TArr n a) = TPair (Tup (Replicate n TIx)) (AcIdx p a)
- AcIdx (APArrSlice m) (TArr n a) = Tup (Replicate m TIx)
+ AcIdx (APArrIdx p) (TArr n a) =
+ -- ((index, array shape), recursive info)
+ TPair (TPair (Tup (Replicate n TIx)) (Tup (Replicate n TIx)))
+ (AcIdx p a)
+ -- AcIdx (APArrSlice m) (TArr n a) =
+ -- -- (index, array shape)
+ -- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx))
+
+acPrjTy :: SAcPrj p a b -> STy a -> STy b
+acPrjTy SAPHere t = t
+acPrjTy (SAPFst prj) (STPair t _) = acPrjTy prj t
+acPrjTy (SAPSnd prj) (STPair _ t) = acPrjTy prj t
+acPrjTy (SAPLeft prj) (STEither t _) = acPrjTy prj t
+acPrjTy (SAPRight prj) (STEither _ t) = acPrjTy prj t
+acPrjTy (SAPJust prj) (STMaybe t) = acPrjTy prj t
+acPrjTy (SAPArrIdx prj _) (STArr _ t) = acPrjTy prj t
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index f91aff2..b9406d7 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -235,9 +235,9 @@ ppExpr' d val expr = case expr of
,e1'
,e2']
- EWith _ e1 e2 -> do
+ EWith _ t e1 e2 -> do
e1' <- ppExpr' 11 val e1
- name <- genNameIfUsedIn' "ac" (STAccum (typeOf e1)) IZ e2
+ name <- genNameIfUsedIn' "ac" (STAccum t) IZ e2
e2' <- ppExpr' 0 (Const name `SCons` val) e2
return $ ppParen (d > 0) $
group $ flatAlt
@@ -247,12 +247,12 @@ ppExpr' d val expr = case expr of
<> hardline <> e2')
(ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2'])
- EAccum _ i e1 e2 e3 -> do
+ EAccum _ _ prj e1 e2 e3 -> do
e1' <- ppExpr' 11 val e1
e2' <- ppExpr' 11 val e2
e3' <- ppExpr' 11 val e3
return $ ppParen (d > 10) $
- ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (show (fromSNat i)), e1', e2', e3']
+ ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (ppAcPrj prj), e1', e2', e3']
EZero _ t -> return $ ppParen (d > 0) $
annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSTy' 11 t
@@ -263,11 +263,11 @@ ppExpr' d val expr = case expr of
return $ ppParen (d > 10) $
ppApp (annotate AMonoid (ppString "plus") <> ppX expr) [a', b']
- EOneHot _ _ i a b -> do
+ EOneHot _ _ prj a b -> do
a' <- ppExpr' 11 val a
b' <- ppExpr' 11 val b
return $ ppParen (d > 10) $
- ppApp (annotate AMonoid (ppString "onehot") <> ppX expr) [ppString (show (fromSNat i)), a', b']
+ ppApp (annotate AMonoid (ppString "onehot") <> ppX expr) [ppString (ppAcPrj prj), a', b']
EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s)
@@ -300,6 +300,15 @@ ppLam :: [ADoc] -> ADoc -> ADoc
ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"])
<> softline <> body <> ppString ")")
+ppAcPrj :: SAcPrj p a b -> String
+ppAcPrj SAPHere = "@"
+ppAcPrj (SAPFst prj) = "(" ++ ppAcPrj prj ++ ",)"
+ppAcPrj (SAPSnd prj) = "(," ++ ppAcPrj prj ++ ")"
+ppAcPrj (SAPLeft prj) = "(" ++ ppAcPrj prj ++ "|)"
+ppAcPrj (SAPRight prj) = "(|" ++ ppAcPrj prj ++ ")"
+ppAcPrj (SAPJust prj) = "J" ++ ppAcPrj prj
+ppAcPrj (SAPArrIdx prj n) = "[" ++ ppAcPrj prj ++ "]" ++ intSubscript (fromSNat n)
+
ppX :: PrettyX x => Expr x env t -> ADoc
ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr)
@@ -355,6 +364,14 @@ ppParen :: Bool -> Doc x -> Doc x
ppParen True = parens
ppParen False = id
+intSubscript :: Int -> String
+intSubscript = \case 0 -> "₀"
+ n | n < 0 -> '₋' : go (-n) ""
+ | otherwise -> go n ""
+ where go 0 suff = suff
+ go n suff = let (q, r) = n `quotRem` 10
+ in go q ("₀₁₂₃₄₅₆₇₈₉" !! r : suff)
+
data Annot = AKey | AWith | AHighlight | AMonoid | AExt
deriving (Show)
diff --git a/src/AST/Types.hs b/src/AST/Types.hs
index be7cffe..0b41671 100644
--- a/src/AST/Types.hs
+++ b/src/AST/Types.hs
@@ -21,7 +21,7 @@ data Ty
| TMaybe Ty
| TArr Nat Ty -- ^ rank, element type
| TScal ScalTy
- | TAccum Ty
+ | TAccum Ty -- ^ the accumulator contains D2 of this type
deriving (Show, Eq, Ord)
data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index ec5e11e..ae9728a 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -2,7 +2,7 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeOperators #-}
-module AST.UnMonoid where
+module AST.UnMonoid (unMonoid, zero, plus) where
import AST
import CHAD.Types
@@ -117,110 +117,21 @@ plusSparse t a b adder =
(weakenExpr WSink a)
onehot :: STy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env (D2 a) -> Ex env (D2 t)
-onehot _ topprj arg = case topprj of
- SAPHere -> arg
+onehot typ topprj idx arg = case (typ, topprj) of
+ (_, SAPHere) -> arg
- SAPFst prj -> _
+ (STPair t1 t2, SAPFst prj) -> EJust ext (EPair ext (onehot t1 prj idx arg) (zero t2))
+ (STPair t1 t2, SAPSnd prj) -> EJust ext (EPair ext (zero t1) (onehot t2 prj idx arg))
-onehot t (SS dep) arg = case t of
- STPair t1 t2 ->
- case dep of
- SZ -> EJust ext val
- SS dep' ->
- let STEither tidx1 tidx2 = typeOf idx
- STEither tval1 tval2 = typeOf val
- in EJust ext $
- ECase ext idx
- (ECase ext (weakenExpr WSink val)
- (EPair ext (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ))
- (zero t2))
- (EError ext (STPair (d2 t1) (d2 t2)) "onehot pair l/r"))
- (ECase ext (weakenExpr WSink val)
- (EError ext (STPair (d2 t1) (d2 t2)) "onehot pair r/l")
- (EPair ext (zero t1)
- (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ))))
+ (STEither t1 t2, SAPLeft prj) -> EJust ext (EInl ext (d2 t2) (onehot t1 prj idx arg))
+ (STEither t1 t2, SAPRight prj) -> EJust ext (EInr ext (d2 t1) (onehot t2 prj idx arg))
- STEither t1 t2 ->
- case dep of
- SZ -> EJust ext val
- SS dep' ->
- let STEither tidx1 tidx2 = typeOf idx
- STEither tval1 tval2 = typeOf val
- in EJust ext $
- ECase ext idx
- (ECase ext (weakenExpr WSink val)
- (EInl ext (d2 t2) (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ)))
- (EError ext (STEither (d2 t1) (d2 t2)) "onehot either l/r"))
- (ECase ext (weakenExpr WSink val)
- (EError ext (STEither (d2 t1) (d2 t2)) "onehot either r/l")
- (EInr ext (d2 t1) (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ))))
+ (STMaybe t1, SAPJust prj) -> EJust ext (onehot t1 prj idx arg)
- STMaybe t1 -> EJust ext (onehot t1 dep idx val)
-
- STArr n t1 ->
- ELet ext val $
- EBuild ext n (EFst ext (EVar ext (typeOf val) IZ))
- (onehotArrayElem t1 n (SS dep)
- (EVar ext (tTup (sreplicate n tIx)) IZ)
- (weakenExpr (WSink .> WSink) idx)
- (ESnd ext (EVar ext (typeOf val) (IS IZ))))
-
- STNil -> error "Cannot index into nil"
- STScal{} -> error "Cannot index into scalar"
- STAccum{} -> error "Accumulators not allowed in input program"
-
--- onehotArrayElem
--- :: STy t -> SNat n -> SNat i
--- -> Ex env (Tup (Replicate n TIx)) -- ^ where are we now, OUTSIDE-IN as produced by 'outsideInIndex'
--- -> Ex env (AcIdx (TArr n (D2 t)) i) -- ^ where to put the one-hot
--- -> Ex env (AcValArr n (D2 t) i) -- ^ value to put in the hole
--- -> Ex env (D2 t)
--- onehotArrayElem t n dep eltidx idx val =
--- ELet ext eltidx $
--- ELet ext (weakenExpr WSink idx) $
--- let (cond, elt) = onehotArrayElemRec t n dep
--- (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ)))
--- (EVar ext (typeOf idx) IZ)
--- (weakenExpr (WSink .> WSink) val)
--- in eif cond elt (zero t)
-
--- -- AcIdx must be duplicable
--- onehotArrayElemRec
--- :: STy t -> SNat n -> SNat i
--- -> [Ex env TIx]
--- -> Ex env (AcIdx (TArr n (D2 t)) i)
--- -> Ex env (AcValArr n (D2 t) i)
--- -> (Ex env (TScal TBool), Ex env (D2 t))
--- onehotArrayElemRec _ n SZ eltidx _ val =
--- (EConst ext STBool True
--- ,EIdx ext val (reconstructFromOutsideIn n eltidx))
--- onehotArrayElemRec t SZ (SS dep) eltidx idx val =
--- case eltidx of
--- [] -> (EConst ext STBool True, onehot t dep idx val)
--- _ -> error "onehotArrayElemRec: mismatched list length"
--- onehotArrayElemRec t (SS n) (SS dep) eltidx idx val =
--- case eltidx of
--- i : eltidx' ->
--- let (cond, elt) = onehotArrayElemRec t n dep eltidx' (ESnd ext idx) val
--- in (EOp ext OAnd (EPair ext (EOp ext (OEq STI64) (EPair ext i (EFst ext idx))) cond)
--- ,elt)
--- [] -> error "onehotArrayElemRec: mismatched list length"
-
--- | Outermost index at the head. The input expression must be duplicable.
-outsideInIndex :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx]
-outsideInIndex = \n idx -> go n idx []
- where
- go :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx] -> [Ex env TIx]
- go SZ _ acc = acc
- go (SS n) idx acc = go n (EFst ext idx) (ESnd ext idx : acc)
-
--- Takes a list with the outermost index at the head. Returns a tuple with the
--- innermost index on the right.
-reconstructFromOutsideIn :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx))
-reconstructFromOutsideIn = \n list -> go n (reverse list)
- where
- -- Takes list with the _innermost_ index at the head.
- go :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx))
- go SZ [] = ENil ext
- go (SS n) (i:is) = EPair ext (go n is) i
- go _ _ = error "reconstructFromOutsideIn: mismatched list length"
+ (STArr n t1, SAPArrIdx prj _) ->
+ let tidx = tTup (sreplicate n tIx)
+ in ELet ext idx $
+ EBuild ext n (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ))) $
+ eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))))
+ (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg))
+ (zero t1)