Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 2 additions & 13 deletions codebleu/codebleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from . import bleu, dataflow_match, syntax_match, weighted_ngram_match

PACKAGE_DIR = Path(__file__).parent
# AVAILABLE_LANGS = ['java', 'javascript', 'c_sharp', 'php', 'go', 'python', 'ruby']
AVAILABLE_LANGS = ["java", "javascript", "c_sharp", "php", "c", "cpp", "python"] # keywords available


Expand Down Expand Up @@ -56,7 +55,8 @@ def tokenizer(s):
ngram_match_score = bleu.corpus_bleu(tokenized_refs, tokenized_hyps)

# calculate weighted ngram match
keywords = [x.strip() for x in open(keywords_dir / (lang + ".txt"), "r", encoding="utf-8").readlines()]
with open(keywords_dir / (lang + ".txt"), "r", encoding="utf-8") as f:
keywords = [x.strip() for x in f.readlines()]

def make_weights(reference_tokens, key_word_list):
return {token: 1 if token in key_word_list else 0.2 for token in reference_tokens}
Expand All @@ -74,15 +74,6 @@ def make_weights(reference_tokens, key_word_list):
# calculate dataflow match
dataflow_match_score = dataflow_match.corpus_dataflow_match(references, hypothesis, lang, lang_so_file)

# print(
# "ngram match: {0}, weighted ngram match: {1}, syntax_match: {2}, dataflow_match: {3}".format(
# ngram_match_score,
# weighted_ngram_match_score,
# syntax_match_score,
# dataflow_match_score,
# )
# )

alpha, beta, gamma, theta = weights
code_bleu_score = (
alpha * ngram_match_score
Expand All @@ -91,8 +82,6 @@ def make_weights(reference_tokens, key_word_list):
+ theta * (dataflow_match_score or 1)
)

# print("CodeBLEU score: ", code_bleu_score)

return {
"codebleu": code_bleu_score,
"ngram_match_score": ngram_match_score,
Expand Down
3 changes: 2 additions & 1 deletion codebleu/dataflow_match.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging

from tree_sitter import Language, Parser

Expand Down Expand Up @@ -67,7 +68,7 @@ def corpus_dataflow_match(references, candidates, lang, langso_so_file):
match_count += 1
normalized_cand_dfg.remove(dataflow)
if total_count == 0:
print(
logging.warning(
"WARNING: There is no reference data-flows extracted from the whole corpus, "
"and the data-flow match score degenerates to 0. Please consider ignoring this score."
)
Expand Down
2 changes: 0 additions & 2 deletions codebleu/weighted_ngram_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@ def corpus_bleu(
# it tries to retain the Fraction object as much as the
# smoothing method allows.
p_n = smoothing_function(p_n, references=references, hypothesis=hypothesis, hyp_len=hyp_lengths)
# pdb.set_trace()
s = (w_i * math.log(p_i[0] / p_i[1]) for w_i, p_i in zip(weights, p_n))
s = bp * math.exp(math.fsum(s))
return s
Expand All @@ -212,7 +211,6 @@ def modified_recall(references, hypothesis, n):
"""
# Extracts all ngrams in hypothesis
# Set an empty Counter if hypothesis is empty.
# pdb.set_trace()
numerator = 0
denominator = 0

Expand Down
2 changes: 1 addition & 1 deletion evaluate_app/codebleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _info(self):
def _download_and_prepare(self, dl_manager):
"""Optional: download external resources useful to compute the scores"""
# workaround: this file has to be named codebleu (evaluate library requirement)
self.codebleu_package = importlib.import_module('codebleu')
self.codebleu_package = importlib.import_module("codebleu")
pass

def _compute(self, predictions, references, lang, weights=(0.25, 0.25, 0.25, 0.25), tokenizer=None):
Expand Down
24 changes: 13 additions & 11 deletions tests/test_codebleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,26 @@
from typing import Any, List

import pytest
import logging

from codebleu.codebleu import AVAILABLE_LANGS, calc_codebleu


@pytest.mark.parametrize(['predictions', 'references', 'codebleu'], [
(['some rannnndom words in length more than 3'], ['def test ( ) :\n pass'], 0.25), # 'cause data_flow is 0 and considered as 1
(['some rannnndom words in length more than 3'],
['def test ( ) :\n pass'], 0.25), # 'cause data_flow is 0 and considered as 1
(['def bar ( y , x ) :\n a = x * x\n return a'], ['def foo ( x ) :\n return x'], 0.4),
(['def foo ( x ) :\n return x * x'], ['def bar ( x ) :\n return x'], 0.6),
(['def bar ( x ) :\n return x'], ['def foo ( x ) :\n return x'], 0.8),
(['def foo ( x ) :\n return x'], ['def foo ( x ) :\n return x'], 1.0),
])
def test_simple_cases(predictions: List[Any], references: List[Any], codebleu: float) -> None:
result = calc_codebleu(references, predictions, 'python')
print(result)
logging.debug(result)
assert result['codebleu'] == pytest.approx(codebleu, 0.1)


@pytest.mark.parametrize(['lang'], [(l,) for l in AVAILABLE_LANGS])
@pytest.mark.parametrize(['lang'], [(lang,) for lang in AVAILABLE_LANGS])
def test_exact_match_works_for_all_langs(lang: str) -> None:
predictions = references = ['some matching string a couple of times']
assert calc_codebleu(references, predictions, lang)['codebleu'] == 1.0
Expand All @@ -36,7 +38,7 @@ def test_exact_match_works_for_all_langs(lang: str) -> None:
])
def test_simple_cases_work_for_all_langs(lang: str, predictions: List[Any], references: List[Any]) -> None:
result = calc_codebleu(references, predictions, lang)
print(result)
logging.debug(result)
assert result['codebleu'] == pytest.approx(0.6, 0.1)


Expand All @@ -54,17 +56,17 @@ def test_error_when_input_length_mismatch() -> None:
(
['public static int Sign ( double d ) { return ( float ) ( ( d == 0 ) ? 0 : ( c < 0.0 ) ? - 1 : 1) ; }'],
['public static int Sign ( double d ) { return ( int ) ( ( d == 0 ) ? 0 : ( d < 0 ) ? - 1 : 1) ; }'],
0.7238
0.7019
),
(
['public static int Sign ( double c ) { return ( int ) ( ( c == 0 ) ? 0 : ( c < 0 ) ? - 1 : 1) ; }'],
['public static int Sign ( double d ) { return ( int ) ( ( d == 0 ) ? 0 : ( d < 0 ) ? - 1 : 1) ; }'],
0.8804
),
# (
# ['public static int Sign ( double c ) { return ( int ) ( ( c == 0 ) ? 0 : ( c < 0 ) ? - 1 : 1) ; }'],
# ['public static int Sign ( double d ) { return ( int ) ( ( d == 0 ) ? 0 : ( d < 0 ) ? - 1 : 1) ; }'],
# 0.8397
# ),
])
def test_code_x_glue_readme_examples(predictions: List[Any], references: List[Any], codebleu: float) -> None:
result = calc_codebleu(references, predictions, 'java')
print(result)
logging.debug(result)
assert result['codebleu'] == pytest.approx(codebleu, 0.01)


Expand Down