diff --git a/.github/workflows/beam_PostCommit_Yaml_Xlang_Direct.yml b/.github/workflows/beam_PostCommit_Yaml_Xlang_Direct.yml index ea1c255f7cc9..afa437b64de1 100644 --- a/.github/workflows/beam_PostCommit_Yaml_Xlang_Direct.yml +++ b/.github/workflows/beam_PostCommit_Yaml_Xlang_Direct.yml @@ -80,7 +80,7 @@ jobs: - name: run PostCommit Yaml Xlang Direct script uses: ./.github/actions/gradle-command-self-hosted-action with: - gradle-command: :sdks:python:postCommitYamlIntegrationTests -PyamlTestSet=${{ matrix.test_set }} -PbeamPythonExtra=ml_test,yaml + gradle-command: :sdks:python:postCommitYamlIntegrationTests -PyamlTestSet=${{ matrix.test_set }} - name: Archive Python Test Results uses: actions/upload-artifact@v7 if: failure() diff --git a/.github/workflows/beam_PreCommit_Yaml_Xlang_Direct.yml b/.github/workflows/beam_PreCommit_Yaml_Xlang_Direct.yml index 7d17fd2140c9..0b8f4cd63939 100644 --- a/.github/workflows/beam_PreCommit_Yaml_Xlang_Direct.yml +++ b/.github/workflows/beam_PreCommit_Yaml_Xlang_Direct.yml @@ -91,7 +91,7 @@ jobs: - name: run PreCommit Yaml Xlang Direct script uses: ./.github/actions/gradle-command-self-hosted-action with: - gradle-command: :sdks:python:yamlIntegrationTests -PbeamPythonExtra=ml_test,yaml + gradle-command: :sdks:python:yamlIntegrationTests - name: Archive Python Test Results uses: actions/upload-artifact@v7 if: failure() diff --git a/sdks/python/apache_beam/yaml/tests/runinference_huggingface.yaml b/sdks/python/apache_beam/yaml/tests/runinference_huggingface.yaml new file mode 100644 index 000000000000..8728a6f544ad --- /dev/null +++ b/sdks/python/apache_beam/yaml/tests/runinference_huggingface.yaml @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +pipelines: + - pipeline: + type: chain + transforms: + - type: Create + config: + elements: + - text: "I love Apache Beam!" + - text: "I hate this error." + - type: RunInference + config: + model_handler: + type: "HuggingFacePipeline" + config: + task: "text-classification" + inference_fn: + callable: | + def real_inference(batch, pipeline, inference_args): + predictions = pipeline(batch, **inference_args) + + # If it's a single dictionary (batch size of 1), wrap it in a list + if isinstance(predictions, dict): + predictions = [predictions] + + return { + 'label': [p['label'] for p in predictions], + 'score': [p['score'] for p in predictions] + } + preprocess: + callable: 'lambda x: x.text' + - type: MapToFields + config: + language: python + fields: + text: text + sentiment: + callable: 'lambda x: x.inference.inference["label"]' + - type: AssertEqual + config: + elements: + - text: "I love Apache Beam!" + sentiment: "POSITIVE" + - text: "I hate this error." 
+ sentiment: "NEGATIVE" + + options: + yaml_experimental_features: ['ML'] diff --git a/sdks/python/apache_beam/yaml/tests/runinference.yaml b/sdks/python/apache_beam/yaml/tests/runinference_vertexai.yaml similarity index 100% rename from sdks/python/apache_beam/yaml/tests/runinference.yaml rename to sdks/python/apache_beam/yaml/tests/runinference_vertexai.yaml diff --git a/sdks/python/apache_beam/yaml/yaml_ml.py b/sdks/python/apache_beam/yaml/yaml_ml.py index 51f18c733046..05cbed3bd456 100644 --- a/sdks/python/apache_beam/yaml/yaml_ml.py +++ b/sdks/python/apache_beam/yaml/yaml_ml.py @@ -282,6 +282,88 @@ def inference_output_type(self): ('model_id', Optional[str])]) + +
+@ModelHandlerProvider.register_handler_type('HuggingFacePipeline')
+class HuggingFacePipelineProvider(ModelHandlerProvider):
+  """Provider registered under the 'HuggingFacePipeline' handler type.
+
+  Builds a HuggingFacePipelineModelHandler from the supplied arguments and
+  hands it, together with the preprocess/postprocess specs, to the base
+  ModelHandlerProvider.
+  """
+  def __init__(
+      self,
+      task: Optional[str] = None,
+      model: Optional[str] = None,
+      preprocess: Optional[dict[str, str]] = None,
+      postprocess: Optional[dict[str, str]] = None,
+      device: Optional[Any] = None,
+      inference_fn: Optional[dict[str, str]] = None,
+      load_pipeline_args: Optional[dict[str, Any]] = None,
+      **kwargs):
+    """Creates the wrapped HuggingFacePipelineModelHandler.
+
+    Args:
+      task: Forwarded to the handler as `task`.
+      model: Forwarded to the handler as `model`.
+      preprocess: Passed unchanged to the base provider.
+      postprocess: Passed unchanged to the base provider.
+      device: Forwarded to the handler as `device`.
+      inference_fn: Optional spec run through parse_processing_transform;
+        the result is forwarded as the handler's `inference_fn`. When
+        omitted, no `inference_fn` argument is passed at all.
+      load_pipeline_args: Forwarded to the handler unchanged.
+      **kwargs: Extra handler arguments; keys starting with '_' are
+        dropped before forwarding.
+
+    Raises:
+      ValueError: If HuggingFacePipelineModelHandler cannot be imported.
+    """
+    try:
+      from apache_beam.ml.inference.huggingface_inference import HuggingFacePipelineModelHandler
+    except ImportError as e:
+      # Explicitly chain the ImportError as the cause of this ValueError.
+      raise ValueError(
+          'Unable to import HuggingFacePipelineModelHandler. Please '
+          'install transformers dependencies.') from e
+
+    # Keys starting with '_' are not forwarded to the handler.
+    kwargs = {k: v for k, v in kwargs.items() if not k.startswith('_')}
+
+    inference_fn_obj = self.parse_processing_transform(
+        inference_fn, 'inference_fn') if inference_fn else None
+
+    # Pass inference_fn only when one was supplied.
+    handler_kwargs = {}
+    if inference_fn_obj:
+      handler_kwargs['inference_fn'] = inference_fn_obj
+
+    _handler = HuggingFacePipelineModelHandler(
+        task=task,
+        model=model,
+        device=device,
+        load_pipeline_args=load_pipeline_args,
+        **handler_kwargs,
+        **kwargs)
+
+    super().__init__(_handler, preprocess, postprocess)
+
+  @staticmethod
+  def validate(config):
+    """Checks that `config` has a truthy 'task' or 'model' entry.
+
+    Raises:
+      ValueError: If neither 'task' nor 'model' is set.
+    """
+    if not config.get('task') and not config.get('model'):
+      raise ValueError(
+          "HuggingFacePipeline requires either 'task' or "
+          "'model' to be specified.")
+
+  def inference_output_type(self):
+    """Returns Any as the inference output type."""
+    return Any
+ + + @beam.ptransform.ptransform_fn def run_inference( pcoll, diff --git a/sdks/python/build.gradle b/sdks/python/build.gradle index 5f09dff57e8f..e676fd110433 100644 --- a/sdks/python/build.gradle +++ b/sdks/python/build.gradle @@ -124,10 +124,20 @@ tasks.register("generateYamlDocs") { outputs.file "${buildDir}/yaml-examples.html" } +tasks.register("installYamlIntegrationTestDeps") { + dependsOn installGcpTest + doLast { + exec { + executable 'sh' + args '-c', ". ${envdir}/bin/activate && pip install --pre --retries 10 ${buildDir}/apache-beam.tar.gz[ml_test,yaml,transformers]" + } + } +} + tasks.register("yamlIntegrationTests") { description "Runs precommit integration tests for yaml pipelines. 
- dependsOn installGcpTest + dependsOn installYamlIntegrationTestDeps // Need to build all expansion services referenced in apache_beam/yaml/*.* // grep -oh 'sdk.*Jar' sdks/python/apache_beam/yaml/*.yaml | sort | uniq dependsOn ":sdks:java:extensions:schemaio-expansion-service:shadowJar" @@ -146,7 +156,7 @@ tasks.register("yamlIntegrationTests") { tasks.register("postCommitYamlIntegrationTests") { description "Runs postcommit integration tests for yaml pipelines - parameterized by yamlTestSet." - dependsOn installGcpTest + dependsOn installYamlIntegrationTestDeps // Need to build all expansion services referenced in apache_beam/yaml/*.* // grep -oh 'sdk.*Jar' sdks/python/apache_beam/yaml/*.yaml | sort | uniq dependsOn ":sdks:java:extensions:schemaio-expansion-service:shadowJar"