From acedd87d024cfd933f935f4890d29f9884d3e8f8 Mon Sep 17 00:00:00 2001 From: hemildesai Date: Tue, 30 Mar 2021 20:58:08 -0700 Subject: [PATCH 1/9] Add initial script for finetuning MLM models with accelerate --- .../language-modeling/run_mlm_no_trainer.py | 506 ++++++++++++++++++ 1 file changed, 506 insertions(+) create mode 100755 examples/language-modeling/run_mlm_no_trainer.py diff --git a/examples/language-modeling/run_mlm_no_trainer.py b/examples/language-modeling/run_mlm_no_trainer.py new file mode 100755 index 000000000000..fa1593467a7f --- /dev/null +++ b/examples/language-modeling/run_mlm_no_trainer.py @@ -0,0 +1,506 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) on a text file or a dataset without. + +Here is the full list of checkpoints on the hub that can be fine-tuned by this script: +https://huggingface.co/models?filter=masked-lm +""" +# You can also adapt this script on your own mlm task. Pointers for this are left as comments. + +import argparse +import logging +import math +import os +import random + +import datasets +from datasets import load_dataset, load_metric +from torch.utils.data.dataloader import DataLoader +from tqdm.auto import tqdm + +import transformers +from accelerate import Accelerator +from transformers import ( + CONFIG_MAPPING, + MODEL_MAPPING, + AdamW, + AutoConfig, + AutoModelForMaskedLM, + AutoTokenizer, + DataCollatorForLanguageModeling, + PretrainedConfig, + SchedulerType, + get_scheduler, + set_seed, +) + + +logger = logging.getLogger(__name__) +# You should update this to your particular problem to have better documentation of `model_type` +MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task") + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help="The name of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The configuration name of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--train_file", type=str, default=None, help="A csv or a json file containing the training data." + ) + parser.add_argument( + "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." + ) + parser.add_argument( + "--validation_split_percentage", + default=5, + help="The percentage of the train set used as validation set in case there's no validation split", + ) + parser.add_argument( + "--max_length", + type=int, + default=128, + help=( + "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated," + " sequences shorter will be padded if `--pad_to_max_lengh` is passed." + ), + ) + parser.add_argument( + "--pad_to_max_length", + action="store_true", + help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.", + ) + parser.add_argument( + "--model_name_or_path", + type=str, + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=True, + ) + parser.add_argument( + "--config_name", + type=str, + default=None, + help="Pretrained config name or path if not the same as model_name", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--use_slow_tokenizer", + action="store_true", + help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", + ) + parser.add_argument( + "--per_device_train_batch_size", + type=int, + default=8, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--per_device_eval_batch_size", + type=int, + default=8, + help="Batch size (per device) for the evaluation dataloader.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--lr_scheduler_type", + type=SchedulerType, + default="linear", + help="The scheduler type to use.", + choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], + ) + parser.add_argument( + "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--model_type", + type=str, + default=None, + help="Model type to use if training from scratch.", + choices=MODEL_TYPES, + ) + parser.add_argument( + "max_seq_length", + type=int, + default=None, + help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated.", + ) + parser.add_argument( + "line_by_line", + type=bool, + default=False, + help="Whether distinct lines of text in the dataset are to be handled as distinct sequences.", + ) + parser.add_argument( + "preprocessing_num_workers", + type=int, + default=None, + help="The number of processes to use for the preprocessing.", + ) + parser.add_argument( + "overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets" + ) + parser.add_argument( + "mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss" + ) + + args = parser.parse_args() + + # Sanity checks + if args.task_name is None and args.train_file is None and args.validation_file is None: + raise ValueError("Need either a task name or a training/validation file.") + else: + if args.train_file is not None: + extension = args.train_file.split(".")[-1] + assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." + if args.validation_file is not None: + extension = args.validation_file.split(".")[-1] + assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." + + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + return args + + +def main(): + args = parse_args() + + # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. + accelerator = Accelerator() + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state) + + # Setup logging, we only want one process per machine to log things on the screen. + # accelerator.is_local_main_process is only True for one process per machine. + logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) + if "validation" not in datasets.keys(): + raw_datasets["validation"] = load_dataset( + args.dataset_name, + args.dataset_config_name, + split=f"train[:{args.validation_split_percentage}%]", + ) + raw_datasets["train"] = load_dataset( + args.dataset_name, + args.dataset_config_name, + split=f"train[{args.validation_split_percentage}%:]", + ) + else: + data_files = {} + if args.train_file is not None: + data_files["train"] = args.train_file + if args.validation_file is not None: + data_files["validation"] = args.validation_file + extension = args.train_file.split(".")[-1] + if extension == "txt": + extension = "text" + raw_datasets = load_dataset(extension, data_files=data_files) + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model and tokenizer + # + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + if args.config_name: + config = AutoConfig.from_pretrained(args.config_name) + elif args.model_name_or_path: + config = AutoConfig.from_pretrained(args.model_name_or_path) + else: + config = CONFIG_MAPPING[args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer) + elif args.model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer) + else: + raise ValueError( + "You are instantiating a new tokenizer from scratch. This is not supported by this script." + "You can do it from another script, save it, and load it from here, using --tokenizer_name." + ) + + if args.model_name_or_path: + model = AutoModelForMaskedLM.from_pretrained( + args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + ) + else: + logger.info("Training new model from scratch") + model = AutoModelForMaskedLM.from_config(config) + + model.resize_token_embeddings(len(tokenizer)) + + # Preprocessing the datasets. + # First we tokenize all the texts. + column_names = raw_datasets["train"].column_names + text_column_name = "text" if "text" in column_names else column_names[0] + + if args.max_seq_length is None: + max_seq_length = tokenizer.model_max_length + if max_seq_length > 1024: + logger.warn( + f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " + "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx." + ) + max_seq_length = 1024 + else: + if args.max_seq_length > tokenizer.model_max_length: + logger.warn( + f"The max_seq_length passed ({args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + max_seq_length = min(args.max_seq_length, tokenizer.model_max_length) + + if args.line_by_line: + # When using line_by_line, we just tokenize each nonempty line. + padding = "max_length" if args.pad_to_max_length else False + + def tokenize_function(examples): + # Remove empty lines + examples["text"] = [line for line in examples["text"] if len(line) > 0 and not line.isspace()] + return tokenizer( + examples["text"], + padding=padding, + truncation=True, + max_length=max_seq_length, + # We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it + # receives the `special_tokens_mask`. + return_special_tokens_mask=True, + ) + + tokenized_datasets = raw_datasets.map( + tokenize_function, + batched=True, + num_proc=args.preprocessing_num_workers, + remove_columns=[text_column_name], + load_from_cache_file=not args.overwrite_cache, + ) + else: + # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts. + # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more + # efficient when it receives the `special_tokens_mask`. + def tokenize_function(examples): + return tokenizer(examples[text_column_name], return_special_tokens_mask=True) + + tokenized_datasets = raw_datasets.map( + tokenize_function, + batched=True, + num_proc=args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not args.overwrite_cache, + ) + + # Main data processing function that will concatenate all texts from our dataset and generate chunks of + # max_seq_length. + def group_texts(examples): + # Concatenate all texts. + concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can + # customize this part to your needs. + total_length = (total_length // max_seq_length) * max_seq_length + # Split by chunks of max_len. + result = { + k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)] + for k, t in concatenated_examples.items() + } + return result + + # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a + # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value + # might be slower to preprocess. + # + # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: + # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map + + tokenized_datasets = tokenized_datasets.map( + group_texts, + batched=True, + num_proc=args.preprocessing_num_workers, + load_from_cache_file=not args.overwrite_cache, + ) + + train_dataset = tokenized_datasets["train"] + eval_dataset = tokenized_datasets["validation"] + + # Log a few random samples from the training set: + for index in random.sample(range(len(train_dataset)), 3): + logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + + # Data collator + # This one will take care of randomly masking the tokens. + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=args.mlm_probability) + + # DataLoaders creation: + train_dataloader = DataLoader( + train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size + ) + eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size) + + # Optimizer + # Split weights in two groups, one with weight decay and the other not. + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": args.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate) + + # Prepare everything with our `accelerator`. + model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare( + model, optimizer, train_dataloader, eval_dataloader + ) + + # Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be + # shorter in multiprocess) + + # Scheduler and math around the number of training steps. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + else: + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + lr_scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=args.num_warmup_steps, + num_training_steps=args.max_train_steps, + ) + + # Train! + total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + completed_steps = 0 + + for epoch in range(args.num_train_epochs): + model.train() + for step, batch in enumerate(train_dataloader): + outputs = model(**batch) + loss = outputs.loss + loss = loss / args.gradient_accumulation_steps + accelerator.backward(loss) + if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + progress_bar.update(1) + completed_steps += 1 + + if completed_steps >= args.max_train_steps: + break + + # model.eval() + # for step, batch in enumerate(eval_dataloader): + # outputs = model(**batch) + # predictions = outputs.logits.argmax(dim=-1) + # metric.add_batch( + # predictions=accelerator.gather(predictions), + # references=accelerator.gather(batch["labels"]), + # ) + + # eval_metric = metric.compute() + # logger.info(f"epoch {epoch}: {eval_metric}") + + if args.output_dir is not None: + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) + + +if __name__ == "__main__": + main() From f9767b0fb1e615075bc306efe2430f41333b1347 Mon Sep 17 00:00:00 2001 From: hemildesai Date: Wed, 31 Mar 2021 07:41:03 -0700 Subject: [PATCH 2/9] Add evaluation metric calculation --- .../language-modeling/run_mlm_no_trainer.py | 36 ++++++++----------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/examples/language-modeling/run_mlm_no_trainer.py b/examples/language-modeling/run_mlm_no_trainer.py index fa1593467a7f..8488e872b8d7 100755 --- a/examples/language-modeling/run_mlm_no_trainer.py +++ b/examples/language-modeling/run_mlm_no_trainer.py @@ -27,8 +27,9 @@ import os import random +import torch import datasets -from datasets import load_dataset, load_metric +from datasets import load_dataset from torch.utils.data.dataloader import DataLoader from tqdm.auto import tqdm @@ -42,7 +43,6 @@ AutoModelForMaskedLM, AutoTokenizer, DataCollatorForLanguageModeling, - PretrainedConfig, SchedulerType, get_scheduler, set_seed, @@ -80,15 +80,6 @@ def parse_args(): default=5, help="The percentage of the train set used as validation set in case there's no validation split", ) - parser.add_argument( - "--max_length", - type=int, - default=128, - help=( - "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated," - " sequences shorter will be padded if `--pad_to_max_lengh` is passed." - ), - ) parser.add_argument( "--pad_to_max_length", action="store_true", @@ -484,17 +475,18 @@ def group_texts(examples): if completed_steps >= args.max_train_steps: break - # model.eval() - # for step, batch in enumerate(eval_dataloader): - # outputs = model(**batch) - # predictions = outputs.logits.argmax(dim=-1) - # metric.add_batch( - # predictions=accelerator.gather(predictions), - # references=accelerator.gather(batch["labels"]), - # ) - - # eval_metric = metric.compute() - # logger.info(f"epoch {epoch}: {eval_metric}") + model.eval() + losses = [] + for step, batch in enumerate(eval_dataloader): + outputs = model(**batch) + loss = outputs.loss * args.per_device_eval_batch_size + losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size))) + + losses = torch.cat(losses) + losses = losses[: len(eval_dataset)] + perplexity = math.exp(torch.mean(losses)) + + logger.info(f"epoch {epoch}: perplexity: {perplexity}") if args.output_dir is not None: accelerator.wait_for_everyone() From 3e635a6771a91f411bd2aff4c4a1aa99061d6df9 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Wed, 31 Mar 2021 17:55:35 +0000 Subject: [PATCH 3/9] Fix bugs --- examples/language-modeling/run_mlm_no_trainer.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/language-modeling/run_mlm_no_trainer.py b/examples/language-modeling/run_mlm_no_trainer.py index 8488e872b8d7..0f753e494572 100755 --- a/examples/language-modeling/run_mlm_no_trainer.py +++ b/examples/language-modeling/run_mlm_no_trainer.py @@ -160,35 +160,35 @@ def parse_args(): choices=MODEL_TYPES, ) parser.add_argument( - "max_seq_length", + "--max_seq_length", type=int, default=None, help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated.", ) parser.add_argument( - "line_by_line", + "--line_by_line", type=bool, default=False, help="Whether distinct lines of text in the dataset are to be handled as distinct sequences.", ) parser.add_argument( - "preprocessing_num_workers", + "--preprocessing_num_workers", type=int, default=None, help="The number of processes to use for the preprocessing.", ) parser.add_argument( - "overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets" + "--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets" ) parser.add_argument( - "mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss" + "--mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss" ) args = parser.parse_args() # Sanity checks - if args.task_name is None and args.train_file is None and args.validation_file is None: - raise ValueError("Need either a task name or a training/validation file.") + if args.dataset_name is None and args.train_file is None and args.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") else: if args.train_file is not None: extension = args.train_file.split(".")[-1] @@ -242,7 +242,7 @@ def main(): if args.dataset_name is not None: # Downloading and loading a dataset from the hub. raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) - if "validation" not in datasets.keys(): + if "validation" not in raw_datasets.keys(): raw_datasets["validation"] = load_dataset( args.dataset_name, args.dataset_config_name, From dc42dedd11f9b0c273f258f35e00ff17374f6e16 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Wed, 31 Mar 2021 18:23:16 +0000 Subject: [PATCH 4/9] Use no_grad on evaluation --- examples/language-modeling/run_mlm_no_trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/language-modeling/run_mlm_no_trainer.py b/examples/language-modeling/run_mlm_no_trainer.py index 0f753e494572..4c916df44515 100755 --- a/examples/language-modeling/run_mlm_no_trainer.py +++ b/examples/language-modeling/run_mlm_no_trainer.py @@ -478,7 +478,9 @@ def group_texts(examples): model.eval() losses = [] for step, batch in enumerate(eval_dataloader): - outputs = model(**batch) + with torch.no_grad(): + outputs = model(**batch) + loss = outputs.loss * args.per_device_eval_batch_size losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size))) From 79a3941bbab4da4069187191afc09b69fa1eba6d Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Wed, 31 Mar 2021 18:27:06 +0000 Subject: [PATCH 5/9] update script docstring --- examples/language-modeling/run_mlm_no_trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/language-modeling/run_mlm_no_trainer.py b/examples/language-modeling/run_mlm_no_trainer.py index 4c916df44515..aa670ab90260 100755 --- a/examples/language-modeling/run_mlm_no_trainer.py +++ b/examples/language-modeling/run_mlm_no_trainer.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) on a text file or a dataset without. +Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) on a text file or a dataset without using HuggingFace Trainer. Here is the full list of checkpoints on the hub that can be fine-tuned by this script: https://huggingface.co/models?filter=masked-lm @@ -50,7 +50,6 @@ logger = logging.getLogger(__name__) -# You should update this to your particular problem to have better documentation of `model_type` MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) From 2a569997bdb2806f7cad4a8d1ac2586034a05de3 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Thu, 1 Apr 2021 00:32:59 +0530 Subject: [PATCH 6/9] Update examples/language-modeling/run_mlm_no_trainer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- examples/language-modeling/run_mlm_no_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/language-modeling/run_mlm_no_trainer.py b/examples/language-modeling/run_mlm_no_trainer.py index aa670ab90260..33f720e19cea 100755 --- a/examples/language-modeling/run_mlm_no_trainer.py +++ b/examples/language-modeling/run_mlm_no_trainer.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # coding=utf-8 -# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved. +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From e5912b01c4651ef1cb05297650f10335c694a174 Mon Sep 17 00:00:00 2001 From: hemildesai Date: Wed, 31 Mar 2021 12:04:34 -0700 Subject: [PATCH 7/9] PR feedback --- examples/language-modeling/run_mlm_no_trainer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/language-modeling/run_mlm_no_trainer.py b/examples/language-modeling/run_mlm_no_trainer.py index 33f720e19cea..fb06c2018dc2 100755 --- a/examples/language-modeling/run_mlm_no_trainer.py +++ b/examples/language-modeling/run_mlm_no_trainer.py @@ -14,7 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) on a text file or a dataset without using HuggingFace Trainer. +Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) +on a text file or a dataset without using HuggingFace Trainer. Here is the full list of checkpoints on the hub that can be fine-tuned by this script: https://huggingface.co/models?filter=masked-lm @@ -191,10 +192,10 @@ def parse_args(): else: if args.train_file is not None: extension = args.train_file.split(".")[-1] - assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." + assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file." if args.validation_file is not None: extension = args.validation_file.split(".")[-1] - assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." + assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file." if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) From 670f31171db5eb964277f747d71256ab588d6958 Mon Sep 17 00:00:00 2001 From: hemildesai Date: Wed, 31 Mar 2021 12:16:22 -0700 Subject: [PATCH 8/9] Fix CI failure --- examples/language-modeling/run_mlm_no_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/language-modeling/run_mlm_no_trainer.py b/examples/language-modeling/run_mlm_no_trainer.py index fb06c2018dc2..f33462c74f2c 100755 --- a/examples/language-modeling/run_mlm_no_trainer.py +++ b/examples/language-modeling/run_mlm_no_trainer.py @@ -28,8 +28,8 @@ import os import random -import torch import datasets +import torch from datasets import load_dataset from torch.utils.data.dataloader import DataLoader from tqdm.auto import tqdm From ca05accd05d177c6aea6e74e20a786d1fd4bdfba Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Thu, 1 Apr 2021 01:50:39 +0530 Subject: [PATCH 9/9] Update examples/language-modeling/run_mlm_no_trainer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- examples/language-modeling/run_mlm_no_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/language-modeling/run_mlm_no_trainer.py b/examples/language-modeling/run_mlm_no_trainer.py index f33462c74f2c..a943bfd4a715 100755 --- a/examples/language-modeling/run_mlm_no_trainer.py +++ b/examples/language-modeling/run_mlm_no_trainer.py @@ -481,7 +481,7 @@ def group_texts(examples): with torch.no_grad(): outputs = model(**batch) - loss = outputs.loss * args.per_device_eval_batch_size + loss = outputs.loss losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size))) losses = torch.cat(losses)