diff --git a/sdk/diffgram/core/core.py b/sdk/diffgram/core/core.py index 57cbec1..15b1adc 100644 --- a/sdk/diffgram/core/core.py +++ b/sdk/diffgram/core/core.py @@ -70,6 +70,8 @@ def __init__( self.directory_id = None self.name_to_file_id = None + self.client_id = client_id + self.client_secret = client_secret if init_default_directory is True: self.set_default_directory(directory = self.directory) @@ -78,9 +80,6 @@ def __init__( if refresh_local_label_dict is True: self.get_label_file_dict() - self.client_id = client_id - self.client_secret = client_secret - self.label_schema_list = self.get_label_schema_list() self.directory_list = None @@ -133,6 +132,21 @@ def get_label_schema_list(self): data = response.json() return data + def get_attributes(self, schema_id = None): + if schema_id is None: + schema = self.get_default_label_schema() + if schema is not None: + schema_id = schema.get('id') + url = f'/api/v1/project/{self.project_string_id}/attribute/template/list' + data = { + 'schema_id': schema_id, + 'mode': "from_project", + } + response = self.session.post(url = self.host + url, json=data) + self.handle_errors(response) + data = response.json() + return data.get('attribute_group_list') + def get_http_auth(self): return HTTPBasicAuth(self.client_id, self.client_secret) diff --git a/sdk/diffgram/core/diffgram_dataset_iterator.py b/sdk/diffgram/core/diffgram_dataset_iterator.py index 237a1a0..64e172b 100644 --- a/sdk/diffgram/core/diffgram_dataset_iterator.py +++ b/sdk/diffgram/core/diffgram_dataset_iterator.py @@ -8,8 +8,16 @@ class DiffgramDatasetIterator: - - def __init__(self, project, + diffgram_file_id_list: list + max_size_cache: int = 1073741824 + pool: ThreadPoolExecutor + project: 'Project' + file_cache: dict + _internal_file_list: list + current_file_index: int + + def __init__(self, + project, diffgram_file_id_list, validate_ids = True, max_size_cache = 1073741824, @@ -19,6 +27,21 @@ def __init__(self, project, :param project (sdk.core.core.Project): A Project object from the Diffgram SDK :param diffgram_file_list (list): An arbitrary number of file ID's from Diffgram. """ + self.diffgram_file_id_list = [] + self.max_size_cache = 1073741824 + self.pool = None + self.file_cache = {} + self._internal_file_list = [] + self.current_file_index = 0 + self.start_iterator( + project = project, + diffgram_file_id_list = diffgram_file_id_list, + validate_ids = validate_ids, + max_size_cache = max_size_cache, + max_num_concurrent_fetches = max_num_concurrent_fetches) + + def start_iterator(self, project, diffgram_file_id_list, validate_ids = True, max_size_cache = 1073741824, + max_num_concurrent_fetches = 25): self.diffgram_file_id_list = diffgram_file_id_list self.max_size_cache = max_size_cache self.pool = ThreadPoolExecutor(max_num_concurrent_fetches) @@ -62,7 +85,8 @@ def get_next_n_items(self, idx, num_items = 25): return True def __get_file_data_for_index(self, idx): - diffgram_file = self.project.file.get_by_id(self.diffgram_file_id_list[idx], with_instances = True, use_session = False) + diffgram_file = self.project.file.get_by_id(self.diffgram_file_id_list[idx], with_instances = True, + use_session = False) instance_data = self.get_file_instances(diffgram_file) self.save_file_in_cache(idx, instance_data) return instance_data @@ -88,7 +112,7 @@ def __validate_file_ids(self): if not self.diffgram_file_id_list: return result = self.project.file.file_list_exists( - self.diffgram_file_id_list, + self.diffgram_file_id_list, use_session = False) if not result: raise Exception( @@ -112,6 +136,30 @@ def get_image_data(self, diffgram_file): else: raise Exception('Pytorch datasets only support images. Please provide only file_ids from images') + def gen_global_attrs(self, instance_list): + res = [] + for inst in instance_list: + if inst['type'] != 'global': + continue + res.append(inst['attribute_groups']) + return res + + def gen_tag_instances(self, instance_list): + result = [] + for inst in instance_list: + if inst['type'] != 'tag': + continue + for k in list(inst.keys()): + val = inst[k] + if val is None: + inst.pop(k) + elm = { + 'label': inst['label_file']['label']['name'], + 'label_file_id': inst['label_file']['id'], + } + result.append(elm) + return result + def get_file_instances(self, diffgram_file): if diffgram_file.type not in ['image', 'frame']: raise NotImplementedError('File type "{}" is not supported yet'.format(diffgram_file['type'])) @@ -123,6 +171,9 @@ def get_file_instances(self, diffgram_file): sample = {'image': image, 'diffgram_file': diffgram_file} has_boxes = False has_poly = False + has_tags = False + has_global = False + sample['raw_instance_list'] = instance_list if 'box' in instance_types_in_file: has_boxes = True x_min_list, x_max_list, y_min_list, y_max_list = self.extract_bbox_values(instance_list, diffgram_file) @@ -140,12 +191,19 @@ def get_file_instances(self, diffgram_file): has_poly = True mask_list = self.extract_masks_from_polygon(instance_list, diffgram_file) sample['polygon_mask_list'] = mask_list + if 'tag' in instance_types_in_file: + has_tags = True + sample['tags'] = self.gen_tag_instances(instance_list) + if 'global' in instance_types_in_file: + has_global = True + sample['global_attributes'] = self.gen_global_attrs(instance_list) + else: sample['polygon_mask_list'] = [] - if len(instance_types_in_file) > 2 and has_boxes and has_boxes: + if len(instance_types_in_file) > 4 and has_poly and has_boxes and has_tags and has_global: raise NotImplementedError( - 'SDK only supports boxes and polygon types currently. If you want a new instance type to be supported please contact us!' + 'SDK Streaming only supports boxes and polygon, tags and global attributes types currently. If you want a new instance type to be supported please contact us!' ) label_id_list, label_name_list = self.extract_labels(instance_list) @@ -174,11 +232,13 @@ def extract_masks_from_polygon(self, instance_list, diffgram_file, empty_value = def extract_labels(self, instance_list, allowed_instance_types = None): label_file_id_list = [] label_names_list = [] - for inst in instance_list: + if inst['type'] == 'global': + continue + if inst is None: + continue if allowed_instance_types and inst['type'] in allowed_instance_types: continue - label_file_id_list.append(inst['label_file']['id']) label_names_list.append(inst['label_file']['label']['name']) diff --git a/sdk/diffgram/core/directory.py b/sdk/diffgram/core/directory.py index f3dc269..de00d50 100644 --- a/sdk/diffgram/core/directory.py +++ b/sdk/diffgram/core/directory.py @@ -7,10 +7,10 @@ class Directory(DiffgramDatasetIterator): - def __init__(self, - client, - file_id_list_sliced = None, - init_file_ids = True, + def __init__(self, + client, + file_id_list_sliced = None, + init_file_ids = True, validate_ids = True): self.client = client @@ -25,11 +25,8 @@ def __init__(self, self.file_id_list = file_id_list_sliced super(Directory, self).__init__(self.client, self.file_id_list, validate_ids) - - def init_files(self): self.file_id_list = self.all_file_ids() - def get_directory_list(self): """ Get a list of available directories for a project @@ -50,7 +47,7 @@ def get_directory_list(self): self.client.handle_errors(response) data = response.json() - + directory_list_json = data.get('directory_list') default_directory_json = data.get('default_directory') @@ -60,7 +57,6 @@ def get_directory_list(self): directory_list = self.convert_json_to_sdk_object(directory_list_json) return directory_list - def convert_json_to_sdk_object(self, directory_list_json): @@ -71,18 +67,21 @@ def convert_json_to_sdk_object(self, directory_list_json): client = self.client, init_file_ids = False, validate_ids = False - ) + ) refresh_from_dict(new_directory, directory_json) # note timing issue, this needs to happen after id is refreshed - new_directory.init_files() + new_directory.init_files() + new_directory.start_iterator( + project = new_directory.project, + diffgram_file_id_list = new_directory.file_id_list, + validate_ids = True + ) directory_list.append(new_directory) return directory_list - - def all_files(self): """ Get all the files of the directoy. @@ -93,8 +92,8 @@ def all_files(self): result = [] while page_num is not None: diffgram_files = self.list_files( - limit = 1000, - page_num = page_num, + limit = 1000, + page_num = page_num, file_view_mode = 'base') page_num = self.file_list_metadata['next_page'] result = result + diffgram_files @@ -105,9 +104,9 @@ def all_file_ids(self, query = None): result = [] diffgram_ids = self.list_files( - limit = 5000, - page_num = page_num, - file_view_mode = 'ids_only', + limit = 5000, + page_num = page_num, + file_view_mode = 'ids_only', query = query) if diffgram_ids is False: @@ -299,7 +298,6 @@ def get(self, TODO refactor set_directory_by_name() to use this """ - if name is None: raise Exception("No name provided.") @@ -312,6 +310,7 @@ def get(self, for directory in self.client.directory_list: if directory.nickname == name: + directory.init_files() return directory else: