Maxime committed
Commit 2fe95cd
1 Parent(s): c1382e7

fix distributed devices (#612)

* fix distributed devices

* Update distributed.py

* Update distributed.py

Files changed (1)
  1. src/axolotl/utils/distributed.py +12 -4
src/axolotl/utils/distributed.py CHANGED
@@ -77,7 +77,9 @@ def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
     value_scalar = fn()
     if not is_distributed():
         return [value_scalar]
-    value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
+    value_tensor = torch.tensor(
+        value_scalar, device=torch.cuda.current_device()
+    ).float()
 
     if not is_main_process():
         dist.gather(value_tensor, dst=0)
@@ -137,9 +139,13 @@ def compute_and_broadcast(fn):  # pylint: disable=invalid-name
     """
     if is_main_process():
         value_scalar = fn()
-        value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
+        value_tensor = torch.tensor(
+            value_scalar, device=torch.cuda.current_device()
+        ).float()
     else:
-        value_tensor = torch.tensor(0.0, device=dist.get_rank())  # Placeholder tensor
+        value_tensor = torch.tensor(
+            0.0, device=torch.cuda.current_device()
+        )  # Placeholder tensor
 
     # Broadcast the tensor to all processes.
     barrier()
@@ -164,7 +170,9 @@ def gather_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
     - A list of computed values from all ranks if on the gathering rank, otherwise None.
     """
     value_scalar = fn()
-    value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
+    value_tensor = torch.tensor(
+        value_scalar, device=torch.cuda.current_device()
+    ).float()
 
     # Placeholder tensor for gathering results
     if is_main_process():
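
Context for the change: dist.get_rank() returns the process's global rank, which only happens to be a valid CUDA device index on a single node; on multi-node runs (or with remapped CUDA_VISIBLE_DEVICES) it can point at a GPU that does not exist locally. torch.cuda.current_device() instead returns the GPU already bound to the current process. Below is a minimal, hypothetical sketch (not part of this commit) of the same placement pattern; gather_scalar_example and its arguments are illustrative names, and it assumes torch.distributed is initialized with a CUDA-capable backend and each process has already selected its local device.

# Illustrative sketch only: mirrors the tensor placement adopted by the patch.
import torch
import torch.distributed as dist


def gather_scalar_example(value_scalar, world_size):
    # Place the tensor on this process's local GPU. Using the global rank as a
    # device index (the old behaviour) breaks once ranks span multiple nodes.
    value_tensor = torch.tensor(
        value_scalar, device=torch.cuda.current_device()
    ).float()

    if dist.get_rank() != 0:
        # Non-destination ranks only contribute their tensor.
        dist.gather(value_tensor, dst=0)
        return None

    # Rank 0 collects one tensor per rank and converts back to Python scalars.
    gathered = [torch.zeros_like(value_tensor) for _ in range(world_size)]
    dist.gather(value_tensor, gather_list=gathered, dst=0)
    return [t.item() for t in gathered]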