
mpi4py Wrapper

H / 2025-02-22

Parallelization is an effective way to speed up Python programs.

In Python, there are two main approaches to parallelization: multiprocessing and mpi4py.

multiprocessing distributes tasks conveniently through the concept of a process pool, but it can only run on a single machine.
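
For comparison, this is roughly what the process-pool style looks like on a single machine (a minimal sketch; square is a hypothetical task function used only for illustration):

from multiprocessing import Pool


def square(i):
    return i**2


if __name__ == "__main__":
    # a pool of 4 worker processes on one machine
    with Pool(processes=4) as pool:
        results = pool.map(square, range(10))  # distribute the tasks to the pool
    print(results)  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]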

mpi4py follows the MPI (Message Passing Interface) standard and can therefore run across multiple nodes.
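
For reference, the smallest mpi4py program only queries the global communicator; mpiexec launches several copies of the same script, and each copy learns its own rank (a minimal sketch, not part of the wrapper below):

from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()  # index of this process
size = comm.Get_size()  # total number of processes launched by mpiexec
print("rank %d of %d" % (rank, size))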

This article introduces an mpi4py wrapper that aims to simplify its use.

Evenly Distributing Tasks

For simplicity, assume we need to run a number of independent tasks.

We distribute these tasks evenly among the processes.

import logging
import time

import numpy as np
from mpi4py import MPI

# make the logging.info progress messages visible
logging.basicConfig(level=logging.INFO)


def split_task(N):
    """
    Split N tasks evenly among the MPI processes.
    Returns the half-open range [start, end) of task indices for this rank.
    """
    comm = MPI.COMM_WORLD
    size = comm.Get_size()
    rank = comm.Get_rank()
    # the first N % size ranks each get one extra task
    if rank <= N % size:
        start = (N // size + 1) * rank
    else:
        start = rank * (N // size) + (N % size)
    if rank + 1 <= N % size:
        end = (N // size + 1) * (rank + 1)
    else:
        end = (rank + 1) * (N // size) + (N % size)
    return start, end


def MPI_run_tasks_equal_distribution(func, args, show_progress=False):
    """
    Run tasks in MPI
    """
    startTime = time.time()
    comm = MPI.COMM_WORLD
    size = comm.Get_size()
    rank = comm.Get_rank()
    Ntask = len(args)
    # shuffle the task order so uneven workloads average out across the ranks
    if rank == 0:
        index = np.arange(Ntask)
        np.random.shuffle(index)
    else:
        index = None
    # broadcast the index
    index = comm.bcast(index, root=0)
    args = [args[i] for i in index]
    # Equal distribution of tasks
    start, end = split_task(Ntask)
    results = []
    for i in range(start, end):
        if not isinstance(args[i], tuple):
            result = func(args[i])
        else:
            result = func(*args[i])
        if not isinstance(result, tuple):
            result = (result,)
        if show_progress and rank == 0:
            currentTime = time.time()
            elapsedTime = currentTime - startTime
            remainingTime = elapsedTime / (i + 1 - start) * (end - i - 1)
            elapsedTime = "%d:%02d:%02d" % (
                elapsedTime // 3600,
                elapsedTime % 3600 // 60,
                elapsedTime % 60,
            )
            remainingTime = "%d:%02d:%02d" % (
                remainingTime // 3600,
                remainingTime % 3600 // 60,
                remainingTime % 60,
            )
            logging.info(
                "Root progress %.2f%%, elapsed time %s, remaining time %s"
                % ((i + 1 - start) / (end - start) * 100, elapsedTime, remainingTime)
            )
        results.append(result)
    results_list = comm.gather(results, root=0)
    if rank == 0:
        results = []
        for i in range(size):
            results.extend(results_list[i])
        # restore the original order of results
        results = [results[i] for i in np.argsort(index)]
        # transpose: one inner list per value returned by func
        results = [list(row) for row in zip(*results)]
    results = comm.bcast(results, root=0)
    return results

The split_task function distributes the tasks evenly among the processes. For example, with 10 tasks and 4 processes, ranks 0 and 1 each handle 3 tasks (indices 0-2 and 3-5), while ranks 2 and 3 handle 2 tasks each (indices 6-7 and 8-9). The MPI_run_tasks_equal_distribution function wraps the MPI bookkeeping around this split: it shuffles the tasks, lets every rank process its own slice, and then gathers, reorders, and broadcasts the results. Below is an example of using the wrapper.

def single_task(i):
    return i**2

args = list(range(10))
results = MPI_run_tasks_equal_distribution(single_task, args)
if MPI.COMM_WORLD.Get_rank() == 0:
    logging.info("Results: %s" % results)

To run the code, use the following command:

mpiexec -n 4 python main.py

The result is as follows. Note that the wrapper always returns a list of lists, with one inner list per value returned by the task function:

[[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]]
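
Because mpi4py follows the MPI standard, the same script can also be launched across several nodes. The exact command depends on the MPI implementation; with Open MPI, for example, it might look like the following, where hosts is a hypothetical file listing the node names (MPICH's mpiexec uses -f instead of --hostfile):

mpiexec -n 8 --hostfile hosts python main.py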

Dynamically Distributing Tasks

Since all tasks are independent, MPI_run_tasks_equal_distribution should achieve a linear speedup if every task takes the same amount of time.

However, if the tasks take different amounts of time, it is better to assign them to processes dynamically.

This is done by the MPI_run_tasks_root_distribution function, in which the root process hands out tasks to the workers as they finish. Note that the root only coordinates, so launching n processes leaves n - 1 workers doing the actual computation.

def MPI_run_tasks_root_distribution(func, args, show_progress=False):
    """
    Run tasks in MPI where the root process distributes tasks to worker processes.
    """
    startTime = time.time()
    comm = MPI.COMM_WORLD
    size = comm.Get_size()
    rank = comm.Get_rank()
    Ntask = len(args)
    results = [None] * Ntask
    if rank == 0:
        status = MPI.Status()
        send_count = 0
        recv_count = 0
        # send one initial task to each worker (never more tasks than exist)
        for i in range(1, min(size, Ntask + 1)):
            comm.send(i - 1, dest=i, tag=i - 1)
            send_count += 1
        # receive results and send new tasks
        while recv_count < Ntask:
            result = comm.recv(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status)
            if not isinstance(result, tuple):
                result = (result,)
            results[status.Get_tag()] = result
            recv_count += 1
            # throttle progress output: only log when a result arrives from worker 1
            if show_progress and status.Get_source() == 1:
                currentTime = time.time()
                elapsedTime = currentTime - startTime
                remainingTime = elapsedTime / recv_count * (Ntask - recv_count)
                elapsedTime = "%d:%02d:%02d" % (
                    elapsedTime // 3600,
                    elapsedTime % 3600 // 60,
                    elapsedTime % 60,
                )
                remainingTime = "%d:%02d:%02d" % (
                    remainingTime // 3600,
                    remainingTime % 3600 // 60,
                    remainingTime % 60,
                )
                logging.info(
                    "Root progress %.2f%%, elapsed time %s, remaining time %s"
                    % (recv_count / Ntask * 100, elapsedTime, remainingTime)
                )
            if send_count < Ntask:
                comm.send(send_count, dest=status.Get_source(), tag=send_count)
                send_count += 1
        # transpose: one inner list per value returned by func
        results = [list(row) for row in zip(*results)]
        # send a stop signal (tag == Ntask) to every worker
        for i in range(1, size):
            comm.send(None, dest=i, tag=Ntask)
    else:
        while True:
            status = MPI.Status()
            # receive tasks
            task = comm.recv(source=0, tag=MPI.ANY_TAG, status=status)
            if status.Get_tag() == Ntask:  # stop signal from the root
                break
            # the tag carries the task index; look up its arguments and run it
            task = args[status.Get_tag()]
            if not isinstance(task, tuple):
                result = func(task)
            else:
                result = func(*task)
            # send results
            comm.send(result, dest=0, tag=status.Get_tag())
    results = comm.bcast(results, root=0)
    return results

Below is an example of using the MPI_run_tasks_root_distribution function.

def single_task_1(i, j):
    return i**2, j**2, i + j


args = [(i, i + 1) for i in range(10)]
results = MPI_run_tasks_root_distribution(single_task_1, args)
if MPI.COMM_WORLD.Get_rank() == 0:
    print(results)

The result is as follows; as before, there is one inner list per value returned by the task function:

[[0, 1, 4, 9, 16, 25, 36, 49, 64, 81], [1, 4, 9, 16, 25, 36, 49, 64, 81, 100], [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]]

With these two functions, Python programs can be parallelized with very little extra code.
