How to exactly restore dataloader state in the accelerate framework
Comparing the candidate approaches
It quickly becomes apparent that, when using the accelerate framework, accelerator.load_state only reliably restores the state of the model, optimizer, and scheduler; restoring the dataloader state comes with many pitfalls.
In principle, to exactly restore a shuffled dataloader in the accelerate framework, one of two conditions must hold:
- either, exactly as with the model, optimizer, and scheduler, the dataloader's internal state can be saved and loaded (stateful dataloader);
- or the dataloader draws a seed from the global random state at initialization and stores it internally, after which normal iteration no longer touches the global random state and depends only on that internal state (seedable sampler).
Put simply, either the state can be reset to the correct position in one step, or fast-forwarding must neither read from nor write to anything outside the dataloader.
A typical counterexample is a plain dataloader plus manual skipping (hand-written or via skip_first_batches). A plain dataloader depends on the global random state during iteration, yet the global random state from before the skip has not been restored, and the way it evolves during the skip also differs from what actually happened in the previous run. Unless extra, even more cumbersome machinery is added to restore it, this approach cannot exactly recover the dataloader state; the sketch below illustrates the dependence.
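A minimal sketch of that dependence (not part of the original experiments): with shuffle=True and no explicit generator, the default sampler reseeds itself from the global torch RNG each time iteration starts, so anything else that consumes the global RNG changes the shuffle order.

import torch
from torch.utils.data import DataLoader

# With shuffle=True and no explicit generator, the sampler draws its seed from
# the global torch RNG at the start of each epoch.
torch.manual_seed(0)
first = [b.tolist() for b in DataLoader(list(range(6)), batch_size=1, shuffle=True)]

torch.manual_seed(0)
_ = torch.rand(1)  # something else (model init, dropout, ...) touched the global RNG
second = [b.tolist() for b in DataLoader(list(range(6)), batch_size=1, shuffle=True)]

print(first)
print(second)  # typically a different permutation, despite the same initial seed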
The official accelerate documentation mainly mentions two kinds of solutions:
- Stateful Dataloader: accelerate claims to support the StatefulDataLoader from recent torchdata releases, but the issue tracker contains many reports of assorted compatibility bugs. (I have hit such a compatibility bug myself, although I could not even pin down exactly how it was incompatible.)
- skip_first_batches: the accelerator.skip_first_batches approach looks slightly better, but it is essentially the same as skipping by hand, so it likewise cannot solve the problem of the random state not being reproduced.
After reading the source code with deepwiki and going through issue https://github.com/huggingface/accelerate/issues/3242, I found use_seedable_sampler to be a fairly good solution. https://github.com/huggingface/accelerate/issues/3398 claims this option also has a bug, but I could not reproduce it on v1.11.0; it may have been fixed by some PR. (The principle behind this approach is also simple and reliable, so as long as there is no obvious bug it should not run into strange corner cases.)
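For orientation before the experiments, enabling the option looks roughly like this (a minimal sketch; the full test scripts follow below):

from torch.utils.data import DataLoader
from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration, set_seed

# use_seedable_sampler makes prepare() swap the default RandomSampler for a
# sampler that keeps its own seed instead of reading the global RNG each epoch.
dataloader_config = DataLoaderConfiguration(use_seedable_sampler=True)
accelerator = Accelerator(dataloader_config=dataloader_config)
set_seed(42)

dataloader = DataLoader(list(range(10)), batch_size=1, shuffle=True)
dataloader = accelerator.prepare(dataloader)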
The verification experiments below were run with accelerate==1.11.0.
Tests
Test 1
For the following program:
import torch
from torch.utils.data import DataLoader, Dataset
from accelerate import Accelerator
from accelerate.utils import set_seed, DataLoaderConfiguration
import os
# from torchdata.stateful_dataloader import StatefulDataLoader


# Simple dataset with 10 elements
class SimpleDataset(Dataset):
    def __init__(self):
        self.data = list(range(10))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


# Function to print batch order in an epoch
def print_epoch_batches(epoch, dataloader, interrupt=False):
    output = f"\n--- Process {accelerator.process_index} ---\n"
    output += f"Epoch {epoch + 1}:\n"
    data = []
    for i, batch in enumerate(dataloader):
        data.append(batch.tolist())
        if interrupt and epoch == 1 and i == 1:
            accelerator.save_state(output_dir="debug_data_order_checkpoints")
            output += f"Random state saved\n"
    output += f"{data}\n"
    return output


if __name__ == "__main__":
    accelerator = Accelerator()
    set_seed(42)

    # Create the dataset and DataLoader with shuffle=True
    dataset = SimpleDataset()
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
    )
    dataloader = accelerator.prepare(dataloader)

    # Check data order for 3 epochs
    all_outputs = []
    for epoch in range(3):
        epoch_output = print_epoch_batches(epoch, dataloader, interrupt=True)
        all_outputs.append(epoch_output)

    # Print all outputs at the end
    for output in all_outputs:
        print(output)

    accelerator.wait_for_everyone()

    # Resume from checkpoint
    accelerator = Accelerator(
        # dataloader_config=dataloader_config,
    )
    set_seed(42)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
    )
    dataloader = accelerator.prepare(dataloader)

    # Load random state
    if os.path.exists("debug_data_order_checkpoints"):
        accelerator.load_state("debug_data_order_checkpoints")
        print("Random state loaded from debug_data_order_checkpoints")

    # skip_dataloader = accelerator.skip_first_batches(dataloader, 2)

    # Check data order for 1 epoch
    all_outputs = []
    for epoch in range(1):
        epoch_output = print_epoch_batches(epoch, dataloader)
        all_outputs.append(epoch_output)

    # Print all outputs at the end
    for output in all_outputs:
        print(output)
running accelerate launch --multi_gpu --num_processes 2 a.py produces:
--- Process 0 ---
Epoch 1:
[[4], [8], [1], [0], [5]]
--- Process 0 ---
Epoch 2:
Random state saved
[[6], [3], [5], [7], [1]]
--- Process 0 ---
Epoch 3:
[[3], [9], [1], [7], [2]]
--- Process 1 ---
Epoch 1:
[[2], [9], [7], [3], [6]]
--- Process 1 ---
Epoch 2:
Random state saved
[[2], [8], [4], [9], [0]]
--- Process 1 ---
Epoch 3:
[[4], [0], [8], [6], [5]]
Random state loaded from debug_data_order_checkpoints
Random state loaded from debug_data_order_checkpoints
--- Process 1 ---
Epoch 1:
[[2], [9], [7], [3], [6]]
--- Process 0 ---
Epoch 1:
[[4], [8], [1], [0], [5]]
This shows that load_state does not affect the state of a plain dataloader.
Test 2
For the following program:
import torch
from torch.utils.data import DataLoader, Dataset
from accelerate import Accelerator
from accelerate.utils import set_seed, DataLoaderConfiguration
import os
# from torchdata.stateful_dataloader import StatefulDataLoader


# Simple dataset with 10 elements
class SimpleDataset(Dataset):
    def __init__(self):
        self.data = list(range(10))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


# Function to print batch order in an epoch
def print_epoch_batches(epoch, dataloader, interrupt=False):
    output = f"\n--- Process {accelerator.process_index} ---\n"
    output += f"Epoch {epoch + 1}:\n"
    data = []
    for i, batch in enumerate(dataloader):
        data.append(batch.tolist())
        if interrupt and epoch == 1 and i == 1:
            accelerator.save_state(output_dir="debug_data_order_checkpoints")
            output += f"Random state saved\n"
    output += f"{data}\n"
    return output


if __name__ == "__main__":
    accelerator = Accelerator()
    set_seed(42)

    # Create the dataset and DataLoader with shuffle=True
    dataset = SimpleDataset()
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
    )
    dataloader = accelerator.prepare(dataloader)

    # Check data order for 3 epochs
    all_outputs = []
    for epoch in range(3):
        epoch_output = print_epoch_batches(epoch, dataloader, interrupt=True)
        all_outputs.append(epoch_output)

    # Print all outputs at the end
    for output in all_outputs:
        print(output)

    accelerator.wait_for_everyone()

    # Resume from checkpoint
    accelerator = Accelerator(
        # dataloader_config=dataloader_config,
    )
    set_seed(42)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
    )
    dataloader = accelerator.prepare(dataloader)

    # Load random state
    if os.path.exists("debug_data_order_checkpoints"):
        accelerator.load_state("debug_data_order_checkpoints")
        print("Random state loaded from debug_data_order_checkpoints")

    # skip_dataloader = accelerator.skip_first_batches(dataloader, 2)

    # Check data order for epochs 2 and 3
    all_outputs = []
    for epoch in range(1, 3):
        current_dataloader = dataloader
        if epoch == 1:
            # skip first 2 batches in epoch 2
            current_dataloader = accelerator.skip_first_batches(dataloader, 2)
        epoch_output = print_epoch_batches(epoch, current_dataloader)
        all_outputs.append(epoch_output)

    # Print all outputs at the end
    for output in all_outputs:
        print(output)
running accelerate launch --multi_gpu --num_processes 2 a.py produces:
--- Process 0 ---
Epoch 1:
[[4], [8], [1], [0], [5]]
--- Process 0 ---
Epoch 2:
Random state saved
[[6], [3], [5], [7], [1]]
--- Process 0 ---
Epoch 3:
[[3], [9], [1], [7], [2]]
--- Process 1 ---
Epoch 1:
[[2], [9], [7], [3], [6]]
--- Process 1 ---
Epoch 2:
Random state saved
[[2], [8], [4], [9], [0]]
--- Process 1 ---
Epoch 3:
[[4], [0], [8], [6], [5]]
Random state loaded from debug_data_order_checkpoints
Random state loaded from debug_data_order_checkpoints
--- Process 1 ---
Epoch 2:
[[7], [3], [6]]
--- Process 1 ---
Epoch 3:
[[2], [8], [4], [9], [0]]
--- Process 0 ---
Epoch 2:
[[1], [0], [5]]
--- Process 0 ---
Epoch 3:
[[6], [3], [5], [7], [1]]
This shows that a plain DataLoader plus skip_first_batches cannot correctly resume when the interrupted epoch is not the first one.
Test 3
For the following program:
import torch
from torch.utils.data import DataLoader, Dataset
from accelerate import Accelerator
from accelerate.utils import set_seed, DataLoaderConfiguration
import os
# from torchdata.stateful_dataloader import StatefulDataLoader


# Simple dataset with 10 elements
class SimpleDataset(Dataset):
    def __init__(self):
        self.data = list(range(10))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


# Function to print batch order in an epoch
def print_epoch_batches(epoch, dataloader, interrupt=False):
    output = f"\n--- Process {accelerator.process_index} ---\n"
    output += f"Epoch {epoch + 1}:\n"
    data = []
    for i, batch in enumerate(dataloader):
        data.append(batch.tolist())
        if interrupt and epoch == 1 and i == 1:
            accelerator.save_state(output_dir="debug_data_order_checkpoints")
            output += f"Random state saved\n"
    output += f"{data}\n"
    return output


if __name__ == "__main__":
    dataloader_config = DataLoaderConfiguration(
        use_stateful_dataloader=True
    )
    accelerator = Accelerator(
        dataloader_config=dataloader_config,
    )
    set_seed(42)

    # Create the dataset and DataLoader with shuffle=True
    dataset = SimpleDataset()
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
    )
    dataloader = accelerator.prepare(dataloader)

    # Check data order for 3 epochs
    all_outputs = []
    for epoch in range(3):
        epoch_output = print_epoch_batches(epoch, dataloader, interrupt=True)
        all_outputs.append(epoch_output)

    # Print all outputs at the end
    for output in all_outputs:
        print(output)

    accelerator.wait_for_everyone()

    # Resume from checkpoint
    accelerator = Accelerator(
        dataloader_config=dataloader_config,
    )
    set_seed(42)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
    )
    dataloader = accelerator.prepare(dataloader)

    # Load random state
    if os.path.exists("debug_data_order_checkpoints"):
        accelerator.load_state("debug_data_order_checkpoints")
        print("Random state loaded from debug_data_order_checkpoints")

    # skip_dataloader = accelerator.skip_first_batches(dataloader, 2)

    # Check data order for epochs 2 and 3
    all_outputs = []
    for epoch in range(1, 3):
        epoch_output = print_epoch_batches(epoch, dataloader)
        all_outputs.append(epoch_output)

    # Print all outputs at the end
    for output in all_outputs:
        print(output)
running accelerate launch --multi_gpu --num_processes 2 a.py produces:
--- Process 0 ---
Epoch 1:
[[4], [8], [1], [0], [5]]
--- Process 0 ---
Epoch 2:
Random state saved
[[6], [3], [5], [7], [1]]
--- Process 0 ---
Epoch 3:
[[3], [9], [1], [7], [2]]
--- Process 1 ---
Epoch 1:
[[2], [9], [7], [3], [6]]
--- Process 1 ---
Epoch 2:
Random state saved
[[2], [8], [4], [9], [0]]
--- Process 1 ---
Epoch 3:
[[4], [0], [8], [6], [5]]
Random state loaded from debug_data_order_checkpoints
Random state loaded from debug_data_order_checkpoints
--- Process 0 ---
Epoch 2:
[[8], [1], [0], [5]]
--- Process 0 ---
Epoch 3:
[[6], [3], [5], [7], [1]]
--- Process 1 ---
Epoch 2:
[[9], [7], [3], [6]]
--- Process 1 ---
Epoch 3:
[[2], [8], [4], [9], [0]]
This shows that even when no compatibility problem occurs, StatefulDataLoader cannot correctly restore the epoch information, because the dataloader here still depends on the global random state.
Test 4
For the following program:
import torch
from torch.utils.data import DataLoader, Dataset
from accelerate import Accelerator
from accelerate.utils import set_seed, DataLoaderConfiguration
import os
# from torchdata.stateful_dataloader import StatefulDataLoader


# Simple dataset with 10 elements
class SimpleDataset(Dataset):
    def __init__(self):
        self.data = list(range(10))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


# Function to print batch order in an epoch
def print_epoch_batches(epoch, dataloader, interrupt=False):
    output = f"\n--- Process {accelerator.process_index} ---\n"
    output += f"Epoch {epoch + 1}:\n"
    data = []
    for i, batch in enumerate(dataloader):
        data.append(batch.tolist())
        if interrupt and epoch == 1 and i == 1:
            accelerator.save_state(output_dir="debug_data_order_checkpoints")
            output += f"Random state saved\n"
    output += f"{data}\n"
    return output


if __name__ == "__main__":
    dataloader_config = DataLoaderConfiguration(
        use_seedable_sampler=True
    )
    accelerator = Accelerator(
        dataloader_config=dataloader_config,
    )
    set_seed(42)

    # Create the dataset and DataLoader with shuffle=True
    dataset = SimpleDataset()
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
    )
    dataloader = accelerator.prepare(dataloader)

    # Check data order for 3 epochs
    all_outputs = []
    for epoch in range(3):
        epoch_output = print_epoch_batches(epoch, dataloader, interrupt=True)
        all_outputs.append(epoch_output)

    # Print all outputs at the end
    for output in all_outputs:
        print(output)

    accelerator.wait_for_everyone()

    # Resume from checkpoint
    accelerator = Accelerator(
        dataloader_config=dataloader_config,
    )
    set_seed(42)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
    )
    dataloader = accelerator.prepare(dataloader)

    # Load random state
    if os.path.exists("debug_data_order_checkpoints"):
        accelerator.load_state("debug_data_order_checkpoints")
        print("Random state loaded from debug_data_order_checkpoints")

    # skip_dataloader = accelerator.skip_first_batches(dataloader, 2)

    # Check data order for epochs 2 and 3
    dataloader.set_epoch(1)  # resume in the second epoch (0-based index 1)
    all_outputs = []
    for epoch in range(1, 3):
        epoch_output = print_epoch_batches(epoch, dataloader)
        all_outputs.append(epoch_output)

    # Print all outputs at the end
    for output in all_outputs:
        print(output)
running accelerate launch --multi_gpu --num_processes 2 a.py produces:
--- Process 0 ---
Epoch 1:
[[2], [1], [4], [0], [3]]
--- Process 0 ---
Epoch 2:
Random state saved
[[8], [9], [5], [6], [2]]
--- Process 0 ---
Epoch 3:
[[2], [3], [7], [1], [5]]
--- Process 1 ---
Epoch 1:
[[6], [8], [5], [9], [7]]
--- Process 1 ---
Epoch 2:
Random state saved
[[4], [0], [1], [7], [3]]
--- Process 1 ---
Epoch 3:
[[6], [8], [9], [4], [0]]
Random state loaded from debug_data_order_checkpoints
Random state loaded from debug_data_order_checkpoints
--- Process 1 ---
Epoch 2:
[[4], [0], [1], [7], [3]]
--- Process 1 ---
Epoch 3:
[[6], [8], [9], [4], [0]]
--- Process 0 ---
Epoch 2:
[[8], [9], [5], [6], [2]]
--- Process 0 ---
Epoch 3:
[[2], [3], [7], [1], [5]]
This shows that use_seedable_sampler behaves correctly.
Test 5
For the following program:
import torch
from torch.utils.data import DataLoader, Dataset
from accelerate import Accelerator
from accelerate.utils import set_seed, DataLoaderConfiguration
import os
# from torchdata.stateful_dataloader import StatefulDataLoader


# Simple dataset with 10 elements
class SimpleDataset(Dataset):
    def __init__(self):
        self.data = list(range(10))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


# Function to print batch order in an epoch
def print_epoch_batches(epoch, dataloader, interrupt=False):
    output = f"\n--- Process {accelerator.process_index} ---\n"
    output += f"Epoch {epoch + 1}:\n"
    data = []
    for i, batch in enumerate(dataloader):
        data.append(batch.tolist())
        if interrupt and epoch == 1 and i == 1:
            accelerator.save_state(output_dir="debug_data_order_checkpoints")
            output += f"Random state saved\n"
    output += f"{data}\n"
    return output


if __name__ == "__main__":
    dataloader_config = DataLoaderConfiguration(
        use_stateful_dataloader=True
    )
    accelerator = Accelerator(
        dataloader_config=dataloader_config,
    )
    set_seed(42)

    # Create the dataset and DataLoader with shuffle=True
    dataset = SimpleDataset()
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
    )
    dataloader = accelerator.prepare(dataloader)

    # Check data order for 3 epochs
    all_outputs = []
    for epoch in range(3):
        epoch_output = print_epoch_batches(epoch, dataloader, interrupt=True)
        all_outputs.append(epoch_output)

    # Print all outputs at the end
    for output in all_outputs:
        print(output)

    accelerator.wait_for_everyone()

    # Resume from checkpoint
    accelerator = Accelerator(
        dataloader_config=dataloader_config,
    )
    set_seed(42)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
    )
    dataloader = accelerator.prepare(dataloader)

    # Load random state
    if os.path.exists("debug_data_order_checkpoints"):
        accelerator.load_state("debug_data_order_checkpoints")
        print("Random state loaded from debug_data_order_checkpoints")

    # skip_dataloader = accelerator.skip_first_batches(dataloader, 2)

    # Check data order for epochs 2 and 3
    all_outputs = []
    for epoch in range(1, 3):
        current_dataloader = dataloader
        if epoch == 1:
            # skip first 2 batches in epoch 2
            current_dataloader = accelerator.skip_first_batches(dataloader, 2)
        epoch_output = print_epoch_batches(epoch, current_dataloader)
        all_outputs.append(epoch_output)

    # Print all outputs at the end
    for output in all_outputs:
        print(output)
running accelerate launch --multi_gpu --num_processes 2 a.py produces:
--- Process 1 ---
Epoch 1:
[[2], [9], [7], [3], [6]]
--- Process 1 ---
Epoch 2:
Random state saved
[[2], [8], [4], [9], [0]]
--- Process 1 ---
Epoch 3:
[[4], [0], [8], [6], [5]]
--- Process 0 ---
Epoch 1:
[[4], [8], [1], [0], [5]]
--- Process 0 ---
Epoch 2:
Random state saved
[[6], [3], [5], [7], [1]]
--- Process 0 ---
Epoch 3:
[[3], [9], [1], [7], [2]]
Random state loaded from debug_data_order_checkpoints
Random state loaded from debug_data_order_checkpoints
--- Process 1 ---
Epoch 2:
[[7], [3], [6]]
--- Process 1 ---
Epoch 3:
[[8], [4], [9], [0]]
--- Process 0 ---
Epoch 2:
[[1], [0], [5]]
--- Process 0 ---
Epoch 3:
[[3], [5], [7], [1]]
This shows that StatefulDataLoader is incompatible with skip-based approaches like skip_first_batches: part of the following epoch ends up being skipped as well.
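To close, the pattern that Test 4 suggests for a real training loop looks roughly like the following. This is a hedged sketch rather than something verified in this post: resume_epoch and batches_done are hypothetical bookkeeping values you have to save yourself (accelerate does not store them), the checkpoint directory name is arbitrary, and combining the mid-epoch skip_first_batches with the seedable sampler relies on the sampler not touching the global RNG, which is plausible from the principle above but was not tested here.

import os
import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration, set_seed

# Sketch of a resumable loop built on use_seedable_sampler.
dataloader_config = DataLoaderConfiguration(use_seedable_sampler=True)
accelerator = Accelerator(dataloader_config=dataloader_config)
set_seed(42)

model = torch.nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataloader = DataLoader(list(range(10)), batch_size=1, shuffle=True)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

resume_epoch, batches_done = 0, 0
if os.path.exists("checkpoints"):
    accelerator.load_state("checkpoints")  # restores model/optimizer/RNG states
    # resume_epoch, batches_done = ...     # restore your own bookkeeping here

for epoch in range(resume_epoch, 5):
    dataloader.set_epoch(epoch)  # regenerate this epoch's permutation from the stored seed
    epoch_loader = dataloader
    if epoch == resume_epoch and batches_done:
        # The seedable sampler does not consume the global RNG, so skipping
        # here should not perturb anything else (untested in this post).
        epoch_loader = accelerator.skip_first_batches(dataloader, batches_done)
    for batch in epoch_loader:
        loss = model(batch.float().unsqueeze(-1)).mean()
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
    accelerator.save_state("checkpoints")  # plus your own (epoch, batches_done) bookkeeping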