aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/AST.hs5
-rw-r--r--src/AST/Count.hs15
-rw-r--r--src/AST/Pretty.hs5
-rw-r--r--src/AST/SplitLets.hs1
-rw-r--r--src/AST/UnMonoid.hs1
-rw-r--r--src/Analysis/Identity.hs6
-rw-r--r--src/Array.hs5
-rw-r--r--src/CHAD.hs15
-rw-r--r--src/Compile.hs24
-rw-r--r--src/ForwardAD/DualNumbers.hs2
-rw-r--r--src/Interpreter.hs4
-rw-r--r--src/Language.hs3
-rw-r--r--src/Language/AST.hs2
-rw-r--r--src/Simplify.hs2
14 files changed, 89 insertions, 1 deletions
diff --git a/src/AST.hs b/src/AST.hs
index f7b63cf..7549ff0 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -69,6 +69,7 @@ data Expr x env t where
EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t)
EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
+ EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t)
-- MapAccum-like (is it real mapaccum? If so, rename)
EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative
@@ -233,6 +234,7 @@ typeOf = \case
EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t
EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
+ EReshape _ n _ e | STArr _ t <- typeOf e -> STArr n t
EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb)
EFold1InnerD2 _ _ _ e2 _ e4 _ | t2 <- typeOf e2, STArr sn _ <- typeOf e4 -> STPair t2 (STArr sn t2)
@@ -283,6 +285,7 @@ extOf = \case
EReplicate1Inner x _ _ -> x
EMaximum1Inner x _ -> x
EMinimum1Inner x _ -> x
+ EReshape x _ _ _ -> x
EFold1InnerD1 x _ _ _ _ -> x
EFold1InnerD2 x _ _ _ _ _ _ -> x
EConst x _ _ -> x
@@ -331,6 +334,7 @@ travExt f = \case
EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b
EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e
EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e
+ EReshape x n a b -> EReshape <$> f x <*> pure n <*> travExt f a <*> travExt f b
EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
EFold1InnerD2 x cm a b c d e -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f d <*> travExt f e
EConst x t v -> EConst <$> f x <*> pure t <*> pure v
@@ -392,6 +396,7 @@ subst' f w = \case
EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b)
EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e)
EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e)
+ EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b)
EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
EFold1InnerD2 x cm a b c d e -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) c) (subst' f w d) (subst' f w e)
EConst x t v -> EConst x t v
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index 229661f..66b4e0b 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -598,6 +598,21 @@ occCountX initialS topexpr k = case topexpr of
EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e
EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e
+ EReshape _ n esh e ->
+ case s of
+ SsNone ->
+ occCountX SsNone esh $ \env1 mkesh ->
+ occCountX SsNone e $ \env2 mke ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mkesh env') $ use (mke env') $ ENil ext
+ SsArr' s' ->
+ occCountX SsFull esh $ \env1 mkesh ->
+ occCountX (SsArr s') e $ \env2 mke ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EReshape ext n (mkesh env') (mke env')
+
EFold1InnerD1 _ cm e1 e2 e3 ->
case s of
-- If nothing is necessary, we can execute a fold and then proceed to ignore it
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 587328d..67197f9 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -235,6 +235,11 @@ ppExpr' d val expr = case expr of
e' <- ppExpr' 11 val e
return $ ppParen (d > 10) $ ppString "minimum1i" <> ppX expr <+> e'
+ EReshape _ n esh e -> do
+ esh' <- ppExpr' 11 val esh
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr <+> esh' <+> e'
+
EFold1InnerD1 _ cm a b c -> do
name1 <- genNameIfUsedIn (typeOf b) (IS IZ) a
name2 <- genNameIfUsedIn (typeOf b) IZ a
diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs
index 6034084..73c1c67 100644
--- a/src/AST/SplitLets.hs
+++ b/src/AST/SplitLets.hs
@@ -61,6 +61,7 @@ splitLets' = \sub -> \case
EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b)
EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e)
EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e)
+ EReshape x n a b -> EReshape x n (splitLets' sub a) (splitLets' sub b)
EConst x t v -> EConst x t v
EIdx0 x e -> EIdx0 x (splitLets' sub e)
EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b)
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index 6904715..e5a9708 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -44,6 +44,7 @@ unMonoid = \case
EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b)
EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)
EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
+ EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b)
EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
EFold1InnerD2 _ cm a b c d e -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid d) (unMonoid e)
EConst _ t x -> EConst ext t x
diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs
index 9dc8811..b3a6664 100644
--- a/src/Analysis/Identity.hs
+++ b/src/Analysis/Identity.hs
@@ -244,6 +244,12 @@ idana env expr = case expr of
res <- VIArr <$> genId <*> pure sh
pure (res, EMinimum1Inner res e1')
+ EReshape _ dim e1 e2 -> do
+ (v1, e1') <- idana env e1
+ (_, e2') <- idana env e2
+ res <- VIArr <$> genId <*> shidsToVec dim v1
+ pure (res, EReshape res dim e1' e2')
+
EFold1InnerD1 _ cm e1 e2 e3 -> do
let t1 = typeOf e2
x1 <- genIds t1
diff --git a/src/Array.hs b/src/Array.hs
index 707dce2..6ceb9fe 100644
--- a/src/Array.hs
+++ b/src/Array.hs
@@ -91,6 +91,11 @@ arrayFromList sh l = Array sh (V.fromListN (shapeSize sh) l)
arrayToList :: Array n t -> [t]
arrayToList (Array _ v) = V.toList v
+arrayReshape :: Shape n -> Array m t -> Array n t
+arrayReshape sh (Array sh' v)
+ | shapeSize sh == shapeSize sh' = Array sh v
+ | otherwise = error $ "arrayReshape: different shape size than original (" ++ show sh' ++ " -> " ++ show sh ++ ")"
+
arrayUnit :: t -> Array Z t
arrayUnit x = Array ShNil (V.singleton x)
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 93fabf9..04c4231 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1327,6 +1327,21 @@ drev des accumMap sd = \case
EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e
EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e
+ EReshape _ n esh e
+ | SpArr sd' <- sd
+ , STArr orign t <- typeOf e
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e
+ , Refl <- indexTupD1Id n ->
+ Ret (e0 `bpush` e1
+ `bpush` EShape ext (EVar ext (STArr orign (d1 t)) IZ))
+ (SEYesR (SENo subtape))
+ (EReshape ext n (weakenExpr (WSink .> WSink .> wSinks (bindingsBinds e0)) (drevPrimal des esh))
+ (EVar ext (STArr orign (d1 t)) (IS IZ)))
+ sub
+ (elet (EReshape ext orign (EVar ext (tTup (sreplicate orign tIx)) (IS IZ))
+ (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink)) e2)
+
ENothing{} -> err_unsupported "ENothing"
EJust{} -> err_unsupported "EJust"
EMaybe{} -> err_unsupported "EMaybe"
diff --git a/src/Compile.hs b/src/Compile.hs
index 0ab7ea4..4e81c6a 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -920,6 +920,25 @@ compile' env = \case
EMinimum1Inner _ e -> compileExtremum "min" "minimum1i" "<" env e
+ EReshape _ dim esh earg -> do
+ let STArr origDim eltty = typeOf earg
+ strname <- emitStruct (STArr dim eltty)
+
+ shname <- compileAssign "reshsh" env esh
+ arrname <- compileAssign "resharg" env earg
+
+ when emitChecks $ do
+ emit $ SIf (CEBinop (compileArrShapeSize origDim arrname) "!=" (CECast "size_t" (prodExpr (indexTupleComponents dim shname))))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: reshape on unequal sizes (%zu <- %zu)\\n\", " ++
+ printCExpr 0 (prodExpr (indexTupleComponents dim shname)) ", " ++
+ printCExpr 0 (compileArrShapeSize origDim arrname) "); return false;")
+ mempty
+
+ return (CEStruct strname
+ [("buf", CEProj (CELit arrname) "buf")
+ ,("sh", CELit ("{" ++ intercalate ", " [printCExpr 0 e "" | e <- indexTupleComponents dim shname] ++ "}"))])
+
EConst _ t x -> return $ CELit $ compileScal True t x
EIdx0 _ e -> do
@@ -1323,7 +1342,7 @@ compileShapeQuery (SS n) var =
-- | Takes a variable name for the array, not the buffer.
compileArrShapeSize :: SNat n -> String -> CExpr
-compileArrShapeSize n var = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") (compileArrShapeComponents n var)
+compileArrShapeSize n var = prodExpr (compileArrShapeComponents n var)
-- | Takes a variable name for the array, not the buffer.
compileArrShapeComponents :: SNat n -> String -> [CExpr]
@@ -1347,6 +1366,9 @@ shapeTupFromLitVars = \n -> go n . reverse
go (SS n) (var : vars) = CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) [("a", go n vars), ("b", CELit var)]
go _ _ = error "shapeTupFromLitVars: SNat and list do not correspond"
+prodExpr :: [CExpr] -> CExpr
+prodExpr = foldl0' (\a b -> CEBinop a "*" b) (CELit "1")
+
compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr
compileOpGeneral op e1 = do
let unary cop = return @CompM $ CECall cop [e1]
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index 467b895..44bdbb2 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -168,6 +168,8 @@ dfwdDN = \case
EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b)
EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e
EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e
+ EReshape _ n esh e
+ | Refl <- dnPreservesTupIx n -> EReshape ext n (dfwdDN esh) (dfwdDN e)
EConst _ t x -> scalTyCase t
(EPair ext (EConst ext t x) (EConst ext t 0.0))
(EConst ext t x)
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index db7033d..79d5014 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -146,6 +146,10 @@ interpret'Rec env = \case
sh `ShCons` n = arrayShape arr
numericIsNum t $ return $
arrayGenerate sh (\idx -> minimum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]])
+ EReshape _ dim esh e -> do
+ sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env esh
+ arr <- interpret' env e
+ return $ arrayReshape sh arr
EFold1InnerD1 _ _ a b c -> do
let t = typeOf b
let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a
diff --git a/src/Language.hs b/src/Language.hs
index 4e6d604..d3c38d6 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -130,6 +130,9 @@ maximum1i e = NEMaximum1Inner e
minimum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
minimum1i e = NEMinimum1Inner e
+reshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t)
+reshape = NEReshape
+
const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t)
const_ x =
let ty = knownScalTy
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
index be98ccf..325817d 100644
--- a/src/Language/AST.hs
+++ b/src/Language/AST.hs
@@ -56,6 +56,7 @@ data NExpr env t where
NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t)
NEMaximum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
+ NEReshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t)
-- expression operations
NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t)
@@ -205,6 +206,7 @@ fromNamedExpr val = \case
NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b)
NEMaximum1Inner e -> EMaximum1Inner ext (go e)
NEMinimum1Inner e -> EMinimum1Inner ext (go e)
+ NEReshape n a b -> EReshape ext n (go a) (go b)
NEConst t x -> EConst ext t x
NEIdx0 e -> EIdx0 ext (go e)
diff --git a/src/Simplify.hs b/src/Simplify.hs
index aac9963..74306a1 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -314,6 +314,7 @@ simplify'Rec = \case
EReplicate1Inner _ a b -> [simprec| EReplicate1Inner ext *a *b |]
EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |]
EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |]
+ EReshape _ n a b -> [simprec| EReshape ext n *a *b |]
EFold1InnerD1 _ cm a b c -> [simprec| EFold1InnerD1 ext cm *a *b *c |]
EFold1InnerD2 _ cm a b c d e -> [simprec| EFold1InnerD2 ext cm *a *b *c *d *e |]
EConst _ t v -> pure $ EConst ext t v
@@ -369,6 +370,7 @@ hasAdds = \case
EReplicate1Inner _ a b -> hasAdds a || hasAdds b
EMaximum1Inner _ e -> hasAdds e
EMinimum1Inner _ e -> hasAdds e
+ EReshape _ _ a b -> hasAdds a || hasAdds b
EFold1InnerD1 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c
EFold1InnerD2 _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e
ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e