65 lines
1.8 KiB
Python
65 lines
1.8 KiB
Python
|
|
"""Thread execution pool."""
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from concurrent.futures import Future, ThreadPoolExecutor, wait
|
||
|
|
from typing import TYPE_CHECKING, Any, Callable
|
||
|
|
|
||
|
|
from .base import BasePool, apply_target
|
||
|
|
|
||
|
|
__all__ = ('TaskPool',)
|
||
|
|
|
||
|
|
if TYPE_CHECKING:
|
||
|
|
from typing import TypedDict
|
||
|
|
|
||
|
|
PoolInfo = TypedDict('PoolInfo', {'max-concurrency': int, 'threads': int})
|
||
|
|
|
||
|
|
# `TargetFunction` should be a Protocol that represents fast_trace_task and
|
||
|
|
# trace_task_ret.
|
||
|
|
TargetFunction = Callable[..., Any]
|
||
|
|
|
||
|
|
|
||
|
|
class ApplyResult:
|
||
|
|
def __init__(self, future: Future) -> None:
|
||
|
|
self.f = future
|
||
|
|
self.get = self.f.result
|
||
|
|
|
||
|
|
def wait(self, timeout: float | None = None) -> None:
|
||
|
|
wait([self.f], timeout)
|
||
|
|
|
||
|
|
|
||
|
|
class TaskPool(BasePool):
|
||
|
|
"""Thread Task Pool."""
|
||
|
|
limit: int
|
||
|
|
|
||
|
|
body_can_be_buffer = True
|
||
|
|
signal_safe = False
|
||
|
|
|
||
|
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||
|
|
super().__init__(*args, **kwargs)
|
||
|
|
self.executor = ThreadPoolExecutor(max_workers=self.limit)
|
||
|
|
|
||
|
|
def on_stop(self) -> None:
|
||
|
|
self.executor.shutdown()
|
||
|
|
super().on_stop()
|
||
|
|
|
||
|
|
def on_apply(
|
||
|
|
self,
|
||
|
|
target: TargetFunction,
|
||
|
|
args: tuple[Any, ...] | None = None,
|
||
|
|
kwargs: dict[str, Any] | None = None,
|
||
|
|
callback: Callable[..., Any] | None = None,
|
||
|
|
accept_callback: Callable[..., Any] | None = None,
|
||
|
|
**_: Any
|
||
|
|
) -> ApplyResult:
|
||
|
|
f = self.executor.submit(apply_target, target, args, kwargs,
|
||
|
|
callback, accept_callback)
|
||
|
|
return ApplyResult(f)
|
||
|
|
|
||
|
|
def _get_info(self) -> PoolInfo:
|
||
|
|
info = super()._get_info()
|
||
|
|
info.update({
|
||
|
|
'max-concurrency': self.limit,
|
||
|
|
'threads': len(self.executor._threads)
|
||
|
|
})
|
||
|
|
return info
|