From b221e5876f878cc87f07b02b69c6a031e80e714b Mon Sep 17 00:00:00 2001 From: Pablo Date: Tue, 31 May 2022 08:58:28 -0600 Subject: [PATCH] fix: add better label schema support --- sdk/diffgram/core/core.py | 462 +++++++++++++++++--------------- sdk/diffgram/file/view.py | 6 +- sdk/diffgram/label/label_new.py | 95 +++---- 3 files changed, 303 insertions(+), 260 deletions(-) diff --git a/sdk/diffgram/core/core.py b/sdk/diffgram/core/core.py index 3e4fa03..db6fea4 100644 --- a/sdk/diffgram/core/core.py +++ b/sdk/diffgram/core/core.py @@ -22,233 +22,269 @@ class Project(): + def __init__( + self, + project_string_id, + client_id = None, + client_secret = None, + debug = False, + staging = False, + host = None + ): + + self.session = requests.Session() + adapter = requests.adapters.HTTPAdapter(pool_connections = 30, pool_maxsize = 30) + self.session.mount('http://', adapter) + self.session.mount('https://', adapter) + self.project_string_id = None + + self.debug = debug + self.staging = staging + if host is None: + if self.debug is True: + self.host = "http://127.0.0.1:8085" + print("Debug", __version__) + elif self.staging is True: + self.host = "https://20200110t142358-dot-walrus-dot-diffgram-001.appspot.com/" + else: + self.host = "https://diffgram.com" + else: + self.host = host + self.directory_id = None + self.name_to_file_id = None + self.auth( + project_string_id = project_string_id, + client_id = client_id, + client_secret = client_secret) + self.client_id = client_id + self.client_secret = client_secret + + self.file = FileConstructor(self) + self.train = Train(self) + self.job = Job(self) + self.guide = Guide(self) + self.directory = Directory(self, validate_ids = False) + self.export = Export(self) + self.task = Task(client = self) + self.label_schema_list = self.get_label_schema_list() + + def get_member_list(self): + url = '/api/project/{}/view'.format(self.project_string_id) + response = self.session.get(url = self.host + url) + self.handle_errors(response) + data = response.json() + return data['project']['member_list'] + + def get_label_schema_by_id(self, id): + if self.label_schema_list is None or len(self.label_schema_list) == 0: + self.label_schema_list = self.get_label_schema_list() + for s in self.label_schema_list: + if s['id'] == id: + return s + + def get_label_schema_by_name(self, name): + if self.label_schema_list is None or len(self.label_schema_list) == 0: + self.label_schema_list = self.get_label_schema_list() + for s in self.label_schema_list: + if s['name'] == name: + return s + + def get_default_label_schema(self): + if self.label_schema_list is None or len(self.label_schema_list) == 0: + self.label_schema_list = self.get_label_schema_list() + + return self.label_schema_list[0] + + def get_label_list(self, schema_id = None): + url = f'/api/project/{self.project_string_id}/labels' + if schema_id is None: + schema = self.get_default_label_schema() + if schema is not None: + schema_id = schema.get('id') + + params = {'schema_id': schema_id} + response = self.session.get(url = self.host + url, params=params) + self.handle_errors(response) + data = response.json() + return data.get('labels_out') + + def get_label_schema_list(self): + url = f'/api/v1/project/{self.project_string_id}/labels-schema' + response = self.session.get(url = self.host + url) + self.handle_errors(response) + data = response.json() + return data + + def get_http_auth(self): + return HTTPBasicAuth(self.client_id, self.client_secret) + + def get_label( + self, + name = None, + schema_id = None, + name_list = None): + """ + name, str + name_list, list, optional + + Name must be an exact match to label name. + + If a name_list is provided it will construct a list of + objects that match that name. + + Returns + None if not found. + File object of type Label if found. + List of File objects if a proj is provided. + """ + if self.name_to_file_id is None: + self.get_label_file_dict() + + if name_list: + out = [] + for name in name_list: + out.append(self.get_label(name)) + return out + + id = self.name_to_file_id.get(name) + + if id is None: + return None + + file = File(id = id) + return file + + def get_model( + self, + name = None, + local = False): + + brain = Brain( + client = self, + name = name, + local = local + ) + + return brain + + def handle_errors(self, + response): + + """ + Upon a bad request (400), our error log contains + good information to raise. + + We also catch a few more common codes to + try and print simpler messages. + + Otherwise expects this to be caught by raise_for_status() + if applicable + https://2.python-requests.org/en/master/_modules/requests/models/#Response.raise_for_status + + This is under the assumption that we generaly call response.json() + after this, and that fails in poor way if there is no json available. + """ + + # Default + if response.status_code == 200: + return + + # Errors + if response.status_code == 400: + try: + raise Exception(response.json()["log"]["error"]) + except: + raise Exception(response.text) + + if response.status_code == 403: + raise Exception("Invalid Permission", response.text) + + if response.status_code == 404: + raise (Exception("404 Not Found" + response.text)) + + if response.status_code == 429: + raise Exception( + "Rate Limited. Please add buffer between calls eg time.sleep(1). Otherwise, please try again later. Else contact us if this persists.") + + if response.status_code == 500: + raise Exception("Internal error, please try again later.") + + raise_for_status = response.raise_for_status() + if raise_for_status: + Exception(raise_for_status) + + def auth(self, + project_string_id, + client_id = None, + client_secret = None, + set_default_directory = True, + refresh_local_label_dict = True + ): + """ + Define authorization configuration - def __init__( - self, - project_string_id, - client_id = None, - client_secret = None, - debug = False, - staging = False, - host = None - ): - - self.session = requests.Session() - adapter = requests.adapters.HTTPAdapter(pool_connections = 30, pool_maxsize = 30) - self.session.mount('http://', adapter) - self.session.mount('https://', adapter) - self.project_string_id = None - - self.debug = debug - self.staging = staging - if host is None: - if self.debug is True: - self.host = "http://127.0.0.1:8085" - print("Debug", __version__) - elif self.staging is True: - self.host = "https://20200110t142358-dot-walrus-dot-diffgram-001.appspot.com/" - else: - self.host = "https://diffgram.com" - else: - self.host = host - self.directory_id = None - self.name_to_file_id = None - self.auth( - project_string_id = project_string_id, - client_id = client_id, - client_secret = client_secret) - self.client_id = client_id - self.client_secret = client_secret - self.file = FileConstructor(self) - self.train = Train(self) - self.job = Job(self) - self.guide = Guide(self) - self.directory = Directory(self, validate_ids = False) - self.export = Export(self) - self.task = Task(client = self) - - def get_member_list(self): - url = '/api/project/{}/view'.format(self.project_string_id) - response = self.session.get(url=self.host + url) - self.handle_errors(response) - data = response.json() - return data['project']['member_list'] - - def get_http_auth(self): - return HTTPBasicAuth(self.client_id, self.client_secret) - - def get_label( - self, - name=None, - name_list=None): - """ - name, str - name_list, list, optional - - Name must be an exact match to label name. - - If a name_list is provided it will construct a list of - objects that match that name. - - Returns - None if not found. - File object of type Label if found. - List of File objects if a name_list is provided. - """ - if self.name_to_file_id is None: - self.get_label_file_dict() - - if name_list: - out = [] - for name in name_list: - out.append(self.get_label(name)) - return out - - id = self.name_to_file_id.get(name) - - if id is None: - return None - - file = File(id = id) - return file - - - def get_model( - self, - name = None, - local = False): - - - brain = Brain( - client = self, - name = name, - local = local - ) - - return brain - - - def handle_errors(self, - response): - - """ - Upon a bad request (400), our error log contains - good information to raise. - - We also catch a few more common codes to - try and print simpler messages. - - Otherwise expects this to be caught by raise_for_status() - if applicable - https://2.python-requests.org/en/master/_modules/requests/models/#Response.raise_for_status - - This is under the assumption that we generaly call response.json() - after this, and that fails in poor way if there is no json available. - """ - - # Default - if response.status_code == 200: - return - - # Errors - if response.status_code == 400: - try: - raise Exception(response.json()["log"]["error"]) - except: - raise Exception(response.text) - - if response.status_code == 403: - raise Exception("Invalid Permission", response.text) - - if response.status_code == 404: - raise(Exception("404 Not Found" + response.text)) + If no client_id / secret is provided it assumes project is public + And if project isn't public it will return a 403 permission denied. - if response.status_code == 429: - raise Exception("Rate Limited. Please add buffer between calls eg time.sleep(1). Otherwise, please try again later. Else contact us if this persists.") + Arguments + client_id, string + client_secret, string + project_string_id, string - if response.status_code == 500: - raise Exception("Internal error, please try again later.") + Returns + None + + Future + More gracefully intial setup (ie validate upon setting) + """ + self.project_string_id = project_string_id - raise_for_status = response.raise_for_status() - if raise_for_status: - Exception(raise_for_status) + if client_id and client_secret: + self.session.auth = (client_id, client_secret) + if set_default_directory is True: + self.set_default_directory() + if refresh_local_label_dict is True: + # Refresh local labels from Diffgram project + self.get_label_file_dict() - def auth(self, - project_string_id, - client_id = None, - client_secret = None, - set_default_directory = True, - refresh_local_label_dict = True - ): - """ - Define authorization configuration + def set_default_directory(self, + directory_id = None): + """ + -> If no id is provided fetch directory list for project + and set first directory to default. + -> Sets the headers of self.session - If no client_id / secret is provided it assumes project is public - And if project isn't public it will return a 403 permission denied. + Arguments + directory_id, int, defaults to None - Arguments - client_id, string - client_secret, string - project_string_id, string + Returns + None - Returns - None + Future + TODO return error if invalid directory? - Future - More gracefully intial setup (ie validate upon setting) - """ - self.project_string_id = project_string_id + """ - if client_id and client_secret: - self.session.auth = (client_id, client_secret) - - if set_default_directory is True: - self.set_default_directory() + if directory_id: + # TODO check if valid? + # data = {} + # data["directory_id"] = directory_id + self.directory_id = directory_id + else: - if refresh_local_label_dict is True: - # Refresh local labels from Diffgram project - self.get_label_file_dict() + data = self.get_directory_list() - - - def set_default_directory(self, - directory_id=None): - """ - -> If no id is provided fetch directory list for project - and set first directory to default. - -> Sets the headers of self.session - - Arguments - directory_id, int, defaults to None - - Returns - None - - Future - TODO return error if invalid directory? - - """ - - if directory_id: - # TODO check if valid? - # data = {} - # data["directory_id"] = directory_id - self.directory_id = directory_id - else: - - data = self.get_directory_list() - - self.default_directory = data['default_directory'] - - # Hold over till refactoring (would prefer to - # just call self.directory_default.id - self.directory_id = self.default_directory['id'] - - self.directory_list = data["directory_list"] - self.session.headers.update( - {'directory_id': str(self.directory_id)}) + self.default_directory = data['default_directory'] + + # Hold over till refactoring (would prefer to + # just call self.directory_default.id + self.directory_id = self.default_directory['id'] + + self.directory_list = data["directory_list"] + self.session.headers.update( + {'directory_id': str(self.directory_id)}) # TODO review not using this pattern anymore diff --git a/sdk/diffgram/file/view.py b/sdk/diffgram/file/view.py index 026807c..d4c8014 100644 --- a/sdk/diffgram/file/view.py +++ b/sdk/diffgram/file/view.py @@ -18,7 +18,7 @@ def get_file_id(): pass -def get_label_file_dict(self, use_session = True): +def get_label_file_dict(self, schema_id = None, use_session = True): """ Get Project label file id dict for project @@ -42,11 +42,13 @@ def get_label_file_dict(self, use_session = True): endpoint = "/api/v1/project/" + self.project_string_id + \ "/labels/view/name_to_file_id" + params = {'schema_id': schema_id} if use_session: - response = self.session.get(self.host + endpoint) + response = self.session.get(self.host + endpoint, params = params) else: # Add Auth response = requests.get(self.host + endpoint, + params = params, headers = {'directory_id': str(self.directory_id)}, auth = self.get_http_auth()) diff --git a/sdk/diffgram/label/label_new.py b/sdk/diffgram/label/label_new.py index 4f23292..49c05ae 100644 --- a/sdk/diffgram/label/label_new.py +++ b/sdk/diffgram/label/label_new.py @@ -1,48 +1,53 @@ import warnings -def label_new(self, - label, - allow_duplicates=False, - print_success=True): - """ - - Arguments - self, - label_list, a list of label strings - ignore_duplicates, bool - print_success, bool - - Expects - - Returns - - """ - - # Check if already exists - name = label.get('name', None) - if not name: - raise Exception("Please provide a key of name with a value of label") - - if allow_duplicates is False: - - label_file_id = self.name_to_file_id.get(name, None) - - if label_file_id: - warnings.warn("\n\n '" + name + "' label already exists and was skipped." + \ - "\n Set allow_duplicates = True to bypass this check. \n") - return - - endpoint = "/api/v1/project/" + self.project_string_id + \ - "/label/new" - - response = self.session.post(self.host + endpoint, - json = label) - - data = response.json() - - if data["log"]["success"] == True: - if print_success is True: - print("New label success") - else: - raise Exception(data["log"]["errors"]) \ No newline at end of file +def label_new(self, + label: dict, + schema_id: int = None, + allow_duplicates: bool = False, + print_success: bool = True): + """ + + Arguments + self, + label_list, a list of label strings + ignore_duplicates, bool + print_success, bool + + Expects + + Returns + + """ + if schema_id is None: + schema = self.get_default_label_schema() + if schema is not None: + schema_id = schema.get('id') + + # Check if already exists + name = label.get('name', None) + if not name: + raise Exception("Please provide a key of name with a value of label") + label['schema_id'] = schema_id + if allow_duplicates is False: + + label_file_id = self.name_to_file_id.get(name, None) + + if label_file_id: + warnings.warn("\n\n '" + name + "' label already exists and was skipped." + \ + "\n Set allow_duplicates = True to bypass this check. \n") + return + + endpoint = "/api/v1/project/" + self.project_string_id + \ + "/label/new" + + response = self.session.post(self.host + endpoint, + json = label) + + data = response.json() + self.get_label_file_dict() + if data["log"]["success"] == True: + if print_success is True: + print("New label success") + else: + raise Exception(data["log"]["error"])