New performance improvements in Amazon SageMaker model parallel library
Foundation models are large deep learning models trained on a vast quantity of data at scale. They can be further fine-tuned to perform a variety of downstream tasks and form the core backbone of many AI applications. The most prominent category is large language models (LLMs), including auto-regressive models such as GPT variants trained to complete natural text. LLMs typically contain billions of parameters, which rarely fit on a single accelerator and therefore require model parallelism techniques. Another category is diffusion models, notably Stable Diffusion, which generate images from text prompts.
To help our customers further minimize training costs and accelerate time-to-market, we are thrilled to introduce two new performance improvements in SageMaker model parallel: SMDDP Collectives and FlashAttention. SMDDP Collectives is the most performant collective library on Amazon Web Services infrastructure for large model training, offered by the SageMaker distributed data parallel (SMDDP) library, and FlashAttention reimplements the attention mechanism in an IO-aware way that reduces its memory footprint and speeds up training.
“Our mission at Stability AI is to build the foundation to activate humanity’s potential through AI. To achieve this mission, we need to efficiently train open-source foundation models on hundreds of accelerated compute instances. We rely on SageMaker and its distributed training libraries to optimize performance and implement state-of-the-art strategies to shard models and data across our training cluster. These optimizations reduce our training costs, help us meet customer needs faster, and speed up the development of new models.”
— Emad Mostaque, Founder and CEO of Stability AI.
In this blog post, we’ll first present our latest performance improvements in the SageMaker model parallel library. Then, we’ll revisit how to train foundation models using sharded data parallel. Finally, we’ll benchmark the performance of 13B, 50B, and 100B parameter auto-regressive models and wrap up with future work.
New performance improvements in SageMaker model parallel library
Starting with SMP v1.13, the SageMaker model parallel library delivers two new features for the sharded data parallel training technique, described below.
1. Amazon Web Services-optimized AllGather from SMDDP Collectives
In sharded data parallel, since only a shard of the model state is present on each GPU, an AllGather collective is needed to gather the full set of parameters from all GPUs in the sharding group during forward or backward pass computations. In previous versions of SageMaker model parallel, we used the NVIDIA Collective Communications Library (NCCL) for these collectives. However, NCCL is a general-purpose collective communications library not designed for Amazon Web Services infrastructure, which leads to sub-optimal performance even with Elastic Fabric Adapter (EFA) enabled.
Previously, we had developed the SMDDP library, which provides Amazon Web Services-optimized collectives for distributed training on SageMaker. With this release, SMDDP Collectives adds an AllGather implementation custom-built for Amazon Web Services infrastructure: it takes full advantage of the 400 Gbps EFA networking on p4d.24xlarge instances and offloads communication-related processing to the CPU, freeing GPU cycles for computation and helping communication overlap with compute.
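To make the role of AllGather concrete, here is a minimal sketch in plain PyTorch (not the SMP or SMDDP API; the helper name and usage are illustrative) showing how a rank that owns only one shard of a weight reassembles the full parameter right before using it:

```python
import torch
import torch.distributed as dist

def gather_full_weight(local_shard: torch.Tensor, group=None) -> torch.Tensor:
    """All-gather parameter shards from every rank in the sharding group.

    Each rank permanently stores only 1/world_size of the flattened weight;
    the full weight is reassembled just-in-time for the forward or backward
    computation and can be freed immediately afterwards.
    """
    world_size = dist.get_world_size(group)
    buckets = [torch.empty_like(local_shard) for _ in range(world_size)]
    dist.all_gather(buckets, local_shard, group=group)  # the collective SMDDP optimizes
    return torch.cat(buckets)

# Illustrative use inside a sharded forward pass:
#   full_w = gather_full_weight(shard).view(out_features, in_features)
#   y = torch.nn.functional.linear(x, full_w)
#   del full_w  # only the shard stays resident between layers
```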
2. FlashAttention
In modern transformer architectures, one of the largest sources of memory consumption is the activation footprint in the self-attention layer. Each attention head computes an S × S attention matrix for each input, where S is the sequence length, and this matrix goes through several operations, such as dropout, softmax, and matrix multiplication, with each intermediate output requiring memory for use in back-propagation.
FlashAttention (Dao et al., 2022) addresses this bottleneck with an IO-aware implementation of exact attention: it tiles the computation so that the full S × S matrix is never materialized in GPU high-bandwidth memory, and it recomputes the required intermediates during back-propagation instead of storing them. As a result, the attention activation footprint grows linearly rather than quadratically with sequence length, freeing memory for larger batch sizes or longer sequences without changing the model's outputs.
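As a rough back-of-the-envelope illustration (the head count, head dimension, micro-batch size, and the assumption of three stored S × S intermediates are ours, not SMP benchmark figures), the sketch below contrasts how attention activation memory scales with sequence length:

```python
def std_attention_bytes(batch, heads, seq_len, bytes_per_elem=2, stored_intermediates=3):
    # Standard attention keeps several S x S tensors per head for the backward
    # pass (e.g., pre-softmax scores, softmax output, dropout mask): O(S^2).
    return batch * heads * stored_intermediates * seq_len * seq_len * bytes_per_elem

def flash_attention_bytes(batch, heads, seq_len, head_dim, bytes_per_elem=2):
    # FlashAttention stores only O(S)-sized tensors (the attention output plus
    # per-row softmax statistics) and recomputes the rest during backward.
    return batch * heads * seq_len * (head_dim + 2) * bytes_per_elem

# Example: micro-batch 4, 40 heads, head_dim 128, sequence length 2048, FP16.
print(std_attention_bytes(4, 40, 2048) / 2**30)         # ~3.75 GiB per layer
print(flash_attention_bytes(4, 40, 2048, 128) / 2**30)  # ~0.08 GiB per layer
```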
Train foundation models at scale with SageMaker model parallel
To train foundation models with SMP powered by SMDDP Collectives, no additional changes are required in your sharded data parallel training jobs. If you're new to sharded data parallel, follow the sharded data parallelism documentation to get started; the examples in this post use the GPT-2 training script train_gpt_simple.py.
We highlight the key hyperparameters in the SageMaker PyTorch estimator that kick off a sharded data parallel training job with SMDDP Collectives enabled. The parameter ddp_dist_backend in smp_options now has a new option, "auto", as its default value. With "auto", SMP uses the Amazon Web Services-optimized AllGather for sharded data parallelism jobs and falls back to NCCL otherwise; refer to the SageMaker model parallel documentation for the configurations that SMDDP Collectives supports. If you want to run sharded data parallel with NCCL as the communication backend of choice, set "ddp_dist_backend" to "nccl" in smp_options.
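For reference, here is a minimal sketch of such an estimator configuration. The instance count, hyperparameter values, framework versions, script name, and S3 path below are illustrative assumptions rather than a verified recipe:

```python
import sagemaker
from sagemaker.pytorch import PyTorch

# Sharded data parallel settings for the SageMaker model parallel (SMP) library.
smp_options = {
    "enabled": True,
    "parameters": {
        "ddp": True,
        "ddp_dist_backend": "auto",          # default in SMP v1.13; set "nccl" to force NCCL
        "partitions": 1,                     # no pipeline parallelism
        "sharded_data_parallel_degree": 64,  # illustrative sharding degree
        "fp16": True,
        "delayed_parameter_initialization": True,
    },
}

estimator = PyTorch(
    entry_point="train_gpt_simple.py",  # the GPT-2 example script referenced above
    role=sagemaker.get_execution_role(),
    instance_type="ml.p4d.24xlarge",
    instance_count=16,                   # illustrative cluster size
    framework_version="1.12.1",          # assumption: a PyTorch version supported by SMP v1.13
    py_version="py38",
    distribution={
        "smdistributed": {"modelparallel": smp_options},
        "mpi": {"enabled": True, "processes_per_host": 8},  # 8 GPUs per p4d.24xlarge
    },
)

estimator.fit("s3://my-bucket/gpt-training-data")  # hypothetical dataset location
```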
With the latest SMP v1.13 release, the sharded data parallel training technique supports FlashAttention out of the box for popular models including BERT, RoBERTa, GPT-2, GPT-J, GPT-Neo, and GPT-NeoX. It is enabled by passing tensor_parallelism=True during model creation without setting tensor_parallel_degree. You can find example code in the same training script, train_gpt_simple.py.
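As a minimal sketch of what this looks like in a training script (assuming the smp.model_creation context manager available in SMP v1.13; the Hugging Face model configuration chosen here is purely illustrative):

```python
import torch
import smdistributed.modelparallel.torch as smp
from transformers import AutoConfig, AutoModelForCausalLM

smp.init()  # the sharded data parallel settings come from the estimator's smp_options

model_config = AutoConfig.from_pretrained("EleutherAI/gpt-neox-20b")  # illustrative choice

# tensor_parallelism=True (with tensor_parallel_degree left at its default of 1)
# lets SMP swap in its optimized transformer implementation, which includes
# FlashAttention support in SMP v1.13.
with smp.model_creation(tensor_parallelism=True, dtype=torch.float16):
    model = AutoModelForCausalLM.from_config(model_config)

model = smp.DistributedModel(model)
```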
Benchmarking performance
We benchmarked sharded data parallelism in the SageMaker model parallel library on three different scales of models to understand how the two new features, FlashAttention and Amazon Web Services-optimized AllGather, contribute to performance improvements. A placement group is not required to reproduce these benchmarks on SageMaker.
13B parameter GPT-NeoX
In this setting, we focus on the performance gain contributed by FlashAttention and leave Amazon Web Services-optimized AllGather out of the picture. Using FlashAttention saves substantial GPU memory, which helps us increase the batch size or reduce the sharding degree, thereby improving performance. As the results below show, we observed an average speedup of about 20.4% with FlashAttention in SMP for the 13B parameter GPT-NeoX model across various configurations on 16 to 64 p4d nodes. Memory usage during standard attention computation scales quadratically with sequence length, whereas FlashAttention's memory usage scales linearly, so FlashAttention becomes even more helpful as sequence length increases and makes larger sequence lengths practical. Being memory-efficient without trading off model quality, FlashAttention has quickly gained traction in the large model training community in recent months, including integrations with several popular training frameworks and libraries.
Model/Training | Cluster | SMP configuration | Without FlashAttention (TFLOPs/GPU) | With FlashAttention (TFLOPs/GPU) | % Speedup
13B GPT-NeoX, seq length 2048, global batch size 1024, FP16 | 16 p4d.24xlarge nodes | Activation checkpointing, sharded_data_parallel_degree: 64, gradient_accumulation: 1 | 130 | 159 | 22.31
13B GPT-NeoX, seq length 2048, global batch size 2048, FP16 | 32 p4d.24xlarge nodes | Activation checkpointing, sharded_data_parallel_degree: 64, gradient_accumulation: 1 | 131 | 157 | 19.85
13B GPT-NeoX, seq length 2048, global batch size 4096, FP16 | 64 p4d.24xlarge nodes | Activation checkpointing, sharded_data_parallel_degree: 64, gradient_accumulation: 1 | 131 | 156 | 19.08
50B parameter Bloom
Now, we look at how the Amazon Web Services-optimized AllGather from SMDDP Collectives speeds up large model training with SMP. We benchmark a 50B-parameter Bloom model and compare the performance with and without the Amazon Web Services-optimized AllGather collective. We observe that SMDDP Collectives speeds up model training by up to 40% for training jobs spanning 32 to 64 nodes. SMDDP Collectives helps achieve better performance through better utilization of the 400 Gbps network bandwidth available on p4d.24xlarge instances. This, coupled with the design choice to offload communication-related processing to the CPU, helps achieve good compute-to-network overlap, leading to optimized performance. Compute-to-network overlap becomes especially important for large models, since the amount of data communicated across nodes scales linearly with model size.
Model/Training | Cluster | SMP configuration | Without Amazon Web Services-optimized AllGather (TFLOPs/GPU) | With Amazon Web Services-optimized AllGather (TFLOPs/GPU) | % Speedup
50B Bloom, seq length 2048, global batch size 2048, BF16 | 32 p4d.24xlarge nodes | Activation checkpointing, sharded_data_parallel_degree: 128, gradient_accumulation: 1 | 102 | 143 | 40.20
50B Bloom, seq length 2048, global batch size 4096, BF16 | 64 p4d.24xlarge nodes | Activation checkpointing, sharded_data_parallel_degree: 128, gradient_accumulation: 1 | 101 | 140 | 38.61
100B parameter GPT-NeoX
Finally, we benchmark SMP with both of the new features enabled. The results show that SMP v1.13 is 30% faster than the previous version on a 100B-parameter GPT-NeoX model.
Model/Training | Cluster | SMP configuration | Without FlashAttention and without Amazon Web Services-optimized AllGather (TFLOPs/GPU) | With FlashAttention + Amazon Web Services-optimized AllGather (TFLOPs/GPU) | % Speedup
100B GPT-NeoX, seq length 2048, global batch size 2048, FP16 | 32 p4d.24xlarge nodes | Activation checkpointing, sharded_data_parallel_degree: 256, offload_activations | 121 | 158 | 30.58
100B GPT-NeoX, seq length 2048, global batch size 4096, FP16 | 64 p4d.24xlarge nodes | Activation checkpointing, sharded_data_parallel_degree: 256, offload_activations | 122 | 158 | 29.51
For future work, we'll be adding support for an Amazon Web Services-optimized Reduce-Scatter to SMDDP Collectives. The Reduce-Scatter collective is critical for averaging and sharding the gradients computed in the backward pass. We expect this to further speed up the SMP library in future releases.
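For intuition, here is a plain torch.distributed sketch (not the SMDDP implementation) of what Reduce-Scatter does: it sums gradients across ranks and leaves each rank with only the shard it owns, which is exactly the shape sharded data parallelism needs after the backward pass.

```python
import torch
import torch.distributed as dist

def reduce_scatter_gradient(full_grad: torch.Tensor, group=None) -> torch.Tensor:
    """Sum a full-parameter gradient across ranks, keeping only this rank's shard.

    Assumes the flattened gradient length is divisible by the group size.
    Divide the result by the data-parallel degree afterwards to average.
    """
    world_size = dist.get_world_size(group)
    shards = list(full_grad.chunk(world_size))      # per-rank destination shards
    local_out = torch.empty_like(shards[0])
    dist.reduce_scatter(local_out, shards, op=dist.ReduceOp.SUM, group=group)
    return local_out
```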
Conclusion
In this post, we discussed the two latest performance improvements for the sharded data parallel technique in the SageMaker model parallel library. LLMs show great promise in improving the quality and re-usability of ML models. Amazon Web Services teams are working closely with customers to keep reducing their training costs and time-to-market. You can find more SageMaker model parallel examples in the amazon-sagemaker-examples GitHub repository.