From e62345e9613da311a4c63f9b65cf548e9dff4f05 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Mon, 12 Dec 2022 20:29:43 +0300 Subject: [PATCH 01/24] bump versions --- setup.cfg | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.cfg b/setup.cfg index 230f4ebfc..43682256b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,11 +33,11 @@ python_requires = >=3.7 install_requires = torch>=1.12 bitsandbytes==0.34.0 - accelerate==0.10.0 + accelerate==0.15.0 huggingface-hub==0.7.0 - transformers==4.21.3 + transformers==4.25.3 protobuf>=3.20.3,<4.0dev - hivemind==1.1.3 + https://github.com/learning-at-home/hivemind/archive/4c9c477e674f4ae40c5e7b9bc82056697720245c.zip humanfriendly async-timeout>=4.0.2 From 88348cfcf8cd06ad480c50273a8afbab94c5b218 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Mon, 12 Dec 2022 22:08:37 +0300 Subject: [PATCH 02/24] bump versions --- setup.cfg | 2 +- src/petals/bloom/block.py | 273 ++++---------------------- src/petals/bloom/from_pretrained.py | 27 ++- src/petals/bloom/model.py | 1 - src/petals/bloom/ops.py | 242 ----------------------- src/petals/cli/convert_model.py | 2 +- src/petals/cli/inference_one_block.py | 2 +- src/petals/client/remote_model.py | 6 +- src/petals/server/backend.py | 40 ++-- src/petals/server/handler.py | 6 +- src/petals/server/throughput.py | 2 +- tests/test_full_model.py | 4 +- 12 files changed, 82 insertions(+), 525 deletions(-) delete mode 100644 src/petals/bloom/ops.py diff --git a/setup.cfg b/setup.cfg index 43682256b..41d36d2f8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,7 +34,7 @@ install_requires = torch>=1.12 bitsandbytes==0.34.0 accelerate==0.15.0 - huggingface-hub==0.7.0 + huggingface-hub==0.11.1 transformers==4.25.3 protobuf>=3.20.3,<4.0dev https://github.com/learning-at-home/hivemind/archive/4c9c477e674f4ae40c5e7b9bc82056697720245c.zip diff --git a/src/petals/bloom/block.py b/src/petals/bloom/block.py index 857bf3b16..4f0735303 100644 --- a/src/petals/bloom/block.py +++ b/src/petals/bloom/block.py @@ -3,253 +3,52 @@ Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b See commit history for authorship. """ -import math +from typing import Optional, Tuple -import torch -import torch.nn as nn import torch.nn.quantized.dynamic.modules.linear +from transformers.models.bloom.modeling_bloom import BloomBlock, _expand_mask, _make_causal_mask, build_alibi_tensor -from petals.bloom.ops import ( - BloomGelu, - BloomScaledSoftmax, - attention_mask_func, - build_alibi_tensor, - dropout_add, - pre_process_alibi_for_pad, - split_tensor_along_last_dim, -) - - -class BloomAttention(nn.Module): - def __init__(self, config, layer_number=None): - super().__init__() - - self.hidden_size = config.hidden_size - self.num_heads = config.n_head - self.head_dim = self.hidden_size // self.num_heads - self.split_size = self.hidden_size - self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 - self.masked_softmax_fusion = config.masked_softmax_fusion - self.hidden_dropout = config.hidden_dropout - - if self.head_dim * self.num_heads != self.hidden_size: - raise ValueError( - f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:" - f" {self.num_heads})." - ) - - # Layer-wise attention scaling - self.layer_number = max(1, layer_number) - self.norm_factor = math.sqrt(self.head_dim) * self.layer_number - - # Scaled Softmax - self.scale_mask_softmax = BloomScaledSoftmax( - self.masked_softmax_fusion, - attention_mask_func, - self.attention_softmax_in_fp32, - self.layer_number, - ) - - self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True) - self.dense = nn.Linear(self.hidden_size, self.hidden_size) - - self.attention_dropout = nn.Dropout(config.attention_dropout) +class WrappedBloomBlock(BloomBlock): def forward( self, - hidden_states, - residual, - layer_past=None, - attention_mask=None, - alibi=None, - head_mask=None, - use_cache=False, - output_attentions=False, + hidden_states: torch.Tensor, + *args, + attention_mask: Optional[torch.Tensor] = None, + alibi: Optional[torch.Tensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs ): + assert attention_mask is None + batch_size, seq_length = hidden_states.shape[:2] + past_length = 0 if layer_past is None else layer_past[0].shape[-1] + seq_length_with_past = seq_length + past_length + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) if alibi is None: - current_sequence_length = hidden_states.shape[1] + (0 if layer_past is None else layer_past[0].shape[1]) - alibi = build_alibi_tensor( - current_sequence_length, n_head=self.num_heads, dtype=hidden_states.dtype, device=hidden_states.device - ) - - # hidden_states: [batch_size, seq_length, hidden_size] - # apply preprocessing if the input is padded - if attention_mask is not None: - alibi = pre_process_alibi_for_pad(alibi, attention_mask) - # otherwise repeat alibi tensor with the batch size - else: - alibi = alibi.repeat(hidden_states.shape[0], 1, 1) - - mixed_x_layer = self.query_key_value(hidden_states) - - # [batch_size, seq_length, 3 x hidden_size] --> [batch_size, seq_length, num_heads, 3 x head_dim] - new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_heads, 3 * self.head_dim) - mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - - # [batch_size, seq_length, num_heads, 3 x head_dim] --> 3 [batch_size, seq_length, num_heads, head_dim] - (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) - - if layer_past is not None: - past_key, past_value = layer_past - key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1) - value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1) - - if use_cache is True: - present = (key_layer, value_layer) - else: - present = None - - # [batch_size, head_dim, q_length, k_length] - output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1)) - - # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim] - query_layer = query_layer.transpose(1, 0).reshape(output_size[2], output_size[0] * output_size[1], -1) - - # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim] - key_layer = key_layer.transpose(1, 0).reshape(output_size[3], output_size[0] * output_size[1], -1) - - # Raw attention scores. [batch_size * num_heads, q_length, k_length] - beta = 1.0 / self.layer_number - - matmul_result = torch.baddbmm( - alibi, - query_layer.transpose(1, 0), - key_layer.transpose(1, 0).transpose(1, 2), - beta=beta, - alpha=(1.0 / self.norm_factor), + alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype) + attention_mask = self._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length) + return super().forward( + hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs ) - # change view to [batch_size, num_heads, q_length, k_length] - attention_scores = matmul_result.view(*output_size) - - # attention scores and attention mask [b, np, sq, sk] - max_positions = max(attention_scores.shape[-1], attention_scores.shape[-2]) - attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(value_layer.dtype) - attention_probs = self.attention_dropout(attention_probs) - - if head_mask is not None: - attention_probs = attention_probs * head_mask - - # context layer shape: [batch_size, num_heads, q_length, head_dim] - output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(0), value_layer.size(3)) - - # change view [k_length, batch_size x num_heads, head_dim] - value_layer = value_layer.transpose(1, 0).reshape(value_layer.size(1), output_size[0] * output_size[1], -1) - - # change view [batch_size x num_heads, q_length, k_length] - attention_probs_reshaped = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) - - # matmul: [batch_size * num_heads, q_length, head_dim] - context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1)) - - # change view [batch_size, num_heads, q_length, head_dim] - context_layer = context_layer.view(*output_size) - - # [batchs_size, num_heads, q_length, head_dim] --> [q_length, batch_size, num_heads, head_dim] - context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - - # [q_length, batch_size, num_heads, head_dim] --> [q_length, batch_size, hidden_size] - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,) - - context_layer = context_layer.view(*new_context_layer_shape) - - # Output. [q_length, batch_size, hidden_size] - - # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232 - output_tensor = self.dense(context_layer) - output = output_tensor.transpose(1, 0) - - output = dropout_add(output, residual, self.hidden_dropout, self.training) - - outputs = (output, present) - if output_attentions: - outputs += (attention_probs,) - - return outputs - - -class BloomMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.hidden_size = config.hidden_size - self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size) - self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size) - self.hidden_dropout = config.hidden_dropout - self.gelu_impl = BloomGelu() - - def forward(self, hidden_states, residual): - hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states)) - intermediate_output = self.dense_4h_to_h(hidden_states) - output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training) - return output - - -class BloomBlock(nn.Module): - def __init__(self, config, layer_number=None): - super().__init__() - self.hidden_size = config.hidden_size - - self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon) - self.n_head = config.n_head - self.self_attention = BloomAttention(config, layer_number=layer_number) - self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon) - - self.mlp = BloomMLP(config) - - self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm - self.hidden_dropout = config.hidden_dropout - - def forward( - self, - hidden_states, - layer_past=None, - attention_mask=None, - head_mask=None, - use_cache=False, - output_attentions=False, - alibi=None, - ): - # hidden_states: [batch_size, seq_length, hidden_size] - - # Layer norm at the beginning of the transformer layer. - layernorm_output = self.input_layernorm(hidden_states) - - # Layer norm post the self attention. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = hidden_states + def _prepare_attn_mask( + self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int + ) -> torch.BoolTensor: + # create causal mask + # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length] + combined_attention_mask = None + device = attention_mask.device + _, src_length = input_shape + + if src_length > 1: + combined_attention_mask = _make_causal_mask( + torch.Size(input_shape), device=device, past_key_values_length=past_key_values_length + ) - # Self attention. - attn_outputs = self.self_attention( - layernorm_output, - residual, - layer_past=layer_past, - attention_mask=attention_mask, - alibi=alibi, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, + # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length] + expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask ) - attention_output = attn_outputs[0] - - outputs = attn_outputs[1:] - - layernorm_output = self.post_attention_layernorm(attention_output) - - # Get residual - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = attention_output - - # MLP. - output = self.mlp(layernorm_output, residual) - - if use_cache: - outputs = (output,) + outputs - else: - outputs = (output,) + outputs[1:] - - return outputs # hidden_states, present, attentions + return combined_attention_mask diff --git a/src/petals/bloom/from_pretrained.py b/src/petals/bloom/from_pretrained.py index 16a4e7240..fb4555170 100644 --- a/src/petals/bloom/from_pretrained.py +++ b/src/petals/bloom/from_pretrained.py @@ -13,9 +13,10 @@ import torch from hivemind.utils.logging import get_logger, use_hivemind_log_handler from transformers.modeling_utils import WEIGHTS_NAME -from transformers.utils.hub import cached_path, hf_bucket_url +from transformers.models.bloom.configuration_bloom import BloomConfig +from transformers.utils import get_file_from_repo -from petals.bloom import BloomBlock, BloomConfig +from petals.bloom.block import WrappedBloomBlock from petals.utils.disk_cache import DEFAULT_CACHE_DIR use_hivemind_log_handler("in_root_logger") @@ -36,7 +37,7 @@ def load_pretrained_block( torch_dtype: Union[torch.dtype, str] = "auto", use_auth_token: Optional[str] = None, cache_dir: Optional[str] = None, -) -> BloomBlock: +) -> WrappedBloomBlock: """Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it.""" if config is None: @@ -44,7 +45,7 @@ def load_pretrained_block( if cache_dir is None: cache_dir = DEFAULT_CACHE_DIR - block = BloomBlock(config, layer_number=block_index) + block = WrappedBloomBlock(config) state_dict = _load_state_dict( converted_model_name_or_path, block_index, use_auth_token=use_auth_token, cache_dir=cache_dir ) @@ -70,20 +71,14 @@ def _load_state_dict( cache_dir: Optional[str] = None, ) -> OrderedDict[str, torch.Tensor]: revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH - archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None) - - # Load from URL or cache if already cached - resolved_archive_file = cached_path( - archive_file, - cache_dir=cache_dir, - force_download=FORCE_DOWNLOAD, - proxies=None, - resume_download=RESUME_DOWNLOAD, - local_files_only=LOCAL_FILES_ONLY, + archive_file = get_file_from_repo( + pretrained_model_name_or_path, + filename=WEIGHTS_NAME, + revision=revision, use_auth_token=use_auth_token, - user_agent=USER_AGENT, + cache_dir=cache_dir, ) - state_dict = torch.load(resolved_archive_file, map_location="cpu") + state_dict = torch.load(archive_file, map_location="cpu") return state_dict diff --git a/src/petals/bloom/model.py b/src/petals/bloom/model.py index 687d7651c..88084b520 100644 --- a/src/petals/bloom/model.py +++ b/src/petals/bloom/model.py @@ -213,7 +213,6 @@ def forward( all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None - # Compute alibi tensor: check build_alibi_tensor documentation current_sequence_length = hidden_states.shape[1] if past_key_values and past_key_values[0]: current_sequence_length += past_key_values[0][0].shape[1] diff --git a/src/petals/bloom/ops.py b/src/petals/bloom/ops.py deleted file mode 100644 index 4df872e6a..000000000 --- a/src/petals/bloom/ops.py +++ /dev/null @@ -1,242 +0,0 @@ -""" -Utility operations used in the the BLOOM model -Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b -See commit history for authorship. -""" -import math - -import torch -import torch.autograd -import torch.nn.functional as F -from torch import nn - - -def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False): - """Split a tensor along its last dimension. - - Args: - tensor: ([`torch.tensor`], *required*): - input tensor to split - num_partitions ([`int`], *required*): - number of partitions to split the tensor - contiguous_split_chunks ([`bool`], *optional*, default=`False`):: - If True, make each chunk contiguous in memory. - """ - # Get the size and dimension. - last_dim = tensor.dim() - 1 - numerator, denominator = tensor.size()[last_dim], num_partitions - if not (numerator % denominator == 0): - raise ValueError(f"{numerator} is not divisible by {denominator}") - last_dim_size = numerator // denominator - # Split. - tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) - # Note: torch.split does not create contiguous tensors by default. - if contiguous_split_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - - return tensor_list - - -def attention_mask_func(attention_scores, attention_mask, causal_mask): - if attention_mask.dtype == torch.bool: - attention_mask_bool = ~attention_mask - else: - attention_mask_bool = (1 - attention_mask).bool() - - query_length, key_length, n_heads = attention_scores.size(2), attention_scores.size(3), attention_scores.size(1) - padded_causal_mask = ( - attention_mask_bool[:, None, key_length - query_length : key_length, None] - + ~causal_mask[:, :, key_length - query_length : key_length, :key_length] - ).bool() - padded_causal_mask = padded_causal_mask + attention_mask_bool[:, None, None, :key_length].bool() - # Make use of floats - return ( - attention_scores.masked_fill_(padded_causal_mask.expand(-1, n_heads, -1, -1), -10000.0), - padded_causal_mask, - ) - - -def build_alibi_tensor( - max_seq_len: int, n_head: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = torch.device("cpu") -) -> torch.Tensor: - """ - Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it - relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value - `softmax(l+a) = softmax(l)`. Based on - https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 - Args: - Returns tensor shaped (n_head, 1, max_seq_len) - max_seq_len: (`int`, *required*): - max sequence length - n_head: (`int`, *required*): - number of heads - dtype: (`torch.dtype`, *optional*, default=`torch.bfloat16`): - dtype of the output tensor - device: (`torch.device`, *optional*, default=`torch.device('cpu')`): - device of the output alibi tensor - """ - closest_power_of_2 = 2 ** math.floor(math.log2(n_head)) - base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32) - powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32) - slopes = torch.pow(base, powers) - - if closest_power_of_2 != n_head: - extra_base = torch.tensor( - 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32 - ) - num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2) - extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32) - slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) - - lengths = torch.arange(max_seq_len, device=device, dtype=torch.int32) - return (slopes.view(-1, 1, 1) * lengths.view(1, 1, -1)).to(dtype) - - -def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor): - """ - Args: - Pre-process the alibi tensor for padding. - alibi: ([`torch.tensor`], *required*): - alibi tensor to pre-process - attention_mask: ([`torch.tensor`], *required*): - attention mask to pre-process - """ - assert attention_mask.ndim == 2, "mask should be [batch_size, seq_length]" - unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1) - # ^-- [batch, max_len], values correspond to element indices after removing padding - # We shift the alibi tensor + replace all the values where attention_mask==0.0 by 0 - alibi = alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1) * attention_mask.unsqueeze(0) - return alibi.reshape(alibi.shape[0] * alibi.shape[1], 1, -1) - - -def dropout_add(x, residual, prob, training): - """ - Dropout add function - - Args: - x (`torch.tensor`, *required*): - input tensor - residual (`torch.tensor`, *required*): - esidual tensor - prob (`float`, *required*): - dropout probability - training (`bool`, *required*): - training mode - """ - out = nn.functional.dropout(x, p=prob, training=training) - out = residual + out - return out - - -def bloom_gelu_forward(x): - """ - Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference) to - make the model jitable. - - Args: - x (`torch.tensor`, *required*): - input hidden states - """ - return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) - - -def bloom_gelu_back(g, x): - """ - gradient of tanh approximation of gelu gradient of actual gelu is: 0.5 * (1. + torch.erf(x * 0.70710678)) + - 0.3989423 * x * torch.exp(-0.5 * x * x) - - Args: - g (`torch.tensor`, *required*): - gradient output tensor - x (`torch.tensor`, *required*): - input tensor - """ - x = x[0] # x is a tuple of 1 element, needs to unpack it first - tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 - ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) - return ff * g - - -class GeLUFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - ctx.save_for_backward(input) - return bloom_gelu_forward(input) - - @staticmethod - def backward(ctx, grad_output): - input = ctx.saved_tensors - tmp = bloom_gelu_back(grad_output, input) - return tmp - - -class BloomGelu(nn.Module): - """ - BloomBiasGelu wrapper function that make use of the simple function on inference mode to make the model - torchscriptable and use the autograd function in training mode to get the accurate results of the gradients Partly - copied from Megatron-DeepSpeed code and adapted for our needs - - See here why autograd functions are not torchscriptable: https://github.com/pytorch/pytorch/issues/22329 - - """ - - def __init__(self): - super().__init__() - - def forward(self, x): - if self.training: - return GeLUFunction.apply(x) - else: - return bloom_gelu_forward(x) - - -class BloomScaledSoftmax(nn.Module): - """ - fused operation: scaling + mask + softmax - - Args: - scaled_masked_softmax_fusion (`bool`, *required*): - flag to indicate user want to use softmax fusion - mask_func (`function`, *required*): - mask function to be applied. - softmax_in_fp32 (`bool`, *required*): - if true, softmax in performed at fp32 precision. - scale (`float`, *required*): - scaling factor used in input tensor scaling. - """ - - def __init__(self, scaled_masked_softmax_fusion, mask_func, softmax_in_fp32, scale): - super().__init__() - self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion - self.mask_func = mask_func - self.softmax_in_fp32 = softmax_in_fp32 - self.scale = scale - - if not (self.scale is None or softmax_in_fp32): - raise ValueError("softmax should be in fp32 when scaled") - - def forward(self, input, mask, max_positions): - input_dtype = input.dtype - input_in_16bit = input_dtype in [torch.float16, torch.bfloat16] - softmax_dtype = torch.float32 if self.softmax_in_fp32 else input_dtype - - if self.scale is not None: - input = input * self.scale - - if mask is None: - mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device) - - mask = mask.to(input.device) - causal_mask = ( - torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)) - .view(1, 1, max_positions, max_positions) - .to(input.device) - ) - mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask) - probs = F.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask) - - if input_in_16bit and self.softmax_in_fp32: - probs = probs.to(dtype=input_dtype) - - return probs diff --git a/src/petals/cli/convert_model.py b/src/petals/cli/convert_model.py index 1846dab5d..2678eea6c 100644 --- a/src/petals/cli/convert_model.py +++ b/src/petals/cli/convert_model.py @@ -8,8 +8,8 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler from huggingface_hub import Repository from tqdm.auto import tqdm +from transformers.models.bloom.modeling_bloom import BloomModel -from petals.bloom import BloomModel from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH from petals.client import DistributedBloomConfig diff --git a/src/petals/cli/inference_one_block.py b/src/petals/cli/inference_one_block.py index a5e5a005d..a3a1ff2a6 100644 --- a/src/petals/cli/inference_one_block.py +++ b/src/petals/cli/inference_one_block.py @@ -3,10 +3,10 @@ import torch from hivemind.utils.logging import get_logger, use_hivemind_log_handler from tqdm.auto import trange +from transformers.models.bloom.modeling_bloom import build_alibi_tensor from petals.bloom.block import BloomBlock from petals.bloom.model import BloomConfig -from petals.bloom.ops import build_alibi_tensor use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) diff --git a/src/petals/client/remote_model.py b/src/petals/client/remote_model.py index d69dd4b4b..041822291 100644 --- a/src/petals/client/remote_model.py +++ b/src/petals/client/remote_model.py @@ -7,15 +7,15 @@ import torch.nn as nn from hivemind.utils.logging import get_logger, loglevel, use_hivemind_log_handler from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions - -from petals.bloom.model import ( +from transformers.models.bloom import ( BloomConfig, BloomForCausalLM, BloomForSequenceClassification, BloomModel, BloomPreTrainedModel, - LMHead, ) + +from petals.bloom.model import LMHead from petals.client.remote_generation import RemoteGenerationMixin from petals.client.remote_sequential import RemoteSequential from petals.constants import PUBLIC_INITIAL_PEERS diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index 4fadda276..47b51b0ca 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -6,7 +6,7 @@ from hivemind.moe.server.module_backend import ModuleBackend from hivemind.utils import get_logger -from petals.bloom.from_pretrained import BloomBlock +from petals.bloom.block import WrappedBloomBlock from petals.server.memory_cache import MemoryCache from petals.server.task_pool import PrioritizedTaskPool from petals.utils.misc import is_dummy @@ -16,11 +16,11 @@ class TransformerBackend(ModuleBackend): - """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward""" + """A wrapper for a bloom block that can process requests for bloom layer forward, backward and inference""" def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs): super().__init__(*args, **kwargs) - assert isinstance(self.module, BloomBlock) + assert isinstance(self.module, WrappedBloomBlock) self.memory_cache = memory_cache for name, param in self.module.named_parameters(): assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does" @@ -50,6 +50,7 @@ def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: torch.dtype, ) def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: + num_heads, head_dim = self.module.self_attention.num_heads, self.module.self_attention.head_dim with torch.inference_mode(): attention_cache_handle = int(cache_metadata[0, 0].item()) prefix_length = int(cache_metadata[0, 1].item()) @@ -59,24 +60,31 @@ def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]" with self.memory_cache.use_cache(attention_cache_handle) as cache: - assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5 + batch_size = cache.shape[1] + max_length = cache.numel() // (2 * batch_size * head_dim * num_heads) + assert isinstance(self.module, WrappedBloomBlock) and cache.shape[0] == 2 and cache.ndim == 3 if not is_dummy(hypo_ids): assert hypo_ids.shape[0] == cache.shape[1] cache[:, :] = cache[:, hypo_ids] # in-place reorder cache by hypo ids - layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length] - logger.debug(f"Metadata: {cache_metadata}, past_k.shape={past_k.shape}, past_v.shape={past_v.shape}") - hidden_states, (new_k, new_v) = self.module.forward( - hidden_states, layer_past=layer_past, use_cache=True - ) + key_cache = cache[0].view(batch_size, num_heads, head_dim, max_length) + value_cache = cache[1].view(batch_size, num_heads, max_length, head_dim) - # todo remove these asserts once we pass all tests - new_length = new_v.shape[1] + key_past = key_cache.flatten(0, 1)[:, :, :prefix_length] # [batch * num_heads, head_dim, kv_length] + value_past = value_cache.flatten(0, 1)[:, :prefix_length, :] # [batch * num_heads, kv_length, head_dim] + logger.debug( + f"Metadata: {cache_metadata}, past_k.shape={key_past.shape}, past_v.shape={value_past.shape}" + ) + hidden_states, (new_key, new_value) = self.module.forward( + hidden_states, layer_past=(key_past, value_past), use_cache=True + ) + new_length = new_key.shape[-1] assert new_length > prefix_length - assert new_k.shape[0] == past_k.shape[0] and new_v.shape[0] == past_v.shape[0] - assert new_k.shape[1] == new_length and new_v.shape[1] == new_length - assert new_k.shape[2:] == past_k.shape[2:] and new_v.shape[2:] == past_v.shape[2:] - cache[0, :, prefix_length:new_length, :] = new_k[:, prefix_length:new_length] - cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length] + assert new_key.shape[0] == key_past.shape[0] and new_value.shape[0] == value_past.shape[0] + assert new_key.shape[-1] == new_length and new_value.shape[-2] == new_length + new_key = new_key.view(batch_size, num_heads, head_dim, -1) + new_value = new_value.view(batch_size, num_heads, -1, head_dim) + key_cache[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length] + value_cache[:, :, prefix_length:new_length, :] = new_value[:, :, prefix_length:new_length, :] return (hidden_states,) def get_pools(self) -> Sequence[PrioritizedTaskPool]: diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index 1613c19be..50f621500 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -343,10 +343,8 @@ async def _allocate_caches( for backend in backends: num_heads = backend.module.self_attention.num_heads head_dim = backend.module.self_attention.head_dim - - descr = TensorDescriptor(size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype) - # [key_or_value, batch_size, max_length, num_heads, head_dim] - + descr = TensorDescriptor(size=(2, batch_size, num_heads * head_dim * max_length), dtype=backend.dtype) + # ^-- flattened batch-first tensor of both keys and values; based on BLOOM layer_past layout handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(descr))) total_size += descr.numel() * torch.finfo(descr.dtype).bits // 8 diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index 6d22bb513..f3d718fd5 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -9,10 +9,10 @@ import torch from hivemind.utils.logging import get_logger, use_hivemind_log_handler +from transformers.models.bloom.modeling_bloom import build_alibi_tensor from petals.bloom.block import BloomBlock from petals.bloom.model import BloomConfig -from petals.bloom.ops import build_alibi_tensor from petals.server.block_utils import resolve_block_dtype from petals.utils.convert_8bit import replace_8bit_linear from petals.utils.disk_cache import DEFAULT_CACHE_DIR diff --git a/tests/test_full_model.py b/tests/test_full_model.py index 0ce8652d0..eac2339a4 100644 --- a/tests/test_full_model.py +++ b/tests/test_full_model.py @@ -3,9 +3,9 @@ import transformers from hivemind import get_logger, use_hivemind_log_handler from test_utils import * -from transformers.generation_utils import BeamSearchScorer +from transformers.generation import BeamSearchScorer +from transformers.models.bloom import BloomForCausalLM -from petals.bloom.model import BloomForCausalLM from petals.client.remote_model import DistributedBloomForCausalLM use_hivemind_log_handler("in_root_logger") From 7018954d1996c4c607b8d02e0003b2b6eb85be0e Mon Sep 17 00:00:00 2001 From: justheuristic Date: Mon, 12 Dec 2022 22:14:36 +0300 Subject: [PATCH 03/24] yeet models --- src/petals/bloom/__init__.py | 2 - src/petals/bloom/model.py | 594 -------------------------- src/petals/bloom/modeling_utils.py | 74 ++++ src/petals/cli/inference_one_block.py | 2 +- src/petals/client/remote_model.py | 13 +- src/petals/server/server.py | 2 +- src/petals/server/throughput.py | 2 +- 7 files changed, 89 insertions(+), 600 deletions(-) delete mode 100644 src/petals/bloom/model.py create mode 100644 src/petals/bloom/modeling_utils.py diff --git a/src/petals/bloom/__init__.py b/src/petals/bloom/__init__.py index f0139a9f6..e69de29bb 100644 --- a/src/petals/bloom/__init__.py +++ b/src/petals/bloom/__init__.py @@ -1,2 +0,0 @@ -from petals.bloom.block import BloomBlock -from petals.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel diff --git a/src/petals/bloom/model.py b/src/petals/bloom/model.py deleted file mode 100644 index 88084b520..000000000 --- a/src/petals/bloom/model.py +++ /dev/null @@ -1,594 +0,0 @@ -""" -PyTorch BLOOM model that implements several memory-efficient modes. -Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b -See commit history for authorship. -""" -from typing import Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from hivemind import use_hivemind_log_handler -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss -from transformers.file_utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, -) -from transformers.modeling_outputs import ( - BaseModelOutputWithPastAndCrossAttentions, - CausalLMOutputWithCrossAttentions, - SequenceClassifierOutputWithPast, -) -from transformers.models.bloom.configuration_bloom import BloomConfig -from transformers.models.bloom.modeling_bloom import BloomPreTrainedModel -from transformers.utils import logging - -from petals.bloom.block import BloomBlock - -use_hivemind_log_handler("in_root_logger") -logger = logging.get_logger(__file__) - -_CHECKPOINT_FOR_DOC = "bigscience/Bloom" -_CONFIG_FOR_DOC = "BloomConfig" -_TOKENIZER_FOR_DOC = "BloomTokenizer" - - -BLOOM_START_DOCSTRING = r""" - - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`MemoryEfficientBloomConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -BLOOM_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input - sequence tokens in the vocabulary. - - If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as - `input_ids`. - - Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`): - Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see - `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have - their past given to this model should not be passed as `input_ids` as they have already been computed. - attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - - If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see - `past_key_values`). - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. -""" - - -class _BloomPreTrainedModelWithModifiedDefaults(BloomPreTrainedModel): - @classmethod - def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs): - if low_cpu_mem_usage is None: - low_cpu_mem_usage = True - return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs) - - from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace( - "low_cpu_mem_usage(`bool`, *optional*)", - "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)", - ) - - -@add_start_docstrings( - "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.", - BLOOM_START_DOCSTRING, -) -class BloomModel(_BloomPreTrainedModelWithModifiedDefaults): - def __init__(self, config): - super().__init__(config) - assert not config.slow_but_exact, "slow_but_exact mode was removed for code simplicity" - - self.embed_dim = config.hidden_size - self.n_head = config.n_head - - # Embedding + LN Embedding - self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim) - self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) - - # Transformer blocks - self.h = nn.ModuleList([BloomBlock(config, layer_number=i) for i in range(config.num_hidden_layers)]) - - # Final Layer Norm - self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) - - self.gradient_checkpointing = False - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.word_embeddings - - def set_input_embeddings(self, new_embeddings): - self.word_embeddings = new_embeddings - - @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=BaseModelOutputWithPastAndCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids=None, - past_key_values=None, - attention_mask=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - if position_ids is not None: - logger.warning("position_ids are ignored in this bloom implementation") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if past_key_values is None: - past_key_values = tuple([None] * len(self.h)) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_head x N x N - # head_mask has shape n_layer x batch x n_head x N x N - head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - - # Note: it supports only float32 or bfloat16 inputs - hidden_states = self.word_embeddings_layernorm(inputs_embeds) - - output_shape = input_shape + (hidden_states.size(-1),) - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - current_sequence_length = hidden_states.shape[1] - if past_key_values and past_key_values[0]: - current_sequence_length += past_key_values[0][0].shape[1] - - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - - if use_cache: - logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions, alibi=None) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - None, - attention_mask, - head_mask[i], - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask[i], - use_cache=use_cache, - output_attentions=output_attentions, - alibi=None, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - - # Add last hidden state - hidden_states = self.ln_f(hidden_states) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - hidden_states = hidden_states.view(output_shape) - - if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) - - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - -@add_start_docstrings( - """ - The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input - embeddings). - """, - BLOOM_START_DOCSTRING, -) -class BloomForCausalLM(_BloomPreTrainedModelWithModifiedDefaults): - _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.transformer = BloomModel(config) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): - # only last token for inputs_ids if past is defined in kwargs - if past: - input_ids = input_ids[:, -1].unsqueeze(-1) - - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past: - position_ids = position_ids[:, -1].unsqueeze(-1) - else: - position_ids = None - return { - "input_ids": input_ids, - "past_key_values": past, - "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, - "attention_mask": attention_mask, - } - - @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=CausalLMOutputWithCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids=None, - past_key_values=None, - attention_mask=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - - lm_logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithCrossAttentions( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - @staticmethod - def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - """ - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past - ) - - -@add_start_docstrings( - """ - The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input - embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries. - In addition, it provides an effcient way to deal with half-precision word embeddings on CPU. - """, - BLOOM_START_DOCSTRING, -) -class LMHead(nn.Module): - def __init__(self, config, word_embeddings: nn.Embedding): - super().__init__() - self.word_embeddings = word_embeddings - self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu - - @property - def in_features(self) -> int: - return self.word_embeddings.num_embeddings - - @property - def out_features(self) -> int: - return self.word_embeddings.embedding_dim - - @property - def weight(self): - return self.word_embeddings.weight - - @property - def bias(self): - return None - - def forward(self, hidden_states): - word_embeddings = self.word_embeddings.weight - - # We use 'chunked_forward' only when embeddings are in half-precision on CPU. - if word_embeddings.dtype in [torch.float16, torch.bfloat16] and word_embeddings.device.type == "cpu": - lm_logits = self.chunked_forward(hidden_states) - else: - # Switch dtype in case word_embeddings are fp16/bf16 - hidden_states = hidden_states.to(word_embeddings.dtype) - lm_logits = F.linear(hidden_states, word_embeddings) - return lm_logits - - def chunked_forward(self, hidden_states): - """Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU. - chunk_size: provides trade-off between efficiency and extra memory consumption. - """ - assert self.chunk_size > 0, "Chunk size for chunked forward must be positive" - - word_embeddings = self.word_embeddings.weight - num_embeddings = self.word_embeddings.num_embeddings - - hidden_states = hidden_states.float() - output = torch.zeros(*hidden_states.shape[:-1], num_embeddings) - - for i in range(0, num_embeddings, self.chunk_size): - chunk = word_embeddings[i : i + self.chunk_size].float() - output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk) - return output - - -@add_start_docstrings( - """ - The Bloom Model transformer with a sequence classification head on top (linear layer). - [`BloomForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-1) do. - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - BLOOM_START_DOCSTRING, -) -class BloomForSequenceClassification(_BloomPreTrainedModelWithModifiedDefaults): - _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.transformer = BloomModel(config) - self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=SequenceClassifierOutputWithPast, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids=None, - past_key_values=None, - attention_mask=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 - else: - sequence_lengths = -1 - logger.warning( - f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`" - ) - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - loss = None - if labels is not None: - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/src/petals/bloom/modeling_utils.py b/src/petals/bloom/modeling_utils.py new file mode 100644 index 000000000..08ca40dc2 --- /dev/null +++ b/src/petals/bloom/modeling_utils.py @@ -0,0 +1,74 @@ +""" +PyTorch BLOOM model that implements several memory-efficient modes. +Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b +See commit history for authorship. +""" + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from hivemind import use_hivemind_log_handler +from torch import nn +from transformers import BloomConfig +from transformers.utils import logging + +use_hivemind_log_handler("in_root_logger") +logger = logging.get_logger(__file__) + + +class LMHead(nn.Module): + """ + The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input + embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries. + In addition, it provides an effcient way to deal with half-precision word embeddings on CPU. + """ + + def __init__(self, config: BloomConfig, word_embeddings: nn.Embedding): + super().__init__() + self.word_embeddings = word_embeddings + self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu + + @property + def in_features(self) -> int: + return self.word_embeddings.num_embeddings + + @property + def out_features(self) -> int: + return self.word_embeddings.embedding_dim + + @property + def weight(self): + return self.word_embeddings.weight + + @property + def bias(self): + return None + + def forward(self, hidden_states): + word_embeddings = self.word_embeddings.weight + + # We use 'chunked_forward' only when embeddings are in half-precision on CPU. + if word_embeddings.dtype in [torch.float16, torch.bfloat16] and word_embeddings.device.type == "cpu": + lm_logits = self.chunked_forward(hidden_states) + else: + # Switch dtype in case word_embeddings are fp16/bf16 + hidden_states = hidden_states.to(word_embeddings.dtype) + lm_logits = F.linear(hidden_states, word_embeddings) + return lm_logits + + def chunked_forward(self, hidden_states): + """Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU. + chunk_size: provides trade-off between efficiency and extra memory consumption. + """ + assert self.chunk_size > 0, "Chunk size for chunked forward must be positive" + + word_embeddings = self.word_embeddings.weight + num_embeddings = self.word_embeddings.num_embeddings + + hidden_states = hidden_states.float() + output = torch.zeros(*hidden_states.shape[:-1], num_embeddings) + + for i in range(0, num_embeddings, self.chunk_size): + chunk = word_embeddings[i : i + self.chunk_size].float() + output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk) + return output diff --git a/src/petals/cli/inference_one_block.py b/src/petals/cli/inference_one_block.py index a3a1ff2a6..dfe1f94a0 100644 --- a/src/petals/cli/inference_one_block.py +++ b/src/petals/cli/inference_one_block.py @@ -6,7 +6,7 @@ from transformers.models.bloom.modeling_bloom import build_alibi_tensor from petals.bloom.block import BloomBlock -from petals.bloom.model import BloomConfig +from petals.bloom.modeling_utils import BloomConfig use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) diff --git a/src/petals/client/remote_model.py b/src/petals/client/remote_model.py index 041822291..8b0fd919d 100644 --- a/src/petals/client/remote_model.py +++ b/src/petals/client/remote_model.py @@ -15,7 +15,7 @@ BloomPreTrainedModel, ) -from petals.bloom.model import LMHead +from petals.bloom.modeling_utils import LMHead from petals.client.remote_generation import RemoteGenerationMixin from petals.client.remote_sequential import RemoteSequential from petals.constants import PUBLIC_INITIAL_PEERS @@ -191,6 +191,17 @@ def forward( attentions=None, ) + @classmethod + def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs): + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs) + + from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace( + "low_cpu_mem_usage(`bool`, *optional*)", + "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)", + ) + class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM): """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm""" diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 97b03e040..749d53b98 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -18,7 +18,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler from petals.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block -from petals.bloom.model import BloomConfig +from petals.bloom.modeling_utils import BloomConfig from petals.constants import PUBLIC_INITIAL_PEERS from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState from petals.dht_utils import declare_active_modules, get_remote_module_infos diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index f3d718fd5..bca119aff 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -12,7 +12,7 @@ from transformers.models.bloom.modeling_bloom import build_alibi_tensor from petals.bloom.block import BloomBlock -from petals.bloom.model import BloomConfig +from petals.bloom.modeling_utils import BloomConfig from petals.server.block_utils import resolve_block_dtype from petals.utils.convert_8bit import replace_8bit_linear from petals.utils.disk_cache import DEFAULT_CACHE_DIR From f46d421c645561c9c8a164a51a63c2f9d6279705 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Mon, 12 Dec 2022 22:18:33 +0300 Subject: [PATCH 04/24] y u no instal? --- setup.cfg | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 41d36d2f8..96d8ce2c9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,9 +35,9 @@ install_requires = bitsandbytes==0.34.0 accelerate==0.15.0 huggingface-hub==0.11.1 - transformers==4.25.3 + transformers==4.25.1 protobuf>=3.20.3,<4.0dev - https://github.com/learning-at-home/hivemind/archive/4c9c477e674f4ae40c5e7b9bc82056697720245c.zip + hivemind==1.1.3 humanfriendly async-timeout>=4.0.2 From 1e2ab6b6a8bf91182b60c7a0176463965fe63b8e Mon Sep 17 00:00:00 2001 From: justheuristic Date: Mon, 12 Dec 2022 22:31:27 +0300 Subject: [PATCH 05/24] fix imports --- src/petals/server/block_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/petals/server/block_utils.py b/src/petals/server/block_utils.py index ce5f678f5..43b193057 100644 --- a/src/petals/server/block_utils.py +++ b/src/petals/server/block_utils.py @@ -3,7 +3,8 @@ import torch from accelerate import init_empty_weights -from petals.bloom import BloomBlock, BloomConfig +from petals.bloom.block import WrappedBloomBlock +from petals.bloom.modeling_utils import BloomConfig def resolve_block_dtype(config: BloomConfig, dtype: Union[str, torch.dtype]) -> Union[str, torch.dtype]: @@ -22,7 +23,6 @@ def get_block_size( *, dtype: Optional[Union[str, torch.dtype]] = None, load_in_8bit: Optional[bool] = None, - layer_index: int = 0, eps: float = 0.01, # eps accounts for ~1% of metainfo for tensor descriptions, quantization tables, etc. ) -> int: if location == "memory": @@ -31,7 +31,7 @@ def get_block_size( ), 'get_block_size(..., location="memory") requires to specify dtype and load_in_8bit for calculations' with init_empty_weights(): - block = BloomBlock(config, layer_index) + block = WrappedBloomBlock(config) n_params = sum(param.numel() for param in block.parameters()) if location == "memory" and load_in_8bit: From 4b35dd7d8c73ac5f404a51bab2e8ed990f08d6fb Mon Sep 17 00:00:00 2001 From: justheuristic Date: Mon, 12 Dec 2022 22:47:13 +0300 Subject: [PATCH 06/24] fix edge case where session crashes when receiving seq length 0 --- src/petals/server/handler.py | 3 +++ tests/test_full_model.py | 10 +++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index 50f621500..a7588bb85 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -146,6 +146,9 @@ async def rpc_inference( for backend, prompt, cache_handle in zip(requested_backends, prompts, cache_handles): if not is_dummy(prompt): hidden_states[:, : prompt.shape[1]] += prompt + if hidden_states.numel() == 0: + continue # user passed a tensor with 0 tokens. This is a special case that occurs, e.g. + # when user to pre-allocate cache or check that server *can* allocate that cache cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length assert isinstance( diff --git a/tests/test_full_model.py b/tests/test_full_model.py index eac2339a4..b185ea904 100644 --- a/tests/test_full_model.py +++ b/tests/test_full_model.py @@ -13,7 +13,8 @@ @pytest.mark.forked -def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3): +@pytest.mark.parametrize("pass_empty_tensors", (True, False)) +def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3): tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME) model = DistributedBloomForCausalLM.from_pretrained( MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32 @@ -33,8 +34,15 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3): embs = model.transformer.word_embeddings_layernorm(embs) recurrent_outputs = [] with model.transformer.h.inference_session(max_length=embs.shape[1]) as sess: + if pass_empty_tensors: + recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size))) + for t in range(embs.shape[1]): recurrent_outputs.append(sess.step(embs[:, t : t + 1, :])) + if t == 5 and pass_empty_tensors: + recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size))) + recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size))) + recurrent_outputs = torch.cat(recurrent_outputs, dim=1) recurrent_outputs = model.transformer.ln_f(recurrent_outputs) recurrent_outputs = model.lm_head(recurrent_outputs) From 062cd51e75b9266ae4a295b416a7276a0cab8ca7 Mon Sep 17 00:00:00 2001 From: Aleksandr Borzunov Date: Mon, 12 Dec 2022 23:53:12 +0300 Subject: [PATCH 07/24] review --- src/petals/cli/inference_one_block.py | 3 +-- src/petals/server/throughput.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/petals/cli/inference_one_block.py b/src/petals/cli/inference_one_block.py index dfe1f94a0..01dacd8cd 100644 --- a/src/petals/cli/inference_one_block.py +++ b/src/petals/cli/inference_one_block.py @@ -31,7 +31,6 @@ def print_device_info(device=None): parser = argparse.ArgumentParser(description="Run a single bloom block locally on dummy data") parser.add_argument("--config", required=True, type=str, help="Path to a config json file") parser.add_argument("--state_dict", default=None, type=str, help="Optional path to saved block state dict") - parser.add_argument("--layer_index", default=0, type=int, help="Optional path to saved block state dict") parser.add_argument("--num_steps", default=500, type=int, help="How many inference steps to run") parser.add_argument("--device", default=None, type=str, help="Run inference on this device") args = parser.parse_args() @@ -40,7 +39,7 @@ def print_device_info(device=None): args.device = "cuda" if torch.cuda.is_available() else "cpu" config = BloomConfig.from_json_file(args.config) - block = BloomBlock(config, args.layer_index).to(args.device) + block = BloomBlock(config).to(args.device) cache = None diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index bca119aff..3a96ad393 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -115,10 +115,9 @@ def measure_compute_rps( load_in_8bit: bool, n_tokens: int = 16, n_steps: int = 500, - layer_index: int = 0, ) -> float: with torch.inference_mode(): - block = BloomBlock(config, layer_index).to(dtype) + block = BloomBlock(config).to(dtype) if load_in_8bit: block = replace_8bit_linear(block) block = block.to(device) From d227021a2e3325b3a92833b8df7e5707de7a5089 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 13 Dec 2022 00:01:33 +0300 Subject: [PATCH 08/24] mixin --- src/petals/client/remote_model.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/src/petals/client/remote_model.py b/src/petals/client/remote_model.py index 8b0fd919d..5c70df80e 100644 --- a/src/petals/client/remote_model.py +++ b/src/petals/client/remote_model.py @@ -66,7 +66,20 @@ def force_non_empty_weights(): nn.Module.register_parameter = possibly_patched_register_parameter -class DistributedBloomModel(BloomModel): +class _LowCPUMemoryMixin: + @classmethod + def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs): + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs) + + from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace( + "low_cpu_mem_usage(`bool`, *optional*)", + "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)", + ) + + +class DistributedBloomModel(BloomModel, _LowCPUMemoryMixin): """BloomModel, but all transformer layers are hosted by the swarm""" _keys_to_ignore_on_load_missing = BloomModel._keys_to_ignore_on_load_missing + [ @@ -191,19 +204,8 @@ def forward( attentions=None, ) - @classmethod - def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs): - if low_cpu_mem_usage is None: - low_cpu_mem_usage = True - return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs) - - from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace( - "low_cpu_mem_usage(`bool`, *optional*)", - "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)", - ) - -class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM): +class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM, _LowCPUMemoryMixin): """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm""" _keys_to_ignore_on_load_missing = ( @@ -241,7 +243,7 @@ def set_output_embeddings(self, new_lm_head: nn.Linear): self.lm_head.bias[...] = new_lm_head.bias -class DistributedBloomForSequenceClassification(BloomForSequenceClassification): +class DistributedBloomForSequenceClassification(BloomForSequenceClassification, _LowCPUMemoryMixin): _keys_to_ignore_on_load_missing = ( BloomForSequenceClassification._keys_to_ignore_on_load_missing + DistributedBloomModel._keys_to_ignore_on_load_missing From ab813baa3e65ecdcfde82961cf83270892f5c7dc Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 13 Dec 2022 00:03:59 +0300 Subject: [PATCH 09/24] remix --- src/petals/client/remote_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/petals/client/remote_model.py b/src/petals/client/remote_model.py index 5c70df80e..daaef83bb 100644 --- a/src/petals/client/remote_model.py +++ b/src/petals/client/remote_model.py @@ -79,7 +79,7 @@ def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwar ) -class DistributedBloomModel(BloomModel, _LowCPUMemoryMixin): +class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel): """BloomModel, but all transformer layers are hosted by the swarm""" _keys_to_ignore_on_load_missing = BloomModel._keys_to_ignore_on_load_missing + [ @@ -205,7 +205,7 @@ def forward( ) -class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM, _LowCPUMemoryMixin): +class DistributedBloomForCausalLM(_LowCPUMemoryMixin, RemoteGenerationMixin, BloomForCausalLM): """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm""" _keys_to_ignore_on_load_missing = ( @@ -243,7 +243,7 @@ def set_output_embeddings(self, new_lm_head: nn.Linear): self.lm_head.bias[...] = new_lm_head.bias -class DistributedBloomForSequenceClassification(BloomForSequenceClassification, _LowCPUMemoryMixin): +class DistributedBloomForSequenceClassification(_LowCPUMemoryMixin, BloomForSequenceClassification): _keys_to_ignore_on_load_missing = ( BloomForSequenceClassification._keys_to_ignore_on_load_missing + DistributedBloomModel._keys_to_ignore_on_load_missing From 9bf813b2f52eb8372b8c1aa40b66260001ab354a Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 13 Dec 2022 00:20:34 +0300 Subject: [PATCH 10/24] fix throughput --- src/petals/server/throughput.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index 3a96ad393..7342edfaa 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -126,7 +126,8 @@ def measure_compute_rps( elapsed = 0 for step in range(n_steps + 1): dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=dtype) - alibi = build_alibi_tensor(step + 1, config.num_attention_heads, device=device, dtype=dtype) + mask = torch.ones(n_tokens, step + 1, device=device, dtype=dtype) + alibi = build_alibi_tensor(mask, config.num_attention_heads, dtype=dtype) start_time = time.perf_counter() _, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache) From 524468e125d205c80221a978b91e69a825a6f6f5 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 13 Dec 2022 00:23:22 +0300 Subject: [PATCH 11/24] fix throughput --- src/petals/server/throughput.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index 7342edfaa..0c8714b2a 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -9,9 +9,8 @@ import torch from hivemind.utils.logging import get_logger, use_hivemind_log_handler -from transformers.models.bloom.modeling_bloom import build_alibi_tensor -from petals.bloom.block import BloomBlock +from petals.bloom.block import WrappedBloomBlock from petals.bloom.modeling_utils import BloomConfig from petals.server.block_utils import resolve_block_dtype from petals.utils.convert_8bit import replace_8bit_linear @@ -117,7 +116,7 @@ def measure_compute_rps( n_steps: int = 500, ) -> float: with torch.inference_mode(): - block = BloomBlock(config).to(dtype) + block = WrappedBloomBlock(config).to(dtype) if load_in_8bit: block = replace_8bit_linear(block) block = block.to(device) @@ -126,11 +125,9 @@ def measure_compute_rps( elapsed = 0 for step in range(n_steps + 1): dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=dtype) - mask = torch.ones(n_tokens, step + 1, device=device, dtype=dtype) - alibi = build_alibi_tensor(mask, config.num_attention_heads, dtype=dtype) start_time = time.perf_counter() - _, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache) + _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache) if step >= 1: # Skip the 1st step to exclude the initialization time elapsed += time.perf_counter() - start_time device_rps = n_steps * n_tokens / elapsed From c4730125dfd6543e36aa73ab0918c2f98c2a60a0 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 13 Dec 2022 00:42:22 +0300 Subject: [PATCH 12/24] benchmark throughput in CI jobs --- tests/test_aux_functions.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 tests/test_aux_functions.py diff --git a/tests/test_aux_functions.py b/tests/test_aux_functions.py new file mode 100644 index 000000000..d1569d9c7 --- /dev/null +++ b/tests/test_aux_functions.py @@ -0,0 +1,13 @@ +import torch +from test_utils import * + +from petals.client import DistributedBloomConfig +from petals.server.throughput import measure_compute_rps + + +def test_throughput_basic(): + config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS) + throughput = measure_compute_rps( + config, device=torch.device("cpu"), dtype=torch.bfloat16, load_in_8bit=False, n_steps=10 + ) + assert isinstance(throughput, float) and throughput > 0 From f9e0910d4df2be8a051a376d8d6e34c945e4877d Mon Sep 17 00:00:00 2001 From: Aleksandr Borzunov Date: Tue, 13 Dec 2022 01:07:50 +0300 Subject: [PATCH 13/24] reduce ban timeout --- src/petals/client/routing/sequence_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index 768b9620d..484e134fc 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -57,7 +57,7 @@ def __init__( update_period: float = 30, request_timeout: float = 30, min_backoff: float = 1, - ban_timeout: float = 60, + ban_timeout: float = 15, sequence_info: Optional[RemoteSequenceInfo] = None, rpc_info: Optional[dict] = None, banned_peers: Optional[Blacklist] = None, From 044e9150837111c63105e64a23121ce8ff8b9228 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 13 Dec 2022 01:12:49 +0300 Subject: [PATCH 14/24] fork pytest --- tests/test_aux_functions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_aux_functions.py b/tests/test_aux_functions.py index d1569d9c7..e807fc68e 100644 --- a/tests/test_aux_functions.py +++ b/tests/test_aux_functions.py @@ -1,3 +1,4 @@ +import pytest import torch from test_utils import * @@ -5,6 +6,7 @@ from petals.server.throughput import measure_compute_rps +@pytest.mark.forked def test_throughput_basic(): config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS) throughput = measure_compute_rps( From b12ad06932539e65a4d312f0ef4ab131112ddedc Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Tue, 13 Dec 2022 10:34:22 +0300 Subject: [PATCH 15/24] review --- tests/test_full_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_full_model.py b/tests/test_full_model.py index b185ea904..710ff332c 100644 --- a/tests/test_full_model.py +++ b/tests/test_full_model.py @@ -39,7 +39,7 @@ def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, ato for t in range(embs.shape[1]): recurrent_outputs.append(sess.step(embs[:, t : t + 1, :])) - if t == 5 and pass_empty_tensors: + if t == int(embs.shape[1] // 2) and pass_empty_tensors: recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size))) recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size))) From 35e2c0a9899d9cb8fd1d6201851f332bad5df84b Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Tue, 13 Dec 2022 10:34:59 +0300 Subject: [PATCH 16/24] review --- src/petals/bloom/from_pretrained.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/petals/bloom/from_pretrained.py b/src/petals/bloom/from_pretrained.py index fb4555170..9e039d61f 100644 --- a/src/petals/bloom/from_pretrained.py +++ b/src/petals/bloom/from_pretrained.py @@ -25,9 +25,6 @@ CLIENT_BRANCH = "main" BLOCK_BRANCH_PREFIX = "block_" USER_AGENT = {"file_type": "model", "framework": "pytorch", "from_auto_class": False} -FORCE_DOWNLOAD = False -RESUME_DOWNLOAD = False -LOCAL_FILES_ONLY = False def load_pretrained_block( From 27ac5888efb21b5cf88b8d6d082e4c6c1a94ed6e Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 13 Dec 2022 10:35:19 +0300 Subject: [PATCH 17/24] Update tests/test_aux_functions.py Co-authored-by: Max Ryabinin --- tests/test_aux_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_aux_functions.py b/tests/test_aux_functions.py index e807fc68e..87da88d48 100644 --- a/tests/test_aux_functions.py +++ b/tests/test_aux_functions.py @@ -1,6 +1,6 @@ import pytest import torch -from test_utils import * +from test_utils import MODEL_NAME, INITIAL_PEERS from petals.client import DistributedBloomConfig from petals.server.throughput import measure_compute_rps From b090dd25ee3d00306bbda066d89ee693580ca417 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Tue, 13 Dec 2022 10:36:50 +0300 Subject: [PATCH 18/24] isort --- src/petals/cli/inference_one_block.py | 2 +- src/petals/server/block_utils.py | 2 +- src/petals/server/server.py | 2 +- src/petals/server/throughput.py | 2 +- tests/test_aux_functions.py | 4 ++-- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/petals/cli/inference_one_block.py b/src/petals/cli/inference_one_block.py index 01dacd8cd..336e2a352 100644 --- a/src/petals/cli/inference_one_block.py +++ b/src/petals/cli/inference_one_block.py @@ -3,10 +3,10 @@ import torch from hivemind.utils.logging import get_logger, use_hivemind_log_handler from tqdm.auto import trange +from transformers import BloomConfig from transformers.models.bloom.modeling_bloom import build_alibi_tensor from petals.bloom.block import BloomBlock -from petals.bloom.modeling_utils import BloomConfig use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) diff --git a/src/petals/server/block_utils.py b/src/petals/server/block_utils.py index 43b193057..eca7143bf 100644 --- a/src/petals/server/block_utils.py +++ b/src/petals/server/block_utils.py @@ -2,9 +2,9 @@ import torch from accelerate import init_empty_weights +from transformers import BloomConfig from petals.bloom.block import WrappedBloomBlock -from petals.bloom.modeling_utils import BloomConfig def resolve_block_dtype(config: BloomConfig, dtype: Union[str, torch.dtype]) -> Union[str, torch.dtype]: diff --git a/src/petals/server/server.py b/src/petals/server/server.py index e05603357..356793065 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -16,9 +16,9 @@ from hivemind.moe.server.runtime import Runtime from hivemind.proto.runtime_pb2 import CompressionType from hivemind.utils.logging import get_logger, use_hivemind_log_handler +from transformers import BloomConfig from petals.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block -from petals.bloom.modeling_utils import BloomConfig from petals.constants import PUBLIC_INITIAL_PEERS from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState from petals.dht_utils import declare_active_modules, get_remote_module_infos diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index 0c8714b2a..2bcd3409c 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -9,9 +9,9 @@ import torch from hivemind.utils.logging import get_logger, use_hivemind_log_handler +from transformers import BloomConfig from petals.bloom.block import WrappedBloomBlock -from petals.bloom.modeling_utils import BloomConfig from petals.server.block_utils import resolve_block_dtype from petals.utils.convert_8bit import replace_8bit_linear from petals.utils.disk_cache import DEFAULT_CACHE_DIR diff --git a/tests/test_aux_functions.py b/tests/test_aux_functions.py index 87da88d48..dcc2b31cf 100644 --- a/tests/test_aux_functions.py +++ b/tests/test_aux_functions.py @@ -1,6 +1,6 @@ import pytest import torch -from test_utils import MODEL_NAME, INITIAL_PEERS +from test_utils import MODEL_NAME from petals.client import DistributedBloomConfig from petals.server.throughput import measure_compute_rps @@ -8,7 +8,7 @@ @pytest.mark.forked def test_throughput_basic(): - config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS) + config = DistributedBloomConfig.from_pretrained(MODEL_NAME) throughput = measure_compute_rps( config, device=torch.device("cpu"), dtype=torch.bfloat16, load_in_8bit=False, n_steps=10 ) From 7f7f5dc0b50dbb2f39fdf17a33c8b9b26288e440 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 13 Dec 2022 10:37:36 +0300 Subject: [PATCH 19/24] Update src/petals/server/handler.py Co-authored-by: Max Ryabinin --- src/petals/server/handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index a7588bb85..f9fcf1d76 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -148,7 +148,7 @@ async def rpc_inference( hidden_states[:, : prompt.shape[1]] += prompt if hidden_states.numel() == 0: continue # user passed a tensor with 0 tokens. This is a special case that occurs, e.g. - # when user to pre-allocate cache or check that server *can* allocate that cache + # when user wants to pre-allocate cache or check that server *can* allocate that cache cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length assert isinstance( From 7103268d83e411cf85a0c3482c5b459d43f05f83 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 13 Dec 2022 10:38:20 +0300 Subject: [PATCH 20/24] Update src/petals/bloom/modeling_utils.py Co-authored-by: Max Ryabinin --- src/petals/bloom/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/bloom/modeling_utils.py b/src/petals/bloom/modeling_utils.py index 08ca40dc2..9ae60dc25 100644 --- a/src/petals/bloom/modeling_utils.py +++ b/src/petals/bloom/modeling_utils.py @@ -66,7 +66,7 @@ def chunked_forward(self, hidden_states): num_embeddings = self.word_embeddings.num_embeddings hidden_states = hidden_states.float() - output = torch.zeros(*hidden_states.shape[:-1], num_embeddings) + output = torch.empty(*hidden_states.shape[:-1], num_embeddings) for i in range(0, num_embeddings, self.chunk_size): chunk = word_embeddings[i : i + self.chunk_size].float() From fa632a31e598286bd505b73fb5e434b508c63223 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Tue, 13 Dec 2022 10:40:14 +0300 Subject: [PATCH 21/24] cleanup --- src/petals/bloom/from_pretrained.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/petals/bloom/from_pretrained.py b/src/petals/bloom/from_pretrained.py index 9e039d61f..518a013ff 100644 --- a/src/petals/bloom/from_pretrained.py +++ b/src/petals/bloom/from_pretrained.py @@ -24,7 +24,6 @@ CLIENT_BRANCH = "main" BLOCK_BRANCH_PREFIX = "block_" -USER_AGENT = {"file_type": "model", "framework": "pytorch", "from_auto_class": False} def load_pretrained_block( From 110a3072ec8c7d867f8c3d54faedbce5687fbc0b Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 13 Dec 2022 10:41:56 +0300 Subject: [PATCH 22/24] cleanup --- src/petals/bloom/from_pretrained.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/bloom/from_pretrained.py b/src/petals/bloom/from_pretrained.py index 518a013ff..5e42199bd 100644 --- a/src/petals/bloom/from_pretrained.py +++ b/src/petals/bloom/from_pretrained.py @@ -34,7 +34,7 @@ def load_pretrained_block( use_auth_token: Optional[str] = None, cache_dir: Optional[str] = None, ) -> WrappedBloomBlock: - """Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it.""" + """Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it.""" if config is None: config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token) From a24659fa3606e9be9d7adde00cabcb1a01c98efe Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 13 Dec 2022 10:42:10 +0300 Subject: [PATCH 23/24] Update src/petals/server/backend.py Co-authored-by: Max Ryabinin --- src/petals/server/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index dffc7a753..f1b460dee 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -16,7 +16,7 @@ class TransformerBackend(ModuleBackend): - """A wrapper for a bloom block that can process requests for bloom layer forward, backward and inference""" + """A wrapper for a BLOOM block that can process requests for BLOOM layer forward, backward and inference""" def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs): super().__init__(*args, **kwargs) From a61c5bb8996d864eefceb1ed21c60ea21db394ea Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 13 Dec 2022 10:50:51 +0300 Subject: [PATCH 24/24] check transformers version --- src/petals/bloom/block.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/petals/bloom/block.py b/src/petals/bloom/block.py index 4f0735303..f4d50be51 100644 --- a/src/petals/bloom/block.py +++ b/src/petals/bloom/block.py @@ -3,11 +3,16 @@ Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b See commit history for authorship. """ +import os from typing import Optional, Tuple import torch.nn.quantized.dynamic.modules.linear +import transformers from transformers.models.bloom.modeling_bloom import BloomBlock, _expand_mask, _make_causal_mask, build_alibi_tensor +if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"): + assert transformers.__version__.startswith("4.25."), "Please install transformers 4.25.1" + class WrappedBloomBlock(BloomBlock): def forward(