From 1cfe4c3dd867f272010c22a5a9f30d77d8aa37dc Mon Sep 17 00:00:00 2001 From: Konstantin Chernyshev Date: Thu, 16 Nov 2023 11:24:06 +0100 Subject: [PATCH 01/10] test: recalculate exapmes tests --- tests/test_codebleu.py | 50 +++++++++++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/tests/test_codebleu.py b/tests/test_codebleu.py index 9bf4f32..ee46ac5 100644 --- a/tests/test_codebleu.py +++ b/tests/test_codebleu.py @@ -15,16 +15,16 @@ ["def test ( ) :\n pass"], 0.25, ), # 'cause data_flow is 0 and considered as 1 - (["def bar ( y , x ) :\n a = x * x\n return a"], ["def foo ( x ) :\n return x"], 0.4), - (["def foo ( x ) :\n return x * x"], ["def bar ( x ) :\n return x"], 0.6), - (["def bar ( x ) :\n return x"], ["def foo ( x ) :\n return x"], 0.8), + (["def bar ( y , x ) :\n a = x * x\n return a"], ["def foo ( x ) :\n return x"], 0.38), + (["def foo ( x ) :\n return x * x"], ["def bar ( x ) :\n return x"], 0.61), + (["def bar ( x ) :\n return x"], ["def foo ( x ) :\n return x"], 0.85), (["def foo ( x ) :\n return x"], ["def foo ( x ) :\n return x"], 1.0), ], ) def test_simple_cases(predictions: List[Any], references: List[Any], codebleu: float) -> None: result = calc_codebleu(references, predictions, "python") logging.debug(result) - assert result["codebleu"] == pytest.approx(codebleu, 0.1) + assert result["codebleu"] == pytest.approx(codebleu, 0.01) @pytest.mark.parametrize(["lang"], [(lang,) for lang in AVAILABLE_LANGS]) @@ -48,7 +48,7 @@ def test_exact_match_works_for_all_langs(lang: str) -> None: def test_simple_cases_work_for_all_langs(lang: str, predictions: List[Any], references: List[Any]) -> None: result = calc_codebleu(references, predictions, lang) logging.debug(result) - assert result["codebleu"] == pytest.approx(0.6, 0.1) + assert result["codebleu"] == pytest.approx(0.6, 0.05) def test_error_when_lang_not_supported() -> None: @@ -61,25 +61,39 @@ def test_error_when_input_length_mismatch() -> None: calc_codebleu(["def foo : pass"], ["def bar : pass", "def buz : pass"], "python") -# https://github.com/microsoft/CodeXGLUE/blob/main/Code-Code/code-to-code-trans/example.png @pytest.mark.parametrize( - ["predictions", "references", "codebleu"], + ["predictions", "references", "bleu", "codebleu"], [ - # ( - # ['public static int Sign ( double d ) { return ( float ) ( ( d == 0 ) ? 0 : ( c < 0.0 ) ? - 1 : 1) ; }'], - # ['public static int Sign ( double d ) { return ( int ) ( ( d == 0 ) ? 0 : ( d < 0 ) ? - 1 : 1) ; }'], - # 0.7238 # TODO: lol, not working at <3.12 - # ), - # ( - # ['public static int Sign ( double c ) { return ( int ) ( ( c == 0 ) ? 0 : ( c < 0 ) ? - 1 : 1) ; }'], - # ['public static int Sign ( double d ) { return ( int ) ( ( d == 0 ) ? 0 : ( d < 0 ) ? - 1 : 1) ; }'], - # 0.8397 # TODO: check, lol, not working - # ), + # https://github.com/microsoft/CodeXGLUE/blob/main/Code-Code/code-to-code-trans/example.png + ( + ["public static int Sign ( double d ) { return ( float ) ( ( d == 0 ) ? 0 : ( c < 0.0 ) ? - 1 : 1) ; }"], + ["public static int Sign ( double d ) { return ( int ) ( ( d == 0 ) ? 0 : ( d < 0 ) ? - 1 : 1) ; }"], + 0.7846, + 0.7238, # TODO: lol, not working at <3.12 + ), + # https://arxiv.org/pdf/2009.10297.pdf "3.4 Two Examples" at the page 4 + ( + ["public static int Sign ( double d ) { return ( float ) ( ( d == 0 ) ? 0 : ( c < 0.0 ) ? - 1 : 1) ;"], + ["public static int Sign ( double d ) { return ( int ) ( ( d == 0 ) ? 0 : ( d < 0 ) ? - 1 : 1) ; }"], + 0.7543, + 0.7091, # Should be 0.6973 if AST=13/21, however at the moment tee-sitter AST is 14/21 + ), + # https://arxiv.org/pdf/2009.10297.pdf "3.4 Two Examples" at the page 4 + ( + ["public static int Sign ( double c ) { return ( int ) ( ( c == 0 ) ? 0 : ( c < 0 ) ? - 1 : 1) ; }"], + ["public static int Sign ( double d ) { return ( int ) ( ( d == 0 ) ? 0 : ( d < 0 ) ? - 1 : 1) ; }"], + 0.7571, # Error in the Figure 4, text "Example 2" states 0.7571, not 0.6814, + 0.8804, # Error in the Figure 4, text "Example 2" states 0.8804, not 0.8397, + ), ], ) -def test_code_x_glue_readme_examples(predictions: List[Any], references: List[Any], codebleu: float) -> None: +def test_code_x_glue_readme_examples( + predictions: List[Any], references: List[Any], bleu: float, codebleu: float +) -> None: result = calc_codebleu(references, predictions, "java") logging.debug(result) + + assert result["ngram_match_score"] == pytest.approx(bleu, 0.01) assert result["codebleu"] == pytest.approx(codebleu, 0.01) From 17411d4b94f18e0abd3e1b8adaeac3483c49c3bd Mon Sep 17 00:00:00 2001 From: Konstantin Chernyshev Date: Thu, 16 Nov 2023 11:25:37 +0100 Subject: [PATCH 02/10] style: apply black and add some comments --- codebleu/bleu.py | 5 ++++- codebleu/dataflow_match.py | 4 ++-- codebleu/parser/build.py | 30 +++++++++++++++--------------- codebleu/syntax_match.py | 25 +++++++++++++++---------- 4 files changed, 36 insertions(+), 28 deletions(-) diff --git a/codebleu/bleu.py b/codebleu/bleu.py index be0738c..f528746 100644 --- a/codebleu/bleu.py +++ b/codebleu/bleu.py @@ -18,8 +18,11 @@ from .utils import ngrams -# _normalize=False was removed in 3.12, add custom class for back-compatibility class Fraction(_Fraction): + """Fraction class with _normalize=False support. + _normalize=False was removed in 3.12, add custom class for back-compatibility + """ + # We're immutable, so use __new__ not __init__ def __new__(cls, numerator: Any = 0, denominator: Any = None, *, _normalize: bool = True) -> "Fraction": if sys.version_info >= (3, 12): diff --git a/codebleu/dataflow_match.py b/codebleu/dataflow_match.py index 30e3871..a4dd6d4 100644 --- a/codebleu/dataflow_match.py +++ b/codebleu/dataflow_match.py @@ -47,11 +47,11 @@ def corpus_dataflow_match(references, candidates, lang, langso_so_file): candidate = candidates[i] for reference in references_sample: try: - candidate = remove_comments_and_docstrings(candidate, "java") + candidate = remove_comments_and_docstrings(candidate, lang) except Exception: pass try: - reference = remove_comments_and_docstrings(reference, "java") + reference = remove_comments_and_docstrings(reference, lang) except Exception: pass diff --git a/codebleu/parser/build.py b/codebleu/parser/build.py index c1f1a3a..3e409e2 100644 --- a/codebleu/parser/build.py +++ b/codebleu/parser/build.py @@ -1,20 +1,20 @@ # Copyright (c) Microsoft Corporation. # Copyright (c) 2023 Konstantin Chernyshev. # Licensed under the MIT license. - from tree_sitter import Language -Language.build_library( - "my-languages.so", - [ - "tree-sitter/go", - "tree-sitter/javascript", - "tree-sitter/python", - "tree-sitter/php", - "tree-sitter/java", - "tree-sitter/ruby", - "tree-sitter/c-sharp", - "tree-sitter/c", - "tree-sitter/cpp", - ], -) +if __name__ == "__main__": + Language.build_library( + "my-languages.so", + [ + "tree-sitter/go", + "tree-sitter/javascript", + "tree-sitter/python", + "tree-sitter/php", + "tree-sitter/java", + "tree-sitter/ruby", + "tree-sitter/c-sharp", + "tree-sitter/c", + "tree-sitter/cpp", + ], + ) diff --git a/codebleu/syntax_match.py b/codebleu/syntax_match.py index 5f14e76..92dd6db 100644 --- a/codebleu/syntax_match.py +++ b/codebleu/syntax_match.py @@ -30,11 +30,11 @@ def calc_syntax_match(references, candidate, lang, lang_so_file): def corpus_syntax_match(references, candidates, lang, lang_so_file): - # print(os.listdir()) - JAVA_LANGUAGE = Language(lang_so_file, lang) + tree_sitter_language = Language(lang_so_file, lang) parser = Parser() - parser.set_language(JAVA_LANGUAGE) + parser.set_language(tree_sitter_language) match_count = 0 + match_count_candidate_to_reference = 0 total_count = 0 for i in range(len(candidates)): @@ -42,11 +42,11 @@ def corpus_syntax_match(references, candidates, lang, lang_so_file): candidate = candidates[i] for reference in references_sample: try: - candidate = remove_comments_and_docstrings(candidate, "java") + candidate = remove_comments_and_docstrings(candidate, lang) except Exception: pass try: - reference = remove_comments_and_docstrings(reference, "java") + reference = remove_comments_and_docstrings(reference, lang) except Exception: pass @@ -69,14 +69,19 @@ def get_all_sub_trees(root_node): return sub_tree_sexp_list cand_sexps = [x[0] for x in get_all_sub_trees(candidate_tree)] - ref_sexps = get_all_sub_trees(reference_tree) + ref_sexps = [x[0] for x in get_all_sub_trees(reference_tree)] - # print(cand_sexps) - # print(ref_sexps) - - for sub_tree, depth in ref_sexps: + # TODO: fix, now we count number of reference subtrees matching candidate, + # but we should count number of candidate subtrees matching reference + # See (4) in "3.2 Syntactic AST Match" of https://arxiv.org/pdf/2009.10297.pdf + for sub_tree in ref_sexps: if sub_tree in cand_sexps: match_count += 1 + + for sub_tree in cand_sexps: + if sub_tree in ref_sexps: + match_count_candidate_to_reference += 1 + total_count += len(ref_sexps) score = match_count / total_count From 4751c4677fba7a4a009f4521526775801ed942e3 Mon Sep 17 00:00:00 2001 From: Konstantin Chernyshev Date: Thu, 16 Nov 2023 11:36:02 +0100 Subject: [PATCH 03/10] ci: fast test on 3.12 --- .github/workflows/test.yml | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9c256e1..aa36d11 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -39,7 +39,7 @@ jobs: fetch-depth: 0 - uses: actions/setup-python@v4 with: - python-version: '3.11' + python-version: '3.12' cache: 'pip' # caching pip dependencies - name: Install dependencies run: | diff --git a/pyproject.toml b/pyproject.toml index ddc4a28..371d2da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,7 +102,7 @@ skip = ["build", "dist", ".venv", ".eggs", ".mypy_cache", ".pytest_cache", ".git [tool.black] line_length=120 -target_version=["py38","py39","py310","py311"] +target_version=["py38","py39","py310","py311", "py312"] [tool.ruff] line-length=120 From 60529888da977b9a677aec6014273def326145bd Mon Sep 17 00:00:00 2001 From: Konstantin Chernyshev Date: Thu, 16 Nov 2023 11:42:14 +0100 Subject: [PATCH 04/10] test: add AST and data to examples --- tests/test_codebleu.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/tests/test_codebleu.py b/tests/test_codebleu.py index ee46ac5..e33ad8a 100644 --- a/tests/test_codebleu.py +++ b/tests/test_codebleu.py @@ -62,13 +62,15 @@ def test_error_when_input_length_mismatch() -> None: @pytest.mark.parametrize( - ["predictions", "references", "bleu", "codebleu"], + ["predictions", "references", "bleu", "syntax_match", "dataflow_match", "codebleu"], [ # https://github.com/microsoft/CodeXGLUE/blob/main/Code-Code/code-to-code-trans/example.png ( ["public static int Sign ( double d ) { return ( float ) ( ( d == 0 ) ? 0 : ( c < 0.0 ) ? - 1 : 1) ; }"], ["public static int Sign ( double d ) { return ( int ) ( ( d == 0 ) ? 0 : ( d < 0 ) ? - 1 : 1) ; }"], 0.7846, + 14/21, + 2/3, 0.7238, # TODO: lol, not working at <3.12 ), # https://arxiv.org/pdf/2009.10297.pdf "3.4 Two Examples" at the page 4 @@ -76,6 +78,8 @@ def test_error_when_input_length_mismatch() -> None: ["public static int Sign ( double d ) { return ( float ) ( ( d == 0 ) ? 0 : ( c < 0.0 ) ? - 1 : 1) ;"], ["public static int Sign ( double d ) { return ( int ) ( ( d == 0 ) ? 0 : ( d < 0 ) ? - 1 : 1) ; }"], 0.7543, + 14/21, + 2/3, 0.7091, # Should be 0.6973 if AST=13/21, however at the moment tee-sitter AST is 14/21 ), # https://arxiv.org/pdf/2009.10297.pdf "3.4 Two Examples" at the page 4 @@ -83,17 +87,26 @@ def test_error_when_input_length_mismatch() -> None: ["public static int Sign ( double c ) { return ( int ) ( ( c == 0 ) ? 0 : ( c < 0 ) ? - 1 : 1) ; }"], ["public static int Sign ( double d ) { return ( int ) ( ( d == 0 ) ? 0 : ( d < 0 ) ? - 1 : 1) ; }"], 0.7571, # Error in the Figure 4, text "Example 2" states 0.7571, not 0.6814, + 1.0, + 1.0, 0.8804, # Error in the Figure 4, text "Example 2" states 0.8804, not 0.8397, ), ], ) def test_code_x_glue_readme_examples( - predictions: List[Any], references: List[Any], bleu: float, codebleu: float + predictions: List[Any], + references: List[Any], + bleu: float, + syntax_match: float, + dataflow_match: float, + codebleu: float, ) -> None: result = calc_codebleu(references, predictions, "java") logging.debug(result) assert result["ngram_match_score"] == pytest.approx(bleu, 0.01) + assert result["syntax_match_score"] == pytest.approx(syntax_match, 0.01) + assert result["dataflow_match_score"] == pytest.approx(dataflow_match, 0.01) assert result["codebleu"] == pytest.approx(codebleu, 0.01) From 8d179611863b01db11854725d26bffb1cdcb717c Mon Sep 17 00:00:00 2001 From: Konstantin Chernyshev Date: Thu, 16 Nov 2023 12:45:05 +0100 Subject: [PATCH 05/10] fix: add debug print in syntax match --- codebleu/syntax_match.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/codebleu/syntax_match.py b/codebleu/syntax_match.py index 92dd6db..802e495 100644 --- a/codebleu/syntax_match.py +++ b/codebleu/syntax_match.py @@ -71,6 +71,13 @@ def get_all_sub_trees(root_node): cand_sexps = [x[0] for x in get_all_sub_trees(candidate_tree)] ref_sexps = [x[0] for x in get_all_sub_trees(reference_tree)] + print('cand_sexps') + for tree, depth in get_all_sub_trees(candidate_tree): + print(' ', depth, tree) + print('ref_sexps') + for tree, depth in get_all_sub_trees(reference_tree): + print(' ', depth, tree) + # TODO: fix, now we count number of reference subtrees matching candidate, # but we should count number of candidate subtrees matching reference # See (4) in "3.2 Syntactic AST Match" of https://arxiv.org/pdf/2009.10297.pdf @@ -83,6 +90,7 @@ def get_all_sub_trees(root_node): match_count_candidate_to_reference += 1 total_count += len(ref_sexps) - + print(f'match_count {match_count} / {total_count}') + print(f'match_count_fixed {match_count_candidate_to_reference} / {total_count}') score = match_count / total_count return score From eceb842982c30111a6490e6ee70914957044519d Mon Sep 17 00:00:00 2001 From: Konstantin Chernyshev Date: Thu, 16 Nov 2023 14:35:22 +0100 Subject: [PATCH 06/10] test: fix test output with new version of tree-sitter --- codebleu/syntax_match.py | 11 ++--------- tests/test_codebleu.py | 12 ++++++++---- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/codebleu/syntax_match.py b/codebleu/syntax_match.py index 802e495..0050c1a 100644 --- a/codebleu/syntax_match.py +++ b/codebleu/syntax_match.py @@ -71,13 +71,6 @@ def get_all_sub_trees(root_node): cand_sexps = [x[0] for x in get_all_sub_trees(candidate_tree)] ref_sexps = [x[0] for x in get_all_sub_trees(reference_tree)] - print('cand_sexps') - for tree, depth in get_all_sub_trees(candidate_tree): - print(' ', depth, tree) - print('ref_sexps') - for tree, depth in get_all_sub_trees(reference_tree): - print(' ', depth, tree) - # TODO: fix, now we count number of reference subtrees matching candidate, # but we should count number of candidate subtrees matching reference # See (4) in "3.2 Syntactic AST Match" of https://arxiv.org/pdf/2009.10297.pdf @@ -90,7 +83,7 @@ def get_all_sub_trees(root_node): match_count_candidate_to_reference += 1 total_count += len(ref_sexps) - print(f'match_count {match_count} / {total_count}') - print(f'match_count_fixed {match_count_candidate_to_reference} / {total_count}') + # print(f'match_count {match_count} / {total_count}') + # print(f'match_count_fixed {match_count_candidate_to_reference} / {total_count}') score = match_count / total_count return score diff --git a/tests/test_codebleu.py b/tests/test_codebleu.py index e33ad8a..2e1678a 100644 --- a/tests/test_codebleu.py +++ b/tests/test_codebleu.py @@ -69,18 +69,18 @@ def test_error_when_input_length_mismatch() -> None: ["public static int Sign ( double d ) { return ( float ) ( ( d == 0 ) ? 0 : ( c < 0.0 ) ? - 1 : 1) ; }"], ["public static int Sign ( double d ) { return ( int ) ( ( d == 0 ) ? 0 : ( d < 0 ) ? - 1 : 1) ; }"], 0.7846, - 14/21, + 11/19, # In example, it is 13/21, but with new version of tree-sitter it is 11/19 2/3, - 0.7238, # TODO: lol, not working at <3.12 + 0.7019, # Should be 0.7238 if AST=13/21 in the paper, however at the moment tee-sitter AST is 11/19 ), # https://arxiv.org/pdf/2009.10297.pdf "3.4 Two Examples" at the page 4 ( ["public static int Sign ( double d ) { return ( float ) ( ( d == 0 ) ? 0 : ( c < 0.0 ) ? - 1 : 1) ;"], ["public static int Sign ( double d ) { return ( int ) ( ( d == 0 ) ? 0 : ( d < 0 ) ? - 1 : 1) ; }"], 0.7543, - 14/21, + 11/19, # In example, it is 13/21, but with new version of tree-sitter it is 11/19 2/3, - 0.7091, # Should be 0.6973 if AST=13/21, however at the moment tee-sitter AST is 14/21 + 0.6873, # Should be 0.6973 if AST=13/21 in the paper, however at the moment tee-sitter AST is 11/19 ), # https://arxiv.org/pdf/2009.10297.pdf "3.4 Two Examples" at the page 4 ( @@ -104,11 +104,15 @@ def test_code_x_glue_readme_examples( result = calc_codebleu(references, predictions, "java") logging.debug(result) + print(result) + assert result["ngram_match_score"] == pytest.approx(bleu, 0.01) assert result["syntax_match_score"] == pytest.approx(syntax_match, 0.01) assert result["dataflow_match_score"] == pytest.approx(dataflow_match, 0.01) assert result["codebleu"] == pytest.approx(codebleu, 0.01) + # assert False + @pytest.mark.parametrize( ["predictions", "references", "codebleu"], From d0f48addb50afadd9182609cfc5704d1a94306e8 Mon Sep 17 00:00:00 2001 From: Konstantin Chernyshev Date: Thu, 16 Nov 2023 16:24:43 +0100 Subject: [PATCH 07/10] test: debug fail first test --- tests/test_codebleu.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_codebleu.py b/tests/test_codebleu.py index 2e1678a..6422d51 100644 --- a/tests/test_codebleu.py +++ b/tests/test_codebleu.py @@ -24,6 +24,8 @@ def test_simple_cases(predictions: List[Any], references: List[Any], codebleu: float) -> None: result = calc_codebleu(references, predictions, "python") logging.debug(result) + print(result) + assert False assert result["codebleu"] == pytest.approx(codebleu, 0.01) From 4deda6ace0771d53d2afba8ae3e302a9985e2ecf Mon Sep 17 00:00:00 2001 From: Konstantin Chernyshev Date: Thu, 16 Nov 2023 17:28:38 +0100 Subject: [PATCH 08/10] ci: tmp true fast tests --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index aa36d11..949d9fe 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -45,7 +45,7 @@ jobs: run: | python -m pip install -e .[test] - name: Run tests - run: python -m pytest + run: python -m pytest | true external-build-workflow: needs: [fast-tests-python] From d95b1d15b68a5954b69d04e4104054f88490926e Mon Sep 17 00:00:00 2001 From: Konstantin Chernyshev Date: Thu, 16 Nov 2023 18:33:00 +0100 Subject: [PATCH 09/10] refactor: bleu file to work with 3.12 --- .github/workflows/test.yml | 2 +- codebleu/bleu.py | 156 ++----------------------------- codebleu/weighted_ngram_match.py | 129 +------------------------ pyproject.toml | 5 +- tests/test_codebleu.py | 10 +- 5 files changed, 15 insertions(+), 287 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 949d9fe..aa36d11 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -45,7 +45,7 @@ jobs: run: | python -m pip install -e .[test] - name: Run tests - run: python -m pytest | true + run: python -m pytest external-build-workflow: needs: [fast-tests-python] diff --git a/codebleu/bleu.py b/codebleu/bleu.py index f528746..8db2e7a 100644 --- a/codebleu/bleu.py +++ b/codebleu/bleu.py @@ -12,25 +12,11 @@ import sys import warnings from collections import Counter -from fractions import Fraction as _Fraction from typing import Any from .utils import ngrams -class Fraction(_Fraction): - """Fraction class with _normalize=False support. - _normalize=False was removed in 3.12, add custom class for back-compatibility - """ - - # We're immutable, so use __new__ not __init__ - def __new__(cls, numerator: Any = 0, denominator: Any = None, *, _normalize: bool = True) -> "Fraction": - if sys.version_info >= (3, 12): - return super(Fraction, cls).__new__(cls, numerator, denominator) - else: - return super(Fraction, cls).__new__(cls, numerator, denominator, _normalize=False) - - def sentence_bleu( references, hypothesis, @@ -166,9 +152,9 @@ def corpus_bleu( # For each order of ngram, calculate the numerator and # denominator for the corpus-level modified precision. for i, _ in enumerate(weights, start=1): - p_i = modified_precision(references, hypothesis, i) - p_numerators[i] += p_i.numerator - p_denominators[i] += p_i.denominator + p_i_numerator, p_i_denominator = modified_precision(references, hypothesis, i) + p_numerators[i] += p_i_numerator + p_denominators[i] += p_i_denominator # Calculate the hypothesis length and the closest reference length. # Adds them to the corpus-level hypothesis and reference counts. @@ -185,8 +171,8 @@ def corpus_bleu( if hyp_lengths < 4 and weights == (0.25, 0.25, 0.25, 0.25): weights = (1 / hyp_lengths,) * hyp_lengths - # Collects the various precision values for the different ngram orders. - p_n = [Fraction(p_numerators[i], p_denominators[i], _normalize=False) for i, _ in enumerate(weights, start=1)] + # Collects the various recall values for the different ngram orders. + p_n = [(p_numerators[i], p_denominators[i]) for i, _ in enumerate(weights, start=1)] # Returns 0 if there's no matching n-grams # We only need to check for p_numerators[1] == 0, since if there's @@ -202,7 +188,7 @@ def corpus_bleu( # it tries to retain the Fraction object as much as the # smoothing method allows. p_n = smoothing_function(p_n, references=references, hypothesis=hypothesis, hyp_len=hyp_lengths) - s = (w_i * math.log(p_i) for w_i, p_i in zip(weights, p_n)) + s = (w_i * math.log(p_i[0] / p_i[1]) for w_i, p_i in zip(weights, p_n)) s = bp * math.exp(math.fsum(s)) return s @@ -298,7 +284,8 @@ def modified_precision(references, hypothesis, n): # Usually this happens when the ngram order is > len(reference). denominator = max(1, sum(counts.values())) - return Fraction(numerator, denominator, _normalize=False) + # return Fraction(numerator, denominator, _normalize=False) + return numerator, denominator def closest_ref_length(references, hyp_len): @@ -447,133 +434,8 @@ def __init__(self, epsilon=0.1, alpha=5, k=5): self.alpha = alpha self.k = k - def method0(self, p_n, *args, **kwargs): - """ - No smoothing. - """ - p_n_new = [] - for i, p_i in enumerate(p_n): - if p_i.numerator != 0: - p_n_new.append(p_i) - else: - _msg = str( - "\nThe hypothesis contains 0 counts of {}-gram overlaps.\n" - "Therefore the BLEU score evaluates to 0, independently of\n" - "how many N-gram overlaps of lower order it contains.\n" - "Consider using lower n-gram order or use " - "SmoothingFunction()" - ).format(i + 1) - warnings.warn(_msg) - # When numerator==0 where denonminator==0 or !=0, the result - # for the precision score should be equal to 0 or undefined. - # Due to BLEU geometric mean computation in logarithm space, - # we we need to take the return sys.float_info.min such that - # math.log(sys.float_info.min) returns a 0 precision score. - p_n_new.append(sys.float_info.min) - return p_n_new - def method1(self, p_n, *args, **kwargs): """ Smoothing method 1: Add *epsilon* counts to precision with 0 counts. """ - return [(p_i.numerator + self.epsilon) / p_i.denominator if p_i.numerator == 0 else p_i for p_i in p_n] - - def method2(self, p_n, *args, **kwargs): - """ - Smoothing method 2: Add 1 to both numerator and denominator from - Chin-Yew Lin and Franz Josef Och (2004) Automatic evaluation of - machine translation quality using longest common subsequence and - skip-bigram statistics. In ACL04. - """ - return [Fraction(p_i.numerator + 1, p_i.denominator + 1, _normalize=False) for p_i in p_n] - - def method3(self, p_n, *args, **kwargs): - """ - Smoothing method 3: NIST geometric sequence smoothing - The smoothing is computed by taking 1 / ( 2^k ), instead of 0, for each - precision score whose matching n-gram count is null. - k is 1 for the first 'n' value for which the n-gram match count is null/ - For example, if the text contains: - - one 2-gram match - - and (consequently) two 1-gram matches - the n-gram count for each individual precision score would be: - - n=1 => prec_count = 2 (two unigrams) - - n=2 => prec_count = 1 (one bigram) - - n=3 => prec_count = 1/2 (no trigram, taking 'smoothed' value of 1 / ( 2^k ), with k=1) - - n=4 => prec_count = 1/4 (no fourgram, taking 'smoothed' value of 1 / ( 2^k ), with k=2) - """ - incvnt = 1 # From the mteval-v13a.pl, it's referred to as k. - for i, p_i in enumerate(p_n): - if p_i.numerator == 0: - p_n[i] = 1 / (2**incvnt * p_i.denominator) - incvnt += 1 - return p_n - - def method4(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs): - """ - Smoothing method 4: - Shorter translations may have inflated precision values due to having - smaller denominators; therefore, we give them proportionally - smaller smoothed counts. Instead of scaling to 1/(2^k), Chen and Cherry - suggests dividing by 1/ln(len(T)), where T is the length of the translation. - """ - hyp_len = hyp_len if hyp_len else len(hypothesis) - for i, p_i in enumerate(p_n): - if p_i.numerator == 0 and hyp_len != 0: - incvnt = i + 1 * self.k / math.log(hyp_len) # Note that this K is different from the K from NIST. - p_n[i] = incvnt / p_i.denominator - return p_n - - def method5(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs): - """ - Smoothing method 5: - The matched counts for similar values of n should be similar. To a - calculate the n-gram matched count, it averages the n−1, n and n+1 gram - matched counts. - """ - hyp_len = hyp_len if hyp_len else len(hypothesis) - m = {} - # Requires an precision value for an addition ngram order. - p_n_plus1 = p_n + [modified_precision(references, hypothesis, 5)] - m[-1] = p_n[0] + 1 - for i, p_i in enumerate(p_n): - p_n[i] = (m[i - 1] + p_i + p_n_plus1[i + 1]) / 3 - m[i] = p_n[i] - return p_n - - def method6(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs): - """ - Smoothing method 6: - Interpolates the maximum likelihood estimate of the precision *p_n* with - a prior estimate *pi0*. The prior is estimated by assuming that the ratio - between pn and pn−1 will be the same as that between pn−1 and pn−2; from - Gao and He (2013) Training MRF-Based Phrase Translation Models using - Gradient Ascent. In NAACL. - """ - hyp_len = hyp_len if hyp_len else len(hypothesis) - # This smoothing only works when p_1 and p_2 is non-zero. - # Raise an error with an appropriate message when the input is too short - # to use this smoothing technique. - assert p_n[2], "This smoothing method requires non-zero precision for bigrams." - for i, p_i in enumerate(p_n): - if i in [0, 1]: # Skips the first 2 orders of ngrams. - continue - else: - pi0 = 0 if p_n[i - 2] == 0 else p_n[i - 1] ** 2 / p_n[i - 2] - # No. of ngrams in translation that matches the reference. - m = p_i.numerator - # No. of ngrams in translation. - ngrams_count = sum(1 for _ in ngrams(hypothesis, i + 1)) - # Calculates the interpolated precision. - p_n[i] = (m + self.alpha * pi0) / (ngrams_count + self.alpha) - return p_n - - def method7(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs): - """ - Smoothing method 7: - Interpolates methods 4 and 5. - """ - hyp_len = hyp_len if hyp_len else len(hypothesis) - p_n = self.method4(p_n, references, hypothesis, hyp_len) - p_n = self.method5(p_n, references, hypothesis, hyp_len) - return p_n + return [((p_i[0] + self.epsilon), p_i[1]) if p_i[0] == 0 else p_i for p_i in p_n] diff --git a/codebleu/weighted_ngram_match.py b/codebleu/weighted_ngram_match.py index 507cb76..3f37914 100644 --- a/codebleu/weighted_ngram_match.py +++ b/codebleu/weighted_ngram_match.py @@ -156,8 +156,8 @@ def corpus_bleu( # For each order of ngram, calculate the numerator and # denominator for the corpus-level modified precision. for i, _ in enumerate(weights, start=1): - p_i_numeraotr, p_i_denominator = modified_recall(references, hypothesis, i) - p_numerators[i] += p_i_numeraotr + p_i_numerator, p_i_denominator = modified_recall(references, hypothesis, i) + p_numerators[i] += p_i_numerator p_denominators[i] += p_i_denominator # Calculate the hypothesis length and the closest reference length. @@ -400,133 +400,8 @@ def __init__(self, epsilon=0.1, alpha=5, k=5): self.alpha = alpha self.k = k - def method0(self, p_n, *args, **kwargs): - """ - No smoothing. - """ - p_n_new = [] - for i, p_i in enumerate(p_n): - if p_i[0] != 0: - p_n_new.append(p_i) - else: - _msg = str( - "\nThe hypothesis contains 0 counts of {}-gram overlaps.\n" - "Therefore the BLEU score evaluates to 0, independently of\n" - "how many N-gram overlaps of lower order it contains.\n" - "Consider using lower n-gram order or use " - "SmoothingFunction()" - ).format(i + 1) - warnings.warn(_msg) - # When numerator==0 where denonminator==0 or !=0, the result - # for the precision score should be equal to 0 or undefined. - # Due to BLEU geometric mean computation in logarithm space, - # we we need to take the return sys.float_info.min such that - # math.log(sys.float_info.min) returns a 0 precision score. - p_n_new.append(sys.float_info.min) - return p_n_new - def method1(self, p_n, *args, **kwargs): """ Smoothing method 1: Add *epsilon* counts to precision with 0 counts. """ return [((p_i[0] + self.epsilon), p_i[1]) if p_i[0] == 0 else p_i for p_i in p_n] - - def method2(self, p_n, *args, **kwargs): - """ - Smoothing method 2: Add 1 to both numerator and denominator from - Chin-Yew Lin and Franz Josef Och (2004) Automatic evaluation of - machine translation quality using longest common subsequence and - skip-bigram statistics. In ACL04. - """ - return [(p_i[0] + 1, p_i[1] + 1) for p_i in p_n] - - def method3(self, p_n, *args, **kwargs): - """ - Smoothing method 3: NIST geometric sequence smoothing - The smoothing is computed by taking 1 / ( 2^k ), instead of 0, for each - precision score whose matching n-gram count is null. - k is 1 for the first 'n' value for which the n-gram match count is null/ - For example, if the text contains: - - one 2-gram match - - and (consequently) two 1-gram matches - the n-gram count for each individual precision score would be: - - n=1 => prec_count = 2 (two unigrams) - - n=2 => prec_count = 1 (one bigram) - - n=3 => prec_count = 1/2 (no trigram, taking 'smoothed' value of 1 / ( 2^k ), with k=1) - - n=4 => prec_count = 1/4 (no fourgram, taking 'smoothed' value of 1 / ( 2^k ), with k=2) - """ - incvnt = 1 # From the mteval-v13a.pl, it's referred to as k. - for i, p_i in enumerate(p_n): - if p_i.numerator == 0: - p_n[i] = 1 / (2**incvnt * p_i.denominator) - incvnt += 1 - return p_n - - def method4(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs): - """ - Smoothing method 4: - Shorter translations may have inflated precision values due to having - smaller denominators; therefore, we give them proportionally - smaller smoothed counts. Instead of scaling to 1/(2^k), Chen and Cherry - suggests dividing by 1/ln(len(T)), where T is the length of the translation. - """ - hyp_len = hyp_len if hyp_len else len(hypothesis) - for i, p_i in enumerate(p_n): - if p_i.numerator == 0 and hyp_len != 0: - incvnt = i + 1 * self.k / math.log(hyp_len) # Note that this K is different from the K from NIST. - p_n[i] = incvnt / p_i.denominator - return p_n - - def method5(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs): - """ - Smoothing method 5: - The matched counts for similar values of n should be similar. To a - calculate the n-gram matched count, it averages the n−1, n and n+1 gram - matched counts. - """ - hyp_len = hyp_len if hyp_len else len(hypothesis) - m = {} - # Requires an precision value for an addition ngram order. - p_n_plus1 = p_n + [modified_precision(references, hypothesis, 5)] - m[-1] = p_n[0] + 1 - for i, p_i in enumerate(p_n): - p_n[i] = (m[i - 1] + p_i + p_n_plus1[i + 1]) / 3 - m[i] = p_n[i] - return p_n - - def method6(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs): - """ - Smoothing method 6: - Interpolates the maximum likelihood estimate of the precision *p_n* with - a prior estimate *pi0*. The prior is estimated by assuming that the ratio - between pn and pn−1 will be the same as that between pn−1 and pn−2; from - Gao and He (2013) Training MRF-Based Phrase Translation Models using - Gradient Ascent. In NAACL. - """ - hyp_len = hyp_len if hyp_len else len(hypothesis) - # This smoothing only works when p_1 and p_2 is non-zero. - # Raise an error with an appropriate message when the input is too short - # to use this smoothing technique. - assert p_n[2], "This smoothing method requires non-zero precision for bigrams." - for i, p_i in enumerate(p_n): - if i in [0, 1]: # Skips the first 2 orders of ngrams. - continue - else: - pi0 = 0 if p_n[i - 2] == 0 else p_n[i - 1] ** 2 / p_n[i - 2] - # No. of ngrams in translation that matches the reference. - m = p_i.numerator - # No. of ngrams in translation. - ngrams_count = sum(1 for _ in ngrams(hypothesis, i + 1)) - # Calculates the interpolated precision. - p_n[i] = (m + self.alpha * pi0) / (ngrams_count + self.alpha) - return p_n - - def method7(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs): - """ - Smoothing method 7: - Interpolates methods 4 and 5. - """ - hyp_len = hyp_len if hyp_len else len(hypothesis) - p_n = self.method4(p_n, references, hypothesis, hyp_len) - p_n = self.method5(p_n, references, hypothesis, hyp_len) - return p_n diff --git a/pyproject.toml b/pyproject.toml index 371d2da..22d561c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,14 +40,12 @@ exclude = ["tests", "tests.*", "codebleu.parser.tree-sitter"] "*" = ["py.typed", "*.txt", "*.so", "*.dylib", "*.dll", "keywords/*"] - [project.scripts] codebleu = "codebleu.__main__:main" [project.urls] homepage = "https://github.com/k4black/codebleu" - [project.optional-dependencies] test = [ "pytest >=7.0.0,<8.0.0", @@ -60,10 +58,9 @@ test = [ "flake8 >=6.0.0,<7.0.0", "ruff >=0.0.275,<0.2.0", "isort >=5.0.0,<6.0.0", + "nltk >=3.0.0,<4.0.0", ] - - [tool.setuptools.dynamic] version = {file = "VERSION"} diff --git a/tests/test_codebleu.py b/tests/test_codebleu.py index 6422d51..2c41935 100644 --- a/tests/test_codebleu.py +++ b/tests/test_codebleu.py @@ -10,12 +10,8 @@ @pytest.mark.parametrize( ["predictions", "references", "codebleu"], [ - ( - ["some rannnndom words in length more than 3"], - ["def test ( ) :\n pass"], - 0.25, - ), # 'cause data_flow is 0 and considered as 1 - (["def bar ( y , x ) :\n a = x * x\n return a"], ["def foo ( x ) :\n return x"], 0.38), + (["some rannnndom words in length more than 3"], ["def test ( ) :\n pass"], 0.25), # cause data_flow=1 + (["def bar ( y , x ) :\n a = x * x\n return a"], ["def foo ( x ) :\n return x"], 0.36), (["def foo ( x ) :\n return x * x"], ["def bar ( x ) :\n return x"], 0.61), (["def bar ( x ) :\n return x"], ["def foo ( x ) :\n return x"], 0.85), (["def foo ( x ) :\n return x"], ["def foo ( x ) :\n return x"], 1.0), @@ -24,8 +20,6 @@ def test_simple_cases(predictions: List[Any], references: List[Any], codebleu: float) -> None: result = calc_codebleu(references, predictions, "python") logging.debug(result) - print(result) - assert False assert result["codebleu"] == pytest.approx(codebleu, 0.01) From dc7a0163e2371479ce9aca986583a05367497f63 Mon Sep 17 00:00:00 2001 From: Konstantin Chernyshev Date: Thu, 16 Nov 2023 18:49:28 +0100 Subject: [PATCH 10/10] style: fix ruff --- codebleu/bleu.py | 3 --- codebleu/weighted_ngram_match.py | 3 --- 2 files changed, 6 deletions(-) diff --git a/codebleu/bleu.py b/codebleu/bleu.py index 8db2e7a..5bb2c9a 100644 --- a/codebleu/bleu.py +++ b/codebleu/bleu.py @@ -9,10 +9,7 @@ """BLEU score implementation.""" import math -import sys -import warnings from collections import Counter -from typing import Any from .utils import ngrams diff --git a/codebleu/weighted_ngram_match.py b/codebleu/weighted_ngram_match.py index 3f37914..d651d7c 100644 --- a/codebleu/weighted_ngram_match.py +++ b/codebleu/weighted_ngram_match.py @@ -13,11 +13,8 @@ """BLEU score implementation.""" import math -import sys -import warnings from collections import Counter -from .bleu import modified_precision from .utils import ngrams