
Construct causal mask on-the-fly #493

Closed
andyrdt wants to merge 5 commits into TransformerLensOrg:dev from andyrdt:attn_mask

Conversation

@andyrdt
Contributor

@andyrdt andyrdt commented Jan 24, 2024

Description

Previously, we were allocating causal masks of size (n_ctx, n_ctx) for every instantiation of AbstractAttention, where n_ctx corresponds to the maximum context length.

For models with a large maximum context length, this leads to wasteful memory consumption.

This PR

I tried to keep this change as conservative as possible: I took the existing logic for creating the causal mask out of AbstractAttention.__init__ and moved it into AbstractAttention.apply_causal_mask. The causal mask is now constructed at inference time, and its shape is (cur_ctx_length, cur_ctx_length), which is generally much smaller than (n_ctx, n_ctx).
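To illustrate the idea, here is a minimal sketch of on-the-fly causal masking (not the PR's exact code; the function name and the assumption that scores have shape [batch, head, query_pos, key_pos] are mine):

```python
import torch

def apply_causal_mask_sketch(attn_scores: torch.Tensor) -> torch.Tensor:
    """Mask future positions, building the mask at inference time.

    attn_scores: [batch, head, query_pos, key_pos].
    The mask is sized to the current context, not the model's max n_ctx.
    """
    q_len, k_len = attn_scores.shape[-2], attn_scores.shape[-1]
    # Lower-triangular boolean mask: True means "query may attend to key".
    mask = torch.tril(
        torch.ones(q_len, k_len, dtype=torch.bool, device=attn_scores.device)
    )
    # Fill disallowed (future) positions with the most negative finite value,
    # so they vanish after softmax.
    return attn_scores.masked_fill(~mask, torch.finfo(attn_scores.dtype).min)
```

The mask here costs O(cur_ctx²) memory only transiently during the forward pass, instead of a persistent O(n_ctx²) buffer per attention module.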

I think constructing the causal mask is lightweight enough that this should not cause performance issues. However, there is room for further optimization: we could tie a causal mask buffer to each instantiation of AbstractAttention, initialize it with some small shape (say (128, 128)), and grow it as needed (e.g. if a forward pass has ctx_len 129, regenerate the buffer with shape (256, 256)). This would avoid constructing a new mask on every forward pass. I didn't implement that here, but we can explore it if we feel it's necessary.
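The buffer-growing alternative could look something like this (a hypothetical helper, not part of TransformerLens; the class name and doubling policy are assumptions):

```python
import torch

class GrowingCausalMask:
    """Cache a causal mask and grow it when a longer context arrives.

    Starts small and doubles in size until the requested context fits,
    so the mask is only rebuilt O(log n) times over a model's lifetime.
    """

    def __init__(self, init_size: int = 128):
        self.mask = torch.tril(
            torch.ones(init_size, init_size, dtype=torch.bool)
        )

    def get(self, ctx_len: int) -> torch.Tensor:
        size = self.mask.shape[0]
        if ctx_len > size:
            # Double until the current context fits, then rebuild once.
            while size < ctx_len:
                size *= 2
            self.mask = torch.tril(torch.ones(size, size, dtype=torch.bool))
        # Return a view sized to the current context.
        return self.mask[:ctx_len, :ctx_len]
```

As noted below in the thread, the drawback is that a single very long sequence leaves a large cached mask behind for the model's lifetime.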

Fixes #479

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

@andyrdt
Contributor Author

andyrdt commented Jan 24, 2024

Having written out the alternative solution of a cached attention mask that grows as needed, I'm now thinking that may be the better option.

It does have the following drawback: if you run a very long sequence, the model will construct and cache a very large causal mask, and this will be cached for the model's lifetime. But this doesn't seem so bad (thanks @collingray for pointing this out) - if the mask can fit on the device initially, then it's probably fine to cache it for subsequent usage as well.

@andyrdt
Contributor Author

andyrdt commented Jan 25, 2024

I ran the following benchmarks to measure the performance impact. The difference doesn't seem significant to me, so this simple implementation seems ok. Let me know if you think there's some other benchmark that would be good to check.

#%%
import torch
import time
import gc
from tqdm import tqdm
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained(
    'qwen-7b-chat',
    device='cuda',
    fp16=True,
    dtype=torch.float16,
    fold_ln=False,
    center_writing_weights=False, 
    center_unembed=False,
)
tokenizer = model.tokenizer
torch.set_grad_enabled(False)

#%%
# Testing forward pass perf

runtimes = []

for i in tqdm(range(50)):
    rand_toks = torch.randint(0, model.cfg.d_vocab-20, (1, 2048))

    start = time.time()
    logits = model(rand_toks)
    end = time.time()
    runtimes.append(end-start)

    torch.cuda.empty_cache(); gc.collect()

print(f"Mean: {sum(runtimes)/len(runtimes)} seconds")
print(f"Std: {torch.std(torch.tensor(runtimes))}")

# new impl (on-the-fly):
# Mean: 0.6870033359527588 seconds
# Std: 0.034684497863054276

# old impl (persistent buffer):
# Mean: 0.6738856601715087 seconds
# Std: 0.023556165397167206

#%%
# Testing generate perf

runtimes = []

for i in tqdm(range(20)):
    rand_toks = torch.randint(0, model.cfg.d_vocab-20, (1, 1024))

    start = time.time()
    generation = model.generate(rand_toks, max_new_tokens=32, stop_at_eos=False)
    end = time.time()
    runtimes.append(end-start)

    torch.cuda.empty_cache(); gc.collect()

print(f"Mean: {sum(runtimes)/len(runtimes)} seconds")
print(f"Std: {torch.std(torch.tensor(runtimes))}")

# new impl (on-the-fly):
# Mean: 3.7325199961662294 seconds
# Std: 0.13069213926792145

# old impl (persistent buffer):
# Mean: 3.7241495132446287 seconds
# Std: 0.21790307760238647

@alan-cooney
Collaborator

Thanks for looking into this!

I guess the most efficient way would be to construct it once per model rather than once per head? However this would potentially break some forms of model parallelism (e.g. with deepspeed)

@alan-cooney
Collaborator

Also pinged you directly with a potentially hacky (but more efficient) fix using a static property

@bryce13950
Collaborator

@andyrdt & @alan-cooney Do either of you recall where this was left? @andyrdt, I just merged your branch with the most recent main branch. If you remember the advice Alan gave you and have time to implement it, we can get this merged relatively quickly. Otherwise, if you share that information here, I am happy to make the changes and get this into the main branch.

@andyrdt
Contributor Author

andyrdt commented Apr 9, 2024

Hi @bryce13950 - thanks for pinging on this.

The currently-implemented solution in this PR is to construct attention masks for each attention component (i.e. at each layer) on-the-fly. This solution is simple, safe, and doesn't impact perf much (see benchmarks above), but feels a bit hacky since we're reconstructing/reallocating the same mask across many layers.

I think probably the best solution would be to construct a single attention mask at the model level at the beginning of a forward pass, and then pass this attention mask around. This is what I see most recent HuggingFace model implementations do.

For example, the HF Llama implementation constructs a mask at the model level inside of forward(), and then plumbs it through the forward() of subcomponents, eventually passing it to the attention component’s forward(..., attention_mask, ...).

There is a complication here with certain models that don’t use the same attention mask for each layer (e.g. GPT-neo, which has alternating local/global attention for each layer). One way to deal with this would be to pass the global attention mask around, and then modify it locally in attention components that are configured to use local attention.
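A sketch of that local-modification idea (my own illustration, assuming a boolean mask where True means "may attend"; GPT-neo's actual window size and mask plumbing may differ):

```python
import torch

def to_local_mask(causal_mask: torch.Tensor, window: int) -> torch.Tensor:
    """Restrict a global causal mask to a sliding window of `window` keys.

    causal_mask: [query_pos, key_pos] boolean, True = attend allowed.
    A local-attention layer would apply this to the shared global mask;
    global-attention layers use the causal mask unchanged.
    """
    n = causal_mask.shape[-1]
    pos = torch.arange(n, device=causal_mask.device)
    # Key j is visible from query i only if i - window < j <= i;
    # the <= i half is already enforced by the causal mask.
    in_window = (pos[:, None] - pos[None, :]) < window
    return causal_mask & in_window
```

This keeps the single model-level mask as the source of truth while letting individual layers specialize it cheaply.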

Let me know what you think about this proposed solution. Also please feel free to take a stab at implementing it (either editing this PR or submitting a new one) - I don't think I'll have time to properly revisit this over the next couple of weeks.

@bryce13950
Collaborator

Thanks for getting the details in here. I am in the process of bringing all PRs up to date with the main branch, merging whatever is ready, and wrapping up quick final touches where needed. After that, I want to go through the issues as well, close the irrelevant ones, and organize the rest. This PR is looking like it will be a bit more time consuming than some of the other PRs so far.
When I am done with these upkeep tasks, if this is still pending, I can definitely get back to it and wrap it up. If you find time in the next couple of weeks and there still hasn't been any movement on this, then any time you can give it will definitely help. If you do get it revised, I will make sure to merge it right away. If not, it will just have to wait a little longer, which is not the end of the world.

@bryce13950 bryce13950 marked this pull request as draft April 13, 2024 18:53
@bryce13950 bryce13950 changed the base branch from main to dev May 23, 2024 00:21
@andyrdt
Contributor Author

andyrdt commented Aug 7, 2024

Abandoning this PR.

@andyrdt andyrdt closed this Aug 7, 2024

