Qodana simulation model #40
Merged
Changes from all commits (136 commits)
53dbe88
Merge branch 'main-upd' into develop
nbirillo 16937ba
Add possibility to evaluate csv files
nbirillo 136ed89
Filter solutions by language
nbirillo 78e7201
Optimize evaluation script, fix tests
nbirillo 48e3408
Add readme, add distribute grades
nbirillo 80ea546
Added project template for Java
GirZ0n db30732
Added DatasetMarker
GirZ0n 995aaea
Add tests
nbirillo 30859c0
Add diffs finder between two dfs
nbirillo f21cdf8
Delete unnecessary files
nbirillo 8c85e36
Code refactoring
GirZ0n b1fa21b
Added some words
GirZ0n 216aef4
Merge branch 'qodana' into roberta-model
dariadiatlova ee9e4fe
merge prep
dariadiatlova 61d7199
fixed merge conflicts
dariadiatlova a56eeae
Fixed merge conflict
dariadiatlova 5929f46
Small code refactoring
GirZ0n 1a34b50
Added new requirements
GirZ0n ecbcdf4
Added ID to ColumnName
GirZ0n 92f9a8a
Added README.md
GirZ0n 3ef6a42
Added default value for --chunk-size
GirZ0n 3f028bd
Merge remote-tracking branch 'origin/qodana' into qodana
GirZ0n 7cd79e0
parse qodana output
nbirillo 5c4b85a
Merge remote-tracking branch 'origin/develop' into develop
nbirillo ec6b477
Merge branch 'develop' into qodana
nbirillo a8b80c0
Update README.md
GirZ0n 75cac7b
Change qodana scipt output
nbirillo dd9d502
Merge remote-tracking branch 'origin/develop' into qodana
GirZ0n 6915254
Merge remote-tracking branch 'origin/qodana' into qodana
nbirillo f0c098b
Merge branch 'qodana' into fix/qodana-output
nbirillo c428b78
Fix a bug with qodana
nbirillo 8489c9d
Fix a bug with path to the gradle project
nbirillo 395cc6f
Add a script for filtering inspections
nbirillo 371f985
Fixed PR issues
GirZ0n 0dab1b7
Fix/qodana output (#33)
nbirillo f9b418d
Added is_java function
GirZ0n 96c0518
1) Added copy_directory and copy_file functions;
GirZ0n 38a936a
Removed python_on_whales dependency
GirZ0n 235e60f
Fixed some PR issues
GirZ0n 5faa46c
Merge branch 'fix/qodana-output' into qodana
GirZ0n e7c01a9
Merge branch 'qodana' into qodana-handlers
GirZ0n 082ca6d
Qoadana handlers/get unique inspections (#35)
nbirillo e413678
Update README.md
dariadiatlova 2b054f8
Added train, evaluationa and dataset preprocessing script for dataset
dariadiatlova f6b869d
resolve merge conflicts
dariadiatlova f1e7961
resolve merge conflicts
dariadiatlova 57d796b
fix merge conflicts
dariadiatlova 6b9f493
updates req and whitelist
dariadiatlova 9de5a23
added option to set parameter in encode_data script
dariadiatlova 3b82f8d
improved dataset class
dariadiatlova b41531d
updated requirenments
dariadiatlova 979a918
updated whitelist
dariadiatlova ef62341
changed acc on f1 score and refactoring
dariadiatlova cff3619
updated whitelist
dariadiatlova 26106e9
fixed typing and styles
dariadiatlova 9902efb
fixed style
dariadiatlova cfa2bd4
fixed style
dariadiatlova 521b168
fixed dataset type
dariadiatlova cdb68ca
removed extra to(device)
dariadiatlova 09a51db
small architecture changes
dariadiatlova 6b438b8
fixed styles
dariadiatlova dfa44fa
small changes in help section
dariadiatlova e56d9c5
fixed help section and added readme
dariadiatlova 1ac4408
Update README.md
dariadiatlova 1681db2
small fixe in the help section of train config
dariadiatlova 7e6901e
small fix – added short name in train config
dariadiatlova 363ccd0
changed full name
dariadiatlova 84e7dc8
Variable names refactoring
dariadiatlova 7cd3bb0
Added README
dariadiatlova 4787f3e
Update README.md
dariadiatlova 17bec35
Update README.md
dariadiatlova 31b9c7f
Update README.md
dariadiatlova d57938a
Update README.md
dariadiatlova f2e9d18
Variable name refactoring
dariadiatlova 7caa477
Merge remote-tracking branch 'origin/roberta-model' into roberta-model
dariadiatlova d66001c
Update README.md
dariadiatlova 9643a38
Update README.md
dariadiatlova c0a9488
Update evaluation_config.py
dariadiatlova c21dec0
fixed gpu usage
dariadiatlova 4b379ba
fixed label choice while computing metric
dariadiatlova aa20d14
updated metric class instance
dariadiatlova 4b634a4
updated f1-score computation
dariadiatlova 30188d4
updated metric
dariadiatlova 7e6d6a8
fixed df shape
685c894
debugging - small fix
dariadiatlova e339944
small fix - debugging
dariadiatlova bb554f6
merge preparation
dariadiatlova 21abcdc
fix merge conflicts
dariadiatlova 94a3689
updated whitelist
dariadiatlova 28734ae
Merge branch 'develop' into roberta-model
dariadiatlova ecaa988
Merge branch 'develop' into roberta-model
nbirillo 982faa6
Fix merge conflicts
nbirillo c682448
moved model folder and added links to README.md
dariadiatlova 87c13c7
Merge remote-tracking branch 'origin/roberta-model' into roberta-model
dariadiatlova f76d52b
fixed merge conflicts
dariadiatlova c77e9c3
fixed line length
dariadiatlova e230057
added trailing comma
dariadiatlova 350d171
added description section to the readme.md
dariadiatlova 28d40a5
added description for batch_size and num_classes
dariadiatlova b0700e0
updated paths in readme
dariadiatlova 555d338
added ModelCommonArguments class
dariadiatlova 6a51681
fixed styles
dariadiatlova f6e3e74
split evaluation script into 2 functions
dariadiatlova a2b7a0d
added import for typing
dariadiatlova de3fdb0
extended whitelist
dariadiatlova 6b999f4
resolved arguments naming conflict
dariadiatlova 298fdc3
changed dtype in dataset dataframe
dariadiatlova 2f5cf3c
fixed type issue while reading dataset
dariadiatlova f3786f4
added examples to the readme
dariadiatlova 2baf137
extended description in readme.md
dariadiatlova 72ed326
added description to the MultilabelTrainer class
dariadiatlova 0e7be42
added description to the preprocessing module
dariadiatlova fcbf01d
added constant variable names for Measurer class
dariadiatlova 7be1da6
explained one-hot-encoding
dariadiatlova 98acc20
added docs to the context functions and specified return type
dariadiatlova 0f2128e
typo refactoring
dariadiatlova 92dc75c
added option to compute f1-score by classes
dariadiatlova 1b816be
custom metric debugging
dariadiatlova af0d8fb
custom metric debugging
dariadiatlova c1b496e
custom metric debugging
dariadiatlova 8a68089
custom metric debugging
dariadiatlova 473d623
resolve merge conflicts
dariadiatlova 2074e24
fixed line length
dariadiatlova 29cbb4d
added logger for f1 score
dariadiatlova 7200b11
added description
dariadiatlova fcc68d5
added new line import
dariadiatlova c4e2e54
added new return type
dariadiatlova ddee437
added option to save f1 scores to the file
dariadiatlova 20e9302
added description of the save metric option to the README.md
dariadiatlova 928e766
fixed file names for report
dariadiatlova 6458799
Update README.md
dariadiatlova 9ca35ac
added column with inspections indices
dariadiatlova 7e2ba8c
delete unnecessary requirements.txt
dariadiatlova f686b5c
fixed typing
dariadiatlova f244e41
name refactoring
dariadiatlova 5b7de2a
created separate class for Seed value
dariadiatlova
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
New file (+6 lines):

```
tqdm==4.49.0
scikit-learn~=0.24.2
transformers==4.6.1
tokenizers==0.10.2
torch==1.8.1
wandb==0.10.31
```
New file (+118 lines):

# Qodana imitation model

## Description

The general purpose of the model is to simulate the behavior of [`Qodana`](https://github.com/JetBrains/Qodana/tree/main), a code quality monitoring tool that identifies and suggests fixes for bugs, security vulnerabilities, duplications, and imperfections.

Motivation for developing the model:

- acceleration of the code analysis process by training the model to recognize a certain class of errors;
- the ability to run the model on separate files without the need to create a project (for example, for the Java language).

## Architecture

A [`RobertaForSequenceClassification`](https://huggingface.co/transformers/model_doc/roberta.html#robertaforsequenceclassification) model with [`BCEWithLogitsLoss`](https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html) solves the multilabel classification task.

The model output is a tensor of size `batch_size` x `num_classes`, where `batch_size` is the number of training examples utilized in one iteration and `num_classes` is the number of error types found in the dataset. By a model class, we mean a unique error type. Class probabilities are obtained by applying a `sigmoid`, and final predictions are computed by comparing the probability of each class with a `threshold`.

Since the classes might be unbalanced, the metric used is the `f1-score`.
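The sigmoid-plus-threshold step described above can be sketched without any framework; `predict_inspections` is a hypothetical helper for illustration, not part of the repository:

```python
import math


def predict_inspections(logits, threshold=0.5):
    """Turn one sample's raw logits (one value per class) into binary inspection labels."""
    probabilities = [1 / (1 + math.exp(-logit)) for logit in logits]
    return [1 if p > threshold else 0 for p in probabilities]


# Four classes: with threshold 0.5, exactly the classes with positive logits are flagged.
print(predict_inspections([2.3, -1.7, 0.1, -0.2]))  # [1, 0, 1, 0]
```

Raising the threshold makes the model more conservative: with `threshold=0.95` none of these four logits would be flagged.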
## What it does

The model has two use cases:

- It can be trained to predict the set of unique errors in a **block** of code of unfixed length.

**Example**:

code | inspections
--- | ---
|`import java.util.Scanner; class Main {public static void main(String[] args) {Scanner scanner = new Scanner(System.in);// put your code here int num = scanner.nextInt(); System.out.println((num / 10 ) % 10);}}`| 1, 2|

- It can be trained to predict the set of unique errors in a **line** of code.

**Example**:

code | inspections
--- | ---
|`import java.util.Scanner;`| 0|
|`\n`|0|
|`class Main {`|1|
|`public static void main(String[] args) {`|1|
|`Scanner scanner = new Scanner(System.in);`|0|
|`// put your code here`|0|
|`int num = scanner.nextInt();`|0|
|`System.out.println((num / 10 ) % 10);`|2|
|`}`|0|
|`}`|0|
## Data preprocessing

Please refer to the [`following documentation`](src/python/evaluation/qodana) to label the dataset, and to the [`following documentation`](preprocessing) to preprocess the data for model training and evaluation afterwards.

After completing the third preprocessing step you should have 3 folders, `train`, `val`, and `test`, with `train.csv`, `val.csv`, and `test.csv` respectively.

Each file has the same structure and should consist of 4+ columns:

- `id` – solution id;
- `code` – line of code or block of code;
- `lang` – language version;
- `0`, `1`, `2` ... `n` – one column per unique error detected by Qodana in the dataset. The values in these columns are binary: `1` if the inspection is detected and `0` otherwise.
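The column layout can be illustrated with a tiny synthetic file; the row contents below are made up for the sketch and the inspection columns assume `n = 3` unique errors:

```python
import csv
import io

# Hypothetical preprocessed-file structure: 3 fixed columns plus one
# binary column per unique Qodana inspection found in the dataset.
header = ['id', 'code', 'lang', '0', '1', '2']
rows = [
    ['1', 'class Main {', 'java11', '0', '1', '0'],
    ['2', 'int num = scanner.nextInt();', 'java11', '0', '0', '0'],
]

buffer = io.StringIO()
writer = csv.writer(buffer)
writer.writerow(header)
writer.writerows(rows)
print(buffer.getvalue().splitlines()[0])  # id,code,lang,0,1,2
```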
## How to train the model

Run the [`train.py`](train.py) script from the command line with the following arguments.

Required arguments:

- `train_dataset_path` ‑ path to `train.csv`, the file with the samples the model will use for training.
- `val_dataset_path` ‑ path to `val.csv`, the file with the samples the model will use for evaluation during training.

Both files are produced by the [`split_dataset.py`](preprocessing/split_dataset.py) script and have the structure described above.

Optional arguments:

Argument | Description
--- | ---
|**‑o**, **‑‑output_directory_path**| Path to the directory where model weights will be saved. If not set, a folder will be created in the `train` folder where the `train.csv` dataset is stored.|
|**‑c**, **‑‑context_length**| Sequence length or embedding size of tokenized samples. Any `positive integer`. **Default is 40**.|
|**‑e**, **‑‑epoch**| Number of epochs to train the model. **Default is 2**.|
|**‑bs**, **‑‑batch_size**| Batch size for the training and validation datasets. Any `positive integer`. **Default is 16**.|
|**‑lr**, **‑‑learning_rate**| Optimizer learning rate. **Default is 2e-5**.|
|**‑w**, **‑‑weight_decay**| Weight decay parameter for the optimizer. **Default is 0.01**.|
|**‑th**, **‑‑threshold**| Used to compute predictions. Any value with 0 < `threshold` < 1. If the probability of an inspection is greater than `threshold`, the sample will be classified with that inspection. **Default is 0.5**.|
|**‑ws**, **‑‑warm_up_steps**| The number of steps during which the optimizer uses a constant learning rate before applying the scheduler policy. **Default is 300**.|
|**‑sl**, **‑‑save_limit**| Maximum number of checkpoints to keep. **Default is 1**.|

To inspect the rest of the default training parameters, please refer to [`TrainingArguments`](common/train_config.py).
## How to evaluate the model

Run the [`evaluation.py`](evaluation.py) script from the command line with the following arguments.

Required arguments:

- `test_dataset_path` ‑ path to `test.csv`, produced by the [`split_dataset.py`](preprocessing/split_dataset.py) script.
- `model_weights_directory_path` ‑ path to the folder where the trained model weights are saved.

Optional arguments:

Argument | Description
--- | ---
|**‑o**, **‑‑output_directory_path**| Path to the directory where the labeled dataset will be saved. Default is the `test` folder.|
|**‑c**, **‑‑context_length**| Sequence length or embedding size of tokenized samples. Any `positive integer`. **Default is 40**.|
|**‑sf**, **‑‑save_f1_score**| If enabled, a report with f1 scores by class is saved to a `csv` file in the parent directory of the labeled dataset. **Disabled by default**.|
|**‑bs**, **‑‑batch_size**| The number of examples utilized in one evaluation iteration. Any `positive integer`. **Default is 16**.|
|**‑th**, **‑‑threshold**| Used to compute predictions. Any value with 0 < `threshold` < 1. If the probability of an inspection is greater than `threshold`, the sample will be classified with that inspection. **Default is 0.5**.|

The output is a `predictions.csv` file whose column names match the classes. Each sample has a binary label:

- `0` ‑ if the model didn't find an error in the sample;
- `1` ‑ if an error was found in the sample.
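Reading such a binary-labeled file back into per-sample inspection lists is a one-liner; the rows below are invented for the sketch:

```python
# Hypothetical post-processing of predictions.csv: collect, for each sample,
# the class columns labeled 1 (the inspections the model predicted).
rows = [
    {'0': 0, '1': 1, '2': 1},  # sample flagged with inspections 1 and 2
    {'0': 0, '1': 0, '2': 0},  # clean sample
]
predicted = [[name for name, label in row.items() if label == 1] for row in rows]
print(predicted)  # [['1', '2'], []]
```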
New file (+3 lines):

```python
from src.python import MAIN_FOLDER

MODEL_FOLDER = MAIN_FOLDER.parent / 'python/imitation_model'
```
Empty file.
src/python/evaluation/qodana/imitation_model/common/evaluation_config.py (47 additions, 0 deletions):

```python
import argparse

from src.python.evaluation.qodana.imitation_model.common.util import ModelCommonArgument
from src.python.review.common.file_system import Extension


def configure_arguments(parser: argparse.ArgumentParser) -> None:
    parser.add_argument('test_dataset_path',
                        type=str,
                        help='Path to the dataset received by either the'
                             f' src.python.evaluation.qodana.fragment_to_inspections_list{Extension.PY.value}'
                             ' or the src.python.evaluation.qodana.fragment_to_inspections_list_line_by_line'
                             f'{Extension.PY.value} script.')

    parser.add_argument('model_weights_directory_path',
                        type=str,
                        help='Path to the directory where trained imitation_model weights are stored.')

    parser.add_argument('-o', '--output_directory_path',
                        default=None,
                        type=str,
                        help='Path to the directory where the labeled dataset will be saved. Default is the parent'
                             ' folder of test_dataset_path.')

    parser.add_argument('-sf', '--save_f1_score',
                        default=None,
                        action='store_true',
                        help=f'If enabled, a report with f1 scores by class will be saved to a {Extension.CSV.value}'
                             ' file in the labeled dataset parent directory. Default is False.')

    parser.add_argument(ModelCommonArgument.CONTEXT_LENGTH.value.short_name,
                        ModelCommonArgument.CONTEXT_LENGTH.value.long_name,
                        type=int,
                        default=40,
                        help=ModelCommonArgument.CONTEXT_LENGTH.value.description)

    parser.add_argument(ModelCommonArgument.BATCH_SIZE.value.short_name,
                        ModelCommonArgument.BATCH_SIZE.value.long_name,
                        type=int,
                        default=8,
                        help=ModelCommonArgument.BATCH_SIZE.value.description)

    parser.add_argument(ModelCommonArgument.THRESHOLD.value.short_name,
                        ModelCommonArgument.THRESHOLD.value.long_name,
                        type=float,
                        default=0.5,
                        help=ModelCommonArgument.THRESHOLD.value.description)
```
src/python/evaluation/qodana/imitation_model/common/metric.py (41 additions, 0 deletions):

```python
import logging.config
from typing import Optional

import torch
from sklearn.metrics import multilabel_confusion_matrix
from src.python.evaluation.qodana.imitation_model.common.util import MeasurerArgument

logger = logging.getLogger(__name__)


class Measurer:
    def __init__(self, threshold: float):
        self.threshold = threshold

    def get_f1_score(self, predictions: torch.Tensor, targets: torch.Tensor) -> Optional[float]:
        # Each entry of the multilabel confusion matrix is [[tn, fp], [fn, tp]].
        confusion_matrix = multilabel_confusion_matrix(targets, predictions)
        false_positives = sum(score[0][1] for score in confusion_matrix)
        false_negatives = sum(score[1][0] for score in confusion_matrix)
        true_positives = sum(score[1][1] for score in confusion_matrix)
        try:
            return true_positives / (true_positives + 1 / 2 * (false_positives + false_negatives))
        except ZeroDivisionError:
            logger.error('No values of the class present in the dataset.')
            # Return None to make it clear after printing which classes are missing in the datasets.
            return None

    def compute_metric(self, evaluation_predictions) -> dict:
        logits, targets = evaluation_predictions
        prediction_probabilities = torch.from_numpy(logits).sigmoid()
        predictions = torch.where(prediction_probabilities > self.threshold, 1, 0)
        return {MeasurerArgument.F1_SCORE.value: self.get_f1_score(predictions, torch.tensor(targets))}

    def f1_score_by_classes(self, predictions: torch.Tensor, targets: torch.Tensor) -> dict:
        unique_classes = range(len(targets[0]))
        f1_scores_by_classes = {}
        for unique_class in unique_classes:
            class_mask = torch.where(targets[:, unique_class] == 1)
            f1_scores_by_classes[str(unique_class)] = self.get_f1_score(predictions[class_mask[0], unique_class],
                                                                        targets[class_mask[0], unique_class])
        return f1_scores_by_classes
```
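`get_f1_score` above aggregates true/false positives and negatives across all classes before applying the micro-F1 formula TP / (TP + (FP + FN) / 2). A dependency-free sketch of that formula, with `micro_f1` as a hypothetical helper operating on flattened (sample, class) label lists:

```python
def micro_f1(predictions, targets):
    """Micro-averaged F1 over flattened binary labels: TP / (TP + (FP + FN) / 2)."""
    tp = sum(p == 1 and t == 1 for p, t in zip(predictions, targets))
    fp = sum(p == 1 and t == 0 for p, t in zip(predictions, targets))
    fn = sum(p == 0 and t == 1 for p, t in zip(predictions, targets))
    return tp / (tp + (fp + fn) / 2)


# 3 true positives, 1 false positive, 0 false negatives: 3 / 3.5
print(round(micro_f1([1, 1, 1, 1, 0], [1, 1, 1, 0, 0]), 3))  # 0.857
```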
src/python/evaluation/qodana/imitation_model/common/train_config.py (118 additions, 0 deletions):

```python
import argparse

import torch
from src.python.evaluation.qodana.imitation_model.common.util import (
    DatasetColumnArgument,
    ModelCommonArgument,
    SeedArgument,
)
from transformers import Trainer, TrainingArguments


class MultilabelTrainer(Trainer):
    """By default, RobertaForSequenceClassification does not support
    multi-label classification.

    Target and logits tensors should be represented as torch.FloatTensor of shape (1,).
    https://huggingface.co/transformers/model_doc/roberta.html#transformers.RobertaForSequenceClassification

    To fine-tune the model for the multi-label classification task, we can simply modify the trainer by
    changing its loss function. https://huggingface.co/transformers/main_classes/trainer.html
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop(DatasetColumnArgument.LABELS.value)
        outputs = model(**inputs)
        logits = outputs.logits
        loss_bce = torch.nn.BCEWithLogitsLoss()
        loss = loss_bce(logits.view(-1, self.model.config.num_labels),
                        labels.float().view(-1, self.model.config.num_labels))

        return (loss, outputs) if return_outputs else loss


def configure_arguments(parser: argparse.ArgumentParser) -> None:
    parser.add_argument('train_dataset_path',
                        type=str,
                        help='Path to the train dataset.')

    parser.add_argument('val_dataset_path',
                        type=str,
                        help='Path to the validation dataset.')

    parser.add_argument('-wp', '--trained_weights_directory_path',
                        default=None,
                        type=str,
                        help='Path to the directory where to save imitation_model weights. Default is the directory'
                             ' where the train dataset is.')

    parser.add_argument(ModelCommonArgument.CONTEXT_LENGTH.value.short_name,
                        ModelCommonArgument.CONTEXT_LENGTH.value.long_name,
                        type=int,
                        default=40,
                        help=ModelCommonArgument.CONTEXT_LENGTH.value.description)

    parser.add_argument(ModelCommonArgument.BATCH_SIZE.value.short_name,
                        ModelCommonArgument.BATCH_SIZE.value.long_name,
                        type=int,
                        default=16,
                        help=ModelCommonArgument.BATCH_SIZE.value.description)

    parser.add_argument(ModelCommonArgument.THRESHOLD.value.short_name,
                        ModelCommonArgument.THRESHOLD.value.long_name,
                        type=float,
                        default=0.5,
                        help=ModelCommonArgument.THRESHOLD.value.description)

    parser.add_argument('-lr', '--learning_rate',
                        type=float,  # was `int`; the default 2e-5 is a float
                        default=2e-5,
                        help='Learning rate.')

    parser.add_argument('-wd', '--weight_decay',
                        type=float,  # was `int`; the default 0.01 is a float
                        default=0.01,
                        help='Weight decay parameter for the optimizer.')

    parser.add_argument('-e', '--epoch',
                        type=int,
                        default=1,
                        help='Number of epochs to train the imitation_model.')

    parser.add_argument('-ws', '--warm_up_steps',
                        type=int,
                        default=300,
                        help='Number of steps used for a linear warmup, default is 300.')

    parser.add_argument('-sl', '--save_limit',
                        type=int,
                        default=1,
                        help='Maximum number of checkpoints to keep. Default is 1.')


class TrainingArgs:
    def __init__(self, args):
        self.args = args

    def get_training_args(self, val_steps_to_be_made):
        return TrainingArguments(num_train_epochs=self.args.epoch,
                                 per_device_train_batch_size=self.args.batch_size,
                                 per_device_eval_batch_size=self.args.batch_size,
                                 learning_rate=self.args.learning_rate,
                                 warmup_steps=self.args.warm_up_steps,
                                 weight_decay=self.args.weight_decay,
                                 save_total_limit=self.args.save_limit,
                                 output_dir=self.args.trained_weights_directory_path,
                                 overwrite_output_dir=True,
                                 load_best_model_at_end=True,
                                 greater_is_better=True,
                                 save_steps=val_steps_to_be_made,
                                 eval_steps=val_steps_to_be_made,
                                 logging_steps=val_steps_to_be_made,
                                 evaluation_strategy=DatasetColumnArgument.STEPS.value,
                                 logging_strategy=DatasetColumnArgument.STEPS.value,
                                 seed=SeedArgument.SEED.value,
                                 report_to=[DatasetColumnArgument.WANDB.value])
```
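The loss that `MultilabelTrainer.compute_loss` swaps in treats each class as an independent binary decision. A dependency-free sketch of the same quantity, using the numerically stable formulation `max(x, 0) - x * t + log(1 + exp(-|x|))` over flattened logit/target pairs (`bce_with_logits` is a hypothetical helper, not repository code):

```python
import math


def bce_with_logits(logits, targets):
    """Binary cross-entropy on raw logits, averaged over all (sample, class) entries,
    mirroring what torch.nn.BCEWithLogitsLoss computes for the multilabel head."""
    losses = [max(x, 0) - x * t + math.log(1 + math.exp(-abs(x)))
              for x, t in zip(logits, targets)]
    return sum(losses) / len(losses)


# A logit of 0 means probability 0.5, so the per-entry loss is ln(2) for either target.
print(round(bce_with_logits([0.0, 0.0], [1.0, 0.0]), 4))  # 0.6931
```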
src/python/evaluation/qodana/imitation_model/common/util.py (46 additions, 0 deletions):

```python
from enum import Enum, unique

from src.python.common.tool_arguments import ArgumentsInfo


@unique
class DatasetColumnArgument(Enum):
    ID = 'id'
    IN_ID = 'inspection_id'
    INSPECTIONS = 'inspections'
    INPUT_IDS = 'input_ids'
    LABELS = 'labels'
    DATASET_PATH = 'dataset_path'
    STEPS = 'steps'
    WEIGHTS = 'weights'
    WANDB = 'wandb'


@unique
class SeedArgument(Enum):
    SEED = 42


@unique
class CustomTokens(Enum):
    NOC = '[NOC]'  # "no context" token to add when there are no lines for the context


@unique
class ModelCommonArgument(Enum):
    THRESHOLD = ArgumentsInfo('-th', '--threshold',
                              'If the probability of an inspection on a code sample is greater than the threshold,'
                              ' the inspection id will be assigned to the sample. '
                              'Default is 0.5.')

    CONTEXT_LENGTH = ArgumentsInfo('-cl', '--context_length',
                                   'Sequence length of 1 sample after tokenization, default is 40.')

    BATCH_SIZE = ArgumentsInfo('-bs', '--batch_size',
                               'Batch size – default values are 16 for training and 8 for evaluation mode.')


@unique
class MeasurerArgument(Enum):
    F1_SCORE = 'f1_score'
    F1_SCORES_BY_CLS = 'f1_scores_by_class'
```
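The `@unique` Enum pattern used throughout `util.py` keeps argument and column names in one place; a minimal self-contained stand-in (the `Column` class here is illustrative, not repository code):

```python
from enum import Enum, unique


# @unique rejects duplicate values at class-creation time, and .value is
# what the argparse configuration and dataset code read.
@unique
class Column(Enum):
    ID = 'id'
    LABELS = 'labels'


print(Column.LABELS.value)  # labels
```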
Empty file.