
[trainer] remove --model_parallel #9451

Merged
sgugger merged 12 commits into huggingface:master from stas00:revert-is_parallel-check
Jan 11, 2021

Conversation

Contributor

@stas00 stas00 commented Jan 7, 2021

Per @sgugger's request removing --model_parallel in trainer, as it was never tested or made to work with the trainer.

We will get back to it in the future.

This PR doesn't introduce breaking changes, since --model_parallel never worked (well other than in my MP PRs that have been parked for now, since they are very inefficient and we are looking for a better approach, rather than waste time on sorting those out).

@LysandreJik, @sgugger

Member

@LysandreJik LysandreJik left a comment


Indeed, LGTM! We should have been more attentive during the review, but no harm done.

@sgugger for info, this was removed here: 9f675b0#diff-ed55888e6665791fe92cc8fc0c499da54f4ace6738551cd9a2591881cda076deL245-L248

Collaborator

sgugger commented Jan 7, 2021

Thanks for putting it back. Since we're in a PR on this test alone, can we "fix" it to ignore the args.model_parallel argument? This argument will be removed/renamed (I'd prefer the first option, as it's not useful) since people are confusing it with something that will enable DataParallel. The test can be replaced by model.is_parallelizable and model.parallel, I believe, with the current API.

Contributor Author

stas00 commented Jan 7, 2021

Two things:

  1. You must be referring to self.model_parallel? But it will always be False unless model.parallelize() is called!

    So while you can rename the argument, you can't remove it: the user needs to activate this explicitly, and the trainer then must activate MP with model.parallelize().

    Wrt DataParallel: why are we turning it on automatically in the first place? Why not make it manual and call it --data_parallel - no more confusion. Loud and clear:

    • --model_parallel
    • --data_parallel
  2. As we discovered last night, the current trainer doesn't work at all with --model_parallel - see [trainer] deepspeed integration #9211 (comment). There is no activation of that parallel mode - nobody calls model.parallelize(), so it's very broken.

I changed this code last night to:

        if self.args.model_parallel:
            if model.is_parallelizable:
                model.parallelize()
            else:
                raise ValueError(
                    f"{model.__class__.__name__} implementation currently doesn't support model parallelism, therefore --model_parallel cl arg cannot be used"
                )

and it doesn't work when I try:

rm -r output_dir; CUDA_VISIBLE_DEVICES=0,1 PYTHONPATH=../../src USE_TF=0 ./finetune_trainer.py --model_name_or_path t5-small --output_dir output_dir --adam_eps 1e-06 --data_dir wmt_en_ro --do_eval --do_train --evaluation_strategy=steps --fp16 --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size 4 --per_device_train_batch_size 4 --predict_with_generate --eval_steps 25000 --save_steps 25000 --sortish_sampler --task translation_en_to_ro --test_max_target_length 128 --val_max_target_length 128 --warmup_steps 1 --n_train 2 --n_val 2 --n_test 2 --do_predict --model_parallel

It doesn't look like it ever worked...

i.e. MP works when set up manually but doesn't work in the trainer.

p.s. I tagged you on that discussion - not sure if you saw it.

Collaborator

sgugger commented Jan 7, 2021

i.e. MP works when set up manually but doesn't work in the trainer.
As we discovered last night current trainer doesn't work at all with --model_parallel - see #9211 (comment) there is no activation of that parallel mode - nobody calls model.parallelize() so it's very broken

That's not a discovery on my side, that is exactly why I keep saying that the argument --model_parallel should be removed. It doesn't actually do anything and is confusing for the user. The call to model.parallelize() can always be done outside of Trainer IMO, which is why the test can be changed as suggested. We can think of integrating it inside the Trainer later, when the API is stable and actually used, for now I don't see the point of adding this.

Wrt DataParallel. Why are we turning it on automatically in first place? Why not make it manual and call it --data_parallel

That would be a big breaking change in the API, and beginners actually want to have the parallelism work out of the box when they have several GPUs, so I don't see why change something that works.
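The workflow sgugger describes - the user calling model.parallelize() outside the Trainer, with the Trainer only inspecting the flags - can be sketched with a hypothetical stand-in class (illustrative only, not the real transformers API):

```python
class ToyParallelizableModel:
    """Hypothetical stand-in mimicking the is_parallelizable/model_parallel protocol."""

    is_parallelizable = True  # class-level flag: this architecture supports MP

    def __init__(self):
        self.model_parallel = False  # stays False until parallelize() is called

    def parallelize(self):
        # A real model would shard its layers across GPUs here.
        self.model_parallel = True


# The user activates model parallelism explicitly, *before* handing the model
# to the Trainer; the Trainer never calls parallelize() itself.
model = ToyParallelizableModel()
model.parallelize()
print(model.model_parallel)  # → True
```

This is exactly why self.model_parallel is always False unless the user opts in: the activation step lives outside the training loop.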

Contributor Author

stas00 commented Jan 7, 2021

The call to model.parallelize() can always be done outside of Trainer IMO, which is why the test can be changed as suggested.

It doesn't work

Wrt DataParallel. Why are we turning it on automatically in first place? Why not make it manual and call it --data_parallel

That would be a big breaking change in the API, and beginners actually want to have the parallelism work out of the box when they have several GPUs, so I don't see why change something that works.

OK, then shouldn't the flag be there, with the default set to on? Surely a user should be able to opt out of DP, and that's not possible at the moment.

Contributor Author

stas00 commented Jan 7, 2021

OK, so I did remove --model_parallel - no problem in trainer.py, since I used model.is_parallelizable and model.parallel instead - and I now understand that the point is that the user has to activate model.parallelize() themselves before passing the model to the trainer - i.e. no example scripts will support MP at the moment.

The problem is training_args.py - how do I deal with:

        if not self.model_parallel:
            train_batch_size = per_device_batch_size * max(1, self.n_gpu)
        else:
            train_batch_size = per_device_batch_size

self is args here, and there is no trainer object. Suggestions?
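One way to frame the problem: the decision only needs a boolean, not the trainer object. A sketch of the batch-size logic as a standalone helper, with the MP decision passed in explicitly (a hypothetical helper for illustration, not the fix that actually landed):

```python
def compute_train_batch_size(per_device_batch_size: int, n_gpu: int,
                             is_model_parallel: bool) -> int:
    # Under DataParallel the per-device batch is replicated on every GPU,
    # so the effective train batch scales with n_gpu; under model
    # parallelism the single model is sharded across GPUs, so one batch total.
    if is_model_parallel:
        return per_device_batch_size
    return per_device_batch_size * max(1, n_gpu)


print(compute_train_batch_size(4, 2, False))  # → 8 (DataParallel on 2 GPUs)
print(compute_train_batch_size(4, 2, True))   # → 4 (model sharded across 2 GPUs)
```

The remaining question is exactly the one raised above: where that boolean comes from when only the args object is available.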

But I guess I first need to figure out how to make MP work in the trainer at all; it doesn't look like it was ever tried or tested, as it fails for me.

Contributor Author

stas00 commented Jan 7, 2021

FWIW, --model_parallel works just fine with my Bart MP PR: #9384 (comment) in case someone needs it.

I suspect t5 MP wasn't tested/made to work with the generate tools (beam search, etc.) - edit: it works now in this PR #9323 - but it's super slow in beam search!

Contributor Author

stas00 commented Jan 7, 2021

OK, I committed the bulk of it, and @sgugger will push some magic to deal with training_args.py.

Tests should be failing, I think, until he does that.

@stas00 stas00 changed the title [trainer] fix bad rebase - dropped code [trainer] remove --model_parallel Jan 7, 2021
Contributor Author

stas00 commented Jan 7, 2021

So now I see I can jokingly blame my initial mistake on @sgugger, since he wanted it removed all along: I unconsciously did it during rebasing, and he unconsciously saw it as the right thing to do during the review ;) It's all Freud's fault anyway ;)

Contributor Author

stas00 commented Jan 7, 2021

I added a wrapper first, but it looked out of place, so I introduced and documented a new attribute: self.is_model_parallel - hope it's loud and clear.

Contributor Author

stas00 commented Jan 7, 2021

@sgugger, I must be doing something wrong - the docstring section of important attributes that I started in the model_wrapped PR gets wrapped all funny - so I tried to add bullets, and then it gets all messed up, as it bunches everything into one paragraph. If I add new lines, then make docs fails. Your magic touch is needed. Thank you.

Contributor Author

stas00 commented Jan 8, 2021

And here is why I removed init=False in a7a3921.

The tests were failing with:

TypeError: __init__() got an unexpected keyword argument '_n_gpu'

https://circle-production-customer-artifacts.s3.amazonaws.com/picard/forks/5bdabdd888af1f000130874a/278[…]cc8b6d6c390aab800d0cc1350f731a19529ac82f48

@sgugger sgugger requested a review from LysandreJik January 8, 2021 15:23
Contributor Author

stas00 commented Jan 8, 2021

Thank you for fixing the docs, @sgugger!

Member

@LysandreJik LysandreJik left a comment


Yes, LGTM!

Comment on lines +271 to +274
if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel:
self.is_model_parallel = True
else:
self.is_model_parallel = False
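As an aside, the four-line check above could be collapsed into a single boolean expression with getattr, which behaves the same way (a sketch with a hypothetical stub model, not what was merged):

```python
class _Stub:
    # Hypothetical stand-in carrying the two flags the Trainer inspects.
    is_parallelizable = True
    model_parallel = False


model = _Stub()

# getattr with a False default covers the hasattr() case; the short-circuit
# `and` only reads model.model_parallel when is_parallelizable is truthy,
# matching the original if/else exactly.
is_model_parallel = bool(
    getattr(model, "is_parallelizable", False) and model.model_parallel
)
print(is_model_parallel)  # → False (parallelize() was never called)
```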
Member


Cool!

@sgugger sgugger merged commit 33b7422 into huggingface:master Jan 11, 2021
@stas00 stas00 deleted the revert-is_parallel-check branch January 11, 2021 16:39