使用亚马逊 SageMaker 进行批量转换 Jumpstart Text2Text Generation 大型语言模型

今天我们很高兴地宣布,您现在可以使用 亚马逊 SageMaker JumpStart 大型语言模型 (LLM) 进行批量转换,用于 Text2Text 生成。批量转换在响应不必是实时的情况下很有用,因此您可以批量对大型数据集进行推断。对于批量转换,运行批处理作业,将批量输入作为数据集和预训练模型,并输出数据集中每个数据点的预测。批量转换具有成本效益,因为与具有永久硬件的实时托管端点不同,批处理转换集群会在作业完成后被拆除,因此硬件仅在批处理作业期间使用。

在某些用例中,实时推理请求可以分成小批量进行分组以进行批处理,以创建实时或近乎实时的响应。例如,如果您需要以低延迟和高吞吐量处理连续的数据流,则单独为每个请求调用实时终端节点将需要更多资源,并且可能需要更长的时间来处理所有请求,因为处理是按顺序进行的。更好的方法是对一些请求进行分组,然后在批量推理模式下调用实时端点,该端点会在模型的一次向前传递中处理您的请求,并实时或近乎实时地返回请求的批量响应。响应的延迟将取决于您将多少个请求组合在一起以及实例内存大小,因此您可以根据延迟和吞吐量的业务需求调整批次大小。我们之所以称之为 实时批量推断, 是因为它结合了批处理的概念,同时仍然提供实时响应。通过实时批量推断,您可以在低延迟和高吞吐量之间取得平衡,从而使您能够及时、高效地处理大量数据。

Text2Text 生成模型的 Jumpstart 批量转换允许您通过环境变量传递批处理超参数,从而进一步提高吞吐量并最大限度地减少延迟。

JumpStart 为各种问题类型提供预训练的开源模型,以帮助您开始使用机器学习 (ML)。在部署之前,您可以逐步训练和调整这些模型。JumpStart 还提供用于为常见用例设置基础设施的解决方案模板,以及使用 Am azon SageMaker 进行机器学习的可执行笔记本示例。您可以通过亚马逊 S ageMaker Studio 中的 JumpStart 登录页面访问预先训练的模型、解决方案模板和示例。 你也可以使用 SageMaker Python SDK 访问 JumpStart 模型。

在这篇文章中,我们演示了如何使用来 自 Hugging Face 的最先进的预训练 text2text FLAN T5 模型 进行批量转换和实时批量推断。

解决方案概述

该笔记本显示了从 Hugging Face 中批量转换的预训练 Text2Text FLAN T5 模型,可在以下 GitHub 存储库 中找到。 本笔记本使用 Hugging Face cnn_dailymail 数据集中的数据使用 SageMak er SDK 执行文本摘要任务。

以下是实现批量转换和实时批量推断的关键步骤:

  1. 设置先决条件。
  2. 选择预训练模型。
  3. 检索模型的工件。
  4. 指定批量转换作业超参数。
  5. 为批量转换准备数据。
  6. 运行批量转换作业。
  7. 使用 RO UGE (以回忆为导向的 Gisting 评估底层研究)分数评估摘要。
  8. 执行实时批量推断。

设置先决条件

在运行笔记本之前,必须完成一些初始设置步骤。让我们设置 SageMaker 执行角色,使其有权代表你运行 亚马逊云科技 服务:

sagemaker_session = Session()
aws_role = sagemaker_session.get_caller_identity_arn()
aws_region = boto3.Session().region_name
sess = sagemaker.Session()

选择预训练的模型

我们使用 huggingface-text2text-flan-t5-large 模型作为默认模型。或者,您可以在 JumpStart 上检索可用的 Text2Text 模型列表,然后选择您的首选模型。此方法提供了一种使用同一个笔记本电脑选择不同型号 ID 的简单方法。出于演示目的,我们使用了 huggingface-text2text-flan-t5-large 模型:

model_id, model_version, = (
"huggingface-text2text-flan-t5-large",
"*",
)

检索模型的工件

使用 SageMaker,我们可以对预训练的模型进行推断,即使无需先在新数据集上对其进行微调。我们首先检索预训练模型的 deplo y_image_uri 、deploy_s ource_uri 和 model_ uri:

inference_instance_type = "ml.p3.2xlarge"

# Retrieve the inference docker container uri. This is the base HuggingFace container image for the default model above.
deploy_image_uri = image_uris.retrieve(
region=None,
framework=None, # automatically inferred from model_id
image_scope="inference",
model_id=model_id,
model_version=model_version,
instance_type=inference_instance_type,
)

# Retrieve the model uri.
model_uri = model_uris.retrieve(
model_id=model_id, model_version=model_version, model_scope="inference"
)

#Create the SageMaker model instance
model = Model(
image_uri=deploy_image_uri,
model_data=model_uri,
role=aws_role,
predictor_cls=Predictor)

指定批量转换作业超参数

您可以将任何超参数子集作为环境变量传递给批处理转换作业。您也可以在 JSON 负载中传递这些超参数。但是,如果您为超参数设置环境变量,如以下代码所示,则不会使用 JSON 行负载中各个示例中的高级超参数。如果你想使用来自负载的超参数,你可能需要改为将 hyp er_params_dict 参数设置为空。

#Specify the Batch Job Hyper Params Here, If you want to treate each example hyperparameters different please pass hyper_params_dict as None
hyper_params = {"batch_size":4, "max_length":50, "top_k": 50, "top_p": 0.95, "do_sample": True}
hyper_params_dict = {"HYPER_PARAMS":str(hyper_params)}

为批量转换准备数据

现在我们已经准备好加载来自 Hugging Face 的 cnn_dailymail 数据集 了:

cnn_test = load_dataset('cnn_dailymail','3.0.0',split='test')

我们仔细检查每个数据条目并以所需格式创建输入数据。我们创建一个 ar ticles.jsonl 文件作为测试数据文件,其中包含需要汇总为输入负载的文章。创建此文件时,我们会向每个测试输入行附加 “简要总结此文本:” 的提示。如果要为每个测试输入设置不同的超参数,则可以在创建数据集时附加这些超参数。

我们创建了 highlights.jsonl 作为基本真相文件,其中包含存储在测试文件 articles.jsonl 中的每篇文章的亮点。 我们将两个测试文件存储在 亚马逊简单存储服务 (Amazon S3)存储桶中。参见以下代码:

#You can specify a prompt here
prompt = "Briefly summarize this text: "
#Provide the test data and the ground truth file name
test_data_file_name = "articles.jsonl"
test_reference_file_name = 'highlights.jsonl'

test_articles = []
test_highlights =[]

# We will go over each data entry and create the data in the input required format as described above
for id, test_entry in enumerate(cnn_test):
    article = test_entry['article']
    highlights = test_entry['highlights']
    # Create a payload like this if you want to have different hyperparameters for each test input
    # payload = {"id": id,"text_inputs": f"{prompt}{article}", "max_length": 100, "temperature": 0.95}
    # Note that if you specify hyperparameter for each payload individually, you may want to ensure that hyper_params_dict is set to None instead
    payload = {"id": id,"text_inputs": f"{prompt}{article}"}
    test_articles.append(payload)
    test_highlights.append({"id":id, "highlights": highlights})

with open(test_data_file_name, "w") as outfile:
    for entry in test_articles:
        outfile.write("%s\n" % json.dumps(entry))

with open(test_reference_file_name, "w") as outfile:
    for entry in test_highlights:
        outfile.write("%s\n" % json.dumps(entry))

# Uploading the data        
s3 = boto3.client("s3")
s3.upload_file(test_data_file_name, output_bucket, os.path.join(output_prefix + "/batch_input/articles.jsonl"))

运行批量转换作业

当您启动批量转换作业时,SageMaker 会启动必要的计算资源来处理数据,包括取决于所选实例类型的 CPU 或 GPU 实例。在批量转换作业期间,SageMaker 会自动预置和管理处理数据所需的计算资源,包括实例、存储和网络资源。批量转换作业完成后,SageMaker 会自动清理计算资源。这意味着在作业期间使用的实例和存储将被停止和删除,从而释放资源并最大限度地降低成本。参见以下代码:

# Creating the Batch transformer object
batch_transformer = model.transformer(
    instance_count=1,
    instance_type=inference_instance_type,
    output_path=s3_output_data_path,
    assemble_with="Line",
    accept="text/csv",
    max_payload=1,
    env = hyper_params_dict
)

# Making the predications on the input data
batch_transformer.transform(s3_input_data_path, content_type="application/jsonlines", split_type="Line")

batch_transformer.wait()

以下是 art icles.jsonl 测试文件 中的一条示例记录。请注意,此文件中的记录的 ID 与 predict.jsonl 文件记录相匹配,该记录显示了 Hugging Face Text2Text 模型输出的摘要记录。同样,事实真相文件也有与数据记录相匹配的 ID。测试文件、基本真相文件和输出文件中的匹配ID允许将输入记录与输出记录关联起来,以便于解释结果。

以下是为汇总而提供的示例输入记录:

{"id": 0, "text_inputs": "Briefly summarize this text: (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court's treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What's objectionable is the attempts to undermine international justice, not Palestine's decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court's decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder and Faith Karimi contributed to this report."}

以下是经过汇总的预测输出:

{'id': 0, 'generated_texts': ['The Palestinian Authority officially became a member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories.']}

以下是用于模型评估目的的基本真相摘要:

{"id": 0, "highlights": "Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June .\nIsrael and the United States opposed the move, which could open the door to war crimes investigations against Israelis ."}

接下来,我们使用基本真值和预测输出进行模型评估。

使用 ROUGE 分数评估模型¶

RO UGE ,即以回忆为导向的 Gisting Enderstudy,是一组指标和软件包,用于评估自然语言处理中的自动摘要和机器翻译。这些指标将自动生成的摘要或翻译与参考文献(人工制作的)摘要或翻译或一组参考文献进行比较。

在以下代码中,我们将预测的摘要和原始摘要合并到通用密钥 ID 上, 并使用它来计算 ROUGE 分数:

# Downloading the predictions
s3.download_file(
output_bucket, output_prefix + "/batch_output/" + "articles.jsonl.out", "predict.jsonl"
)

with open('predict.jsonl', 'r') as json_file:
json_list = list(json_file)

# Creating the prediction list for the dataframe
predict_dict_list = []
for predict in json_list:
if len(predict) > 1:
predict_dict = ast.literal_eval(predict)
predict_dict_req = {"id": predict_dict["id"], "prediction": predict_dict["generated_texts"][0]}
predict_dict_list.append(predict_dict_req)

# Creating the predictions dataframe
predict_df = pd.DataFrame(predict_dict_list)

test_highlights_df = pd.DataFrame(test_highlights)

# Combining the predict dataframe with the original summarization on id to compute the rouge score
df_merge = test_highlights_df.merge(predict_df, on="id", how="left")

rouge = evaluate.load('rouge')
results = rouge.compute(predictions=list(df_merge["prediction"]),references=list(df_merge["highlights"]))
print(results)
{'rouge1': 0.32749078992945646, 'rouge2': 0.126038645005132, 'rougeL': 0.22764277967933363, 'rougeLsum': 0.28162915746368966}

执行实时批量推断

接下来,我们将向您展示如何通过以列表形式提供输入来在端点上运行实时批量推断。我们使用与之前相同的模型 ID 和数据集,只是我们从测试数据集中提取了一些记录并使用它们来调用实时端点。

以下代码显示如何创建和部署实时端点以进行实时批量推理:

from sagemaker.utils import name_from_base
endpoint_name = name_from_base(f"jumpstart-example-{model_id}")
# deploy the Model. Note that we need to pass Predictor class when we deploy model through Model class,
# for being able to run inference through the sagemaker API.
model_predictor = model.deploy(
    initial_instance_count=1,
    instance_type=inference_instance_type,
    predictor_cls=Predictor,
    endpoint_name=endpoint_name
)

接下来,我们准备好输入有效载荷。为此,我们使用之前准备的数据,提取前 10 个测试输入,并在文本输入中附加我们想要使用的超参数。我们将此负载提供给实时 invoke_endpoint 然后,响应负载以响应列表的形式返回。参见以下代码:

#Provide all the text inputs to the model as a list
text_inputs = [entry["text_inputs"] for entry in test_articles[0:10]]

# The information about the different Parameters is provided above
payload = {
"text_inputs": text_inputs,
"max_length": 50,
"num_return_sequences": 1,
"top_k": 50,
"top_p": 0.95,
"do_sample": True,
"batch_size": 4
}


def query_endpoint_with_json_payload(encoded_json, endpoint_name):
client = boto3.client("runtime.sagemaker")
response = client.invoke_endpoint(
EndpointName=endpoint_name, ContentType="application/json", Body=encoded_json
)
return response


query_response = query_endpoint_with_json_payload(
json.dumps(payload).encode("utf-8"), endpoint_name=endpoint_name
)


def parse_response_multiple_texts(query_response):
model_predictions = json.loads(query_response["Body"].read())
return model_predictions

generated_text_list = parse_response_multiple_texts(query_response)
print(*generated_text_list, sep='\n')

清理

测试完端点后,请务必删除 SageMaker 推理端点并删除模型以避免产生费用。

结论

在这本笔记本中,我们进行了批量转换,展示了用于摘要任务的 Hugging Face Text2Text 生成器模型。批量转换有利于在无需永久端点的情况下从大型数据集中获得推论。我们将输入记录与推论联系起来,以帮助解释结果。我们使用 ROUGE 分数将测试数据汇总与模型生成的汇总进行了比较。

此外,我们还演示了实时批量推断,在流式输入数据等场景中,您可以向实时端点发送一小批数据,以实现延迟和吞吐量之间的平衡。实时批量推断有助于提高实时请求的吞吐量。

立即在 SageMaker 中试用 Text2Text 生成模型进行批量转换,并告诉我们你的反馈!


作者简介

Hemant Singh 是一名机器学习工程师,拥有亚马逊 SageMaker JumpStart 和亚马逊 SageMaker 内置算法方面的经验。他拥有库兰特数学科学研究所的硕士学位和印度理工学院德里分校的理学学士学位。他在处理自然语言处理、计算机视觉和时间序列分析领域的各种机器学习问题方面拥有经验。

Rachna Chadha 是 亚马逊云科技 战略账户领域首席解决方案架构师 AI/ML。拉赫纳是一位乐观主义者,他相信以合乎道德和负责任的方式使用人工智能可以改善未来的社会,带来经济和社会繁荣。在业余时间,Rachna 喜欢与家人共度时光、远足和听音乐。

Ashish Khetan 博士 是一位高级应用科学家,拥有亚马逊 SageMaker 内置算法,并帮助开发机器学习算法。他在伊利诺伊大学厄巴纳-香槟分校获得博士学位。他是机器学习和统计推理领域的活跃研究人员,曾在Neurips、ICML、ICLR、JMLR、ACL和EMNLP会议上发表过许多论文。