Source code for dl_utils.distributed.breakpoint
# -*- coding: utf-8 -*-
# @Time : 7/21/25
# @Author : Yaojie Shen
# @Project : Deep-Learning-Utils
# @File : breakpoint.py
import sys
from .basic import get_local_rank, barrier_if_distributed, get_world_size
def _require_ipython():
"""Try to install IPython if failed to import it."""
try:
import IPython
except ImportError:
import pip
if hasattr(pip, 'main'):
pip.main(['install', "IPython"])
else:
pip._internal.main(['install', "IPython"])
def _my_embed(*, stack_depth=2, header="", compile_flags=None, **kwargs):
"""
This is a modified version of IPython.terminal.embed.embed(), add `stack_depth` to arguments.
"""
# Install IPython if failed to import it.
_require_ipython()
from IPython.core.interactiveshell import InteractiveShell
from IPython.terminal.embed import InteractiveShellEmbed
from IPython.terminal.ipapp import load_default_config
config = kwargs.get('config')
if config is None:
config = load_default_config()
config.InteractiveShellEmbed = config.TerminalInteractiveShell
kwargs["config"] = config
using = kwargs.get("using", "sync")
colors = kwargs.pop("colors", "nocolor")
if using:
kwargs["config"].update(
{
"TerminalInteractiveShell": {
"loop_runner": using,
"colors": colors,
"autoawait": using != "sync",
}
}
)
# save ps1/ps2 if defined
ps1 = None
ps2 = None
try:
ps1 = sys.ps1
ps2 = sys.ps2
except AttributeError:
pass
# save previous instance
saved_shell_instance = InteractiveShell._instance
if saved_shell_instance is not None:
cls = type(saved_shell_instance)
cls.clear_instance()
frame = sys._getframe(1)
shell = InteractiveShellEmbed.instance(_init_location_id='%s:%s' % (
frame.f_code.co_filename, frame.f_lineno), **kwargs)
shell(header=header, stack_depth=stack_depth, compile_flags=compile_flags,
_call_location_id='%s:%s' % (frame.f_code.co_filename, frame.f_lineno))
InteractiveShellEmbed.clear_instance()
# restore previous instance
if saved_shell_instance is not None:
cls = type(saved_shell_instance)
cls.clear_instance()
for subclass in cls._walk_mro():
subclass._instance = saved_shell_instance
if ps1 is not None:
sys.ps1 = ps1
sys.ps2 = ps2
[docs]
def dist_breakpoint(rank: int = 0):
"""
Breakpoint for distributed training.
Enter the breakpoint only if the current rank is `rank`, and block all other processes using distributed barrier.
"""
assert 0 <= rank < get_world_size(), f"Invalid rank {rank}, world size: {get_world_size()}."
if get_local_rank() == rank:
_my_embed(stack_depth=3)
barrier_if_distributed()
__all__ = ["dist_breakpoint"]