
New in Composer 0.11: FSDP Support, Streaming v0.1 Release, Simplified Checkpointing and Distributed Experience

We’re announcing the 0.11 release of Composer, MosaicML’s open-source library for training PyTorch neural networks faster, cheaper, and to higher accuracy. With Composer, we stack and combine speed-up methods into recipes that optimize your training. Composer 0.11 is available as a Python package via pip, and the source code is on GitHub.

Composer 0.11 (release notes) includes several new features, plus improvements to existing capabilities. Kudos to the Composer community for your engagement, feedback, and contributions to this release.

If you want to join the Composer community, learn more about contributing here, and message us on Slack if you have any questions or suggestions!

About Composer

The Composer library helps developers train PyTorch neural networks faster, at lower cost, and to higher accuracy. Composer includes:

  • 20+ methods for speeding up training networks for computer vision and language modeling.
  • An easy-to-use trainer that integrates best practices for efficient training.
  • Functional forms of all speedup methods that are easy to integrate into your existing training loop.
  • Strong and reproducible training baselines to get you started as quickly as possible.

FSDP support

With Composer v0.11.0, we are releasing beta support for PyTorch FSDP (Fully Sharded Data Parallel). Like PyTorch DDP, FSDP is a data-parallel strategy for distributed training. Unlike DDP, it also shards model parameters, gradients, and optimizer state across devices, which dramatically reduces per-device memory requirements and lets users scale up and train large models more easily. We recently used PyTorch FSDP to simply and efficiently train LLMs with up to 70B parameters.
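
To make the memory savings concrete, here is a rough back-of-the-envelope sketch. It assumes mixed-precision training with Adam at roughly 16 bytes of model state per parameter (2 for half-precision weights, 2 for half-precision gradients, 12 for fp32 master weights and the two optimizer moments), ignores activations and communication buffers, and assumes the state shards evenly across devices. The helper name and these figures are illustrative assumptions, not numbers taken from Composer or PyTorch.

```python
def model_state_gib(n_params, n_gpus, sharded):
    """Estimate per-GPU memory (GiB) for weights, gradients, and optimizer state."""
    bytes_per_param = 16  # assumption: mixed-precision Adam, see the note above
    total_bytes = n_params * bytes_per_param
    # DDP replicates all model state on every GPU; full sharding splits it evenly.
    per_device_bytes = total_bytes / n_gpus if sharded else total_bytes
    return per_device_bytes / 2**30


ddp = model_state_gib(1e9, n_gpus=8, sharded=False)
fsdp = model_state_gib(1e9, n_gpus=8, sharded=True)
print(f"Replicated (DDP-style): {ddp:.1f} GiB per GPU")
print(f"Fully sharded:          {fsdp:.1f} GiB per GPU")
```

Under these assumptions the per-device footprint of model state shrinks in proportion to the number of GPUs, which is why sharding matters most for models that do not fit on a single device.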

Here’s how easy it is to use PyTorch FSDP with Composer:

import torch.nn as nn
from composer import Trainer
from composer.models import ComposerModel

class Block(nn.Module):
    ...

# Your custom model
class Model(nn.Module):
    def __init__(self, n_layers):
        super().__init__()
        self.blocks = nn.ModuleList([
            Block(...) for _ in range(n_layers)
        ])
        self.head = nn.Linear(...)

    def forward(self, inputs):
        ...

    # FSDP Wrap Function
    def fsdp_wrap_fn(self, module):
        return isinstance(module, Block)

    # Activation Checkpointing Function
    def activation_checkpointing_fn(self, module):
        return isinstance(module, Block)

# ComposerModel wrapper, used by the Trainer
# to compute loss, metrics, etc.
class MyComposerModel(ComposerModel):

    def __init__(self, n_layers):
        super().__init__()
        self.model = Model(n_layers)
        ...

    def forward(self, batch):
        ...

    def eval_forward(self, batch, outputs=None):
        ...

    def loss(self, outputs, batch):
        ...

# Pass your ComposerModel and fsdp_config into the Trainer
composer_model = MyComposerModel(n_layers=3)
fsdp_config = {
    'sharding_strategy': 'FULL_SHARD',
    'min_params': 1e8,
    'cpu_offload': False, # Not supported yet
    'mixed_precision': 'DEFAULT',
    'backward_prefetch': 'BACKWARD_POST',
    'activation_checkpointing': False,
    'activation_cpu_offload': False,
    'verbose': True
}

trainer = Trainer(
    model=composer_model,
    fsdp_config=fsdp_config,
    ...
)

trainer.fit()

Streaming v0.1 Release

We are spinning off a new Streaming dataset repository, and it is supported by default in this Composer release. Streaming is a high-performance drop-in replacement for PyTorch's IterableDataset that lets users stream training data from cloud-based object stores. It ships with built-in support for popular open-source datasets (ADE20K, C4, COCO, Enwiki, ImageNet, etc.).

Streaming datasets support the following features:

  • A unified data encoding and decoding interface for converting data into the streaming format
  • Prefetching a target number of samples instead of greedily downloading the entire dataset
  • Random access to samples, with shards loaded lazily to reduce data loading time
  • Fetching a sample with Python indexing to inspect it locally
  • Data compression to enable quick download times and lower cloud egress fees. Supported formats include brotli, gzip, snappy, zstd, and bz2
  • Dataset hashing to ensure data integrity through cryptographic and non-cryptographic hashing algorithms. Supported algorithms include SHA2, SHA3, MD5, xxHash, etc.
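
The compression and hashing features in the list above can be illustrated with a minimal sketch built only on the Python standard library. This is not Streaming's actual shard format or API: the helper names `pack_shard` and `unpack_shard` are hypothetical, only gzip and bz2 are covered, and hashing the uncompressed bytes with SHA3-256 is an assumption made for the example.

```python
import bz2
import gzip
import hashlib

# Hypothetical helpers for illustration; not part of the Streaming package.
_COMPRESSORS = {"gzip": gzip.compress, "bz2": bz2.compress}
_DECOMPRESSORS = {"gzip": gzip.decompress, "bz2": bz2.decompress}


def pack_shard(raw: bytes, compression: str = "gzip"):
    """Compress a shard and record a SHA3-256 digest of the uncompressed bytes."""
    blob = _COMPRESSORS[compression](raw)
    digest = hashlib.sha3_256(raw).hexdigest()
    return blob, digest


def unpack_shard(blob: bytes, expected_digest: str, compression: str = "gzip") -> bytes:
    """Decompress a shard and refuse to return it if the digest does not match."""
    raw = _DECOMPRESSORS[compression](blob)
    if hashlib.sha3_256(raw).hexdigest() != expected_digest:
        raise ValueError("shard failed integrity check")
    return raw


raw = b"sample-bytes " * 1000
blob, digest = pack_shard(raw, "bz2")
restored = unpack_shard(blob, digest, "bz2")
print(len(raw), len(blob), restored == raw)
```

The smaller compressed blob is what would travel over the network, and the digest check catches a corrupted or truncated download before the data reaches training.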

To get started, install the Streaming PyPI package:

pip install mosaicml-streaming

You can use the streaming Dataset class with the PyTorch native DataLoader class as follows:
import torch
from streaming import Dataset

dataloader = torch.utils.data.DataLoader(dataset=Dataset(remote='s3://...'))

Custom datasets can be built by extending streaming.Dataset; the underlying Dataset class handles key concerns such as sharding across worker processes, de-duplication of training samples, data compression, and data integrity. Here is one such example:

from typing import Any

import torch
from streaming.base import Dataset

# Extending `streaming.Dataset` with custom get functionality
class CustomDataset(Dataset):
   def __init__(self, local, remote):
       super().__init__(local, remote)

   def __getitem__(self, idx: int) -> Any:
       obj = super().__getitem__(idx)
       return obj['x'], obj['y']

# Local caching directory
local = '/tmp/cache'

# Remote location to stream from
remote = 's3://mybucket/myfolder'

dataloader = torch.utils.data.DataLoader(dataset=CustomDataset(local=local, remote=remote))

Usability Enhancements

Simplified Checkpointing Interface

Composer supports saving and loading checkpoints both locally and remotely (to S3 or WandB), as well as auto-resuming from those checkpoints. With this release, we’ve greatly simplified how checkpoint saving and loading are configured in Composer.

To save checkpoints to S3, now all you need to do is:

  • Set save_folder to the full URI of your save directory (e.g. 's3://my-bucket/{run_name}/checkpoints')
  • Optionally, set save_filename to the pattern you want for your checkpoint file names

from composer.trainer import Trainer

# Checkpoint saving to S3.
trainer = Trainer(
    model=model,
    save_folder="s3://my-bucket/{run_name}/checkpoints",
    run_name='my-run',
    save_interval="1ep",
    save_filename="ep{epoch}.pt",
    save_num_checkpoints_to_keep=0,  # delete all checkpoints locally
    ...
)

trainer.fit()

Likewise, to load checkpoints from S3, all you have to do is:

  • Set load_path to the full URI of your desired checkpoint file (e.g. 's3://my-bucket/my-run/checkpoints/ep13.pt')

from composer.trainer import Trainer

# Checkpoint loading from S3.
new_trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration="10ep",
    load_path="s3://my-bucket/my-run/checkpoints/ep13.pt",
    ...
)

new_trainer.fit()
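
To see how the save and load settings line up, the sketch below expands the save_folder and save_filename patterns from the saving example into the load_path used in the loading example. It uses plain `str.format` as a stand-in: Composer's own formatting may support additional placeholders and differ in detail, and `checkpoint_uri` is a hypothetical helper, not a Composer function.

```python
save_folder = "s3://my-bucket/{run_name}/checkpoints"
save_filename = "ep{epoch}.pt"


def checkpoint_uri(run_name, epoch):
    """Expand the folder and filename patterns into one checkpoint URI."""
    folder = save_folder.format(run_name=run_name)
    filename = save_filename.format(epoch=epoch)
    return f"{folder}/{filename}"


load_path = checkpoint_uri("my-run", 13)
print(load_path)
```

The point is only that the load_path you pass when resuming should match whatever the save patterns produced for the run and epoch you want.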

Improved Distributed Experience

Composer does a lot of the heavy lifting to set up distributed training so that model training scales close to linearly as more GPUs are added, as discussed in this blog. While the distributed API has always been expressive and powerful, we’re now making the Composer API more usable for writing your own custom distributed entry points.

For example, say we’re creating a Python script that needs to download a dataset. To avoid race conditions where different ranks try to write the dataset to the same place, we need to ensure that only one process per node (local rank 0) downloads the dataset, and that the other ranks wait until it has finished. Previously, users needed to handle the distributed communication directly. Now, we’ve exposed our distributed API so you can leverage all of our helpful functions and contexts.

import datetime
from composer.trainer.devices import DeviceGPU
from composer.utils import dist

dist.initialize(DeviceGPU(), datetime.timedelta(seconds=30)) # Initialize distributed module

if dist.get_local_rank() == 0: # Download dataset on local rank zero (once per node)
    dataset = download_my_dataset()
dist.barrier() # All ranks wait until dataset is downloaded

# Create and train your model!
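
If you want to experiment with the "one rank writes, everyone waits" pattern without GPUs, here is a minimal standard-library sketch. It is an analogy rather than Composer's implementation: threads stand in for distributed processes, `threading.Barrier` stands in for `dist.barrier()`, and writing a small file stands in for the dataset download.

```python
import tempfile
import threading
from pathlib import Path

WORLD_SIZE = 4  # number of simulated ranks (an arbitrary choice for this sketch)
barrier = threading.Barrier(WORLD_SIZE)
data_path = Path(tempfile.mkdtemp()) / "dataset.txt"
results = {}


def worker(rank):
    if rank == 0:
        # Only rank 0 writes the file; this stands in for the download.
        data_path.write_text("fake dataset")
    # No rank passes this point until every rank, including rank 0, has reached it.
    barrier.wait()
    results[rank] = data_path.read_text()


threads = [threading.Thread(target=worker, args=(rank,)) for rank in range(WORLD_SIZE)]
for thread in threads:
    thread.start()
for thread in threads:
    thread.join()

print(sorted(results))
```

Because rank 0 reaches the barrier only after the write completes, every other rank is guaranteed to find the file once it is released.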

Learn more!

Thanks for reading! If you'd like to learn more about Composer, you are welcome to download it and try it out on your own training tasks. As you do, come be a part of our community by engaging with us on Twitter, joining our Slack channel, or just giving us a star on GitHub.
