"""Tools for concurrency in observable calculation."""
import os
from concurrent.futures import ThreadPoolExecutor
from itertools import islice
from typing import Iterable, TypeVar
T = TypeVar('T')
def create_executor() -> ThreadPoolExecutor:
    """
    Create a ``ThreadPoolExecutor`` with the relevant number of workers.

    The worker count is read from the ``OMP_NUM_THREADS`` environment
    variable, defaulting to 1 when it is unset.

    Returns
    -------
    ThreadPoolExecutor
        A thread pool executor with ``max_workers=OMP_NUM_THREADS``
        (or 1 if not set).
    """
    # We use a ThreadPoolExecutor as most of the concurrent operations
    # involve very large arrays; a ProcessPoolExecutor would create a
    # copy of each of these arrays per worker process.
    num_cores = int(os.environ.get("OMP_NUM_THREADS", 1))
    return ThreadPoolExecutor(max_workers=num_cores)
def core_batch(generator: Iterable[T]) -> Iterable[tuple[T, ...]]:
    """
    Batch generator according to the number of available cores, `OMP_NUM_THREADS`.

    Parameters
    ----------
    generator : Iterable[T]
        The generator to batch.

    Yields
    ------
    tuple[T, ...]
        Batches of size `OMP_NUM_THREADS`; the final batch may be smaller.

    See Also
    --------
    itertools.batched : Standard implementation from 3.12.

    Examples
    --------
    >>> core_batch(range(10))  # doctest: +SKIP
    on 1 core produces (0,), (1,), ..., (9,)
    on 4 cores produces (0, 1, 2, 3), (4, 5, 6, 7), (8, 9)
    """
    # Guard against a misconfigured OMP_NUM_THREADS <= 0, which would make
    # every islice() call return an empty tuple and silently drop all items.
    num_cores = max(1, int(os.environ.get("OMP_NUM_THREADS", 1)))
    iterator = iter(generator)
    # islice yields an empty tuple once the iterator is exhausted, which is
    # falsy and terminates the walrus-driven loop.
    while batch := tuple(islice(iterator, num_cores)):
        yield batch