Conversation
Very cool! In particular the sharded part looks nice. Some comments (bear in mind the review was sequential, so some earlier comments might be irrelevant):
- I think you could write everything in another file; if your version is faster we'll just use yours.
- I think the notion of shard_size has to be enforced. Having an upper bound could make tuning much simpler.
- I think byte-based sharding makes sense if, for a given number of bytes, the tokenizer and the workers run at the same speed. Otherwise I'm in favor of using lines, which should make the code easier?
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
self.cache2 = {}
Can you add a comment on what that is? I'd rename cache to token_cache, and cache2 to normalisation_cache.
The argument -batch_shard_size sets the batch size, which is now 10MB, but we can vary it.
cache was the name used by the original GPT2 implementation, so I'll keep it the same, but I'll change cache2 to something more meaningful.
As you want; I feel it should be okay to modify existing code in order to improve readability.
bpe_tokens = []
if sys.version_info[0] == 2:
    for token in re.findall(self.pat, text):
        token = ''.join(self.byte_encoder[ord(b)] for b in token)
I believe you're missing the step where you convert the token to BPE. Otherwise this always returns an empty list, no?
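For reference, in the original GPT2 tokenizer the loop also runs BPE on each normalised token and collects the results; a sketch of what the Python 2 branch would presumably look like with that step restored (not the actual PR code):

bpe_tokens = []
if sys.version_info[0] == 2:
    for token in re.findall(self.pat, text):
        # byte-level normalisation: map every raw byte to its unicode placeholder
        token = ''.join(self.byte_encoder[ord(b)] for b in token)
        # the step that seems to be missing: run BPE and collect the sub-tokens
        bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
    return bpe_tokens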
for token in re.findall(self.pat, text):
    token = ''.join(self.byte_encoder[ord(b)] for b in token)
return bpe_tokens

cache2 = self.cache2
for token in re.findall(self.pat, text):
    if token in cache2:
        ret = cache2[token]
Ah, I see what cache2 is now: it's essentially the same as cache plus normalisation. Instead of this, can we separate cache and cache2? I.e. cache2 can be the normalisation cache. Typically:
cache2[token] = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
cache can be improved by having lists as values instead of strings (though strings might be used on purpose to force pass-by-value, i.e. immutable elements).
Then you would use the caches like the following:
if token in cache2:
    normalised_token = cache2[token]  # already normalised
else:
    normalised_token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
    if len_orig_token < 10:
        cache2[token] = normalised_token
# If you wanted to have something faster, you could force self.bpe to return a list of bpes.
# It should have the same speed as your version as soon as you do this.
ret = [bpe_token for bpe_token in self.bpe(normalised_token).split(' ')]
bpe_tokens.extend(ret)

Let me know what you think.
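Spelling the suggestion out end to end, a sketch of how the full loop might look with the two caches separated (self.normalisation_cache is the proposed rename and the length cutoff is illustrative; this is not the actual PR code):

def tokenize(self, text):
    bpe_tokens = []
    normalisation_cache = self.normalisation_cache  # raw regex token -> byte-level normalised string
    for token in re.findall(self.pat, text):
        if token in normalisation_cache:
            normalised_token = normalisation_cache[token]
        else:
            normalised_token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            if len(token) < 10:  # only cache short tokens to bound memory
                normalisation_cache[token] = normalised_token
        # self.bpe() already memoises its results in self.cache (the original token cache)
        bpe_tokens.extend(self.bpe(normalised_token).split(' '))
    return bpe_tokens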
I changed to normalized_cache. I don't follow the example you wrote above. Let's discuss.
def encode_old(self, text):
    return self.convert_tokens_to_ids(self.tokenize_old(text))
Safe to say that once we benchmark everything, we won't have an "old" version, no? If yours is faster and does the same thing, I'd advocate removing the previous one.
Correct, I'm just leaving it so we can benchmark old vs. new tokenize. If the new tokenize is faster, we can remove all the tokenize_old and encode_old references.
You also need to benchmark the memory footprint, I'd say.
Proposal to do this on JZ? Is there a memory profiling function?
So I'm not an expert on this ... I'll try to think of something. Otherwise let's assume that we have lots of memory (which is safe, I think)? Concerning the speed, your colab suggested that this works faster? If so I'd make it a separate PR (Improve tokenizer) so we can merge separately.
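If it helps, one way to get a rough memory-footprint number without extra dependencies is Python's built-in tracemalloc; a minimal sketch (tokenizer and text are hypothetical stand-ins for whatever call we want to profile, and this only tracks Python-level allocations, not a full RSS measurement):

import tracemalloc

tracemalloc.start()
tokens = tokenizer.tokenize(text)   # hypothetical call being measured
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
print("current: %.1f MB, peak: %.1f MB" % (current / (1024 * 1024), peak / (1024 * 1024)))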
if not line:
    file_segs.append((file_pos, file_size-file_pos))
    break
seg_len += len(line)
Shouldn't you get the line's byte size instead of its character size, or is it always the same thing (I'm thinking of things like emojis and such)? Also it's a shame that you're not obtaining an upper bound on the memory, i.e. shards are all bigger than shard_size, which is counter-intuitive IMO. (Though I admit, your implementation of computing the end cursor looks really good.)
Hmm... I think this is bytes. When you read a file in binary mode it always reads bytes, right?
Ah yes, I did not see 'b'. My bad!
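For the record, the distinction raised above does matter for non-ASCII text; it only disappears here because the file is opened in binary mode. A quick illustration:

line = "hello 👍\n"
print(len(line))                   # 8  (characters)
print(len(line.encode("utf-8")))   # 11 (bytes; the emoji alone is 4 bytes in UTF-8)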
if file_size-(file_pos+seg_len) < shard_size:
    file_segs.append((file_pos, file_size-file_pos))
    break
This should be handled at the next iteration by your code at line 390, because you're merging two shards IMO: (file_pos, f.tell() - file_pos) and (f.tell(), file_size - f.tell()).
Let's discuss. This is the case where we merge one shard with another shard that has the rest of the line. But we are now at the end of the file.
output_idx_files = {}
builders = {}
for key in self.args.json_keys:
    output_bin_files[key] = "{}_{}_{}_shard_{}.bin".format(self.args.output_prefix,
You could pass a shard_id to your function, instead of taking job[0] as your id.
You are right. We can pass any unique ID. One benefit of passing the position is that if we didn't return the shard file names in order, we could have just done glob.glob("shard.") and sorted by the position number and cat everything together to get the data in order. Btw, in my other code where I use this technique, I actually do os.system("cat ... > output.")
So I was really thinking more of values between 0 and the number of shards. Though after thinking about it, it might be cool to use job[0] to check on the original file. You can ignore my comment then.
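A sketch of the recovery trick described above, assuming shard files are named with their starting byte offset (the file names here are purely illustrative, not the actual PR naming scheme):

import glob, os

# hypothetical shard names like "output_shard_1048576.bin", where the number is the start offset
shard_files = glob.glob("output_shard_*.bin")
shard_files.sort(key=lambda p: int(p.rsplit("_", 1)[-1].split(".")[0]))
os.system("cat {} > output.bin".format(" ".join(shard_files)))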
vocab_size = build_tokenizer(args).vocab_size

pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
shards = get_file_shard_ranges(args.input, args.batch_shard_size*1000000, num_proc=args.workers)
Huh? I thought we wanted to do it in megabytes. Is it more efficient to do 1024 * 1024?
I think it depends on the convention taken, but in the logging of writing speed, 1MB = 1024 * 1024 B
Ah ok. Let's follow one convention then.
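If we settle on the 1024 * 1024 convention mentioned above, pinning it down in one constant would keep it consistent; a small sketch (using the get_file_shard_ranges call from this PR):

MB = 1024 * 1024  # 1 MB == 1,048,576 bytes, matching the write-speed logging convention
shards = get_file_shard_ranges(args.input, args.batch_shard_size * MB, num_proc=args.workers)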
return args


def get_file_shard_ranges(input_file_path, shard_size, num_proc, num_shards=None):
You should write this as an iterator; I think .map takes care of creating a thread-safe iterator on top. This will reduce your memory footprint (though not by a lot, as it's just indices).
Thanks for the comment. If I have time I'll do this.
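A sketch of what the generator version might look like, following the same logic discussed in this thread (seek ahead by shard_size, then extend to the next newline); it omits the num_proc/num_shards handling and has not been tested against the actual PR code:

import os

def iter_file_shard_ranges(input_file_path, shard_size):
    """Yield (start_offset, length) pairs, each ending on a line boundary."""
    file_size = os.path.getsize(input_file_path)
    with open(input_file_path, 'rb') as f:
        file_pos = 0
        while file_pos < file_size:
            if file_size - file_pos <= shard_size:
                yield (file_pos, file_size - file_pos)
                return
            f.seek(file_pos + shard_size, 0)
            f.readline()                 # advance the cursor to the end of the current line
            end = f.tell()
            yield (file_pos, end - file_pos)
            file_pos = end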
# this does writing to shard files
# if _iter is true, read line by line; otherwise read in blocks
def encode_shard(self, job):
Suggested change:
-    def encode_shard(self, job):
+    def read_encode_and_write(self, job):
Also, isn't that just encode_batch plus writing to a file directly? If so you could use that function instead of rewriting everything.
if file_size - file_pos <= shard_size:
    file_segs.append((file_pos, file_size-file_pos))
    break
f.seek(file_pos+shard_size, 0)
Would this be equivalent to f.seek(shard_size)?
No because it's in a loop, and file_pos is moving.
Yes, but so is the current cursor in the file?
Not sure I understand. Seek does a seek from some point (in this case 0). You could do seek(shard_size, file_pos) I suppose.
Okay, turns out f.seek(shard_size, 1) should use the relative position; however it doesn't work on my setup, something about Python 3.8, so let's keep yours I guess. What I like about the relative one is that you're basically telling it to "move forward ${shard_size} bytes" and not "go to ${file_pos + shard_size}", though equivalent in this case.
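For reference on the whence argument discussed above: relative seeks with arbitrary offsets should work on files opened in binary mode (it's text-mode files where Python restricts SEEK_CUR/SEEK_END to zero offsets), so the setup issue mentioned is unclear. A small illustration on an arbitrary binary file (the path is just an example):

with open("data.jsonl", "rb") as f:
    f.seek(1024, 0)   # absolute: go to byte 1024 from the start (whence=0 is the default)
    f.seek(1024, 1)   # relative: move 1024 bytes forward from the current position
    f.seek(0, 2)      # go to the end of the file
    print(f.tell())   # == file size in bytes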
if not line:
    file_segs.append((file_pos, file_size-file_pos))
    break
Can you add a comment saying "empty lines can only be found at the end of the file"?
pool.close()
pool.join()
# let's merge
# we could have done the merge incrementally but the dataset builder doesn't really do incremental merge efficiently
Why not? I think one of the strengths of your implementation is that you can write shards, merge them, and remove shards as you go.
Ok. I made it incremental, and we delete the shard files as we go. This should reduce disk usage and in theory be a bit faster.
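A rough sketch of what that incremental variant could look like, assuming the dataset builder exposes a merge method along the lines of Megatron's indexed-dataset builder; the variable names and the exact merge_file_/finalize API are assumptions here, not the actual PR code:

import os

# hypothetical: ordered shard prefixes, e.g. "output_text_shard_0", produced by the workers
for shard_prefix in shard_prefixes:
    # assumed Megatron-style builder call: append another indexed dataset onto this one
    builder.merge_file_(shard_prefix)
    # once merged, the shard's files are no longer needed, so disk usage stays bounded
    os.remove(shard_prefix + ".bin")
    os.remove(shard_prefix + ".idx")
builder.finalize(output_idx_file)  # output_idx_file: hypothetical path of the final index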
This WIP is closed in favor of PR #37, which has been merged.
Draft PR to hopefully do better preprocessing with multiprocessing.