From 06efb74c97d065828f197709d1769f14db69c349 Mon Sep 17 00:00:00 2001 From: Ali Tarik Date: Thu, 17 Aug 2023 19:22:46 +0300 Subject: [PATCH 1/2] support for string ner_tags column --- langtest/datahandler/datasource.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/langtest/datahandler/datasource.py b/langtest/datahandler/datasource.py index 5aedab62c..1d4466a45 100644 --- a/langtest/datahandler/datasource.py +++ b/langtest/datahandler/datasource.py @@ -946,15 +946,23 @@ def load_data_ner( else: dataset = self.load_dataset(self.dataset_name, split=split) - label_names = dataset.features[target_column].feature.names - - dataset = map( - lambda example: { - "tokens": example[feature_column], - "ner_tags": [label_names[x] for x in example[target_column]], - }, - dataset, - ) + if "label" in str(type(dataset.features[target_column].feature)): + label_names = dataset.features[target_column].feature.names + dataset = map( + lambda example: { + "tokens": example[feature_column], + "ner_tags": [label_names[x] for x in example[target_column]], + }, + dataset, + ) + else: + dataset = map( + lambda example: { + "tokens": example[feature_column], + "ner_tags": example[target_column], + }, + dataset, + ) samples = [self._row_to_ner_sample(example) for example in dataset] return samples From 91a1d1028c79ce412c398e4b1d5eaab334e7aee1 Mon Sep 17 00:00:00 2001 From: Ali Tarik Date: Thu, 17 Aug 2023 19:53:20 +0300 Subject: [PATCH 2/2] add tests --- tests/test_datasource.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_datasource.py b/tests/test_datasource.py index 15552b53b..63705195c 100644 --- a/tests/test_datasource.py +++ b/tests/test_datasource.py @@ -80,6 +80,14 @@ def test_load_raw_data(self, dataset, feature_col, target_col): "split": "test", }, ), + ( + HuggingFaceDataset(dataset_name="Prikshit7766/12", task="ner"), + { + "feature_column": "tokens", + "target_column": "ner_tags", + "split": "test", + }, + ), (CSVDataset(file_path="tests/fixtures/tner.csv", task="ner"), {}), (ConllDataset(file_path="tests/fixtures/test.conll", task="ner"), {}), ],