diff --git a/.gitignore b/.gitignore index bfd0da0..f23365c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,99 @@ # Python build files -__pycache__ *.pyc -*.egg-info +*.clean +*~ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*,cover +.hypothesis/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# dotenv +.env + +# virtualenv +.venv +venv/ +ENV/ + +# Spyder project settings +.spyderproject + +# Rope project settings +.ropeproject diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..b494496 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,12 @@ +# Automated testing. +language: python +python: + - "2.7" + - "3.2" + - "3.3" + - "3.4" + - "3.5" + - "3.6" + - "3.7-dev" +# command to run tests +script: python setup.py test diff --git a/README.md b/README.md index e5c4abf..8b59f48 100644 --- a/README.md +++ b/README.md @@ -10,9 +10,15 @@ with multiple threads outside of the compile time options. 
This library was created to address this by bringing the responsibility of managing the threads to the python layer and is agnostic to the server setup of sqlite3. +[![Build Status](https://travis-ci.org/dashawn888/sqlite3worker.svg?branch=master)](https://travis-ci.org/dashawn888/sqlite3worker) + ## Install -Installation is via the usual ``setup.py`` method: +You can use pip: +```sh +sudo pip install sqlite3worker +``` +You can use setup.py: ```sh sudo python setup.py install ``` @@ -21,14 +27,14 @@ Alternatively one can use ``pip`` to install directly from the git repository without having to clone first: ```sh -sudo pip install git+https://github.com/palantir/sqlite3worker#egg=sqlite3worker +sudo pip install git+https://github.com/dashawn888/sqlite3worker#egg=sqlite3worker ``` One may also use ``pip`` to install on a per-user basis without requiring super-user permissions: ```sh -pip install --user git+https://github.com/palantir/sqlite3worker#egg=sqlite3worker +pip install --user git+https://github.com/dashawn888/sqlite3worker#egg=sqlite3worker ``` ## Example diff --git a/README.rst b/README.rst new file mode 100644 index 0000000..8b1cd99 --- /dev/null +++ b/README.rst @@ -0,0 +1,96 @@ +Sqlite3Worker +============= + +A threadsafe sqlite worker. + +This library implements a thread pool pattern with sqlite3 being the +desired output. + +sqlite3 implementation lacks the ability to safely modify the sqlite3 +database with multiple threads outside of the compile time options. This +library was created to address this by bringing the responsibility of +managing the threads to the python layer and is agnostic to the server +setup of sqlite3. + +|Build Status| + +Install +------- + +You can use pip: + +.. code:: sh + + sudo pip install sqlite3worker + +You can use setup.py: + +.. code:: sh + + sudo python setup.py install + +Alternatively one can use ``pip`` to install directly from the git +repository without having to clone first: + +.. 
code:: sh + + sudo pip install git+https://github.com/dashawn888/sqlite3worker#egg=sqlite3worker + +One may also use ``pip`` to install on a per-user basis without +requiring super-user permissions: + +.. code:: sh + + pip install --user git+https://github.com/dashawn888/sqlite3worker#egg=sqlite3worker + +Example +------- + +.. code:: python + + from sqlite3worker import Sqlite3Worker + + sql_worker = Sqlite3Worker("/tmp/test.sqlite") + sql_worker.execute("CREATE TABLE tester (timestamp DATETIME, uuid TEXT)") + sql_worker.execute("INSERT into tester values (?, ?)", ("2010-01-01 13:00:00", "bow")) + sql_worker.execute("INSERT into tester values (?, ?)", ("2011-02-02 14:14:14", "dog")) + + results = sql_worker.execute("SELECT * from tester") + for timestamp, uuid in results: + print(timestamp, uuid) + + sql_worker.close() + +When to use sqlite3worker +------------------------- + +If you have multiple threads all needing to write to a sqlite3 database +this library will serialize the sqlite3 write requests. + +When NOT to use sqlite3worker +----------------------------- + +If your code DOES NOT use multiple threads then you don't need to use a +thread safe sqlite3 implementation. + +If you need multiple applications to write to a sqlite3 db then +sqlite3worker will not protect you from corrupting the data. + +Internals +--------- + +The library creates a queue to manage multiple queries sent to the +database. Instead of directly calling the sqlite3 interface, you will +call the Sqlite3Worker which inserts your query into a Queue.Queue() +object. The queries are processed in the order that they are inserted +into the queue (first in, first out). In order to ensure that the +multiple threads are managed in the same queue, you will need to pass +the same Sqlite3Worker object to each thread. + +Python docs for sqlite3 +----------------------- + +https://docs.python.org/2/library/sqlite3.html + +.. 
|Build Status| image:: https://travis-ci.org/dashawn888/sqlite3worker.svg?branch=master + :target: https://travis-ci.org/dashawn888/sqlite3worker diff --git a/__init__.py b/__init__.py index 8bf5c9e..eb37eab 100644 --- a/__init__.py +++ b/__init__.py @@ -21,7 +21,13 @@ """Init.""" __author__ = "Shawn Lee" -__email__ = "shawnl@palantir.com" +__email__ = "dashawn@gmail.com" __license__ = "MIT" +__version__ = "1.1.7" -from sqlite3worker import Sqlite3Worker +try: + # Python 2 + from sqlite3worker import Sqlite3Worker +except ImportError: + # Python 3 + from sqlite3worker.sqlite3worker import Sqlite3Worker diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..b88034e --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[metadata] +description-file = README.md diff --git a/setup.py b/setup.py index ab09d3f..fdb1ecb 100755 --- a/setup.py +++ b/setup.py @@ -22,17 +22,42 @@ """Setup.""" __author__ = "Shawn Lee" -__email__ = "shawnl@palantir.com" +__email__ = "dashawn@gmail.com" __license__ = "MIT" from setuptools import setup +import os + +short_description = "Thread safe sqlite3 interface" +long_description = short_description +if os.path.exists('README.rst'): + long_description = open('README.rst').read() setup( - name="sqlite3worker", - version="1.0", - description="Thread safe sqlite3 interface", - author="Shawn Lee", - author_email="shawnl@palantir.com", - packages=["sqlite3worker"], - package_dir={"sqlite3worker": "."}, - test_suite="sqlite3worker_test") + name="sqlite3worker", + version="1.1.7", + description=short_description, + long_description=long_description, + author="Shawn Lee", + author_email="dashawn@gmail.com", + url="https://github.com/dashawn888/sqlite3worker", + packages=["sqlite3worker"], + package_dir={"sqlite3worker": "."}, + keywords=["sqlite", "sqlite3", "thread", "multithread", "multithreading"], + classifiers=[ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python 
:: 2", + "Programming Language :: Python :: 2.7", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.0", + "Programming Language :: Python :: 3.1", + "Programming Language :: Python :: 3.2", + "Programming Language :: Python :: 3.3", + "Programming Language :: Python :: 3.4", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Topic :: Database" + ], + test_suite="sqlite3worker_test") diff --git a/sqlite3worker.py b/sqlite3worker.py index 23eec87..a959ce4 100644 --- a/sqlite3worker.py +++ b/sqlite3worker.py @@ -25,173 +25,369 @@ __license__ = "MIT" import logging -try: - import queue as Queue # module re-named in Python 3 -except ImportError: - import Queue +import platform +import os import sqlite3 import threading import time -import uuid +try: + import queue as Queue # module re-named in Python 3 +except ImportError: # pragma: no cover + import Queue LOGGER = logging.getLogger('sqlite3worker') +OperationalError = sqlite3.OperationalError +ProgrammingError = sqlite3.ProgrammingError + +def dict_factory ( cursor, row ): + d = {} + for idx, col in enumerate ( cursor.description ): + d[col[0]] = row[idx] + return d + +# native sqlite3.Row doesn't like our proxy cursor class so we're going to substitute dict_factory instead which is almost the same thing +Row = dict_factory # sqlite3.Row + +class Frozen_object ( object ): + def __setattr__ ( self, key, value ): + if key not in dir ( self ): # prevent from accidentally creating new attributes + raise AttributeError ( '{!r} object has no attribute {!r}'.format ( type ( self ).__name__, key ) ) + super ( Frozen_object, self ).__setattr__ ( key, value ) + +class Sqlite3WorkerRequest ( Frozen_object ): + def execute ( self ): # pragma: no cover + raise NotImplementedError ( type ( self ).__name__ + '.execute()' ) + +class Sqlite3WorkerSetRowFactory ( Sqlite3WorkerRequest ): + thread = None + row_factory = None + + def __init__ ( self, thread, row_factory ): 
+ self.thread = thread + self.row_factory = row_factory + + def execute ( self ): + self.thread._sqlite3_cursor.row_factory = self.row_factory + +class Sqlite3WorkerSetTextFactory ( Sqlite3WorkerRequest ): + thread = None + text_factory = None + + def __init__ ( self, thread, text_factory ): + self.thread = thread + self.text_factory = text_factory + + def execute ( self ): + self.thread._sqlite3_conn.text_factory = self.text_factory + +class Sqlite3WorkerExecute ( Sqlite3WorkerRequest ): + thread = None + query = None + values = None + results = None + + def __init__ ( self, thread, query, values ): + self.thread = thread + self.query = query + self.values = values + self.results = Queue.Queue() + + def execute ( self ): + LOGGER.debug ( "run execute: %s", self.query ) + cur = self.thread._sqlite3_cursor + try: + cur.execute ( self.query, self.values ) + result = ( cur.fetchall(), cur.description, cur.lastrowid ) + success = True + except Exception as err: + LOGGER.debug ( + "Sqlite3WorkerExecute.execute sending exception back to calling thread: {!r}".format ( err ) ) + result = err + success = False + self.results.put ( ( success, result ) ) + +class Sqlite3WorkerExecuteScript ( Sqlite3WorkerRequest ): + thread = None + query = None + results = None + + def __init__ ( self, thread, query ): + self.thread = thread + self.query = query + self.results = Queue.Queue() + + def execute ( self ): + LOGGER.debug ( "run executescript: %s", self.query ) + cur = self.thread._sqlite3_cursor + try: + cur.executescript ( self.query ) + result = ( cur.fetchall(), cur.description, cur.lastrowid ) + success = True + except Exception as err: + LOGGER.debug ( + "Sqlite3WorkerExecuteScript.execute sending exception back to calling thread: {!r}".format ( err ) ) + result = err + success = False + self.results.put ( ( success, result ) ) + +class Sqlite3WorkerCommit ( Sqlite3WorkerRequest ): + thread = None + + def __init__ ( self, thread ): + self.thread = thread + + def execute ( 
self ): + LOGGER.debug("run commit") + self.thread._sqlite3_conn.commit() + +class Sqlite3WorkerExit ( Exception, Sqlite3WorkerRequest ): + def execute ( self ): + raise self + +def normalize_file_name ( file_name ): + if file_name.lower() == ':memory:': + return ':memory:' + # lookup absolute path of file_name + file_name = os.path.abspath ( file_name ) + if platform.system() == 'Windows': + file_name = file_name.lower() # Windows filenames are not case-sensitive + return file_name + +class Sqlite3WorkerThread ( threading.Thread ): + _workers = None + _sqlite3_conn = None + _sqlite3_cursor = None + _sql_queue = None + _max_queue_size = None + + def __init__ ( self, file_name, max_queue_size, *args, **kwargs ): + super ( Sqlite3WorkerThread, self ).__init__ ( *args, **kwargs ) + self.daemon = True + self._workers = set() + self._sqlite3_conn = sqlite3.connect ( + file_name, check_same_thread=False, + #detect_types=sqlite3.PARSE_DECLTYPES + ) + self._sqlite3_cursor = self._sqlite3_conn.cursor() + self._sql_queue = Queue.Queue ( maxsize=max_queue_size ) + self._max_queue_size = max_queue_size + self.name = self.name.replace ( 'Thread-', 'Sqlite3WorkerThread-' ) + self.start() + + def run ( self ): + """Thread loop. + This is an infinite loop. The iter method calls self._sql_queue.get() + which blocks if there are not values in the queue. As soon as values + are placed into the queue the process will continue. + If many executes happen at once it will churn through them all before + calling commit() to speed things up by reducing the number of times + commit is called. 
+ """ + LOGGER.debug("run: Thread started") + while True: + try: + x = self._sql_queue.get() + x.execute() + except Sqlite3WorkerExit as e: + if not self._sql_queue.empty(): # pragma: no cover ( TODO FIXME: come back to this ) + LOGGER.debug ( 'requeueing the exit event because there are unfinished actions' ) + self._sql_queue.put ( e ) # push the exit event to the end of the queue + continue + LOGGER.debug ( 'closing database connection' ) + self._sqlite3_cursor.close() + self._sqlite3_conn.commit() + self._sqlite3_conn.close() + LOGGER.debug ( 'exiting thread' ) + break + +class Sqlite3Worker ( Frozen_object ): + """Sqlite thread safe object. + Example: + from sqlite3worker import Sqlite3Worker + sql_worker = Sqlite3Worker("/tmp/test.sqlite") + sql_worker.execute( + "CREATE TABLE tester (timestamp DATETIME, uuid TEXT)") + sql_worker.execute( + "INSERT into tester values (?, ?)", ("2010-01-01 13:00:00", "bow")) + sql_worker.execute( + "INSERT into tester values (?, ?)", ("2011-02-02 14:14:14", "dog")) + sql_worker.execute("SELECT * from tester") + sql_worker.close() + """ + _file_name = None + _exit_set = False + _thread = None + + # class shared attributes + _threads = {} + _threads_lock = threading.Lock() + + def __init__ ( self, file_name, max_queue_size=100 ): + """Automatically starts the thread. + Args: + file_name: The name of the file. + max_queue_size: The max queries that will be queued. 
+ """ + + self._file_name = normalize_file_name ( file_name ) + with self._threads_lock: + self._thread = self._threads.get ( self._file_name ) + if self._thread is None: + self._thread = Sqlite3WorkerThread ( self._file_name, max_queue_size ) + # ':memory:' databases are private to each worker, so their thread is never registered for sharing + if self._file_name != ':memory:': + self._threads[self._file_name] = self._thread + self._thread._workers.add ( self ) + + def close ( self ): + """If we're the last worker, close down the thread which closes the sqlite3 database file.""" + if self._exit_set: + raise ProgrammingError ( 'sqlite worker already closed' ) + self._exit_set = True + with self._threads_lock: + self._thread._workers.remove ( self ) + if not self._thread._workers: + self._thread._sql_queue.put ( Sqlite3WorkerExit(), timeout=5 ) + # wait for the thread to finish what it's doing and shut down + self._thread.join() + try: + del self._threads[self._file_name] + except KeyError: + assert self._file_name == ':memory:' + + @property + def queue_size ( self ): + """Return the queue size.""" + return self._thread._sql_queue.qsize() + + def set_row_factory ( self, row_factory ): + self._thread._sql_queue.put ( Sqlite3WorkerSetRowFactory ( self._thread, row_factory ), timeout=5 ) + + def set_text_factory ( self, text_factory ): + self._thread._sql_queue.put ( Sqlite3WorkerSetTextFactory ( self._thread, text_factory ), timeout=5 ) + + def execute_ex ( self, query, values=None ): + """Execute a query. + Args: + query: The sql string using ? for placeholders of dynamic values. + values: A tuple of values to be replaced into the ? of the query. 
+ Returns: + a tuple of ( rows, description, lastrowid ): + rows is a list of row results returned by fetchall() or [] if no rows + description is the results of cursor.description after executing the query + lastrowid is the result of calling cursor.lastrowid after executing the query + """ + if self._exit_set: + LOGGER.debug ( "Exit set, not running: %s", query ) + raise ProgrammingError ( 'sqlite worker already closed' ) + LOGGER.debug ( "request execute: %s", query ) + r = Sqlite3WorkerExecute ( self._thread, query, values or [] ) + self._thread._sql_queue.put ( r, timeout=5 ) + success, result = r.results.get() + if not success: + raise result + else: + return result + + def execute ( self, query, values=None ): + return self.execute_ex ( query, values )[0] + + def executescript_ex ( self, query ): + if self._exit_set: + LOGGER.debug ( "Exit set, not running: %s", query ) + raise ProgrammingError ( 'sqlite worker already closed' ) + LOGGER.debug ( "request executescript: %s", query ) + r = Sqlite3WorkerExecuteScript ( self._thread, query ) + self._thread._sql_queue.put ( r, timeout=5 ) + success, result = r.results.get() + if not success: + raise result + else: + return result + + def executescript ( self, sql ): + return self.executescript_ex ( sql )[0] + + def commit ( self ): + if self._exit_set: + LOGGER.debug ( "Exit set, not commiting" ) + raise ProgrammingError ( 'sqlite worker already closed' ) + LOGGER.debug ( "request commit" ) + self._thread._sql_queue.put ( Sqlite3WorkerCommit ( self._thread ), timeout=5 ) + + @property + def total_changes ( self ): + if self._exit_set: + LOGGER.debug ( "Exit set, not querying total_changes" ) + raise ProgrammingError ( 'sqlite worker already closed' ) + return self._thread._sqlite3_conn.total_changes + +class Sqlite3worker_dbapi_cursor ( Frozen_object ): + con = None + rows = None + description = None + lastrowid = None + + def __init__ ( self, con ): + self.con = con + + def close ( self ): + pass + + def execute 
( self, sql, values=None ): + self.rows, self.description, self.lastrowid = self.con.worker.execute_ex ( sql, values ) + + def executescript ( self, sql_script ): + self.rows, self.description, self.lastrowid = self.con.worker.executescript_ex ( sql_script ) + + def fetchone ( self ): + try: + return self.con.row_factory ( self, self.rows.pop ( 0 ) ) + except IndexError: + return None + + def __iter__ ( self ): + while self.rows: + yield self.fetchone() + +class Sqlite3worker_dbapi_connection ( Frozen_object ): + worker = None + + def __init__ ( self, worker ): + self.worker = worker + + def commit ( self ): + self.worker.commit() + + def cursor ( self ): + return Sqlite3worker_dbapi_cursor ( self ) + + def execute ( self, sql, values=None ): + cur = self.cursor() + cur.execute ( sql, values ) + return cur + + def executescript ( self, sql_script ): + cur = self.cursor() + cur.executescript ( sql_script ) + return cur + + def close ( self ): + self.worker.close() + self.worker = None + + @staticmethod + def row_factory ( cursor, row ): + return row + + @property + def text_factory ( self ): # pragma: no cover + raise NotImplementedError ( type ( self ).__name__ + '.text_factory' ) + + @text_factory.setter + def text_factory ( self, text_factory ): + self.worker.set_text_factory ( text_factory ) -class Sqlite3Worker(threading.Thread): - """Sqlite thread safe object. - - Example: - from sqlite3worker import Sqlite3Worker - sql_worker = Sqlite3Worker("/tmp/test.sqlite") - sql_worker.execute( - "CREATE TABLE tester (timestamp DATETIME, uuid TEXT)") - sql_worker.execute( - "INSERT into tester values (?, ?)", ("2010-01-01 13:00:00", "bow")) - sql_worker.execute( - "INSERT into tester values (?, ?)", ("2011-02-02 14:14:14", "dog")) - sql_worker.execute("SELECT * from tester") - sql_worker.close() - """ - def __init__(self, file_name, max_queue_size=100): - """Automatically starts the thread. - - Args: - file_name: The name of the file. 
- max_queue_size: The max queries that will be queued. - """ - threading.Thread.__init__(self) - self.daemon = True - self.sqlite3_conn = sqlite3.connect( - file_name, check_same_thread=False, - detect_types=sqlite3.PARSE_DECLTYPES) - self.sqlite3_cursor = self.sqlite3_conn.cursor() - self.sql_queue = Queue.Queue(maxsize=max_queue_size) - self.results = {} - self.max_queue_size = max_queue_size - self.exit_set = False - # Token that is put into queue when close() is called. - self.exit_token = str(uuid.uuid4()) - self.start() - self.thread_running = True - - def run(self): - """Thread loop. - - This is an infinite loop. The iter method calls self.sql_queue.get() - which blocks if there are not values in the queue. As soon as values - are placed into the queue the process will continue. - - If many executes happen at once it will churn through them all before - calling commit() to speed things up by reducing the number of times - commit is called. - """ - LOGGER.debug("run: Thread started") - execute_count = 0 - for token, query, values in iter(self.sql_queue.get, None): - LOGGER.debug("sql_queue: %s", self.sql_queue.qsize()) - if token != self.exit_token: - LOGGER.debug("run: %s", query) - self.run_query(token, query, values) - execute_count += 1 - # Let the executes build up a little before committing to disk - # to speed things up. - if ( - self.sql_queue.empty() or - execute_count == self.max_queue_size): - LOGGER.debug("run: commit") - self.sqlite3_conn.commit() - execute_count = 0 - # Only exit if the queue is empty. Otherwise keep getting - # through the queue until it's empty. - if self.exit_set and self.sql_queue.empty(): - self.sqlite3_conn.commit() - self.sqlite3_conn.close() - self.thread_running = False - return - - def run_query(self, token, query, values): - """Run a query. - - Args: - token: A uuid object of the query you want returned. - query: A sql query with ? placeholders for values. - values: A tuple of values to replace "?" in query. 
- """ - if query.lower().strip().startswith("select"): - try: - self.sqlite3_cursor.execute(query, values) - self.results[token] = self.sqlite3_cursor.fetchall() - except sqlite3.Error as err: - # Put the error into the output queue since a response - # is required. - self.results[token] = ( - "Query returned error: %s: %s: %s" % (query, values, err)) - LOGGER.error( - "Query returned error: %s: %s: %s", query, values, err) - else: - try: - self.sqlite3_cursor.execute(query, values) - except sqlite3.Error as err: - LOGGER.error( - "Query returned error: %s: %s: %s", query, values, err) - - def close(self): - """Close down the thread and close the sqlite3 database file.""" - self.exit_set = True - self.sql_queue.put((self.exit_token, "", ""), timeout=5) - # Sleep and check that the thread is done before returning. - while self.thread_running: - time.sleep(.01) # Don't kill the CPU waiting. - - @property - def queue_size(self): - """Return the queue size.""" - return self.sql_queue.qsize() - - def query_results(self, token): - """Get the query results for a specific token. - - Args: - token: A uuid object of the query you want returned. - - Returns: - Return the results of the query when it's executed by the thread. - """ - delay = .001 - while True: - if token in self.results: - return_val = self.results[token] - del self.results[token] - return return_val - # Double back on the delay to a max of 8 seconds. This prevents - # a long lived select statement from trashing the CPU with this - # infinite loop as it's waiting for the query results. - LOGGER.debug("Sleeping: %s %s", delay, token) - time.sleep(delay) - if delay < 8: - delay += delay - - def execute(self, query, values=None): - """Execute a query. - - Args: - query: The sql string using ? for placeholders of dynamic values. - values: A tuple of values to be replaced into the ? of the query. - - Returns: - If it's a select query it will return the results of the query. 
- """ - if self.exit_set: - LOGGER.debug("Exit set, not running: %s", query) - return "Exit Called" - LOGGER.debug("execute: %s", query) - values = values or [] - # A token to track this query with. - token = str(uuid.uuid4()) - # If it's a select we queue it up with a token to mark the results - # into the output queue so we know what results are ours. - if query.lower().strip().startswith("select"): - self.sql_queue.put((token, query, values), timeout=5) - return self.query_results(token) - else: - self.sql_queue.put((token, query, values), timeout=5) +def connect ( file_name ): + return Sqlite3worker_dbapi_connection ( Sqlite3Worker ( file_name ) ) diff --git a/sqlite3worker_test.py b/sqlite3worker_test.py index 9fc90d3..4a85aff 100755 --- a/sqlite3worker_test.py +++ b/sqlite3worker_test.py @@ -1,4 +1,5 @@ #!/usr/bin/env python +# -*- coding: utf-8 -*- # Copyright (c) 2014 Palantir Technologies # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -22,46 +23,55 @@ """sqlite3worker test routines.""" __author__ = "Shawn Lee" -__email__ = "shawnl@palantir.com" +__email__ = "dashawn@gmail.com" __license__ = "MIT" +import logging import os +import sys import tempfile +import threading import time +import uuid + import unittest import sqlite3worker +if sys.version_info[0] >= 3: + unicode = str class Sqlite3WorkerTests(unittest.TestCase): # pylint:disable=R0904 """Test out the sqlite3worker library.""" - def setUp(self): # pylint:disable=C0103 - self.tmp_file = tempfile.NamedTemporaryFile( - suffix="pytest", prefix="sqlite").name + + def setUp(self): # pylint:disable=D0102 + self.tmp_file = tempfile.mktemp( + suffix="pytest", prefix="sqlite") self.sqlite3worker = sqlite3worker.Sqlite3Worker(self.tmp_file) # Create sql db. 
- self.sqlite3worker.execute( + self.sqlite3worker.executescript( # using executescript here for code coverage reasons "CREATE TABLE tester (timestamp DATETIME, uuid TEXT)") - def tearDown(self): # pylint:disable=C0103 - self.sqlite3worker.close() + def tearDown(self): # pylint:disable=D0102 + try: + self.sqlite3worker.close() + except sqlite3worker.ProgrammingError: + pass # the test may have already closed the database os.unlink(self.tmp_file) def test_bad_select(self): """Test a bad select query.""" query = "select THIS IS BAD SQL" - self.assertEqual( - self.sqlite3worker.execute(query), - ( - "Query returned error: select THIS IS BAD SQL: " - "[]: no such column: THIS")) + with self.assertRaises ( sqlite3worker.OperationalError ): + self.sqlite3worker.execute(query) def test_bad_insert(self): """Test a bad insert query.""" query = "insert THIS IS BAD SQL" - self.sqlite3worker.execute(query) + with self.assertRaises ( sqlite3worker.OperationalError ): + self.sqlite3worker.execute(query) # Give it one second to clear the queue. - if self.sqlite3worker.queue_size != 0: + if self.sqlite3worker.queue_size != 0: # pragma: no cover - this never happens any more time.sleep(1) self.assertEqual(self.sqlite3worker.queue_size, 0) self.assertEqual( @@ -77,12 +87,158 @@ def test_valid_insert(self): self.sqlite3worker.execute( "INSERT into tester values (?, ?)", ("2011-02-02 14:14:14", "dog")) # Give it one second to clear the queue. 
- if self.sqlite3worker.queue_size != 0: + if self.sqlite3worker.queue_size != 0: # pragma: no cover - this never happens any more time.sleep(1) self.assertEqual( self.sqlite3worker.execute("SELECT * from tester"), [("2010-01-01 13:00:00", "bow"), ("2011-02-02 14:14:14", "dog")]) + def test_run_after_close(self): + """Test to make sure all events are cleared after object closed.""" + self.sqlite3worker.close() + with self.assertRaises ( sqlite3worker.ProgrammingError ): + self.sqlite3worker.execute( + "INSERT into tester values (?, ?)", ("2010-01-01 13:00:00", "bow")) + + def test_double_close(self): + """Make sure double closing messages properly.""" + self.sqlite3worker.close() + with self.assertRaises ( sqlite3worker.ProgrammingError ): + self.sqlite3worker.close() + + def test_db_closed_properly(self): + """Make sure sqlite object is properly closed out.""" + self.sqlite3worker.close() + with self.assertRaises ( sqlite3worker.ProgrammingError ): + self.sqlite3worker.total_changes + + def test_many_threads(self): + """Make sure lots of threads work together.""" + class threaded(threading.Thread): + def __init__(self, sqlite_obj): + threading.Thread.__init__(self, name=__name__) + self.sqlite_obj = sqlite_obj + self.daemon = True + self.failed = False + self.completed = False + self.start() + + def run(self): + for _ in range(5): + token = str(uuid.uuid4()) + self.sqlite_obj.execute( + "INSERT into tester values (?, ?)", + ("2010-01-01 13:00:00", token)) + resp = self.sqlite_obj.execute( + "SELECT * from tester where uuid = ?", (token,)) + if resp != [("2010-01-01 13:00:00", token)]: # pragma: no cover ( we don't expect tests to fail ) + self.failed = True + break + self.completed = True + + threads = [] + for _ in range(5): + threads.append(threaded(self.sqlite3worker)) + + for i in range(5): + while not threads[i].completed: + time.sleep(.1) + self.assertEqual(threads[i].failed, False) + threads[i].join() + + def test_many_dbapi_threads ( self ): + """Make sure 
lots of threads work together with dbapi interface.""" + class threaded ( threading.Thread ): + def __init__ ( self, id, tmp_file ): + threading.Thread.__init__ ( self, name='test {}'.format ( id ) ) + self.tmp_file = tmp_file + self.daemon = True + self.failed = False + self.completed = False + self.start() -if __name__ == "__main__": + def run ( self ): + logging.debug ( 'connecting' ) + con = sqlite3worker.connect ( self.tmp_file ) + for i in range ( 5 ): + logging.debug ( 'creating cursor #{}'.format ( i ) ) + c = con.cursor() + token = str ( uuid.uuid4() ) + logging.debug ( 'cursor #{} inserting token {!r}'.format ( i, token ) ) + c.execute ( + "INSERT into tester values (?, ?)", + ( "2010-01-01 13:00:00", token ) + ) + logging.debug ( 'cursor #{} querying token {!r}'.format ( i, token ) ) + c.execute ( + "SELECT * from tester where uuid = ?", (token,) + ) + resp = c.fetchone() + logging.debug ( 'cursor #{} closing'.format ( i ) ) + c.close() + if resp != ( "2010-01-01 13:00:00", token ): # pragma: no cover ( we don't expect tests to fail ) + logging.debug ( 'cursor #{} invalid resp {!r}'.format ( i, resp ) ) + logging.debug ( repr ( resp ) ) + self.failed = True + break + else: + logging.debug ( 'cursor #{} success'.format ( i ) ) + logging.debug ( 'closing connection' ) + con.close() + self.completed = True + + threads = [] + for id in range ( 5 ): + threads.append ( threaded ( id, self.tmp_file ) ) + + con = sqlite3worker.connect ( self.tmp_file ) + con.executescript ( 'pragma foreign_keys=on;' ) # not using this, put here for code coverage reasons + con.row_factory = sqlite3worker.Row + con.text_factory = unicode + + for i in range ( 5 ): + while not threads[i].completed: + time.sleep ( 0.1 ) + self.assertEqual ( threads[i].failed, False ) + threads[i].join() + + logging.debug ( 'counting results' ) # yes I could do a count(*) here but I'm doing it this way for code coverage reasons + con.commit() + cur = con.execute ( 'select * from tester' ) + count = 0 
+ for row in cur: + self.assertEqual ( len ( row['uuid'] ), 36 ) + count += 1 + self.assertEqual ( cur.fetchone(), None ) # make sure all rows retrieved + con.close() + self.assertEqual ( count, 25 ) + + def test_coverage ( self ): + """ a bunch of miscellaneous things to get code coverage to 100% """ + class Foo ( sqlite3worker.Frozen_object ): + pass + foo = Foo() + with self.assertRaises ( AttributeError ): + foo.bar = 'bar' + self.sqlite3worker.set_row_factory ( sqlite3worker.Row ) + self.assertEqual ( self.sqlite3worker.total_changes, 0 ) + self.sqlite3worker.set_text_factory ( unicode ) + with self.assertRaises ( sqlite3worker.OperationalError ): + self.sqlite3worker.executescript ( 'THIS IS INTENTIONALLY BAD SQL' ) + + # try to force and catch an assert in the close logic... + del self.sqlite3worker._threads[self.sqlite3worker._file_name] + with self.assertRaises ( AssertionError ): + self.sqlite3worker.close() + + self.assertEqual ( sqlite3worker.normalize_file_name ( ':MEMORY:' ), ':memory:' ) + with self.assertRaises ( sqlite3worker.ProgrammingError ): + self.sqlite3worker.executescript ( 'drop table tester' ) + with self.assertRaises ( sqlite3worker.ProgrammingError ): + self.sqlite3worker.commit() + +if __name__ == "__main__": # pragma: no cover ( only executed when running test directly ) + if False: + import sys + logging.basicConfig ( stream=sys.stdout, level=logging.DEBUG, format='%(asctime)s [%(threadName)s %(levelname)s] %(message)s' ) unittest.main()