
WIP: Faster preprocess multi #32

Closed
huu4ontocord wants to merge 15 commits into main from faster_preprocess_multi

Conversation

@huu4ontocord
Contributor

Draft PR to hopefully do better preprocessing with multi-processing.

@thomasw21 thomasw21 changed the title Faster preprocess multi WIP: Faster preprocess multi Jul 29, 2021
@thomasw21 thomasw21 left a comment
Member

Very cool! In particular the sharded part looks nice. Some comments (bear in mind the review was sequential, so some earlier comments might be irrelevant):

  • I think you could write everything in another file; if your version is faster we'll just use yours.
  • I think the notion of shard_size has to be enforced. Having an upper bound could make tuning much simpler.
  • I think the byte-based approach makes sense if, for a given number of bytes, the tokenizer and the workers run at the same speed. Otherwise I'm in favor of using lines, which should make the code easier (see the sketch below).
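
For reference, a minimal sketch of the line-based alternative (the function name and the lines_per_shard knob are hypothetical, not part of this PR):

# Hypothetical sketch: shard a file by line count instead of byte size.
def iter_line_shards(input_file_path, lines_per_shard):
    shard = []
    with open(input_file_path, "rb") as f:
        for line in f:
            shard.append(line)
            if len(shard) >= lines_per_shard:
                yield shard
                shard = []
    if shard:  # last, possibly smaller shard
        yield shard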

self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}

self.cache2 = {}
Member

Can you add a comment on what that is? I'd rename cache to token_cache, and cache2 to normalisation_cache.

Contributor Author

The argument -batch_shard_size sets the batch size, which is now 10MB, but we can vary it.

Contributor Author

cache was the name used by the original GPT2 implementation, so I'll keep it the same, but I'll change cache2 to something more meaningful.

Member

As you want; I feel it should be okay to modify existing code in order to improve readability.

bpe_tokens = []
if sys.version_info[0] == 2:
    for token in re.findall(self.pat, text):
        token = ''.join(self.byte_encoder[ord(b)] for b in token)
Member

I believe you're missing the step where you convert the token to BPE. Otherwise this always returns an empty list, no?

Contributor Author

fixed

for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[ord(b)] for b in token)
return bpe_tokens
cache2 = self.cache2
Member

I don't think that's needed?

Contributor Author

fixed

cache2 = self.cache2
for token in re.findall(self.pat, text):
    if token in cache2:
        ret = cache2[token]
Member

Ah I see what cache2 is now, it's essentially the same as cache + normalization. Instead of this, can we separate cache and cache2? I.e. cache2 can be the normalisation cache. Typically:

  • cache2[token] = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
  • cache can be improved by having lists as values instead of strings (though strings might be used in order to force pass-by-value, i.e. immutable elements)

You would then use the caches like the following:

if token in cache2:
  normalised_token = cache2[token]  # normalise it
else:
  normalised_token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
  if len_orig_token < 10:
    cache2[token] = normalised_token

# If you wanted something faster, you could force self.bpe to return a list of BPEs. It should have the same speed as your version as soon as you do this.
ret = [bpe_token for bpe_token in self.bpe(token).split(' ')]
bpe_tokens.extend(ret)

Let me know what you think

Contributor Author

I changed to normalized_cache. I don't follow the example you wrote above. Let's discuss.

Comment on lines +301 to +302
def encode_old(self, text):
    return self.convert_tokens_to_ids(self.tokenize_old(text))
Member

Safe to say that once we benchmark everything, we won't have an "old" version, no? If yours is faster and it does the same thing, I'd advocate removing the previous one.

Contributor Author

Correct, I'm just leaving it so we can benchmark the old vs. new tokenize. If the new tokenize is faster, we can remove all the tokenize_old and encode_old references.

Member

You also need to benchmark the memory footprint, I'd say.

Contributor Author

Proposal to do this on JZ? Is there a memory profiling function?

@thomasw21 thomasw21 Jul 30, 2021
Member

So I'm not an expert on this ... I'll try to think of something. Otherwise let's assume that we have lots of memory (which is safe, I think)? Concerning the speed, your colab suggested that this works faster? If so I'd make it a separate PR (Improve tokenizer) so we can merge it separately.
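
For the memory question, one option from the Python standard library is tracemalloc; a minimal sketch (the profiled call is a placeholder):

import tracemalloc

tracemalloc.start()
# ... run the preprocessing step to be measured here ...
current, peak = tracemalloc.get_traced_memory()  # both values are in bytes
print(f"current: {current / (1024 * 1024):.1f} MiB, peak: {peak / (1024 * 1024):.1f} MiB")
tracemalloc.stop()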

if not line:
    file_segs.append((file_pos, file_size-file_pos))
    break
seg_len += len(line)
@thomasw21 thomasw21 Jul 29, 2021
Member

Shouldn't you get the line byte size, instead of its character size, or is it always the same thing (I'm thinking of things like emojis and such)? Also, it's a shame that you're not obtaining an upper bound on the memory, i.e. shards are all bigger than shard_size, which is counter-intuitive IMO. (Though I admit, your implementation of computing the end cursor looks really good.)

Contributor Author

Hmm... I think this is bytes. When you read a file in binary mode it always reads in bytes, right?

Member

Ah yes, I did not see 'b'. My bad!
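
For reference, a small illustration of the bytes-vs-characters point (toy values; multi-byte characters like emojis are where the two counts differ):

line = "hi 🙂\n"
print(len(line))                  # 5 characters
print(len(line.encode("utf-8")))  # 8 bytes: the emoji alone takes 4 bytes
# A file opened in binary mode ("rb") yields bytes objects, so len(line) counts bytes.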

Comment on lines +400 to +402
if file_size-(file_pos+seg_len) < shard_size:
    file_segs.append((file_pos, file_size-file_pos))
    break
@thomasw21 thomasw21 Jul 29, 2021
Member

This should be handled at the next iteration with your code at line 390, because you're merging two shards IMO:
(file_pos, f.tell() - file_pos) and (f.tell(), file_size - f.tell())

Contributor Author

Let's discuss. This is the case where we merge one shard with another shard that has the rest of the line. But we are now at the end of the file.

output_idx_files = {}
builders = {}
for key in self.args.json_keys:
    output_bin_files[key] = "{}_{}_{}_shard_{}.bin".format(self.args.output_prefix,
Member

You could pass a shard_id to your function, instead of taking job[0] as your id.

Contributor Author

You are right. We can pass any unique ID. One benefit of passing the position is that if we didn't return the shard file names in order, we could have just done glob.glob("shard.") and sorted by the position number and concatenated everything together to get the data in order. Btw, in my other code where I use this technique, I actually do os.system("cat ... > output.")
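
A rough sketch of that glob-and-sort pattern (the shard naming scheme here is assumed for illustration, not the one in this PR):

import glob
import os

# Hypothetical shard names like "output_shard_<file_pos>.bin"; sorting by the
# embedded position restores the original order before concatenation.
shard_files = sorted(
    glob.glob("output_shard_*.bin"),
    key=lambda path: int(path.rsplit("_", 1)[-1].split(".")[0]),
)
os.system("cat " + " ".join(shard_files) + " > output.bin")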

Member

So I was really more thinking of values between 0 and the number of shards. Though after thinking about it, it might be cool to use job[0] to check on the original file. You can ignore my comment then.

vocab_size = build_tokenizer(args).vocab_size

pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
shards = get_file_shard_ranges(args.input, args.batch_shard_size*1000000, num_proc=args.workers)
Member

Nit: it's more like 1024 * 1024.

Contributor Author

Huh? I thought we wanted to do it in megabytes. Is it more efficient to do 1024 * 1024?

Member

I think it depends on the convention taken, but in the logging of writing speed, 1MB = 1024 * 1024 B

Contributor Author

Ah ok. Let's follow one convention then.

return args


def get_file_shard_ranges(input_file_path, shard_size, num_proc, num_shards=None):
Member

You should write this as an iterator; I think .map makes sure of creating a thread-safe iterator on top. This will reduce your memory footprint (though not by a lot, as it's just indices).
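
A rough generator-style sketch of that suggestion, assuming the same (start, length) byte ranges the current function builds; the line-boundary handling is simplified:

import os

def iter_file_shard_ranges(input_file_path, shard_size):
    # Hypothetical generator variant: yields (start, length) byte ranges lazily,
    # so the pool can consume them without materialising the full list.
    file_size = os.path.getsize(input_file_path)
    with open(input_file_path, "rb") as f:
        file_pos = 0
        while file_pos < file_size:
            if file_size - file_pos <= shard_size:
                yield (file_pos, file_size - file_pos)
                return
            f.seek(file_pos + shard_size, 0)
            f.readline()  # extend the shard to the next line boundary
            end = f.tell()
            yield (file_pos, end - file_pos)
            file_pos = end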

Contributor Author

Thanks for the comment. If I have time I'll do this.


# this does writing to shard files
# if _iter is true, read line by line; otherwise read in blocks
def encode_shard(self, job):
Member

Suggested change
def encode_shard(self, job):
def read_encode_and_write(self, job):

Member

Also, isn't that just encode_batch + writing to a file directly? If so, you could use that function instead of rewriting everything.

Contributor Author

Good point.

if file_size - file_pos <= shard_size:
    file_segs.append((file_pos, file_size-file_pos))
    break
f.seek(file_pos+shard_size, 0)
@thomasw21 thomasw21 Jul 29, 2021
Member

Would this be equivalent to f.seek(shard_size)?

Contributor Author

No because it's in a loop, and file_pos is moving.

Member

Yes, but so is the current cursor in the file?

Contributor Author

Not sure I understand. Seek does a seek from some point (in this case 0). You could do seek(shard_size, file_pos) I suppose.

Member

Okay, turns out f.seek(shard_size, 1) should use the relative position; however it doesn't work on my setup, something about Python 3.8, so let's keep yours I guess. What I like about the relative one is that you're basically telling it to "move forward ${shard_size} bytes" and not "go to ${file_pos + shard_size}", though they're equivalent in this case.
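
For reference, a small illustration of the two seek styles on a binary file (the file name and offsets are placeholders):

file_pos, shard_size = 0, 10 * 1024 * 1024  # example values

with open("data.jsonl", "rb") as f:
    # absolute seek (whence=0): jump to byte file_pos + shard_size
    f.seek(file_pos + shard_size, 0)

    # relative seek (whence=1): move forward shard_size bytes from the current cursor;
    # non-zero relative seeks are only allowed on binary-mode files
    f.seek(file_pos, 0)
    f.seek(shard_size, 1)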

Comment on lines +396 to +398
if not line:
    file_segs.append((file_pos, file_size-file_pos))
    break
Member

Can you add a comment saying "empty lines can only be found at the end of the file"?

Contributor Author

Ok.

pool.close()
pool.join()
# let's merge
# we could have done the merge incrementally but the dataset builder doesn't really do incremental merge efficiently
@thomasw21 thomasw21 Jul 29, 2021
Member

Why not? I think one of the strengths of your implementation is that you can write shards, merge them, and remove shards as you go.

Contributor Author

Ok. I made it incremental and we delete the shard files as we go. This should reduce disk space usage and, in theory, be a bit faster.
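
A generic sketch of the merge-as-you-go idea (the PR itself merges through the dataset builder, but the pattern is the same; names are illustrative):

import os

def merge_shards_incrementally(shard_paths, output_path):
    # Append each shard file to the final output and delete it immediately,
    # so peak disk usage stays close to one copy of the data.
    with open(output_path, "wb") as out:
        for shard_path in shard_paths:
            with open(shard_path, "rb") as shard:
                while True:
                    chunk = shard.read(1024 * 1024)
                    if not chunk:
                        break
                    out.write(chunk)
            os.remove(shard_path)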

@huu4ontocord
Contributor Author

This WIP is closed in favor of PR #37, which has been merged.

@jaketae jaketae deleted the faster_preprocess_multi branch November 24, 2021 06:31