Source code for dl_utils.distributed.device

# -*- coding: utf-8 -*-
# @Time    : 9/22/25
# @Author  : Yaojie Shen
# @Project : Deep-Learning-Utils
# @File    : device.py

__all__ = [
    "recursive_to",
]

from typing import Any, Union

import torch


[docs] def recursive_to(obj: Any, device: Union[str, torch.device] = None) -> Any: """ Recursively move all torch.Tensor in obj to the given device. Supports: Tensor, list, tuple, dict, set. Leaves other objects intact. Args: obj: The object to move. device: The device to move to. If None, uses the current device if gpu is available, else "cpu". Returns: The object with all torch.Tensor moved to the given device. """ if device is None: device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu" if isinstance(obj, torch.Tensor): return obj.to(device) elif isinstance(obj, dict): return {k: recursive_to(v, device) for k, v in obj.items()} elif isinstance(obj, list): return [recursive_to(v, device) for v in obj] elif isinstance(obj, tuple): return tuple(recursive_to(v, device) for v in obj) elif isinstance(obj, set): # sets are unordered; converting back to set might lose type/sort return {recursive_to(v, device) for v in obj} else: return obj