# Source code for adaptive_executor.executor

"""Adaptive executor implementation with dynamic worker scaling."""

import queue
import signal
import threading
import time
from typing import Callable, Optional

from .utils import get_logger

# Module-level logger, obtained from the package's ``get_logger`` helper and
# shared by every AdaptiveExecutor instance (and its worker/controller threads).
logger = get_logger(__name__)


class AdaptiveExecutor:
    """Thread pool whose task concurrency is adjusted at runtime by a policy.

    ``max_workers`` daemon worker threads are started up front, but a worker
    only runs a task while holding a permit from ``self.permits``. The number
    of permits in circulation (``current_limit``) is set from
    ``policy.target_workers()`` at construction and then re-checked every
    ``check_interval`` seconds by a daemon controller thread. The target is
    clamped to the range ``[0, max_workers]``.

    Args:
        max_workers: Number of worker threads to start; also the upper bound
            for the concurrency limit.
        policy: Object whose ``target_workers()`` method returns the desired
            number of concurrently running tasks.
        check_interval: Seconds between policy checks.
    """

    def __init__(self, max_workers, policy, check_interval=60):
        self.max_workers = max_workers
        self.policy = policy
        self.check_interval = check_interval
        self.tasks = queue.Queue()
        self.shutdown_flag = False
        # Set by shutdown() so the controller wakes immediately instead of
        # sleeping out the remainder of check_interval.
        self._stop_event = threading.Event()
        # One permit per task allowed to run concurrently. It starts empty and
        # is filled to the policy's initial target by _set_limit below.
        self.permits = threading.Semaphore(0)
        self.current_limit = 0
        self._set_limit(self.policy.target_workers())

        for index in range(max_workers):
            threading.Thread(
                target=self._worker,
                name=f"AdaptiveExecutor-worker-{index}",
                daemon=True,
            ).start()
        threading.Thread(
            target=self._controller,
            name="AdaptiveExecutor-controller",
            daemon=True,
        ).start()
        self._register_signal_handlers()

    def _register_signal_handlers(self):
        """Install SIGINT/SIGTERM handlers that call :meth:`shutdown`.

        ``signal.signal`` may only be called from the main thread. Elsewhere
        it raises ``ValueError``, which is logged as a warning and otherwise
        ignored. The handlers replace any previously installed handler for
        these signals, including Python's default SIGINT handler that raises
        ``KeyboardInterrupt``.
        """

        def handler(signum, frame):
            signame = (
                signal.Signals(signum).name
                if hasattr(signal, "Signals")
                else str(signum)
            )
            logger.info("Received signal %s, shutting down...", signame)
            self.shutdown()

        # getattr() guards against a platform that does not define one of the
        # signal constants.
        for signame in ("SIGINT", "SIGTERM"):
            signum = getattr(signal, signame, None)
            if signum is None:
                continue
            try:
                signal.signal(signum, handler)
                logger.debug("Registered signal handler for %s", signame)
            except (ValueError, AttributeError) as e:
                logger.warning(
                    "Failed to register signal handler for %s: %s", signame, e
                )

    def _set_limit(self, new_limit):
        """Adjust the number of circulating permits to ``new_limit``.

        The value is clamped to ``[0, max_workers]``. A negative target would
        otherwise try to acquire more permits than exist and block forever.
        Raising the limit releases permits immediately. Lowering it acquires
        permits back, which blocks until enough running tasks have finished
        and returned theirs.
        """
        new_limit = max(0, min(new_limit, self.max_workers))
        diff = new_limit - self.current_limit
        if diff > 0:
            for _ in range(diff):
                self.permits.release()
        elif diff < 0:
            for _ in range(-diff):
                self.permits.acquire()
        old_limit = self.current_limit
        self.current_limit = new_limit
        if old_limit != new_limit:
            logger.info(
                "Adjusted worker concurrency: %d -> %d (max: %d)",
                old_limit,
                new_limit,
                self.max_workers,
            )

    def _controller(self):
        """Periodically re-read the policy and apply its target limit.

        Runs until shutdown. An exception from the policy (or while applying
        its target) is logged and the current limit is kept. This stops a
        transient failure from ending the controller thread and silently
        freezing the limit.
        """
        while not self.shutdown_flag:
            try:
                target = self.policy.target_workers()
                if target != self.current_limit:
                    self._set_limit(target)
            except Exception:
                logger.exception(
                    "Error while applying worker policy; keeping limit %d",
                    self.current_limit,
                )
            # Returns early as soon as shutdown() sets the event.
            self._stop_event.wait(self.check_interval)

    def _worker(self):
        """Dequeue tasks and run each one while holding a concurrency permit.

        Exceptions raised by a task are logged and not propagated. The worker
        thread therefore survives a failing task and keeps serving the queue.
        The loop re-checks ``shutdown_flag`` at least once per second.
        """
        thread_name = threading.current_thread().name
        logger.debug("Worker %s started", thread_name)
        while not self.shutdown_flag:
            try:
                fn, args, kwargs = self.tasks.get(timeout=1)
            except queue.Empty:
                continue
            task_name = getattr(fn, "__name__", "anonymous")
            logger.debug("Worker %s starting task: %s", thread_name, task_name)
            # Acquire a permit so active task concurrency follows current_limit.
            self.permits.acquire()
            try:
                start_time = time.monotonic()
                fn(*args, **kwargs)
                duration = time.monotonic() - start_time
                logger.debug(
                    "Worker %s completed task %s in %.3f seconds",
                    thread_name,
                    task_name,
                    duration,
                )
            except Exception as e:
                logger.error(
                    "Error in worker %s while executing task %s: %s",
                    thread_name,
                    task_name,
                    str(e),
                    exc_info=True,
                )
            finally:
                self.permits.release()
                self.tasks.task_done()
        logger.debug("Worker %s shutting down", thread_name)

    def submit(self, fn: Callable, *args, **kwargs) -> None:
        """Submit a task to be executed by the worker pool.

        After :meth:`shutdown` the task is dropped with a warning. No worker
        would ever run it, and queueing it would make :meth:`join` wait
        forever.

        Args:
            fn: The function to execute
            *args: Positional arguments to pass to the function
            **kwargs: Keyword arguments to pass to the function
        """
        task_name = getattr(fn, "__name__", "anonymous")
        if self.shutdown_flag:
            logger.warning(
                "Executor is shut down; dropping submitted task: %s", task_name
            )
            return
        logger.debug("Submitting task: %s", task_name)
        self.tasks.put((fn, args, kwargs))
        logger.debug(
            "Task %s submitted to queue (queue size: %d)",
            task_name,
            self.tasks.qsize(),
        )

    def join(self, timeout: Optional[float] = None) -> bool:
        """Wait until all submitted tasks have finished executing.

        Args:
            timeout: Maximum time to wait in seconds; ``None`` waits forever.

        Returns:
            bool: True if all tasks completed, False if the timeout expired
            or the wait was interrupted by ``KeyboardInterrupt``.
        """
        logger.info("Waiting for all tasks to complete...")
        try:
            if timeout is None:
                self.tasks.join()
                return True
            # unfinished_tasks counts tasks that are queued *or* still
            # running, so this waits for execution to finish, not only for
            # the queue to be drained.
            deadline = time.monotonic() + timeout
            while self.tasks.unfinished_tasks > 0:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                # Never sleep past the deadline.
                time.sleep(min(0.1, remaining))
            return self.tasks.unfinished_tasks == 0
        except KeyboardInterrupt:
            logger.warning("Join interrupted by user")
            return False

    def shutdown(self) -> None:
        """Shut down the executor and signal all worker threads to stop.

        Pending (not yet started) tasks are discarded and marked done so that
        :meth:`join` is not left waiting on them. Tasks already running are
        not interrupted or waited for. Worker threads exit after their current
        task, or within about one second if idle. Calling this more than once
        is a no-op.
        """
        if self.shutdown_flag:
            return
        logger.info("Shutting down executor...")
        self.shutdown_flag = True
        self._stop_event.set()
        # Drain pending tasks, balancing each get with a task_done.
        while True:
            try:
                self.tasks.get_nowait()
            except queue.Empty:
                break
            self.tasks.task_done()
        logger.debug("Executor shutdown complete")