Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions sdk/diffgram/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def __init__(
self.directory_id = None
self.name_to_file_id = None

self.client_id = client_id
self.client_secret = client_secret

if init_default_directory is True:
self.set_default_directory(directory = self.directory)
Expand All @@ -78,9 +80,6 @@ def __init__(
if refresh_local_label_dict is True:
self.get_label_file_dict()

self.client_id = client_id
self.client_secret = client_secret

self.label_schema_list = self.get_label_schema_list()

self.directory_list = None
Expand Down Expand Up @@ -133,6 +132,21 @@ def get_label_schema_list(self):
data = response.json()
return data

def get_attributes(self, schema_id = None):
if schema_id is None:
schema = self.get_default_label_schema()
if schema is not None:
schema_id = schema.get('id')
url = f'/api/v1/project/{self.project_string_id}/attribute/template/list'
data = {
'schema_id': schema_id,
'mode': "from_project",
}
response = self.session.post(url = self.host + url, json=data)
self.handle_errors(response)
data = response.json()
return data.get('attribute_group_list')

def get_http_auth(self):
return HTTPBasicAuth(self.client_id, self.client_secret)

Expand Down
76 changes: 68 additions & 8 deletions sdk/diffgram/core/diffgram_dataset_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,16 @@


class DiffgramDatasetIterator:

def __init__(self, project,
diffgram_file_id_list: list
max_size_cache: int = 1073741824
pool: ThreadPoolExecutor
project: 'Project'
file_cache: dict
_internal_file_list: list
current_file_index: int

def __init__(self,
project,
diffgram_file_id_list,
validate_ids = True,
max_size_cache = 1073741824,
Expand All @@ -19,6 +27,21 @@ def __init__(self, project,
:param project (sdk.core.core.Project): A Project object from the Diffgram SDK
:param diffgram_file_list (list): An arbitrary number of file ID's from Diffgram.
"""
self.diffgram_file_id_list = []
self.max_size_cache = 1073741824
self.pool = None
self.file_cache = {}
self._internal_file_list = []
self.current_file_index = 0
self.start_iterator(
project = project,
diffgram_file_id_list = diffgram_file_id_list,
validate_ids = validate_ids,
max_size_cache = max_size_cache,
max_num_concurrent_fetches = max_num_concurrent_fetches)

def start_iterator(self, project, diffgram_file_id_list, validate_ids = True, max_size_cache = 1073741824,
max_num_concurrent_fetches = 25):
self.diffgram_file_id_list = diffgram_file_id_list
self.max_size_cache = max_size_cache
self.pool = ThreadPoolExecutor(max_num_concurrent_fetches)
Expand Down Expand Up @@ -62,7 +85,8 @@ def get_next_n_items(self, idx, num_items = 25):
return True

def __get_file_data_for_index(self, idx):
diffgram_file = self.project.file.get_by_id(self.diffgram_file_id_list[idx], with_instances = True, use_session = False)
diffgram_file = self.project.file.get_by_id(self.diffgram_file_id_list[idx], with_instances = True,
use_session = False)
instance_data = self.get_file_instances(diffgram_file)
self.save_file_in_cache(idx, instance_data)
return instance_data
Expand All @@ -88,7 +112,7 @@ def __validate_file_ids(self):
if not self.diffgram_file_id_list:
return
result = self.project.file.file_list_exists(
self.diffgram_file_id_list,
self.diffgram_file_id_list,
use_session = False)
if not result:
raise Exception(
Expand All @@ -112,6 +136,30 @@ def get_image_data(self, diffgram_file):
else:
raise Exception('Pytorch datasets only support images. Please provide only file_ids from images')

def gen_global_attrs(self, instance_list):
res = []
for inst in instance_list:
if inst['type'] != 'global':
continue
res.append(inst['attribute_groups'])
return res

def gen_tag_instances(self, instance_list):
result = []
for inst in instance_list:
if inst['type'] != 'tag':
continue
for k in list(inst.keys()):
val = inst[k]
if val is None:
inst.pop(k)
elm = {
'label': inst['label_file']['label']['name'],
'label_file_id': inst['label_file']['id'],
}
result.append(elm)
return result

def get_file_instances(self, diffgram_file):
if diffgram_file.type not in ['image', 'frame']:
raise NotImplementedError('File type "{}" is not supported yet'.format(diffgram_file['type']))
Expand All @@ -123,6 +171,9 @@ def get_file_instances(self, diffgram_file):
sample = {'image': image, 'diffgram_file': diffgram_file}
has_boxes = False
has_poly = False
has_tags = False
has_global = False
sample['raw_instance_list'] = instance_list
if 'box' in instance_types_in_file:
has_boxes = True
x_min_list, x_max_list, y_min_list, y_max_list = self.extract_bbox_values(instance_list, diffgram_file)
Expand All @@ -140,12 +191,19 @@ def get_file_instances(self, diffgram_file):
has_poly = True
mask_list = self.extract_masks_from_polygon(instance_list, diffgram_file)
sample['polygon_mask_list'] = mask_list
if 'tag' in instance_types_in_file:
has_tags = True
sample['tags'] = self.gen_tag_instances(instance_list)
if 'global' in instance_types_in_file:
has_global = True
sample['global_attributes'] = self.gen_global_attrs(instance_list)

else:
sample['polygon_mask_list'] = []

if len(instance_types_in_file) > 2 and has_boxes and has_boxes:
if len(instance_types_in_file) > 4 and has_poly and has_boxes and has_tags and has_global:
raise NotImplementedError(
'SDK only supports boxes and polygon types currently. If you want a new instance type to be supported please contact us!'
'SDK Streaming only supports boxes and polygon, tags and global attributes types currently. If you want a new instance type to be supported please contact us!'
)

label_id_list, label_name_list = self.extract_labels(instance_list)
Expand Down Expand Up @@ -174,11 +232,13 @@ def extract_masks_from_polygon(self, instance_list, diffgram_file, empty_value =
def extract_labels(self, instance_list, allowed_instance_types = None):
label_file_id_list = []
label_names_list = []

for inst in instance_list:
if inst['type'] == 'global':
continue
if inst is None:
continue
if allowed_instance_types and inst['type'] in allowed_instance_types:
continue

label_file_id_list.append(inst['label_file']['id'])
label_names_list.append(inst['label_file']['label']['name'])

Expand Down
37 changes: 18 additions & 19 deletions sdk/diffgram/core/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

class Directory(DiffgramDatasetIterator):

def __init__(self,
client,
file_id_list_sliced = None,
init_file_ids = True,
def __init__(self,
client,
file_id_list_sliced = None,
init_file_ids = True,
validate_ids = True):

self.client = client
Expand All @@ -25,11 +25,8 @@ def __init__(self,
self.file_id_list = file_id_list_sliced
super(Directory, self).__init__(self.client, self.file_id_list, validate_ids)



def init_files(self):
self.file_id_list = self.all_file_ids()

def get_directory_list(self):
"""
Get a list of available directories for a project
Expand All @@ -50,7 +47,7 @@ def get_directory_list(self):
self.client.handle_errors(response)

data = response.json()

directory_list_json = data.get('directory_list')
default_directory_json = data.get('default_directory')

Expand All @@ -60,7 +57,6 @@ def get_directory_list(self):
directory_list = self.convert_json_to_sdk_object(directory_list_json)

return directory_list


def convert_json_to_sdk_object(self, directory_list_json):

Expand All @@ -71,18 +67,21 @@ def convert_json_to_sdk_object(self, directory_list_json):
client = self.client,
init_file_ids = False,
validate_ids = False
)
)
refresh_from_dict(new_directory, directory_json)

# note timing issue, this needs to happen after id is refreshed
new_directory.init_files()
new_directory.init_files()
new_directory.start_iterator(
project = new_directory.project,
diffgram_file_id_list = new_directory.file_id_list,
validate_ids = True
)

directory_list.append(new_directory)

return directory_list



def all_files(self):
"""
Get all the files of the directoy.
Expand All @@ -93,8 +92,8 @@ def all_files(self):
result = []
while page_num is not None:
diffgram_files = self.list_files(
limit = 1000,
page_num = page_num,
limit = 1000,
page_num = page_num,
file_view_mode = 'base')
page_num = self.file_list_metadata['next_page']
result = result + diffgram_files
Expand All @@ -105,9 +104,9 @@ def all_file_ids(self, query = None):
result = []

diffgram_ids = self.list_files(
limit = 5000,
page_num = page_num,
file_view_mode = 'ids_only',
limit = 5000,
page_num = page_num,
file_view_mode = 'ids_only',
query = query)

if diffgram_ids is False:
Expand Down Expand Up @@ -299,7 +298,6 @@ def get(self,
TODO refactor set_directory_by_name() to use this

"""

if name is None:
raise Exception("No name provided.")

Expand All @@ -312,6 +310,7 @@ def get(self,
for directory in self.client.directory_list:

if directory.nickname == name:
directory.init_files()
return directory

else:
Expand Down