[Fix] SLURM distributed training in containers where scontrol is not available #1527

Open · wants to merge 2 commits into base: main
mmengine/dist/utils.py: 41 additions & 2 deletions
@@ -171,6 +171,42 @@ def _init_dist_mpi(backend, **kwargs) -> None:
    os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
    torch_dist.init_process_group(backend=backend, **kwargs)

def _slurm_extract_first_node(slurm_nodelist):
    """Return the first hostname in a SLURM nodelist.

    Fallback for containers where ``scontrol`` is not available as an
    executable; offers the same functionality as
    ``scontrol show hostname {node_list} | head -n1``.
    """
    # Match a node prefix, optionally followed by a bracketed range
    # expression, e.g. 'gpu[001-004,007]' or a plain hostname 'node17'.
    match = re.match(r'([^,\[]+)(?:\[([0-9,-]+)\])?', slurm_nodelist)

    if match is None:
        raise ValueError('Invalid SLURM_NODELIST format')

    node_prefix, node_ranges = match.group(1), match.group(2)

    # A plain hostname has no bracketed range, e.g. 'node17'.
    if node_ranges is None:
        return node_prefix

    # Take the first comma-separated range and its first (or only)
    # number, e.g. '001-004,007' -> '001'.
    first_number = node_ranges.split(',')[0].split('-')[0]
    return node_prefix + first_number
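# Usage sketch (hypothetical nodelist values; assumes ``re`` is imported
# at the top of utils.py, which this diff does not show):
#   _slurm_extract_first_node('gpu[001-004,007]')  # -> 'gpu001'
#   _slurm_extract_first_node('node17')            # -> 'node17'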


def _is_scontrol_available():
    """Check whether the ``scontrol`` executable can be run."""
    try:
        subprocess.run(['scontrol', '-h'],
                       stdout=subprocess.PIPE,
                       stderr=subprocess.PIPE,
                       check=True)
        return True
    except (subprocess.CalledProcessError, FileNotFoundError):
        return False
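# A lighter-weight check (a sketch; assumes a PATH lookup is enough and
# the binary need not actually run) could use the standard library:
#   import shutil
#   shutil.which('scontrol') is not None  # True when scontrol is on PATH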

def _init_dist_slurm(backend,
                     port=None,
@@ -196,8 +232,11 @@ def _init_dist_slurm(backend,
    else:
        num_gpus = torch.cuda.device_count()
        local_rank = proc_id % num_gpus
    addr = subprocess.getoutput(
        f'scontrol show hostname {node_list} | head -n1')
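    # Prefer scontrol when it is available; fall back to parsing
    # SLURM_NODELIST directly inside containers that lack it.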
    if _is_scontrol_available():
        addr = subprocess.getoutput(
            f'scontrol show hostname {node_list} | head -n1')
    else:
        addr = _slurm_extract_first_node(node_list)
    # specify master port
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)