Asynchronous Mapping with asyncio and Python Coroutines

Asynchronous programming in Python allows code to run concurrently, which can greatly improve the efficiency of IO-bound work such as network requests and other operations that spend most of their time waiting on external resources.

The built-in map function provides a neat interface for applying a function to every element of a list or generator: map(myfunction, mylist). Multiprocessing pools parallelize this across several processes (or threads) for a performance boost. This can be done using the following code:

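from multiprocessing import Pool

# myfunction, mylist and num_concurrent_processes are placeholders for
# your own function, input iterable and number of worker processes
results = []
with Pool(num_concurrent_processes) as pool:
    # imap yields results lazily, in the same order as mylist
    for res in pool.imap(myfunction, mylist):
        results.append(res)
# Leaving the with-block terminates the pool; join() waits for the workers to exit
pool.join()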

Alternatively, one can use multithreading instead of multiprocessing as follows:

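from multiprocessing.pool import ThreadPool

# Same placeholders as above; here num_concurrent_threads sets the number of threads
results = []
with ThreadPool(num_concurrent_threads) as pool:
    # imap yields results lazily, in the same order as mylist
    for res in pool.imap(myfunction, mylist):
        results.append(res)
# Leaving the with-block terminates the pool; join() waits for the threads to exit
pool.join()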

An advantage of pools is that the maximum number of parallel workers can be specified, which keeps the program from overwhelming the resources it accesses.
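To see this limit in action, here is a small sketch using ThreadPool with a hypothetical function, slow_identity, that records how many calls are running at the same time. The function and the counters are illustrative assumptions, not part of the original code.

```python
import threading
import time
from multiprocessing.pool import ThreadPool

active = 0  # number of calls currently running
peak = 0    # highest value `active` ever reached
lock = threading.Lock()


def slow_identity(x):
    """Hypothetical IO-like function: waits briefly, then returns its input."""
    global active, peak
    with lock:
        active += 1
        peak = max(peak, active)
    time.sleep(0.05)  # stand-in for waiting on a network or disk
    with lock:
        active -= 1
    return x


with ThreadPool(3) as pool:
    # imap keeps the results in input order
    results = list(pool.imap(slow_identity, range(20)))

# peak never exceeds the pool size of 3
print(results[:5], "peak concurrency:", peak)
```

However many inputs there are, at most three calls are in flight at once, because the pool only has three threads.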

Suppose we want to use this setup but with an async coroutine instead of a normal function. Using asyncio.gather(*[mycoroutine(x) for x in mylist]) might seem like the right idea, but it will end up consuming too many resources, because all of the coroutines run concurrently with no upper limit.

For example, if the coroutine performs a web request on a link and the list has tens of thousands of links, the remote server will very likely block the IP address of the requesting machine, since tens of thousands of simultaneous requests look like an attack.
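A small sketch makes the problem concrete. The hypothetical fake_request coroutine below stands in for a web request and records how many calls are in flight at once. With a plain asyncio.gather over 1,000 inputs, every call is started before any of them finishes.

```python
import asyncio

active = 0  # coroutines that have started but not yet finished
peak = 0    # maximum value of `active` observed


async def fake_request(x):
    """Hypothetical stand-in for a remote request: waits, then returns x."""
    global active, peak
    active += 1
    peak = max(peak, active)
    await asyncio.sleep(0.01)  # every coroutine reaches this await before any resumes
    active -= 1
    return x


async def main(n):
    # gather wraps each coroutine in a task and schedules all of them at once
    return await asyncio.gather(*[fake_request(x) for x in range(n)])


results = asyncio.run(main(1000))
print("peak concurrency:", peak)  # → peak concurrency: 1000
```

If fake_request were a real HTTP call, the server would see all 1,000 requests at practically the same moment.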

The Proposed Solution

We can formulate a solution using a queue and a pool of workers, similar to the example in the asyncio.Queue documentation. However, there are more considerations that need to be taken care of. We start with a worker model in which a fixed number of worker coroutines pull tasks from a queue. This caps the number of concurrent requests. The workers keep pulling tasks from the queue until all tasks are complete.

Let's look at the Python code for this solution:

import asyncio
from typing import Callable, Coroutine, Iterable, List, Optional


async def aworker(
    coroutine: Callable[..., Coroutine],
    tasks_queue: asyncio.Queue,
    result_queue: asyncio.Queue,
    stop_event: asyncio.Event,
    timeout: float = 1,
    callback: Optional[Callable] = None
) -> None:
    """
    A worker coroutine to process tasks from a queue.

    Args:
        coroutine: The coroutine function applied to each task's argument.
        tasks_queue: The queue containing the (index, argument) tasks to be processed.
        result_queue: The queue that receives an (index, result) pair for each processed task.
        stop_event: An event signaling that all tasks have been added to the tasks queue.
        timeout: How long to wait, in seconds, for a task before re-checking the stop condition.
        callback: An optional function called with (index, argument) after each task finishes, whether or not it raised.
    """
    # Continue looping until stop_event is set and the tasks queue is empty
    while not stop_event.is_set() or not tasks_queue.empty():
        try:
            # Try to get a task from the tasks queue with a timeout
            idx, arg = await asyncio.wait_for(tasks_queue.get(), timeout)
        except asyncio.TimeoutError:
            # If no task is available, continue the loop
            continue
        try:
            # Try to execute the coroutine with the argument from the task
            result = await coroutine(arg)
            # If successful, add the result to the result queue
            result_queue.put_nowait((idx, result))

        finally:
            # Mark the task as done in the tasks queue
            tasks_queue.task_done()
            # callback for progress update
            if callback is not None:
                callback(idx, arg)


async def amap(
    coroutine: Callable[..., Coroutine],
    data: Iterable,
    max_concurrent_tasks: int = 10,
    max_queue_size: int = -1,  # infinite
    callback: Optional[Callable] = None,
) -> List:
    """
    An async function to map a coroutine over a list of arguments.

    Args:
        coroutine: The coroutine to be applied to each argument.
        data: The list of arguments to be passed to the coroutine.
        max_concurrent_tasks: The maximum number of concurrent tasks.
        max_queue_size: The maximum number of tasks in the workers queue.
        callback: A function to be called at the end of each coroutine.
    """
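    # Initialize the tasks queue and results queue.
    # The tasks queue is unbounded if max_queue_size is 0 or less.
    # A finite size bounds memory use, but an exception raised by a coroutine
    # only surfaces at the gather call below, after every input has been
    # enqueued; if all workers fail while the queue is full, the loop that
    # fills the queue blocks forever. A finite size should be larger than
    # max_concurrent_tasks.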
    tasks_queue = asyncio.Queue(max_queue_size)
    result_queue = asyncio.PriorityQueue()

    # Create an event to signal when all tasks have been added to the tasks queue
    stop_event = asyncio.Event()
    # Create workers
    workers = [
        asyncio.create_task(aworker(
            coroutine, tasks_queue, result_queue, stop_event, callback=callback
        ))
        for _ in range(max_concurrent_tasks)
    ]

    # Add inputs to the tasks queue as (index, value) pairs
    for item in enumerate(data):
        await tasks_queue.put(item)
    # Set the stop_event to signal that all tasks have been added to the tasks queue
    stop_event.set()

    # Wait for all workers to complete.
    # If a coroutine raised, gather re-raises the first such exception;
    # the remaining workers are not cancelled and keep running in the background.
    await asyncio.gather(*workers)
    # Ensure all tasks have been processed
    await tasks_queue.join()

    # Gather all results
    results = []
    while not result_queue.empty():
        # Pop results in ascending index order (the PriorityQueue sorts on
        # the index) and discard the index, so the output has the same
        # order as the input.
        _, res = result_queue.get_nowait()
        results.append(res)
    return results

The implementation handles several pitfalls:

  1. An optional callback is invoked after each execution of the coroutine, whether it succeeded or raised. This is helpful for tracking the progress of amap, for example with a tqdm progress bar (example below).
  2. The results are collected in a PriorityQueue that sorts them by their index in the input. This ensures that the output is returned in the same order as the input.
  3. Exceptions are handled properly: if a coroutine raises an exception, the finally clause marks the task as done in the queue before the exception propagates. Otherwise, await tasks_queue.join() could hang, because it would never learn that the task was consumed.
  4. A stop event signals to the workers that no more tasks will be added, so they can terminate once the queue is empty.
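Point 2 relies on how asyncio.PriorityQueue orders tuples: the smallest first element comes out first. A minimal sketch, with made-up results arriving out of order as they would when fast tasks finish before slow ones:

```python
import asyncio


async def main():
    results = asyncio.PriorityQueue()
    # Each item is (input index, result), as in amap; the arrival order is scrambled
    for idx, value in [(2, "c"), (0, "a"), (3, "d"), (1, "b")]:
        results.put_nowait((idx, value))

    ordered = []
    while not results.empty():
        _, value = results.get_nowait()  # smallest remaining index first
        ordered.append(value)
    return ordered


ordered = asyncio.run(main())
print(ordered)  # → ['a', 'b', 'c', 'd']
```

Because every index is unique, the tuples never tie on the first element, so the results themselves are never compared and do not need to be orderable.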

The typical setup, in which each worker runs a while True loop, is problematic here because asyncio.gather(*workers) can never complete when the workers run forever. In that setup, if an exception is raised inside a worker before the task is marked as done, await tasks_queue.join() hangs forever, since the count of unfinished tasks never drops to zero. Hence, all of these pieces need to be handled together for proper termination. A while True loop is useful for a long-lived consumer that waits for tasks which may arrive at any time. In this setup, however, we know when the input is exhausted, so the workers have a well-defined condition for terminating.
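For contrast, here is a minimal sketch of the typical while True pattern, loosely modeled on the queue example in the asyncio documentation. The job square and its failure on negative input are hypothetical. Since the workers never exit on their own, the main coroutine waits on queue.join() and then cancels them; the finally clause is what keeps join() from hanging when a job fails. Unlike amap, this worker catches the error so it can keep consuming tasks.

```python
import asyncio


async def square(x):
    """Hypothetical job: fails on negative input to exercise error handling."""
    await asyncio.sleep(0.01)
    if x < 0:
        raise ValueError(f"negative input: {x}")
    return x * x


async def worker(queue, results, errors):
    # The typical pattern: loop forever, waiting for new work
    while True:
        idx, x = await queue.get()
        try:
            results[idx] = await square(x)
        except ValueError as exc:
            errors.append(exc)
        finally:
            # Without this, queue.join() below would never return for a failed job
            queue.task_done()


async def main(data, n_workers=3):
    queue = asyncio.Queue()
    results, errors = {}, []
    for item in enumerate(data):
        queue.put_nowait(item)
    workers = [asyncio.create_task(worker(queue, results, errors)) for _ in range(n_workers)]
    await queue.join()  # returns once every item has been marked done
    # The workers are still blocked in queue.get(), so they must be cancelled;
    # gathering them without cancelling first would wait forever
    for w in workers:
        w.cancel()
    await asyncio.gather(*workers, return_exceptions=True)
    return [results[i] for i in sorted(results)], errors


ordered, errors = asyncio.run(main([1, 2, -3, 4]))
print(ordered, [str(e) for e in errors])  # → [1, 4, 16] ['negative input: -3']
```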

The following example uses amap and compares it with the alternatives above:

import asyncio
import multiprocessing
import random
import time
from multiprocessing.pool import ThreadPool

from tqdm import tqdm
from tqdm.asyncio import tqdm as atqdm

# amap (and aworker) are the functions defined in the previous listing


async def mycoroutine(t):
    await asyncio.sleep(t)
    return t + 1


def myfunction(t):
    time.sleep(t)
    return t + 1


async def amap_all(coroutine, data):
    pbar = atqdm(total=len(data))  # track progress with tqdm

    def callback(*_):
        pbar.update()

    res = await amap(coroutine, data, 10, callback=callback)
    pbar.close()
    return res


def map_pool(function, data):
    with multiprocessing.Pool(10) as pool:
        res = list(tqdm(pool.imap(function, data), total=len(data)))  # track progress with tqdm
    pool.join()
    return res


def map_thread_pool(function, data):
    with ThreadPool(10) as pool:
        res = list(tqdm(pool.imap(function, data), total=len(data)))  # track progress with tqdm
    pool.join()
    return res


def map_sequential(function, data):
    return list(tqdm(map(function, data), total=len(data)))  # track progress with tqdm


if __name__ == "__main__":
    random.seed(41)
    inputs = [random.random() for _ in range(100)]

    print(inputs[:4])

    tic = time.time()
    print(map_sequential(myfunction, inputs)[:4])
    toc = time.time()
    print("Map", toc - tic, '\n')

    tic = time.time()
    print(map_pool(myfunction, inputs)[:4])
    toc = time.time()
    print("Multiprocessing pool", toc - tic, '\n')

    tic = time.time()
    print(map_thread_pool(myfunction, inputs)[:4])
    toc = time.time()
    print("Multithreading pool", toc - tic, '\n')

    tic = time.time()
    print(asyncio.run(amap_all(mycoroutine, inputs))[:4])
    toc = time.time()
    print("Async map", toc - tic, '\n')

Running it produced the following output (timings will vary by machine):

[0.38102068999577143, 0.23071918631047517, 0.1660367243141352, 0.913833434772713]
100%|█████████████████████████████████████| 100/100 [00:51<00:00,  1.96it/s]
[1.3810206899957715, 1.2307191863104752, 1.1660367243141352, 1.9138334347727128]
Map 51.13913917541504

100%|█████████████████████████████████████| 100/100 [00:05<00:00, 18.08it/s]
[1.3810206899957715, 1.2307191863104752, 1.1660367243141352, 1.9138334347727128]
Multiprocessing pool 5.5755650997161865

100%|█████████████████████████████████████| 100/100 [00:05<00:00, 18.07it/s]
[1.3810206899957715, 1.2307191863104752, 1.1660367243141352, 1.9138334347727128]
Multithreading pool 5.573708772659302

100%|█████████████████████████████████████|100/100 [00:05<00:00, 18.67it/s]
[1.3810206899957715, 1.2307191863104752, 1.1660367243141352, 1.9138334347727128]
Async map 5.358321905136108

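In this run, the async map was the fastest, although its margin over the pools (about 0.2 seconds) is small for a single measurement. Its main advantage is that it achieves the same bounded concurrency in a single thread, without spawning processes or threads, which should matter more as the number of concurrent tasks grows.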

Summary

The amap function is an asynchronous alternative to the map function, designed to work with asyncio coroutines. It takes a coroutine function, an iterable of arguments, a maximum number of concurrent tasks, an optional maximum queue size, and an optional callback.

The function creates a queue of tasks, each consisting of an argument's index and its value, and a priority queue that returns the results in the order of the original inputs. A fixed number of worker tasks pull tasks from the queue, apply the coroutine function to them, and store the results in the result queue. Workers keep processing tasks until the stop event is set and the queue is empty. Afterwards, the function waits for all worker tasks to complete, confirms that all tasks have been processed, and returns the results in their original order.

This implementation maps a coroutine over a list of arguments efficiently while respecting the specified limit on concurrent tasks. It also marks failed tasks as done so that waiting on the queue never hangs, re-raises the first exception to the caller, and relies only on asyncio and Python coroutines for concurrency.


Created on 2023-07-29 at 22:08