diff --git a/sdks/python/apache_beam/dataframe/io.py b/sdks/python/apache_beam/dataframe/io.py index 02423f517eea..b4cdafdf7ed1 100644 --- a/sdks/python/apache_beam/dataframe/io.py +++ b/sdks/python/apache_beam/dataframe/io.py @@ -680,17 +680,26 @@ def __init__( self.binary = binary def expand(self, pcoll): - if 'file_naming' in self.kwargs: + kwargs = dict(self.kwargs) + if 'file_naming' in kwargs: dir, name = self.path, '' else: dir, name = io.filesystems.FileSystems.split(self.path) + num_shards = kwargs.pop('num_shards', None) + max_writers_per_bundle = kwargs.pop('max_writers_per_bundle', None) + write_to_files_kwargs = {} + if num_shards is not None: + write_to_files_kwargs['shards'] = num_shards + write_to_files_kwargs['max_writers_per_bundle'] = ( + max_writers_per_bundle if max_writers_per_bundle is not None else 0) + elif max_writers_per_bundle is not None: + write_to_files_kwargs['max_writers_per_bundle'] = max_writers_per_bundle return pcoll | fileio.WriteToFiles( path=dir, - shards=self.kwargs.pop('num_shards', None), - file_naming=self.kwargs.pop( - 'file_naming', fileio.default_file_naming(name)), + file_naming=kwargs.pop('file_naming', fileio.default_file_naming(name)), sink=lambda _: _WriteToPandasFileSink( - self.writer, self.args, self.kwargs, self.incremental, self.binary)) + self.writer, self.args, kwargs, self.incremental, self.binary), + **write_to_files_kwargs) class _WriteToPandasFileSink(fileio.FileSink):