Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 40 additions & 34 deletions AstToEcoreConverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@ def __init__(self, resource_set: ResourceSet, repository, write_in_file, output_
Raises:
ValueError: If the repository is None or empty.
"""
self.current_module_name = None
if repository is None or repository == '':
raise ValueError('Directory is required')

self.root_directory = repository.replace('\\', '/')
self.epackage = resource_set.get_resource(URI('Basic.ecore')).contents[0]
self.graph = self.epackage.getEClassifier('TypeGraph')(
self.e_package = resource_set.get_resource(URI('Basic.ecore')).contents[0]
self.graph = self.e_package.getEClassifier('TypeGraph')(
tName=self.root_directory.split('/')[-1])

# initialize internal structures
Expand All @@ -54,6 +55,7 @@ def __init__(self, resource_set: ResourceSet, repository, write_in_file, output_
# entries: [module_node, module_name, package_node, package_name]
self.imported_libraries = []
self.imported_package = None
self.current_parent = None

python_files = [os.path.join(root, file) for root, _, files in os.walk(
self.root_directory) for file in files if file.endswith('.py')]
Expand Down Expand Up @@ -91,7 +93,7 @@ def __init__(self, resource_set: ResourceSet, repository, write_in_file, output_
# create and process modules with contained program entities
for file_path in python_files:
try:
self.process_file(file_path)
self.process_file(str(file_path))
except Exception as e:
if 'invalid syntax' in str(e):
logger.warning(f'skipped: {file_path}')
Expand Down Expand Up @@ -316,7 +318,7 @@ def create_package_hierarchy(self, parent_package, subpackage_names, lib_flag=Tr
if e == 0:
package_node.parent = parent_package
else:
package_node.parent = current_parent
package_node.parent = self.current_parent
if lib_flag is True:
self.imported_libraries.append(
[None, None, package_node, element_lib])
Expand All @@ -326,8 +328,8 @@ def create_package_hierarchy(self, parent_package, subpackage_names, lib_flag=Tr
[package_node, element_lib, parent_package])
else:
self.package_list.append(
[package_node, element_lib, current_parent])
current_parent = package_node
[package_node, element_lib, self.current_parent])
self.current_parent = package_node
self.imported_package = package_node

def create_imported_method_call(self, module_node, method_name, caller_node):
Expand Down Expand Up @@ -430,12 +432,13 @@ def set_external_module_calls(self):
if obj.eClass.name == NodeTypes.METHOD_DEFINITION.value:
self.create_method_call(obj, method_name, caller_node)
if module_node is None:
# if len==1 simple import .. statement, included only if import is used (in that case len>1)
# if len==1 simple import statement, included only if import is used (in that case len>1)
if len(split_import) > 1:
self.call_imported_library.append(
[caller_node, imported_instance])

def set_import_names(self, split_import):
@staticmethod
def set_import_names(split_import):
"""
Sets the names for imported modules, classes, and methods.

Expand Down Expand Up @@ -539,9 +542,9 @@ def create_method_call(self, method_node, method_name, caller_node):
self.create_calls(caller_node, method_node)

def check_for_missing_nodes(self):
"""check_list contains all classes with method defs that are created during conversion.
They are compared to the classes with meth defs found in modules in the type graph at the end,
those not found need to be appended to a module, otherwise the meth defs are missing.
"""check_list contains all classes with method def that are created during conversion.
They are compared to the classes with meth def found in modules in the type graph at the end,
those not found need to be appended to a module, otherwise the meth def are missing.
Entire modules are missing! Perhaps because only .py files are processed. They are created and appended to
their packages, which are also created when they are not in the type graph."""
# check if every created TClass node is in type graph
Expand All @@ -567,7 +570,7 @@ def check_for_missing_nodes(self):
ref, ty = self.get_reference_by_name(obj.tName)
if ref is not None:
imported = ref.split('.')
# if len==1 simple import .. statement, included only if import is used (in that case len>1)
# if len==1 simple import statement, included only if import is used (in that case len>1)
if len(imported) > 1:
package_name, subpackage_names, module_name, class_name, method_name = self.set_import_names(
imported)
Expand Down Expand Up @@ -672,14 +675,14 @@ def create_missing_module(self, module_name, class_node, package_node):
module_node.namespace = package_node
self.graph.modules.append(module_node)

def get_epackage(self):
def get_e_package(self):
"""
Retrieves the EPackage associated with the graph.

Returns:
The EPackage instance.
"""
return self.epackage
return self.e_package

def get_graph(self):
"""
Expand All @@ -690,17 +693,17 @@ def get_graph(self):
"""
return self.graph

def create_ecore_instance(self, ecore_type):
    """Instantiate an Ecore object of the requested node type.

    Args:
        ecore_type: An enum member whose ``value`` names an EClassifier
            registered in the loaded EPackage.

    Returns:
        A freshly constructed instance of the resolved EClassifier.
    """
    classifier = self.e_package.getEClassifier(ecore_type.value)
    return classifier()

def get_current_module(self):
"""
Expand Down Expand Up @@ -780,7 +783,7 @@ def process_file(self, path):
self.classes_without_module.remove(class_object)
class_object.delete()

# added errors='ignore' to fix encoding issues in some repositories ('charmap cannot decode byte..')
# added errors='ignore' to fix encoding issues in some repositories ('char-map cannot decode byte')
with open(path, 'r', errors='ignore') as file:
code = file.read()
# added following to fix some invalid character and syntax errors
Expand Down Expand Up @@ -1015,8 +1018,8 @@ def add_instance(self, instance_name, class_name):
instance_name (str?): The name of the instance.
class_name (str?): The name of the class to which the instance belongs.
"""
reference, type = self.get_reference_by_name(class_name)
if reference is not None and type == 0:
reference, reference_type = self.get_reference_by_name(class_name)
if reference is not None and reference_type == 0:
classes = class_name.split('.')[1:]
classes.insert(0, reference)
class_name = ".".join(classes)
Expand All @@ -1033,7 +1036,8 @@ def remove_instance(self, class_name):
if instance[1] == class_name:
self.instances.remove(instance)

def get_method_def_in_class(self, name, class_node):
@staticmethod
def get_method_def_in_class(name, class_node):
"""
Checks if a method definition exists in a class.

Expand All @@ -1049,7 +1053,8 @@ def get_method_def_in_class(self, name, class_node):
return method_def
return None

@staticmethod
def get_method_def_in_module(method_name, module):
    """Look up a method definition by name inside a module node.

    Both module-level method definitions and the methods defined on the
    module's classes are searched.

    Args:
        method_name: Name of the method to find.
        module: Module node whose ``contains`` children are inspected.

    Returns:
        The matching method-definition node, or None if not found.
    """
    for child in module.contains:
        kind = child.eClass.name
        if kind == NodeTypes.METHOD_DEFINITION.value:
            if child.signature.method.tName == method_name:
                return child
        if kind == NodeTypes.CLASS.value:
            for candidate in child.defines:
                if candidate.signature.method.tName == method_name:
                    return candidate
    return None
Expand Down Expand Up @@ -1116,11 +1121,12 @@ def create_method_signature(self, method_node, name, arguments):

method_node.signature = method_signature

# for interal structure
# for internal structure
module_node = self.get_current_module()
self.method_list.append([method_node, name, module_node])

def get_calls(self, caller_node, called_node):
@staticmethod
def get_calls(caller_node, called_node):
"""
Checks if a call already exists between two nodes.

Expand Down Expand Up @@ -1222,12 +1228,12 @@ def create_inheritance_structure(self, node, child):
"""
base_node = None
if isinstance(node, ast.Name):
base_node, type = self.graph_class.get_reference_by_name(node.id)
base_node, base_type = self.graph_class.get_reference_by_name(node.id)
if base_node is None:
base_node = self.graph_class.get_class_by_name(
node.id, module=self.graph_class.get_current_module())
base_node.childClasses.append(child)
elif isinstance(base_node, str) and type == 0:
elif isinstance(base_node, str) and base_type == 0:
import_parent = None
for import_class in base_node.split('.'):
import_node = self.graph_class.get_class_by_name(
Expand Down Expand Up @@ -1271,7 +1277,7 @@ def visit_ClassDef(self, node):
self.graph_class.create_method_signature(
method_node, method_name, item.args.args)
class_node.defines.append(method_node)
# to search for missing meth defs later
# to search for missing meth def later
self.graph_class.check_list.append(class_node)
self.generic_visit(node)

Expand Down Expand Up @@ -1378,7 +1384,7 @@ def visit_Call(self, node):
[self.current_module, caller_node, instance])

# for calls of imported instances, both within repo and external libraries
instance_from_graph, type = self.graph_class.get_reference_by_name(
instance_from_graph, instance_type = self.graph_class.get_reference_by_name(
instance.replace(f".{instance.split('.')[-1]}", ''))

# this is necessary to get all the called methods' names correctly
Expand Down
82 changes: 45 additions & 37 deletions CustomDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch.utils.data import Dataset
from torch_geometric.data import Data

from DataformatUtils import convert_edge_dim, convert_list_to_floattensor, convert_list_to_longtensor, \
from DataformatUtils import convert_edge_dim, convert_list_to_float_tensor, convert_list_to_long_tensor, \
convert_hashed_names_to_float
from Encoder import multi_hot_encoding
from GraphClasses import defined_labels
Expand All @@ -22,6 +22,7 @@ def __init__(self, directory, label_list=None):
label_list (str, optional): The path to the Excel file containing labeled graphs.
If provided, the labels will be processed and encoded.
"""
self.class_elements = []
if label_list is not None:
try:
self.encoded_labels = self.convert_labeled_graphs(label_list)
Expand All @@ -33,6 +34,8 @@ def __init__(self, directory, label_list=None):
self.directory = directory
self.graph_names = []
self.graph_dir = os.listdir(directory)
self.graph = None

for g, graph in enumerate(self.graph_dir):
if '_nodefeatures.csv' in graph:
graph_name = graph.removesuffix('_nodefeatures.csv')
Expand Down Expand Up @@ -69,32 +72,36 @@ def __getitem__(self, index):
and optionally the label.
"""
graph_name = self.graph_names[index]
for g, graph in enumerate(self.graph_dir):
for g, self.graph in enumerate(self.graph_dir):
try:
if f'{graph_name}_nodefeatures.csv' == graph:
if f'{graph_name}_nodefeatures.csv' == self.graph:
node_features = pd.read_csv(
f'{self.directory}/{graph}', header=None) # load csv file
self.x = convert_hashed_names_to_float(node_features)
if f'{graph_name}_A.csv' == graph:
f'{self.directory}/{self.graph}', header=None) # load csv file
self.x = convert_hashed_names_to_float(node_features.to_numpy())
if f'{graph_name}_A.csv' == self.graph:
adjacency = pd.read_csv(
f'{self.directory}/{graph}', header=None)
edge_tensor = convert_list_to_longtensor(adjacency)
f'{self.directory}/{self.graph}', header=None)
edge_tensor = convert_list_to_long_tensor(adjacency.values.tolist())
self.edge_index = convert_edge_dim(edge_tensor)
if f'{graph_name}_edge_attributes.csv' == graph:
if f'{graph_name}_edge_attributes.csv' == self.graph:
edge_attributes = pd.read_csv(
f'{self.directory}/{graph}', header=None)
self.edge_attr = convert_list_to_floattensor(
edge_attributes)
f'{self.directory}/{self.graph}', header=None)
self.edge_attr = convert_list_to_float_tensor(
edge_attributes.values.tolist())
except Exception as e:
print(graph, e)
print(self.graph, e)
if hasattr(self, 'x') and hasattr(self, 'edge_index'):
graph = Data(x=self.x, edge_index=self.edge_index)
self.graph = Data(x=self.x, edge_index=self.edge_index)
if hasattr(self, 'y'):
label = self.y[index]
graph.y = label
self.graph.y = label
if hasattr(self, 'edge_attr'):
graph.edge_attr = self.edge_attr
return graph
self.graph.edge_attr = self.edge_attr
return self.graph

def __iter__(self):
    """Yield each graph of the dataset in index order."""
    index = 0
    while index < len(self):
        yield self[index]
        index += 1

def sort_labels(self):
    """Align the encoded labels with the order of ``self.graph_names``.

    Each graph name is matched against the (name, label) pairs in
    ``self.encoded_labels`` and the matching label rows are stacked.

    Returns:
        torch.FloatTensor: A tensor containing the sorted labels for the graphs.
    """
    pairs = list(self.encoded_labels)
    stacked = None
    for graph_name in self.graph_names:
        for entry in pairs:
            if graph_name == entry[0]:
                label = entry[1]
                if stacked is None:
                    stacked = np.array(label, dtype=np.float16)
                else:
                    stacked = np.vstack((stacked, label)).astype(np.float16)
    return torch.FloatTensor(stacked)

'''takes directory path of excel file with labeled repositories as input and converts the
Expand Down Expand Up @@ -142,13 +149,13 @@ def convert_labeled_graphs(self, labels):

# iterate over loaded file and retrieve labels
for row in resource.iterrows():
object = row[1]
row_data = row[1]
# column header containing repository url
url = object.get('html_url')
url = row_data.get('html_url')
repo_name = url.split('/')[-1] # last element is repository name
graph_names.append(repo_name)
# column header containing label
type_label = object.get('final type')
type_label = row_data.get('final type')
graph_labels.append(type_label)

self.class_elements = self.count_class_elements(
Expand All @@ -159,7 +166,8 @@ def convert_labeled_graphs(self, labels):
file = zip(graph_names, encoded_nodes)
return file

def count_class_elements(self, labels):
@staticmethod
def count_class_elements(labels):
"""
Counts the number of occurrences of each class type in the provided labels.

Expand Down Expand Up @@ -209,18 +217,18 @@ def check_dataset(self):
"""
for i, item in enumerate(self.graph_names):
graph_name = self.graph_names[i]
for g, graph in enumerate(self.graph_dir):
for g in self.graph_dir:
try:
if f'{graph_name}_nodefeatures.csv' == graph:
node_features = pd.read_csv(
f'{self.directory}/{graph}', header=None)
if f'{graph_name}_A.csv' == graph:
adjacency = pd.read_csv(
f'{self.directory}/{graph}', header=None)
if f'{graph_name}_edge_attributes.csv' == graph:
edge_attributes = pd.read_csv(
f'{self.directory}/{graph}', header=None)
if f'{graph_name}_nodefeatures.csv' == g:
pd.read_csv(
f'{self.directory}/{g}', header=None)
if f'{graph_name}_A.csv' == g:
pd.read_csv(
f'{self.directory}/{g}', header=None)
if f'{graph_name}_edge_attributes.csv' == g:
pd.read_csv(
f'{self.directory}/{g}', header=None)
except Exception as e:
if graph_name in self.graph_names:
self.graph_names.remove(graph_name)
print(f'{graph}, {e}, removing {graph_name} from dataset')
print(f'{g}, {e}, removing {graph_name} from dataset')
Loading
Loading