[Proposal] Memory efficient causal mask implementation #479

@andyrdt

Description

Proposal

[Relatively minor proposal - considered making it a bug, but it's not really a bug.]

In the initialization of each Attention module, we register a causal_mask buffer. This buffer is a boolean tensor of shape (self.cfg.n_ctx, self.cfg.n_ctx).

This is quite inefficient for two reasons:

  1. Most times, the prompt context length is much smaller than self.cfg.n_ctx (which represents the maximum context length).
  2. The same buffer is stored for every layer.

This hasn't been a visible issue so far for models with smallish context lengths. But consider a model like Qwen 72B, which has a max context length of 32768 and 80 layers. With the current implementation, a boolean tensor of shape (32768, 32768) is initialized for each layer, resulting in 32768 * 32768 * 1 byte * 80 layers ~= 86 GB of overhead.
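A back-of-the-envelope check of that overhead figure, using the Qwen 72B numbers above (illustrative arithmetic only; the exact model config may differ):

```python
n_ctx = 32768        # max context length
n_layers = 80
bytes_per_elem = 1   # torch.bool stores one byte per element

per_layer = n_ctx * n_ctx * bytes_per_elem  # 2^30 bytes = 1 GiB per layer
total = per_layer * n_layers
print(f"{total / 1e9:.1f} GB")  # ~85.9 GB
```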

As a temporary fix, we can just cap n_ctx on models at some reasonable value (e.g. 2048 or 4096). But I think the ideal solution is to compute the attention mask on the fly, sized to the particular context length.
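A sketch of the on-the-fly approach, in pure Python standing in for the usual torch.tril construction (the function name and the KV-cache offset handling are illustrative, not the actual TransformerLens API):

```python
def causal_mask(query_len, key_len=None):
    """Build the causal mask at the prompt's actual length, instead of
    pre-registering an (n_ctx, n_ctx) buffer in every layer.

    mask[q][k] is True where key position k may attend to query position q.
    The offset handles generation with a KV cache, where key_len > query_len.
    """
    if key_len is None:
        key_len = query_len
    offset = key_len - query_len
    return [[k <= q + offset for k in range(key_len)]
            for q in range(query_len)]
```

For a 3-token prompt this yields the usual lower-triangular pattern, and the memory cost now scales with the prompt length rather than with n_ctx.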

Note that the same inefficiency exists with rotary embeddings (we precompute sin and cos tensors of length n_ctx). But it's not nearly as bad since they grow O(n_ctx), whereas the mask grows O(n_ctx^2).
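The growth-rate comparison can be made concrete with a small sketch (the rotary table shape, rotary_dim, and float32 dtype are assumed defaults for illustration, not read from any particular config):

```python
def mask_bytes(n_ctx, n_layers):
    # Boolean causal mask: one byte per element, one copy per layer.
    return n_ctx * n_ctx * n_layers

def rotary_bytes(n_ctx, rotary_dim=128, dtype_bytes=4):
    # Precomputed sin and cos tables, each of shape (n_ctx, rotary_dim),
    # stored once in float32 (illustrative assumptions).
    return 2 * n_ctx * rotary_dim * dtype_bytes

# Doubling n_ctx quadruples the mask cost but only doubles the rotary cost.
print(mask_bytes(4096, 80) // mask_bytes(2048, 80))    # 4
print(rotary_bytes(4096) // rotary_bytes(2048))        # 2
```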

Checklist

  • I have checked that there is no similar issue in the repo (required)


    Labels

    complexity-moderate (Moderately complicated issues for people who have intermediate experience with the code), enhancement (New feature or request)
