diff --git a/.github/workflows/test-package.yml b/.github/workflows/test-package.yml index 59b9312..a01522c 100644 --- a/.github/workflows/test-package.yml +++ b/.github/workflows/test-package.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.7, 3.8, 3.9, '3.10'] + python-version: [3.8, 3.9, '3.10', 3.11] steps: - uses: actions/checkout@v3 @@ -28,7 +28,7 @@ jobs: python -m pip install --upgrade pip pip install flake8 pytest pytest-astropy pytest-cov sphinx-astropy codecov # install package and requirements - pip install . + pip install ".[all]" - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..24196a5 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,10 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks +# To install: +# pip install pre-commit +# pre-commit install --allow-missing-config +repos: +- repo: https://github.com/akaihola/darker + rev: 1.7.2 + hooks: + - id: darker \ No newline at end of file diff --git a/astrodbkit2/__init__.py b/astrodbkit2/__init__.py index 672376b..c59abce 100644 --- a/astrodbkit2/__init__.py +++ b/astrodbkit2/__init__.py @@ -6,7 +6,7 @@ # from ._astropy_init import * # noqa # ---------------------------------------------------------------------------- -__all__ = ['__version__'] +__all__ = ["__version__"] # from .example_mod import * # noqa # Then you can be explicit to control what ends up in the namespace, @@ -17,15 +17,26 @@ try: from .version import version as __version__ except ImportError: - __version__ = '' + __version__ = "" # Global variables # These describe the various database tables and their links -REFERENCE_TABLES = ['Publications', 'Telescopes', 'Instruments', 'Modes', 'Filters', 'PhotometryFilters', - 'Citations', 'References', 'Versions', 'Parameters', 'Regimes'] +REFERENCE_TABLES = [ + "Publications", + "Telescopes", + "Instruments", + "Modes", + "Filters", + "PhotometryFilters", + "Citations", + "References", + "Versions", + "Parameters", + "Regimes", +] # REFERENCE_TABLES is a list of tables that do not link to the primary table. # These are treated separately from the other data tables that are all assumed to be linked to the primary table. -PRIMARY_TABLE = 'Sources' # the primary table used for storing objects -PRIMARY_TABLE_KEY = 'source' # the name of the primary key in the primary table; this is used for joining tables -FOREIGN_KEY = 'source' # the name of the foreign key in other tables that refer back to the primary +PRIMARY_TABLE = "Sources" # the primary table used for storing objects +PRIMARY_TABLE_KEY = "source" # the name of the primary key in the primary table; this is used for joining tables +FOREIGN_KEY = "source" # the name of the foreign key in other tables that refer back to the primary diff --git a/astrodbkit2/astrodb.py b/astrodbkit2/astrodb.py index 6bc1796..c367485 100644 --- a/astrodbkit2/astrodb.py +++ b/astrodbkit2/astrodb.py @@ -1,6 +1,6 @@ -# Main database handler code +"""Main database handler code""" -__all__ = ['__version__', 'Database', 'or_', 'and_', 'create_database'] +__all__ = ["__version__", "Database", "or_", "and_", "create_database"] import os import json @@ -24,25 +24,28 @@ try: from .version import version as __version__ except ImportError: - __version__ = '' + __version__ = "" -# pylint: disable=dangerous-default-value, too-many-arguments, trailing-whitespace +# pylint: disable=dangerous-default-value, too-many-arguments, trailing-whitespace, abstract-method # For SQLAlchemy ORM Declarative mapping -# User created schema should import and use astrodb.Base so that +# User created schema should import and use astrodb.Base so that # create_database can properly handle them Base = declarative_base() class AstrodbQuery(Query): - # Subclassing the Query class to add more functionality. - # See: https://stackoverflow.com/questions/15936111/sqlalchemy-can-you-add-custom-methods-to-the-query-object - def _make_astropy(self): + """Subclassing the Query class to add more functionality. + See: https://stackoverflow.com/questions/15936111/sqlalchemy-can-you-add-custom-methods-to-the-query-object + """ + + def _make_astropy(self, **kwargs): + """Helper method to convert query results to an Astropy Table""" temp = self.all() if len(temp) > 0: - t = AstropyTable(rows=temp, names=temp[0]._fields) + t = AstropyTable(rows=temp, names=temp[0]._fields, **kwargs) else: - t = AstropyTable(temp) + t = AstropyTable(temp, **kwargs) return t def astropy(self, spectra=None, spectra_format=None, **kwargs): @@ -62,7 +65,7 @@ def astropy(self, spectra=None, spectra_format=None, **kwargs): Table output of query """ - t = self._make_astropy() + t = self._make_astropy(**kwargs) # Apply spectra conversion if spectra is not None: @@ -75,7 +78,7 @@ def astropy(self, spectra=None, spectra_format=None, **kwargs): return t def table(self, *args, **kwargs): - # Alternative for getting astropy Table + """Alternative method for getting astropy Table""" return self.astropy(*args, **kwargs) def pandas(self, spectra=None, spectra_format=None, **kwargs): @@ -96,7 +99,7 @@ def pandas(self, spectra=None, spectra_format=None, **kwargs): """ # Relying on astropy to convert to pandas for simplicity as that handles the column names - df = self._make_astropy().to_pandas() + df = self._make_astropy(**kwargs).to_pandas() # Apply spectra conversion if spectra is not None: @@ -108,7 +111,7 @@ def pandas(self, spectra=None, spectra_format=None, **kwargs): return df - def spectra(self, spectra=['spectrum', 'access_url'], fmt='astropy', **kwargs): + def spectra(self, spectra=["spectrum", "access_url"], fmt="astropy", **kwargs): """ Convenience method fo that uses default column name for spectra conversion @@ -119,7 +122,7 @@ def spectra(self, spectra=['spectrum', 'access_url'], fmt='astropy', **kwargs): fmt : str Output format (Default: astropy) """ - if fmt == 'pandas': + if fmt == "pandas": return self.pandas(spectra=spectra, **kwargs) else: return self.astropy(spectra=spectra, **kwargs) @@ -154,13 +157,13 @@ def load_connection(connection_string, sqlite_foreign=True, base=None, connectio engine = create_engine(connection_string, connect_args=connection_arguments) if not base: - base = declarative_base() + base = declarative_base() base.metadata.bind = engine - Session = sessionmaker(bind=engine, query_cls=AstrodbQuery) + Session = sessionmaker(bind=engine, query_cls=AstrodbQuery) # pylint: disable=invalid-name session = Session() # Enable foreign key checks in SQLite - if 'sqlite' in connection_string and sqlite_foreign: + if "sqlite" in connection_string and sqlite_foreign: set_sqlite() # elif 'postgresql' in connection_string: # # Set up schema in postgres (must be lower case?) @@ -172,7 +175,9 @@ def load_connection(connection_string, sqlite_foreign=True, base=None, connectio def set_sqlite(): - # Special overrides when using SQLite + """Special overrides when using SQLite""" + # pylint: disable=unused-argument + @event.listens_for(Engine, "connect") def set_sqlite_pragma(dbapi_connection, connection_record): # Enable foreign key checking in SQLite @@ -201,8 +206,9 @@ def create_database(connection_string, drop_tables=False): return session, base, engine -def copy_database_schema(source_connection_string, destination_connection_string, sqlite_foreign=False, - ignore_tables=[], copy_data=False): +def copy_database_schema( + source_connection_string, destination_connection_string, sqlite_foreign=False, ignore_tables=[], copy_data=False +): """ Copy a database schema (ie, all tables and columns) from one database to another Adapted from https://gist.github.com/pawl/9935333 @@ -237,7 +243,7 @@ def copy_database_schema(source_connection_string, destination_connection_string # Copy schema and create newTable from oldTable for column in src_metadata.tables[table.name].columns: - dest_table.append_column(column._copy()) + dest_table.append_column(column._copy()) # pylint: disable=protected-access dest_table.create(bind=dest_engine) # Copy data, row by row @@ -255,14 +261,19 @@ def copy_database_schema(source_connection_string, destination_connection_string class Database: - def __init__(self, connection_string, - reference_tables=REFERENCE_TABLES, - primary_table=PRIMARY_TABLE, - primary_table_key=PRIMARY_TABLE_KEY, - foreign_key=FOREIGN_KEY, - column_type_overrides={}, - sqlite_foreign=True, - connection_arguments={}): + """Database handler class""" + + def __init__( + self, + connection_string, + reference_tables=REFERENCE_TABLES, + primary_table=PRIMARY_TABLE, + primary_table_key=PRIMARY_TABLE_KEY, + foreign_key=FOREIGN_KEY, + column_type_overrides={}, + sqlite_foreign=True, + connection_arguments={}, + ): """ Wrapper for database calls and utility functions @@ -289,11 +300,12 @@ def __init__(self, connection_string, Additional connection arguments, like {'check_same_thread': False}. Default: {} """ - if connection_string == 'sqlite://': + if connection_string == "sqlite://": self.session, self.base, self.engine = create_database(connection_string) else: - self.session, self.base, self.engine = load_connection(connection_string, sqlite_foreign=sqlite_foreign, - connection_arguments=connection_arguments) + self.session, self.base, self.engine = load_connection( + connection_string, sqlite_foreign=sqlite_foreign, connection_arguments=connection_arguments + ) # Convenience methods self.query = self.session.query @@ -311,9 +323,11 @@ def __init__(self, connection_string, self._foreign_key = foreign_key if len(self.metadata.tables) == 0: - print('Database empty. Import schema (eg, from astrodbkit.schema_example import *) ' - 'and then run create_database()') - raise RuntimeError('Create database first.') + print( + "Database empty. Import schema (eg, from astrodbkit.schema_example import *) " + "and then run create_database()" + ) + raise RuntimeError("Create database first.") # Set tables as explicit attributes of this class for table in self.metadata.tables: @@ -322,19 +336,19 @@ def __init__(self, connection_string, # If column overrides are provided, this will set the types to whatever the user provided if len(column_type_overrides) > 0: for k, v in column_type_overrides.items(): - tab, col = k.split('.') + tab, col = k.split(".") self.metadata.tables[tab].columns[col].type = v # Generic methods @staticmethod def _handle_format(temp, fmt): # Internal method to handle SQLAlchemy output and format it - if fmt.lower() in ('astropy', 'table'): + if fmt.lower() in ("astropy", "table"): if len(temp) > 0: results = AstropyTable(rows=temp, names=temp[0]._fields) else: results = AstropyTable(temp) - elif fmt.lower() == 'pandas': + elif fmt.lower() == "pandas": if len(temp) > 0: results = pd.DataFrame(temp, columns=temp[0]._fields) else: @@ -426,10 +440,17 @@ def inventory(self, name, pretty_print=False): return data_dict # Text query methods - @deprecated_alias(format='fmt') - def search_object(self, name, output_table=None, resolve_simbad=False, - table_names={'Sources': ['source', 'shortname'], 'Names': ['other_name']}, - fmt='table', fuzzy_search=True, verbose=True): + @deprecated_alias(format="fmt") + def search_object( + self, + name, + output_table=None, + resolve_simbad=False, + table_names={"Sources": ["source", "shortname"], "Names": ["other_name"]}, + fmt="table", + fuzzy_search=True, + verbose=True, + ): """ Query the database for the object specified. By default will return the primary table, but this can be specified. Users can also request to resolve the object name via Simbad and query against @@ -462,7 +483,7 @@ def search_object(self, name, output_table=None, resolve_simbad=False, if output_table is None: output_table = self._primary_table if output_table not in self.metadata.tables: - raise RuntimeError(f'Table {output_table} is not in the database') + raise RuntimeError(f"Table {output_table} is not in the database") match_column = self._foreign_key if output_table == self._primary_table: @@ -473,7 +494,7 @@ def search_object(self, name, output_table=None, resolve_simbad=False, simbad_names = get_simbad_names(name, verbose=verbose) name = list(set(simbad_names + [name])) if verbose: - print(f'Including Simbad names, searching for: {name}') + print(f"Including Simbad names, searching for: {name}") # Turn name into a list if not isinstance(name, list): @@ -482,7 +503,7 @@ def search_object(self, name, output_table=None, resolve_simbad=False, # Verify provided tables exist in database for k in table_names.keys(): if k not in self.metadata.tables: - raise RuntimeError(f'Table {k} is not in the database') + raise RuntimeError(f"Table {k} is not in the database") # Get source for objects that match the provided names # The following will build the filters required to query all specified tables @@ -493,11 +514,9 @@ def search_object(self, name, output_table=None, resolve_simbad=False, for k, col_list in table_names.items(): for v in col_list: if fuzzy_search: - filters = [self.metadata.tables[k].columns[v].ilike(f'%{n}%') - for n in name] + filters = [self.metadata.tables[k].columns[v].ilike(f"%{n}%") for n in name] else: - filters = [self.metadata.tables[k].columns[v].ilike(f'{n}') - for n in name] + filters = [self.metadata.tables[k].columns[v].ilike(f"{n}") for n in name] # Column to be returned if k == self._primary_table: @@ -505,22 +524,21 @@ def search_object(self, name, output_table=None, resolve_simbad=False, else: output_to_match = self.metadata.tables[k].columns[self._foreign_key] - temp = self.query(output_to_match).\ - filter(or_(*filters)).\ - distinct().\ - all() + temp = self.query(output_to_match).filter(or_(*filters)).distinct().all() matched_names += [s[0] for s in temp] # Join the matched sources with the desired table - temp = self.query(self.metadata.tables[output_table]).\ - filter(self.metadata.tables[output_table].columns[match_column].in_(matched_names)).\ - all() + temp = ( + self.query(self.metadata.tables[output_table]) + .filter(self.metadata.tables[output_table].columns[match_column].in_(matched_names)) + .all() + ) results = self._handle_format(temp, fmt) return results - def search_string(self, value, fmt='table', fuzzy_search=True, verbose=True): + def search_string(self, value, fmt="table", fuzzy_search=True, verbose=True): """ Search an abitrary string across all string columns in the full database @@ -545,24 +563,24 @@ def search_string(self, value, fmt='table', fuzzy_search=True, verbose=True): for table in self.metadata.tables: # Gather only string-type columns columns = self.metadata.tables[table].columns - col_list = [c for c in columns - if isinstance(c.type, sqlalchemy_types.String) - or isinstance(c.type, sqlalchemy_types.Text) - or isinstance(c.type, sqlalchemy_types.Unicode)] + col_list = [ + c + for c in columns + if isinstance(c.type, sqlalchemy_types.String) + or isinstance(c.type, sqlalchemy_types.Text) + or isinstance(c.type, sqlalchemy_types.Unicode) + ] # Construct filters to query for each string column filters = [] for c in col_list: if fuzzy_search: - filters += [c.ilike(f'%{value}%')] + filters += [c.ilike(f"%{value}%")] else: - filters += [c.ilike(f'{value}')] + filters += [c.ilike(f"{value}")] # Perform the actual query - temp = self.query(self.metadata.tables[table]). \ - filter(or_(*filters)). \ - distinct(). \ - all() + temp = self.query(self.metadata.tables[table]).filter(or_(*filters)).distinct().all() # Append results to dictionary output in specified format if len(temp) > 0: @@ -575,8 +593,8 @@ def search_string(self, value, fmt='table', fuzzy_search=True, verbose=True): return output_dict # General query methods - @deprecated_alias(format='fmt') - def sql_query(self, query, fmt='default'): + @deprecated_alias(format="fmt") + def sql_query(self, query, fmt="default"): """ Wrapper for a direct SQL query. @@ -597,8 +615,18 @@ def sql_query(self, query, fmt='default'): return self._handle_format(temp, fmt) - def query_region(self, target_coords, radius=Quantity(10, unit='arcsec'), output_table=None, fmt='table', - coordinate_table=None, ra_col='ra', dec_col='dec', frame='icrs', unit='deg'): + def query_region( + self, + target_coords, + radius=Quantity(10, unit="arcsec"), + output_table=None, + fmt="table", + coordinate_table=None, + ra_col="ra", + dec_col="dec", + frame="icrs", + unit="deg", + ): """ Perform a cone search of the given coordinates and return the specified output table. @@ -633,11 +661,11 @@ def query_region(self, target_coords, radius=Quantity(10, unit='arcsec'), output if output_table is None: output_table = self._primary_table if output_table not in self.metadata.tables: - raise RuntimeError(f'Table {output_table} is not in the database') + raise RuntimeError(f"Table {output_table} is not in the database") # Radius conversion if not isinstance(radius, Quantity): - radius = Quantity(radius, unit='arcsec') + radius = Quantity(radius, unit="arcsec") # Get the column name to use for matching match_column = self._foreign_key @@ -648,19 +676,19 @@ def query_region(self, target_coords, radius=Quantity(10, unit='arcsec'), output if coordinate_table is None: coordinate_table = self._primary_table if coordinate_table not in self.metadata.tables: - raise RuntimeError(f'Table {coordinate_table} is not in the database') + raise RuntimeError(f"Table {coordinate_table} is not in the database") coordinate_match_column = self._foreign_key if coordinate_table == self._primary_table: coordinate_match_column = self._primary_table_key # This is adapted from the original astrodbkit code df = self.query(self.metadata.tables[coordinate_table]).pandas() - df[['ra', 'dec']] = df[[ra_col, dec_col]].apply(pd.to_numeric) # convert everything to floats - mask = df['ra'].isnull() + df[["ra", "dec"]] = df[[ra_col, dec_col]].apply(pd.to_numeric) # convert everything to floats + mask = df["ra"].isnull() df = df[~mask] # Native use of astropy SkyCoord objects here - coord_list = SkyCoord(df['ra'].tolist(), df['dec'].tolist(), frame=frame, unit=unit) + coord_list = SkyCoord(df["ra"].tolist(), df["dec"].tolist(), frame=frame, unit=unit) sep_list = coord_list.separation(target_coords) # sky separations for each db object against target position good = sep_list <= radius @@ -670,9 +698,11 @@ def query_region(self, target_coords, radius=Quantity(10, unit='arcsec'), output matched_list = [] # Join the matched sources with the desired table - temp = self.query(self.metadata.tables[output_table]). \ - filter(self.metadata.tables[output_table].columns[match_column].in_(matched_list)). \ - all() + temp = ( + self.query(self.metadata.tables[output_table]) + .filter(self.metadata.tables[output_table].columns[match_column].in_(matched_list)) + .all() + ) results = self._handle_format(temp, fmt) return results @@ -691,6 +721,8 @@ def save_json(self, name, directory): Name of directory in which to save the output JSON """ + # pylint: disable=unnecessary-dunder-call + if isinstance(name, str): source_name = str(name) data = self.inventory(name) @@ -699,8 +731,8 @@ def save_json(self, name, directory): data = self.inventory(name.__getattribute__(self._primary_table_key)) # Clean up spaces and other special characters - filename = source_name.lower().replace(' ', '_').replace('*', '').strip() + '.json' - with open(os.path.join(directory, filename), 'w') as f: + filename = source_name.lower().replace(" ", "_").replace("*", "").strip() + ".json" + with open(os.path.join(directory, filename), "w", encoding="utf-8") as f: f.write(json.dumps(data, indent=4, default=json_serializer)) def save_reference_table(self, table, directory): @@ -716,9 +748,9 @@ def save_reference_table(self, table, directory): results = self.session.query(self.metadata.tables[table]).all() data = [row._asdict() for row in results] - filename = table + '.json' + filename = table + ".json" if len(data) > 0: - with open(os.path.join(directory, filename), 'w') as f: + with open(os.path.join(directory, filename), "w", encoding="utf-8") as f: f.write(json.dumps(data, indent=4, default=json_serializer)) def save_database(self, directory, clear_first=True): @@ -737,7 +769,7 @@ def save_database(self, directory, clear_first=True): # Clear existing files first from that directory if clear_first: - print('Clearing existing JSON files...') + print("Clearing existing JSON files...") for filename in os.listdir(directory): os.remove(os.path.join(directory, filename)) @@ -754,7 +786,7 @@ def save_database(self, directory, clear_first=True): self.save_json(row, directory) # Object input methods - def add_table_data(self, data, table, fmt='csv'): + def add_table_data(self, data, table, fmt="csv"): """ Method to insert data into the database. Column names in the file must match those of the database table. Additional columns in the supplied table are ignored. @@ -775,14 +807,14 @@ def add_table_data(self, data, table, fmt='csv'): Data format. Default: csv """ - if fmt.lower() == 'csv': + if fmt.lower() == "csv": df = pd.read_csv(data) - elif fmt.lower() == 'astropy': + elif fmt.lower() == "astropy": df = data.to_pandas() - elif fmt.lower() == 'pandas': + elif fmt.lower() == "pandas": df = data.copy() else: - raise RuntimeError(f'Unrecognized format {fmt}') + raise RuntimeError(f"Unrecognized format {fmt}") # Foreign key constraints will prevent inserts of missing sources, # but for clarity we'll check first and exit if there are missing sources @@ -792,9 +824,9 @@ def add_table_data(self, data, table, fmt='csv'): matched_sources = self.query(primary_column).filter(primary_column.in_(source_list)).all() missing_sources = np.setdiff1d(source_list, matched_sources) if len(missing_sources) > 0: - print(f'{len(missing_sources)} missing source(s):') + print(f"{len(missing_sources)} missing source(s):") print(missing_sources) - raise RuntimeError(f'There are missing entries in {self._primary_table} table. These must exist first.') + raise RuntimeError(f"There are missing entries in {self._primary_table} table. These must exist first.") # Convert format for SQLAlchemy data = [row.to_dict() for _, row in df.iterrows()] @@ -821,14 +853,15 @@ def load_table(self, table, directory, verbose=False): Flag to enable diagnostic messages """ - filename = os.path.join(directory, table+'.json') + filename = os.path.join(directory, table + ".json") if os.path.exists(filename): - with open(filename, 'r') as f: + with open(filename, "r", encoding="utf-8") as f: data = json.load(f) with self.engine.begin() as conn: conn.execute(self.metadata.tables[table].insert().values(data)) else: - if verbose: print(f'{table}.json not found.') + if verbose: + print(f"{table}.json not found.") def load_json(self, filename): """ @@ -840,7 +873,7 @@ def load_json(self, filename): Name of directory containing the JSON file """ - with open(filename, 'r') as f: + with open(filename, "r", encoding="utf-8") as f: data = json.load(f, object_hook=datetime_json_parser) # Loop through the dictionary, adding data to the database. @@ -874,32 +907,36 @@ def load_database(self, directory, verbose=False): # Clear existing database contents # reversed(sorted_tables) can help ensure that foreign key dependencies are taken care of first for table in reversed(self.metadata.sorted_tables): - if verbose: print(f'Deleting {table.name} table') + if verbose: + print(f"Deleting {table.name} table") with self.engine.begin() as conn: conn.execute(self.metadata.tables[table.name].delete()) # Load reference tables first for table in self._reference_tables: - if verbose: print(f'Loading {table} table') + if verbose: + print(f"Loading {table} table") self.load_table(table, directory, verbose=verbose) # Load object data - if verbose: print('Loading object tables') + if verbose: + print("Loading object tables") for file in tqdm(os.listdir(directory)): # Skip reference tables - core_name = file.replace('.json', '') + core_name = file.replace(".json", "") if core_name in self._reference_tables: continue # Skip non-JSON files or hidden files - if not file.endswith('.json') or file.startswith('.'): + if not file.endswith(".json") or file.startswith("."): continue self.load_json(os.path.join(directory, file)) def dump_sqlite(self, database_name): - if self.engine.url.drivername == 'sqlite': + """Output database as a sqlite file""" + if self.engine.url.drivername == "sqlite": destconn = sqlite3.connect(database_name) self.engine.raw_connection().backup(destconn) else: - print('AstrodbKit2: dump_sqlite not available for non-sqlite databases') + print("AstrodbKit2: dump_sqlite not available for non-sqlite databases") diff --git a/astrodbkit2/schema_example.py b/astrodbkit2/schema_example.py index cfb7658..0233397 100644 --- a/astrodbkit2/schema_example.py +++ b/astrodbkit2/schema_example.py @@ -1,9 +1,11 @@ -# Example schema for part of the SIMPLE database +"""Example schema for part of the SIMPLE database""" +# pylint: disable=unused-argument, unused-import + +import enum import sqlalchemy as sa from sqlalchemy import Boolean, Column, Float, ForeignKey, Integer, String, BigInteger, Enum, Date, DateTime from sqlalchemy.orm import validates -import enum from astrodbkit2.astrodb import Base from astrodbkit2.views import view @@ -23,12 +25,14 @@ class Publications(Base): class Telescopes(Base): + """Telescopes table""" __tablename__ = "Telescopes" name = Column(String(30), primary_key=True, nullable=False) reference = Column(String(30), ForeignKey("Publications.name", ondelete="cascade")) class Instruments(Base): + """Instruments table""" __tablename__ = "Instruments" name = Column(String(30), primary_key=True, nullable=False) reference = Column(String(30), ForeignKey("Publications.name", ondelete="cascade")) @@ -39,6 +43,7 @@ class Instruments(Base): class Regime(enum.Enum): """Enumeration for spectral type regime""" + # pylint: disable=invalid-name optical = "optical" infrared = "infrared" ultraviolet = "ultraviolet" @@ -60,24 +65,28 @@ class Sources(Base): @validates("ra") def validate_ra(self, key, value): + """Ensure RA is within bounds""" if value > 360 or value < 0: raise ValueError("RA not in allowed range (0..360)") return value - + @validates("dec") def validate_dec(self, key, value): + """Ensure Dec is within bounds""" if value > 90 or value < -90: raise ValueError("Dec not in allowed range (-90..90)") return value class Names(Base): + """Names table""" __tablename__ = "Names" source = Column(String(100), ForeignKey("Sources.source", ondelete="cascade"), nullable=False, primary_key=True) other_name = Column(String(100), primary_key=True, nullable=False) class Photometry(Base): + """Photometry table""" __tablename__ = "Photometry" source = Column( String(100), @@ -97,6 +106,7 @@ class Photometry(Base): class SpectralTypes(Base): + """SpectralTypes table""" __tablename__ = "SpectralTypes" source = Column( String(100), diff --git a/astrodbkit2/spectra.py b/astrodbkit2/spectra.py index 7e07ab5..761df07 100644 --- a/astrodbkit2/spectra.py +++ b/astrodbkit2/spectra.py @@ -13,14 +13,14 @@ # pylint: disable=no-member, unused-argument + def _identify_spex(filename): """ Check whether the given file is a SpeX data product. """ try: with fits.open(filename, memmap=False) as hdulist: - return 'spex' in hdulist[0].header['INSTRUME'].lower() and \ - 'irtf' in hdulist[0].header['TELESCOP'].lower() + return "spex" in hdulist[0].header["INSTRUME"].lower() and "irtf" in hdulist[0].header["TELESCOP"].lower() except Exception: # pylint: disable=broad-except, return False @@ -34,17 +34,17 @@ def identify_spex_prism(origin, *args, **kwargs): is_spex = _identify_spex(args[0]) if is_spex: with fits.open(args[0], memmap=False) as hdulist: - return (isinstance(args[0], str) and - os.path.splitext(args[0].lower())[1] == '.fits' and - is_spex - and ('lowres' in hdulist[0].header['GRAT'].lower() or - 'prism' in hdulist[0].header['GRAT'].lower()) - ) + return ( + isinstance(args[0], str) + and os.path.splitext(args[0].lower())[1] == ".fits" + and is_spex + and ("lowres" in hdulist[0].header["GRAT"].lower() or "prism" in hdulist[0].header["GRAT"].lower()) + ) else: return is_spex -@data_loader("Spex Prism", identifier=identify_spex_prism, extensions=['fits'], dtype=Spectrum1D) +@data_loader("Spex Prism", identifier=identify_spex_prism, extensions=["fits"], dtype=Spectrum1D) def spex_prism_loader(filename, **kwargs): """Open a SpeX Prism file and convert it to a Spectrum1D object""" @@ -55,12 +55,12 @@ def spex_prism_loader(filename, **kwargs): # Handle missing/incorrect units try: - flux_unit = header['YUNITS'].replace('ergs', 'erg ').strip() - wave_unit = header['XUNITS'].replace('Microns', 'um') + flux_unit = header["YUNITS"].replace("ergs", "erg ").strip() + wave_unit = header["XUNITS"].replace("Microns", "um") except (KeyError, ValueError): # For now, assume some default units - flux_unit = 'erg' - wave_unit = 'um' + flux_unit = "erg" + wave_unit = "um" wave, data = tab[0] * Unit(wave_unit), tab[1] * Unit(flux_unit) @@ -69,7 +69,7 @@ def spex_prism_loader(filename, **kwargs): else: uncertainty = None - meta = {'header': header} + meta = {"header": header} return Spectrum1D(flux=data, spectral_axis=wave, uncertainty=uncertainty, meta=meta) @@ -78,21 +78,21 @@ def identify_wcs1d_multispec(origin, *args, **kwargs): """ Identifier for WCS1D multispec """ - hdu = kwargs.get('hdu', 0) + hdu = kwargs.get("hdu", 0) # Check if number of axes is one and dimension of WCS is greater than one with read_fileobj_or_hdulist(*args, **kwargs) as hdulist: - return (hdulist[hdu].header.get('WCSDIM', 1) > 1 and - hdulist[hdu].header['NAXIS'] > 1 and - 'WAT0_001' in hdulist[hdu].header and - hdulist[hdu].header.get('WCSDIM', 1) == hdulist[hdu].header['NAXIS'] and - 'LINEAR' in hdulist[hdu].header.get('CTYPE1', '')) + return ( + hdulist[hdu].header.get("WCSDIM", 1) > 1 + and hdulist[hdu].header["NAXIS"] > 1 + and "WAT0_001" in hdulist[hdu].header + and hdulist[hdu].header.get("WCSDIM", 1) == hdulist[hdu].header["NAXIS"] + and "LINEAR" in hdulist[hdu].header.get("CTYPE1", "") + ) -@data_loader("wcs1d-multispec", identifier=identify_wcs1d_multispec, extensions=['fits'], - dtype=Spectrum1D, priority=10) -def wcs1d_multispec_loader(file_obj, flux_unit=None, - hdu=0, verbose=False, **kwargs): +@data_loader("wcs1d-multispec", identifier=identify_wcs1d_multispec, extensions=["fits"], dtype=Spectrum1D, priority=10) +def wcs1d_multispec_loader(file_obj, flux_unit=None, hdu=0, verbose=False, **kwargs): """ Loader for multiextension spectra as wcs1d. Adapted from wcs1d_fits_loader @@ -122,27 +122,27 @@ def wcs1d_multispec_loader(file_obj, flux_unit=None, wcs = WCS(header) # Load data, convert units if BUNIT and flux_unit is provided and not the same - if 'BUNIT' in header: - data = u.Quantity(hdulist[hdu].data, unit=header['BUNIT']) + if "BUNIT" in header: + data = u.Quantity(hdulist[hdu].data, unit=header["BUNIT"]) if u.A in data.unit.bases: - data = data * u.A/u.AA # convert ampere to Angroms + data = data * u.A / u.AA # convert ampere to Angroms if flux_unit is not None: data = data.to(flux_unit) else: data = u.Quantity(hdulist[hdu].data, unit=flux_unit) - if wcs.wcs.cunit[0] == '' and 'WAT1_001' in header: + if wcs.wcs.cunit[0] == "" and "WAT1_001" in header: # Try to extract from IRAF-style card or use Angstrom as default. - wat_dict = dict((rec.split('=') for rec in header['WAT1_001'].split())) - unit = wat_dict.get('units', 'Angstrom') + wat_dict = dict((rec.split("=") for rec in header["WAT1_001"].split())) + unit = wat_dict.get("units", "Angstrom") if hasattr(u, unit): wcs.wcs.cunit[0] = unit else: # try with unit name stripped of excess plural 's'... - wcs.wcs.cunit[0] = unit.rstrip('s') + wcs.wcs.cunit[0] = unit.rstrip("s") if verbose: print(f"Extracted spectral axis unit '{unit}' from 'WAT1_001'") - elif wcs.wcs.cunit[0] == '': - wcs.wcs.cunit[0] = 'Angstrom' + elif wcs.wcs.cunit[0] == "": + wcs.wcs.cunit[0] = "Angstrom" # Compatibility attribute for lookup_table (gwcs) WCS wcs.unit = tuple(wcs.wcs.cunit) @@ -153,11 +153,11 @@ def wcs1d_multispec_loader(file_obj, flux_unit=None, else: flux_data = data uncertainty = None - if 'NAXIS3' in header: - for i in range(header['NAXIS3']): - if 'spectrum' in header.get(f'BANDID{i+1}', ''): + if "NAXIS3" in header: + for i in range(header["NAXIS3"]): + if "spectrum" in header.get(f"BANDID{i+1}", ""): flux_data = data[i] - if 'sigma' in header.get(f'BANDID{i+1}', ''): + if "sigma" in header.get(f"BANDID{i+1}", ""): uncertainty = data[i] # Reshape arrays if needed @@ -171,42 +171,39 @@ def wcs1d_multispec_loader(file_obj, flux_unit=None, uncertainty = StdDevUncertainty(uncertainty) # Manually generate spectral axis - pixels = [[i] + [0]*(wcs.naxis-1) for i in range(wcs.pixel_shape[0])] + pixels = [[i] + [0] * (wcs.naxis - 1) for i in range(wcs.pixel_shape[0])] spectral_axis = [i[0] for i in wcs.all_pix2world(pixels, 0)] * wcs.wcs.cunit[0] # Store header as metadata information - meta = {'header': header} + meta = {"header": header} - return Spectrum1D(flux=flux_data, spectral_axis=spectral_axis, uncertainty=uncertainty, - meta=meta) + return Spectrum1D(flux=flux_data, spectral_axis=spectral_axis, uncertainty=uncertainty, meta=meta) -def load_spectrum(filename: str, - spectra_format: str=None, - raise_error: bool=False): +def load_spectrum(filename: str, spectra_format: str = None, raise_error: bool = False): """Attempt to load the filename as a spectrum object - + Parameters ---------- filename Name of the file to read spectra_format - Optional file format, passed to Spectrum1D.read. + Optional file format, passed to Spectrum1D.read. In its absense Spectrum1D.read will attempt to determine the format. raise_error Boolean to control if a failure to read the spectrum should raise an error. """ # Convert filename if using environment variables - if filename.startswith('$'): + if filename.startswith("$"): partial_path, _ = os.path.split(filename) - while partial_path != '': + while partial_path != "": partial_path, envvar_name = os.path.split(partial_path) abs_path = os.getenv(envvar_name[1:]) if abs_path is not None: filename = filename.replace(envvar_name, abs_path) else: - print(f'Could not find environment variable {envvar_name}') + print(f"Could not find environment variable {envvar_name}") try: if spectra_format is not None: @@ -214,7 +211,7 @@ def load_spectrum(filename: str, else: spec1d = Spectrum1D.read(filename) except Exception as e: # pylint: disable=broad-except, invalid-name - msg = f'Error loading {filename}: {e}' + msg = f"Error loading {filename}: {e}" # Control whether an error should be explicitly raised if failing to read if raise_error: diff --git a/astrodbkit2/utils.py b/astrodbkit2/utils.py index f290645..7b1c16b 100644 --- a/astrodbkit2/utils.py +++ b/astrodbkit2/utils.py @@ -1,14 +1,13 @@ -# Utility functions for Astrodbkit2 +"""Utility functions for Astrodbkit2""" import re -import numpy as np import functools import warnings from datetime import datetime from decimal import Decimal from astroquery.simbad import Simbad -__all__ = ['json_serializer', 'get_simbad_names'] +__all__ = ["json_serializer", "get_simbad_names"] def deprecated_alias(**aliases): @@ -18,12 +17,15 @@ def deprecated_alias(**aliases): in order to handle deprecation of renamed columns To use: add @deprecated_alias(old_name='new_name') """ + def deco(f): @functools.wraps(f) def wrapper(*args, **kwargs): rename_kwargs(f.__name__, kwargs, aliases) return f(*args, **kwargs) + return wrapper + return deco @@ -32,17 +34,15 @@ def rename_kwargs(func_name, kwargs, aliases): for alias, new in aliases.items(): if alias in kwargs: if new in kwargs: - raise TypeError('{} received both {} and {}'.format( - func_name, alias, new)) - warnings.warn('{} is deprecated; use {}'.format(alias, new), - DeprecationWarning) + raise TypeError(f"{func_name} received both {alias} and {new}") + warnings.warn(f"{alias} is deprecated; use {new}", DeprecationWarning) kwargs[new] = kwargs.pop(alias) def json_serializer(obj): """Function describing how things should be serialized in JSON. Datetime objects are saved with datetime.isoformat(), Parameter class objects use clean_dict() - while all others use __dict__ """ + while all others use __dict__""" if isinstance(obj, datetime): return obj.isoformat() @@ -51,7 +51,7 @@ def json_serializer(obj): return float(obj) if isinstance(obj, bytes): - return obj.decode('utf-8') + return obj.decode("utf-8") return obj.__dict__ @@ -61,7 +61,7 @@ def datetime_json_parser(json_dict): This is required to get datetime objects into the database. Adapted from: https://stackoverflow.com/questions/8793448/how-to-convert-to-a-python-datetime-object-with-json-loads """ - for (key, value) in json_dict.items(): + for key, value in json_dict.items(): if isinstance(value, str): try: json_dict[key] = datetime.fromisoformat(value) @@ -90,14 +90,14 @@ def _name_formatter(name): name = re.sub(r"\s\s+", " ", name) # Clean up Simbad types - strings_to_delete = ['V* ', 'EM* ', 'NAME ', '** ', 'Cl* ', '* '] + strings_to_delete = ["V* ", "EM* ", "NAME ", "** ", "Cl* ", "* "] for pattern in strings_to_delete: - name = name.replace(pattern, '') + name = name.replace(pattern, "") name = name.strip() # Clean up 'hidden' names from Simbad - if 'HIDDEN' in name.upper(): + if "HIDDEN" in name.upper(): name = None return name @@ -121,9 +121,9 @@ def get_simbad_names(name, verbose=False): t = Simbad.query_objectids(name) if t is not None and len(t) > 0: - temp = [_name_formatter(s) for s in t['ID'].tolist()] - return [s for s in temp if s is not None and s != ''] + temp = [_name_formatter(s) for s in t["ID"].tolist()] + return [s for s in temp if s is not None and s != ""] else: if verbose: - print(f'No Simbad match for {name}') + print(f"No Simbad match for {name}") return [name] diff --git a/astrodbkit2/views.py b/astrodbkit2/views.py index fe2d527..dd8d40e 100644 --- a/astrodbkit2/views.py +++ b/astrodbkit2/views.py @@ -1,11 +1,14 @@ -# Logic to implement and set up views in SQLAlchemy -# Adapted from https://github.com/sqlalchemy/sqlalchemy/wiki/Views +"""Logic to implement and set up views in SQLAlchemy +Adapted from https://github.com/sqlalchemy/sqlalchemy/wiki/Views +""" import sqlalchemy as sa from sqlalchemy.ext import compiler from sqlalchemy.schema import DDLElement from sqlalchemy.sql import table +# pylint: disable=abstract-method, missing-function-docstring, unused-argument, redefined-outer-name, protected-access, missing-class-docstring + class CreateView(DDLElement): def __init__(self, name, selectable): @@ -20,6 +23,7 @@ def __init__(self, name): @compiler.compiles(CreateView) def _create_view(element, compiler, **kw): + # pylint: disable=consider-using-f-string return "CREATE VIEW %s AS %s" % ( element.name, compiler.sql_compiler.process(element.selectable, literal_binds=True), @@ -28,6 +32,7 @@ def _create_view(element, compiler, **kw): @compiler.compiles(DropView) def _drop_view(element, compiler, **kw): + # pylint: disable=consider-using-f-string return "DROP VIEW %s" % (element.name) @@ -42,16 +47,12 @@ def view_doesnt_exist(ddl, target, connection, **kw): def view(name, metadata, selectable): t = table(name) - t._columns._populate_separate_keys( - col._make_proxy(t) for col in selectable.selected_columns - ) + t._columns._populate_separate_keys(col._make_proxy(t) for col in selectable.selected_columns) sa.event.listen( metadata, "after_create", CreateView(name, selectable).execute_if(callable_=view_doesnt_exist), ) - sa.event.listen( - metadata, "before_drop", DropView(name).execute_if(callable_=view_exists) - ) + sa.event.listen(metadata, "before_drop", DropView(name).execute_if(callable_=view_exists)) return t diff --git a/pyproject.toml b/pyproject.toml index 980c279..f0492be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,3 +8,10 @@ requires = ["setuptools", "cython>=0.29.15"] build-backend = 'setuptools.build_meta' + +[tool.darker] +line-length = 120 + +[tool.black] +line-length = 120 +target-version = ["py39", "py310", "py311"] \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 1b885cd..f4c7097 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,8 +30,15 @@ console_scripts = astropy-package-template-example = packagename.example_mod:main [options.extras_require] +all = + astrodbkit2[test, docs] test = + pytest + pytest-cov pytest-astropy + darker==1.7.2 + black==23.9.1 + pre-commit==3.4.0 docs = sphinx-astropy