import abc
import threading
import logging
import queue
import time
from concurrent.futures import ThreadPoolExecutor
# Module-level logger; used for the debug/info messages emitted by the
# runner and result classes in this module.
logger = logging.getLogger(__name__)
class QueryRunner(abc.ABC):
    """Run a query in the background using the provided executor.

    Subclasses implement :meth:`run`, which :meth:`start` submits to an
    executor. A result queue must be attached (via ``_set_result_queue``)
    before the runner is started.
    """

    def __init__(self, deserializer=None):
        """Create an idle runner.

        :param deserializer: optional callable applied to raw values; when
            omitted, the identity function is used.
        """
        super().__init__()
        # Guards the status flags: run() executes on an executor thread
        # while callers poll started()/closed()/done() from elsewhere.
        self._status_lock = threading.RLock()
        self._closed = False
        self._started = False
        self._done = False
        # time.time() at which start() was called; None until started.
        self.timestamp = None
        self._result_queue = None
        self._future = None
        # Not read anywhere in this class; presumably set/used by callers.
        self.study_id = None
        if deserializer is not None:
            self.deserializer = deserializer
        else:
            self.deserializer = lambda v: v

    def adapt(self, adapter_func):
        """Compose ``adapter_func`` on top of the current deserializer.

        Afterwards ``self.deserializer(v)`` returns
        ``adapter_func(previous_deserializer(v))``.
        """
        func = self.deserializer
        self.deserializer = lambda v: adapter_func(func(v))

    def started(self):
        """Return True once :meth:`start` has submitted the runner."""
        with self._status_lock:
            return self._started

    def start(self, executor):
        """Submit :meth:`run` to ``executor`` and mark the runner started.

        A result queue must already be attached; otherwise an
        ``AssertionError`` is raised (when assertions are enabled).
        """
        with self._status_lock:
            assert self._result_queue is not None
            self._future = executor.submit(self.run)
            self._started = True
            self.timestamp = time.time()

    def closed(self):
        """Return True once :meth:`close` has been called."""
        with self._status_lock:
            return self._closed

    def close(self):
        """Mark the runner closed and cancel its future if it was started.

        Safe to call on a runner that was never started: previously this
        raised ``TypeError`` (``timestamp`` is None) before the closed flag
        could be set.
        """
        if self.timestamp is not None:
            elapsed = time.time() - self.timestamp
            logger.debug("closing runner after %0.3f", elapsed)
        with self._status_lock:
            if self._started:
                # Future.cancel() only succeeds if run() has not begun yet;
                # a run() already in progress presumably polls closed().
                self._future.cancel()
            self._closed = True

    def done(self):
        """Return the ``_done`` flag (set by subclasses, not by this class)."""
        with self._status_lock:
            return self._done

    @abc.abstractmethod
    def run(self):
        """Execute the query; must be implemented by subclasses."""

    def _set_future(self, future):
        self._future = future

    def _set_result_queue(self, result_queue):
        self._result_queue = result_queue
class QueryResult:
    """Run a list of queries in the background.

    Each runner is attached to a shared, bounded ``result_queue`` and runs on
    a private thread pool. The results of the queries are enqueued on
    ``result_queue`` and consumed by iterating over this object or by calling
    :meth:`get`.
    """

    def __init__(self, runners: "list[QueryRunner]", limit=-1):
        """Attach ``runners`` to a new result queue.

        :param runners: runners that have not been attached to a queue yet.
        :param limit: number of consumed items after which :meth:`done`
            reports True; negative or None means no limit.
        """
        self.result_queue: queue.Queue = queue.Queue(maxsize=1_000)
        if limit is None:
            limit = -1
        self.limit = limit
        # Number of items handed out so far by __next__() and get().
        self._counter = 0
        self.timestamp = time.time()
        self.runners = runners
        for runner in self.runners:
            assert runner._result_queue is None
            runner._set_result_queue(self.result_queue)
        # ThreadPoolExecutor rejects max_workers=0, so keep at least one
        # worker even when there are no runners.
        self.executor = ThreadPoolExecutor(max_workers=max(1, len(runners)))

    def done(self):
        """Return True when no further results should be awaited.

        That is the case once ``limit`` items have been consumed (for a
        non-negative limit) or when every runner reports done.
        """
        if 0 <= self.limit <= self._counter:
            logger.debug("limit done %d >= %d", self._counter, self.limit)
            return True
        return all(r.done() for r in self.runners)

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next result, or None if none arrived within 0.1 s.

        None is returned while :meth:`done` is False so callers can keep
        polling. StopIteration is raised only when the queue is empty and
        :meth:`done` is True. The limit is consulted only in that empty-queue
        case, so items already queued are still returned.
        """
        try:
            item = self.result_queue.get(timeout=0.1)
        except queue.Empty as exp:
            if not self.done():
                return None
            logger.debug("result done")
            raise StopIteration() from exp
        self._counter += 1
        return item

    def get(self, timeout=0):
        """Pop the next entry from the queue.

        Return None if the queue is still empty after ``timeout`` seconds and
        :meth:`done` is False; raise StopIteration if it is True. Popped
        entries count towards ``limit`` exactly as in :meth:`__next__`.
        """
        try:
            row = self.result_queue.get(timeout=timeout)
        except queue.Empty as exp:
            if self.done():
                raise StopIteration() from exp
            return None
        self._counter += 1
        return row

    def start(self):
        """Reset the timestamp and start every runner on the executor."""
        self.timestamp = time.time()
        for runner in self.runners:
            runner.start(self.executor)
        # NOTE(review): fixed grace period after submitting the runners; the
        # reason is not evident here — presumably lets runners begin work.
        time.sleep(0.1)

    def close(self):
        """Gracefully close and dispose of resources.

        Closes every runner (logging, not raising, any failure), discards
        queued results and waits for the thread pool to shut down.
        """
        for runner in self.runners:
            try:
                runner.close()
            except Exception as ex:  # pylint: disable=broad-except
                logger.info(
                    "exception in result close: %s", type(ex), exc_info=True)
        # Drain without blocking: a blocking get() after an empty() check
        # can hang if the queue is emptied between the two calls.
        while True:
            try:
                self.result_queue.get_nowait()
            except queue.Empty:
                break
        logger.debug("closing thread pool executor")
        self.executor.shutdown(wait=True)
        elapsed = time.time() - self.timestamp
        logger.debug("result closed after %0.3f", elapsed)