Source code for dl_utils.distributed.gather

# -*- coding: utf-8 -*-
# @Time    : 4/20/23
# @Author  : Yaojie Shen
# @Project : Deep-Learning-Utils
# @File    : gather.py

import itertools
from typing import List, Any

import torch.distributed as dist


def gather_objects(list_object: List[Any]) -> List[Any]:
    """Gather a list of picklable objects from every rank and concatenate them.

    Each rank passes its own local list. The lists from all ranks in the
    default process group are collected with
    :func:`torch.distributed.all_gather_object` and flattened into a single
    list ordered by rank (rank 0's items first, then rank 1's, and so on).

    This is a collective operation: when a process group is initialized, every
    rank must call it, otherwise the participating ranks may hang.

    When ``torch.distributed`` is unavailable or no process group has been
    initialized (e.g. a plain single-process run), the function returns a
    shallow copy of ``list_object`` instead of raising. This lets the same
    code path work with and without a distributed launcher.

    Note:
        ``all_gather_object`` serializes the objects with ``pickle``, so only
        use it with data from trusted ranks. With the NCCL backend, PyTorch
        requires the current CUDA device to be set per rank (e.g. via
        ``torch.cuda.set_device``) before this is called.

    Args:
        list_object: The local list contributed by this rank. Its elements
            must be picklable.

    Returns:
        A new list containing the elements of every rank's list, concatenated
        in rank order.

    Raises:
        TypeError: If ``list_object`` is not a list.
    """
    # Validate with an explicit raise: ``assert`` is stripped under ``python -O``.
    if not isinstance(list_object, list):
        raise TypeError(
            f"gather_objects expects a list, got {type(list_object).__name__}."
        )

    # Single-process fallback. ``is_available()`` is checked first because
    # builds without distributed support may not provide ``is_initialized``.
    if not dist.is_available() or not dist.is_initialized():
        return list(list_object)

    # One slot per rank; all_gather_object fills slot i with rank i's list.
    gathered_objects: List[Any] = [None] * dist.get_world_size()
    dist.all_gather_object(gathered_objects, list_object)
    return list(itertools.chain.from_iterable(gathered_objects))


__all__ = ["gather_objects"]