[Python] Hugging Face pipeline support (#27399)
* automodel first pass
* new model
* updated model handler api
* add model_class param
* update doc comments
* updated integration test and example
* unit test, modified params
* add test setup for hugging face tests
* fix lints
* fix import order
* refactor, doc, lints
* refactor, doc comments
* change test file
* update types
* add hugging face pipeline support
* integration test for pipeline
* add doc, gs link
* test raises exception
* fix python lints
* add inference fn
* update doc
* docs, lint
* docs, lint
* remove optional from inference_fn
* add enum for tasks
* update pydoc
* update pydoc
* doc, formatting changes
* fix doc
* fix optional in doc
* pin model version
1 parent 9c80b49, commit 1b76101.
Showing 4 changed files with 393 additions and 8 deletions.
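At its core, the commit adds a HuggingFacePipelineModelHandler so that RunInference can wrap a transformers pipeline directly. A minimal sketch of that usage, assuming the handler also accepts a plain task string and that transformers falls back to its default model for the task (the example file below instead passes a PipelineTask enum value and an explicit model):

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.huggingface_inference import HuggingFacePipelineModelHandler

# Assumed: a bare task string is accepted and transformers supplies the
# task's default question-answering model.
model_handler = HuggingFacePipelineModelHandler(task="question-answering")

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([{
          'question': 'What does Apache Beam do?',
          'context': 'Apache Beam enables batch and streaming data processing.'
      }])
      | RunInference(model_handler)
      | beam.Map(print))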
164 changes: 164 additions & 0 deletions
sdks/python/apache_beam/examples/inference/huggingface_question_answering.py
@@ -0,0 +1,164 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

""""A pipeline that uses RunInference to perform Question Answering using the | ||
model from Hugging Face Models Hub. | ||
This pipeline takes questions and context from a custom text file separated by | ||
a semicolon. These are converted to SquadExamples by using the utility provided | ||
by transformers.QuestionAnsweringPipeline and passed to the model handler. | ||
We just provide the model name here because the model repository specifies the | ||
task that it will do. The pipeline then writes the prediction to an output | ||
file in which users can then compare against the original context. | ||
""" | ||
|
||
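# An illustrative input line (question and context joined by a semicolon),
# matching the built-in sentences used further down:
#   What does Apache Beam do?;Apache Beam enables batch and streaming data processing.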
import argparse
import logging
from typing import Iterable
from typing import Tuple

import apache_beam as beam
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.huggingface_inference import HuggingFacePipelineModelHandler
from apache_beam.ml.inference.huggingface_inference import PipelineTask
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.runners.runner import PipelineResult
from transformers import QuestionAnsweringPipeline


class PostProcessor(beam.DoFn):
  """Processes the PredictionResult to get the predicted answer.

  The Hugging Face pipeline for Question Answering returns a dictionary
  with the score, the start and end indices of the answer, and the answer
  itself.
  """
  def process(self, result: Tuple[str, PredictionResult]) -> Iterable[str]:
    text, prediction = result
    predicted_answer = prediction.inference['answer']
    yield text + ';' + predicted_answer


def preprocess(text):
  """
  preprocess separates the text into question and context
  by splitting on the semicolon.

  Args:
    text (str): string with question and context separated by a semicolon.
  Yields:
    (str, str): yields question and context from text.
  """
  if len(text.strip()) > 0:
    question, context = text.split(';')
    yield (question, context)


def create_squad_example(text):
  """Creates SquadExample objects to be fed to QuestionAnsweringPipeline
  supported by Hugging Face.

  Check out https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.QuestionAnsweringPipeline.__call__.X #pylint: disable=line-too-long
  to learn about valid input types for QuestionAnswering Pipeline.

  Args:
    text (Tuple[str, str]): a tuple of question and context.
  """
  question, context = text
  yield question, QuestionAnsweringPipeline.create_sample(question, context)


def parse_known_args(argv):
  """Parses args for the workflow."""
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--input',
      dest='input',
      help='Path of file containing question and context separated by semicolon'
  )
  parser.add_argument(
      '--output',
      dest='output',
      required=True,
      help='Path of file in which to save the output predictions.')
  parser.add_argument(
      '--model_name',
      dest='model_name',
      default="deepset/roberta-base-squad2",
      help='Model repository-id from Hugging Face Models Hub.')
  parser.add_argument(
      '--revision',
      dest='revision',
      help=
      'Specific model version to use - branch name, tag name, or a commit-id.')
  return parser.parse_known_args(argv)


def run(
    argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult:
  """
  Args:
    argv: Command line arguments defined for this example.
    save_main_session: Used for internal testing.
    test_pipeline: Used for internal testing.
  """
  known_args, pipeline_args = parse_known_args(argv)
  pipeline_options = PipelineOptions(pipeline_args)
  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session

  pipeline = test_pipeline
  if not test_pipeline:
    pipeline = beam.Pipeline(options=pipeline_options)

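  # PipelineTask.QuestionAnswering pins the pipeline type; load_model_args
  # are presumably forwarded to the underlying transformers pipeline call,
  # where 'framework': 'pt' selects PyTorch and 'revision' pins the model
  # version when one is given.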
  model_handler = HuggingFacePipelineModelHandler(
      task=PipelineTask.QuestionAnswering,
      model=known_args.model_name,
      load_model_args={
          'framework': 'pt', 'revision': known_args.revision
      })
  if not known_args.input:
    text = (
        pipeline | 'CreateSentences' >> beam.Create([
            "What does Apache Beam do?;"
            "Apache Beam enables batch and streaming data processing.",
            "What is the capital of France?;The capital of France is Paris.",
            "Where was Beam Summit?;Apache Beam Summit 2023 was in NYC.",
        ]))
  else:
    text = (
        pipeline | 'ReadSentences' >> beam.io.ReadFromText(known_args.input))
  processed_text = (
      text
      | 'PreProcess' >> beam.ParDo(preprocess)
      | 'SquadExample' >> beam.ParDo(create_squad_example))
  output = (
      processed_text
      | 'RunInference' >> RunInference(KeyedModelHandler(model_handler))
      | 'ProcessOutput' >> beam.ParDo(PostProcessor()))
  _ = output | "WriteOutput" >> beam.io.WriteToText(
      known_args.output, shard_name_template='', append_trailing_newlines=True)

  result = pipeline.run()
  result.wait_until_finish()
  return result


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  run()
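The example falls back to the built-in sentences when --input is omitted, so a local smoke run only needs --output. A hypothetical invocation (the output path is illustrative; transformers and PyTorch must be installed):

from apache_beam.examples.inference import huggingface_question_answering

# Downloads the default deepset/roberta-base-squad2 model on first use;
# the /tmp path is illustrative.
huggingface_question_answering.run(
    argv=['--output=/tmp/hf_qa_predictions.txt'])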