From 4d456e4d34b1e4fb3725051d1b8a0c376b704692 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 30 Oct 2025 15:56:35 +0100 Subject: Implement reshape --- src/AST.hs | 5 +++++ src/AST/Count.hs | 15 +++++++++++++++ src/AST/Pretty.hs | 5 +++++ src/AST/SplitLets.hs | 1 + src/AST/UnMonoid.hs | 1 + src/Analysis/Identity.hs | 6 ++++++ src/Array.hs | 5 +++++ src/CHAD.hs | 15 +++++++++++++++ src/Compile.hs | 24 +++++++++++++++++++++++- src/ForwardAD/DualNumbers.hs | 2 ++ src/Interpreter.hs | 4 ++++ src/Language.hs | 3 +++ src/Language/AST.hs | 2 ++ src/Simplify.hs | 2 ++ 14 files changed, 89 insertions(+), 1 deletion(-) (limited to 'src') diff --git a/src/AST.hs b/src/AST.hs index f7b63cf..7549ff0 100644 --- a/src/AST.hs +++ b/src/AST.hs @@ -69,6 +69,7 @@ data Expr x env t where EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) + EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t) -- MapAccum-like (is it real mapaccum? If so, rename) EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative @@ -233,6 +234,7 @@ typeOf = \case EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t + EReshape _ n _ e | STArr _ t <- typeOf e -> STArr n t EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb) EFold1InnerD2 _ _ _ e2 _ e4 _ | t2 <- typeOf e2, STArr sn _ <- typeOf e4 -> STPair t2 (STArr sn t2) @@ -283,6 +285,7 @@ extOf = \case EReplicate1Inner x _ _ -> x EMaximum1Inner x _ -> x EMinimum1Inner x _ -> x + EReshape x _ _ _ -> x EFold1InnerD1 x _ _ _ _ -> x EFold1InnerD2 x _ _ _ _ _ _ -> x EConst x _ _ -> x @@ -331,6 +334,7 @@ travExt f = \case EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e + EReshape x n a b -> EReshape <$> f x <*> pure n <*> travExt f a <*> travExt f b EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c EFold1InnerD2 x cm a b c d e -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f d <*> travExt f e EConst x t v -> EConst <$> f x <*> pure t <*> pure v @@ -392,6 +396,7 @@ subst' f w = \case EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b) EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e) EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e) + EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b) EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) EFold1InnerD2 x cm a b c d e -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) c) (subst' f w d) (subst' f w e) EConst x t v -> EConst x t v diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 229661f..66b4e0b 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -598,6 +598,21 @@ occCountX initialS topexpr k = case topexpr of EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e + EReshape _ n esh e -> + case s of + SsNone -> + occCountX SsNone esh $ \env1 mkesh -> + occCountX SsNone e $ \env2 mke -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mkesh env') $ use (mke env') $ ENil ext + SsArr' s' -> + occCountX SsFull esh $ \env1 mkesh -> + occCountX (SsArr s') e $ \env2 mke -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EReshape ext n (mkesh env') (mke env') + EFold1InnerD1 _ cm e1 e2 e3 -> case s of -- If nothing is necessary, we can execute a fold and then proceed to ignore it diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 587328d..67197f9 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -235,6 +235,11 @@ ppExpr' d val expr = case expr of e' <- ppExpr' 11 val e return $ ppParen (d > 10) $ ppString "minimum1i" <> ppX expr <+> e' + EReshape _ n esh e -> do + esh' <- ppExpr' 11 val esh + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr <+> esh' <+> e' + EFold1InnerD1 _ cm a b c -> do name1 <- genNameIfUsedIn (typeOf b) (IS IZ) a name2 <- genNameIfUsedIn (typeOf b) IZ a diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index 6034084..73c1c67 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -61,6 +61,7 @@ splitLets' = \sub -> \case EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b) EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e) EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e) + EReshape x n a b -> EReshape x n (splitLets' sub a) (splitLets' sub b) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (splitLets' sub e) EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b) diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index 6904715..e5a9708 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -44,6 +44,7 @@ unMonoid = \case EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b) EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e) EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e) + EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b) EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c) EFold1InnerD2 _ cm a b c d e -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid d) (unMonoid e) EConst _ t x -> EConst ext t x diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index 9dc8811..b3a6664 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -244,6 +244,12 @@ idana env expr = case expr of res <- VIArr <$> genId <*> pure sh pure (res, EMinimum1Inner res e1') + EReshape _ dim e1 e2 -> do + (v1, e1') <- idana env e1 + (_, e2') <- idana env e2 + res <- VIArr <$> genId <*> shidsToVec dim v1 + pure (res, EReshape res dim e1' e2') + EFold1InnerD1 _ cm e1 e2 e3 -> do let t1 = typeOf e2 x1 <- genIds t1 diff --git a/src/Array.hs b/src/Array.hs index 707dce2..6ceb9fe 100644 --- a/src/Array.hs +++ b/src/Array.hs @@ -91,6 +91,11 @@ arrayFromList sh l = Array sh (V.fromListN (shapeSize sh) l) arrayToList :: Array n t -> [t] arrayToList (Array _ v) = V.toList v +arrayReshape :: Shape n -> Array m t -> Array n t +arrayReshape sh (Array sh' v) + | shapeSize sh == shapeSize sh' = Array sh v + | otherwise = error $ "arrayReshape: different shape size than original (" ++ show sh' ++ " -> " ++ show sh ++ ")" + arrayUnit :: t -> Array Z t arrayUnit x = Array ShNil (V.singleton x) diff --git a/src/CHAD.hs b/src/CHAD.hs index 93fabf9..04c4231 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1327,6 +1327,21 @@ drev des accumMap sd = \case EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e + EReshape _ n esh e + | SpArr sd' <- sd + , STArr orign t <- typeOf e + , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e + , Refl <- indexTupD1Id n -> + Ret (e0 `bpush` e1 + `bpush` EShape ext (EVar ext (STArr orign (d1 t)) IZ)) + (SEYesR (SENo subtape)) + (EReshape ext n (weakenExpr (WSink .> WSink .> wSinks (bindingsBinds e0)) (drevPrimal des esh)) + (EVar ext (STArr orign (d1 t)) (IS IZ))) + sub + (elet (EReshape ext orign (EVar ext (tTup (sreplicate orign tIx)) (IS IZ)) + (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e2) + ENothing{} -> err_unsupported "ENothing" EJust{} -> err_unsupported "EJust" EMaybe{} -> err_unsupported "EMaybe" diff --git a/src/Compile.hs b/src/Compile.hs index 0ab7ea4..4e81c6a 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -920,6 +920,25 @@ compile' env = \case EMinimum1Inner _ e -> compileExtremum "min" "minimum1i" "<" env e + EReshape _ dim esh earg -> do + let STArr origDim eltty = typeOf earg + strname <- emitStruct (STArr dim eltty) + + shname <- compileAssign "reshsh" env esh + arrname <- compileAssign "resharg" env earg + + when emitChecks $ do + emit $ SIf (CEBinop (compileArrShapeSize origDim arrname) "!=" (CECast "size_t" (prodExpr (indexTupleComponents dim shname)))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: reshape on unequal sizes (%zu <- %zu)\\n\", " ++ + printCExpr 0 (prodExpr (indexTupleComponents dim shname)) ", " ++ + printCExpr 0 (compileArrShapeSize origDim arrname) "); return false;") + mempty + + return (CEStruct strname + [("buf", CEProj (CELit arrname) "buf") + ,("sh", CELit ("{" ++ intercalate ", " [printCExpr 0 e "" | e <- indexTupleComponents dim shname] ++ "}"))]) + EConst _ t x -> return $ CELit $ compileScal True t x EIdx0 _ e -> do @@ -1323,7 +1342,7 @@ compileShapeQuery (SS n) var = -- | Takes a variable name for the array, not the buffer. compileArrShapeSize :: SNat n -> String -> CExpr -compileArrShapeSize n var = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") (compileArrShapeComponents n var) +compileArrShapeSize n var = prodExpr (compileArrShapeComponents n var) -- | Takes a variable name for the array, not the buffer. compileArrShapeComponents :: SNat n -> String -> [CExpr] @@ -1347,6 +1366,9 @@ shapeTupFromLitVars = \n -> go n . reverse go (SS n) (var : vars) = CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) [("a", go n vars), ("b", CELit var)] go _ _ = error "shapeTupFromLitVars: SNat and list do not correspond" +prodExpr :: [CExpr] -> CExpr +prodExpr = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") + compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr compileOpGeneral op e1 = do let unary cop = return @CompM $ CECall cop [e1] diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index 467b895..44bdbb2 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -168,6 +168,8 @@ dfwdDN = \case EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b) EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e + EReshape _ n esh e + | Refl <- dnPreservesTupIx n -> EReshape ext n (dfwdDN esh) (dfwdDN e) EConst _ t x -> scalTyCase t (EPair ext (EConst ext t x) (EConst ext t 0.0)) (EConst ext t x) diff --git a/src/Interpreter.hs b/src/Interpreter.hs index db7033d..79d5014 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -146,6 +146,10 @@ interpret'Rec env = \case sh `ShCons` n = arrayShape arr numericIsNum t $ return $ arrayGenerate sh (\idx -> minimum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]]) + EReshape _ dim esh e -> do + sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env esh + arr <- interpret' env e + return $ arrayReshape sh arr EFold1InnerD1 _ _ a b c -> do let t = typeOf b let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a diff --git a/src/Language.hs b/src/Language.hs index 4e6d604..d3c38d6 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -130,6 +130,9 @@ maximum1i e = NEMaximum1Inner e minimum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) minimum1i e = NEMinimum1Inner e +reshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t) +reshape = NEReshape + const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t) const_ x = let ty = knownScalTy diff --git a/src/Language/AST.hs b/src/Language/AST.hs index be98ccf..325817d 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -56,6 +56,7 @@ data NExpr env t where NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t) NEMaximum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) + NEReshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t) -- expression operations NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t) @@ -205,6 +206,7 @@ fromNamedExpr val = \case NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b) NEMaximum1Inner e -> EMaximum1Inner ext (go e) NEMinimum1Inner e -> EMinimum1Inner ext (go e) + NEReshape n a b -> EReshape ext n (go a) (go b) NEConst t x -> EConst ext t x NEIdx0 e -> EIdx0 ext (go e) diff --git a/src/Simplify.hs b/src/Simplify.hs index aac9963..74306a1 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -314,6 +314,7 @@ simplify'Rec = \case EReplicate1Inner _ a b -> [simprec| EReplicate1Inner ext *a *b |] EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |] EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |] + EReshape _ n a b -> [simprec| EReshape ext n *a *b |] EFold1InnerD1 _ cm a b c -> [simprec| EFold1InnerD1 ext cm *a *b *c |] EFold1InnerD2 _ cm a b c d e -> [simprec| EFold1InnerD2 ext cm *a *b *c *d *e |] EConst _ t v -> pure $ EConst ext t v @@ -369,6 +370,7 @@ hasAdds = \case EReplicate1Inner _ a b -> hasAdds a || hasAdds b EMaximum1Inner _ e -> hasAdds e EMinimum1Inner _ e -> hasAdds e + EReshape _ _ a b -> hasAdds a || hasAdds b EFold1InnerD1 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c EFold1InnerD2 _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e -- cgit v1.2.3-70-g09d2