[feature]Support model loading from cache #3857
Conversation
Thanks for your contribution!
Force-pushed from 201ed8a to 818ec42
self.down_proj_scale_shape = [layer.num_local_experts, layer.hidden_size]

if layer.fd_config.load_config.load_choices == "default_v1":
    if self.quant_config.is_checkpoint_bf16:
What is the reason for this change?
Because the weights in the cache are already quantized, we need to take the offline-quantized weight path here.
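To make the reviewer's explanation concrete, here is a rough sketch of the decision being described: when loading from the weight cache, the checkpoint must be treated as offline-quantized even if the original checkpoint was bf16. The helper name `use_offline_quant_weights` and its exact logic are assumptions for illustration, not the PR's actual code.

```python
def use_offline_quant_weights(is_checkpoint_bf16: bool, loading_from_cache: bool) -> bool:
    """Decide whether to follow the offline-quantized weight-loading path.

    Hypothetical sketch: cached weights were saved *after* quantization,
    so a cache load must use the offline-quantized path regardless of the
    original checkpoint's dtype.
    """
    if loading_from_cache:
        # cache always stores already-quantized weights
        return True
    # otherwise, a bf16 checkpoint still needs online quantization
    return not is_checkpoint_bf16
```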
time_after_load = time.time()
logger.info(f"Model loading took {time_after_load - time_before_load} seconds")
return result

def paddle_weight_iterator(paddle_file_list: list[str]):
paddle_weight_iterator -> pdparams_weight_iterator, since Paddle's weights may well switch to a different format in the future.
done
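For illustration, a minimal sketch of what the renamed iterator might look like. The injectable `load_fn` hook is an assumption added here for testability; the real code presumably calls `paddle.load` on each file directly.

```python
from typing import Callable, Iterator, Optional

def pdparams_weight_iterator(
    pdparams_file_list: list[str],
    load_fn: Optional[Callable[[str], dict]] = None,
) -> Iterator[tuple[str, object]]:
    """Yield (name, weight) pairs from a list of .pdparams files, one file at a time."""
    if load_fn is None:
        import paddle  # deferred import so the sketch is importable without Paddle
        load_fn = paddle.load
    for file_path in pdparams_file_list:
        state_dict = load_fn(file_path)
        for name, weight in state_dict.items():
            yield name, weight
        # drop the per-file dict before moving on, bounding peak memory
        del state_dict
```

Yielding file-by-file keeps at most one shard's state dict in memory at a time, which matters for large checkpoints.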
cache_dir = None
enable_cache = False
if envs.FD_ENABLE_MODEL_CACHE:
    model_cache_path = os.path.join(fd_config.model_config.model, cache_path)
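The snippet above resolves the cache location from the model directory and an env flag. A self-contained sketch of that flow, with the hypothetical helper name `resolve_weight_cache` (the actual PR reads `envs.FD_ENABLE_MODEL_CACHE` and `fd_config.model_config.model` instead of plain arguments):

```python
import os

def resolve_weight_cache(model_dir: str, cache_subdir: str, enable_cache: bool):
    """Return (weight_cache_dir, cache_hit).

    cache_hit is True only if a previous run already saved weights there,
    in which case loading can skip the original checkpoint entirely.
    """
    if not enable_cache:
        return None, False
    weight_cache_dir = os.path.join(model_dir, cache_subdir)
    return weight_cache_dir, os.path.isdir(weight_cache_dir)
```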
Rename every cache_xxx to weight_cache_xxx. Names should be specific; "cache" on its own is too generic.
Does this also support saving the cache in a multi-node setup?
It should. In a multi-node setup every machine currently has its own copy of the weights, right?
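One way this works for both multi-GPU and multi-node setups is to give each tensor-parallel rank its own cache subdirectory, so every rank saves and loads only its own shard. A minimal sketch under that assumption (the helper name and `rank_{n}` layout are hypothetical; the PR's code uses a `tp_cache_dir` variable):

```python
import os

def tp_weight_cache_dir(weight_cache_dir: str, tp_rank: int) -> str:
    """Return (and create) a per-tensor-parallel-rank cache subdirectory.

    Each rank writes its own shard, so ranks never contend for the same
    file, and a node only needs the subdirectories for its local ranks.
    """
    rank_dir = os.path.join(weight_cache_dir, f"rank_{tp_rank}")
    os.makedirs(rank_dir, exist_ok=True)
    return rank_dir
```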
_save_model(model.state_dict(), os.path.join(tp_cache_dir, "cache.pdparams"))
logger.info(f"Saving model to {cache_dir}")
else:
    logger.warning("skip saving")
What is being skipped here? Log messages should be complete; written like this, probably nobody but you can understand it.
done
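As a rough sketch of the save path with the fuller log message the review asks for (the helper name `save_weight_cache` and the injectable `save_fn` are assumptions; the PR calls `_save_model` and `logger` directly):

```python
import os

def save_weight_cache(state_dict, tp_cache_dir: str, save_fn) -> str:
    """Save quantized weights to the per-rank cache, skipping if already present."""
    cache_file = os.path.join(tp_cache_dir, "cache.pdparams")
    if os.path.exists(cache_file):
        # say *what* is skipped and *why*, not just "skip saving"
        print(f"Weight cache {cache_file} already exists; skip saving model weights")
        return cache_file
    os.makedirs(tp_cache_dir, exist_ok=True)
    save_fn(state_dict, cache_file)
    print(f"Saved model weight cache to {cache_file}")
    return cache_file
```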
New feature: adds a weight-cache mechanism. When enabled, the first load takes extra time to save the cache; subsequent loads use the cache by default.
Extra storage is used, equal in size to the quantized weights.
Tested on 4 GPUs with wint4:
os.environ['FD_ENABLE_MODEL_LOAD_CACHE'] = '1'
or
export FD_ENABLE_MODEL_LOAD_CACHE=1