Precisely restoring dataloader state in the accelerate framework


Comparison of approaches

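When using the accelerate framework, it quickly becomes clear that accelerator.load_state reliably restores only the state of the model, optimizer, and scheduler. Restoring the dataloader state is full of pitfalls.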

In principle, a shuffled dataloader under accelerate can only be restored precisely if one of the following holds:

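  • Like the model, optimizer, and scheduler, its internal state can be saved and loaded (stateful dataloader).
  • At initialization, the dataloader draws a seed from the global random state and stores it internally. From then on, normal iteration no longer interacts with the global random state and depends only on that internal state (seedable sampler).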

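Put simply: either the state can be reset directly to the correct position, or fast-forwarding to that position neither reads from nor changes any external (global) random state.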
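The second principle can be illustrated without any framework. The following is a minimal plain-Python sketch, not accelerate's actual implementation; the class name SeedableShuffle is hypothetical. The seed is drawn from the global RNG exactly once, at construction, and every epoch's order is then a pure function of (seed, epoch). Fast-forwarding past a few items therefore leaves the global random state untouched.

```python
import random


class SeedableShuffle:
    """Hypothetical seedable sampler: each epoch's order depends only on
    (seed, epoch), never on the global RNG after construction."""

    def __init__(self, n, seed=None):
        self.n = n
        # The only interaction with the global RNG: draw a seed once.
        self.seed = random.randrange(2**32) if seed is None else seed
        self.epoch = 0

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        # A private generator per epoch; the global `random` state is not used.
        order = list(range(self.n))
        random.Random(self.seed + self.epoch).shuffle(order)
        return iter(order)


random.seed(42)
sampler = SeedableShuffle(10)
sampler.set_epoch(1)
full_epoch = list(sampler)

# "Resume" in epoch 1 after 2 items: rebuild from the saved seed and fast-forward.
state_before = random.getstate()
resumed = SeedableShuffle(10, seed=sampler.seed)
resumed.set_epoch(1)
it = iter(resumed)
skipped = [next(it), next(it)]
assert random.getstate() == state_before  # the skip had no global side effect
assert skipped + list(it) == full_epoch
```

The design choice is that the sampler only needs to persist its seed (and the current epoch). Everything else can be recomputed, so skipping is cheap and side-effect free.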

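A typical counterexample is a plain dataloader plus a manual skip (hand-written, or via skip_first_batches). A plain dataloader depends on the global random state while iterating. The global random state at the moment the skip starts is not restored, and the way it evolves during the skip does not match what actually happened in the previous run. Unless some even more cumbersome mechanism is added to reconstruct it, this approach cannot restore the dataloader state precisely.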
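The counterexample can be reproduced in plain Python. In this sketch the hypothetical function plain_epoch_order stands in for a plain shuffled DataLoader that draws a fresh permutation from the global RNG every epoch. Re-seeding and then skipping replays epoch 1's order. Only restoring the global RNG state captured at the *start* of the interrupted epoch gives the right result, and that state is exactly the kind of extra bookkeeping mentioned above.

```python
import random


def plain_epoch_order(n):
    """Mimics a plain shuffled DataLoader: each epoch draws a fresh
    permutation from the *global* RNG."""
    order = list(range(n))
    random.shuffle(order)
    return order


# Original run: seed once, iterate two epochs.
random.seed(42)
epoch1 = plain_epoch_order(10)
rng_at_epoch2_start = random.getstate()  # hypothetical extra bookkeeping
epoch2 = plain_epoch_order(10)

# Naive resume in epoch 2 after 2 batches: re-seed, then skip 2 items.
random.seed(42)
naive = plain_epoch_order(10)[2:]
assert naive == epoch1[2:]  # replays the tail of epoch 1, not epoch 2

# Restoring the RNG state from the start of epoch 2 gives the right tail.
random.setstate(rng_at_epoch2_start)
fixed = plain_epoch_order(10)[2:]
assert fixed == epoch2[2:]
```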

The official accelerate documentation mainly mentions two solutions:

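  • Stateful Dataloader: accelerate claims to support StatefulDataLoader from newer versions of torchdata, but the issue tracker has many reports of assorted compatibility bugs. (I ran into such a bug myself, although I could not even pin down the incompatibility well enough to reproduce it.)
  • skip_first_batches: the accelerator.skip_first_batches approach looks slightly better, but it is essentially the same as a manual skip. It has the same problem: the random state cannot be reproduced.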

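After reading the source code with deepwiki and going through issue https://github.com/huggingface/accelerate/issues/3242, I concluded that use_seedable_sampler is a fairly good solution. https://github.com/huggingface/accelerate/issues/3398 reports a bug with it as well, but I could not reproduce that bug on v1.11.0; it may have been fixed by some PR. (The principle behind this approach is also quite simple and robust, so once there are no obvious bugs it is unlikely to hit odd corner cases.)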

The following verification experiments were run with accelerate==1.11.0.

Tests

Test 1

Program:

import torch
from torch.utils.data import DataLoader, Dataset

from accelerate import Accelerator
from accelerate.utils import set_seed, DataLoaderConfiguration
import os

# from torchdata.stateful_dataloader import StatefulDataLoader

# Simple dataset with 10 elements
class SimpleDataset(Dataset):
    def __init__(self):
        self.data = list(range(10))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Function to print batch order in an epoch
def print_epoch_batches(epoch, dataloader, interrupt=False):

    output = f"\n--- Process {accelerator.process_index} ---\n"
    output += f"Epoch {epoch + 1}:\n"
    data = []
    for i, batch in enumerate(dataloader):
        data.append(batch.tolist())
        if interrupt and epoch == 1 and i == 1:
            accelerator.save_state(output_dir="debug_data_order_checkpoints")
            output += f"Random state saved\n"
    output += f"{data}\n"
    return output

if __name__ == "__main__":
    accelerator = Accelerator()
    set_seed(42)

    # Create the dataset and DataLoader with shuffle=True
    dataset = SimpleDataset()
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
    )
    dataloader = accelerator.prepare(dataloader)

    # Check data order for 3 epochs
    all_outputs = []
    for epoch in range(3):
        epoch_output = print_epoch_batches(epoch, dataloader, interrupt=True)
        all_outputs.append(epoch_output)

    # Print all outputs at the end
    for output in all_outputs:
        print(output)

    accelerator.wait_for_everyone()

    # Resume from checkpoint
    accelerator = Accelerator(
        # dataloader_config=dataloader_config,
    )
    set_seed(42)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
    )

    dataloader = accelerator.prepare(dataloader)

    # Load random state
    if os.path.exists("debug_data_order_checkpoints"):
        accelerator.load_state("debug_data_order_checkpoints")
        print("Random state loaded from debug_data_order_checkpoints")

    # skip_dataloader = accelerator.skip_first_batches(dataloader, 2)

    # Check data order for 1 epoch
    all_outputs = []
    for epoch in range(1):
        epoch_output = print_epoch_batches(epoch, dataloader)
        all_outputs.append(epoch_output)

    # Print all outputs at the end
    for output in all_outputs:
        print(output)

Running accelerate launch --multi_gpu --num_processes 2 a.py prints:

--- Process 0 ---
Epoch 1:
[[4], [8], [1], [0], [5]]


--- Process 0 ---
Epoch 2:
Random state saved
[[6], [3], [5], [7], [1]]


--- Process 0 ---
Epoch 3:
[[3], [9], [1], [7], [2]]


--- Process 1 ---
Epoch 1:
[[2], [9], [7], [3], [6]]


--- Process 1 ---
Epoch 2:
Random state saved
[[2], [8], [4], [9], [0]]


--- Process 1 ---
Epoch 3:
[[4], [0], [8], [6], [5]]

Random state loaded from debug_data_order_checkpoints
Random state loaded from debug_data_order_checkpoints


--- Process 1 ---
Epoch 1:
[[2], [9], [7], [3], [6]]


--- Process 0 ---
Epoch 1:
[[4], [8], [1], [0], [5]]

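This shows that load_state does not affect the state of a plain DataLoader. After resuming, it starts over with the epoch 1 order instead of continuing from the checkpoint in epoch 2.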
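Because load_state leaves a plain DataLoader untouched, the resume position has to be tracked explicitly. The following is a minimal sketch. It assumes the object is registered with accelerate's Accelerator.register_for_checkpointing, which saves and loads objects that expose state_dict/load_state_dict during save_state/load_state. The class TrainProgress is hypothetical. It only records *where* to resume and does not by itself make a plain DataLoader's shuffle order reproducible.

```python
class TrainProgress:
    """Hypothetical progress tracker with the state_dict/load_state_dict
    protocol, so it can be checkpointed alongside model/optimizer state."""

    def __init__(self):
        self.epoch = 0
        self.step_in_epoch = 0  # batches already consumed in `epoch`

    def state_dict(self):
        return {"epoch": self.epoch, "step_in_epoch": self.step_in_epoch}

    def load_state_dict(self, state):
        self.epoch = state["epoch"]
        self.step_in_epoch = state["step_in_epoch"]


# With accelerate (not run here):
#   progress = TrainProgress()
#   accelerator.register_for_checkpointing(progress)
#   ... accelerator.save_state(...) / accelerator.load_state(...)

# Round trip without accelerate: the state survives a save/load cycle.
progress = TrainProgress()
progress.epoch, progress.step_in_epoch = 1, 2
restored = TrainProgress()
restored.load_state_dict(progress.state_dict())
assert (restored.epoch, restored.step_in_epoch) == (1, 2)
```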

Test 2

Program:

import torch
from torch.utils.data import DataLoader, Dataset

from accelerate import Accelerator
from accelerate.utils import set_seed, DataLoaderConfiguration
import os

# from torchdata.stateful_dataloader import StatefulDataLoader

# Simple dataset with 10 elements
class SimpleDataset(Dataset):
    def __init__(self):
        self.data = list(range(10))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Function to print batch order in an epoch
def print_epoch_batches(epoch, dataloader, interrupt=False):

    output = f"\n--- Process {accelerator.process_index} ---\n"
    output += f"Epoch {epoch + 1}:\n"
    data = []
    for i, batch in enumerate(dataloader):
        data.append(batch.tolist())
        if interrupt and epoch == 1 and i == 1:
            accelerator.save_state(output_dir="debug_data_order_checkpoints")
            output += f"Random state saved\n"
    output += f"{data}\n"
    return output

if __name__ == "__main__":
    accelerator = Accelerator()
    set_seed(42)

    # Create the dataset and DataLoader with shuffle=True
    dataset = SimpleDataset()
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
    )
    dataloader = accelerator.prepare(dataloader)

    # Check data order for 3 epochs
    all_outputs = []
    for epoch in range(3):
        epoch_output = print_epoch_batches(epoch, dataloader, interrupt=True)
        all_outputs.append(epoch_output)

    # Print all outputs at the end
    for output in all_outputs:
        print(output)

    accelerator.wait_for_everyone()

    # Resume from checkpoint
    accelerator = Accelerator(
        # dataloader_config=dataloader_config,
    )
    set_seed(42)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
    )

    dataloader = accelerator.prepare(dataloader)

    # Load random state
    if os.path.exists("debug_data_order_checkpoints"):
        accelerator.load_state("debug_data_order_checkpoints")
        print("Random state loaded from debug_data_order_checkpoints")

    # skip_dataloader = accelerator.skip_first_batches(dataloader, 2)

    # Check data order for epochs 2 and 3
    all_outputs = []
    for epoch in range(1,3):
        current_dataloader = dataloader
        if epoch == 1:
            # skip first 2 batches in epoch 2
            current_dataloader = accelerator.skip_first_batches(dataloader, 2)
        epoch_output = print_epoch_batches(epoch, current_dataloader)
        all_outputs.append(epoch_output)

    # Print all outputs at the end
    for output in all_outputs:
        print(output)

Running accelerate launch --multi_gpu --num_processes 2 a.py prints:

--- Process 0 ---
Epoch 1:
[[4], [8], [1], [0], [5]]


--- Process 0 ---
Epoch 2:
Random state saved
[[6], [3], [5], [7], [1]]


--- Process 0 ---
Epoch 3:
[[3], [9], [1], [7], [2]]


--- Process 1 ---
Epoch 1:
[[2], [9], [7], [3], [6]]


--- Process 1 ---
Epoch 2:
Random state saved
[[2], [8], [4], [9], [0]]


--- Process 1 ---
Epoch 3:
[[4], [0], [8], [6], [5]]

Random state loaded from debug_data_order_checkpoints
Random state loaded from debug_data_order_checkpoints

--- Process 1 ---
Epoch 2:
[[7], [3], [6]]


--- Process 1 ---
Epoch 3:
[[2], [8], [4], [9], [0]]


--- Process 0 ---
Epoch 2:
[[1], [0], [5]]


--- Process 0 ---
Epoch 3:
[[6], [3], [5], [7], [1]]

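This shows that a plain DataLoader combined with skip_first_batches cannot resume correctly when the interrupted epoch is not the first one. The resumed "epoch 2" is actually the tail of the original epoch 1 (process 0: [[1], [0], [5]] instead of [[5], [7], [1]]), and the resumed "epoch 3" repeats the original epoch 2.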
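To make "correct" precise: after a checkpoint taken k batches into epoch e, a correct resume should yield the remaining batches of epoch e and then the later epochs unchanged. The following small sketch uses a hypothetical helper, expected_resume, to check this against the process 0 logs of this test. The size-1 batches are flattened to plain integers.

```python
def expected_resume(original_epochs, resume_epoch, consumed):
    """Hypothetical helper: what a correct resume should yield after a
    checkpoint taken `consumed` batches into `resume_epoch` (0-based)."""
    remainder = original_epochs[resume_epoch][consumed:]
    return [remainder] + original_epochs[resume_epoch + 1:]


# Process 0 in the original run above (epochs 1-3), size-1 batches flattened.
original = [
    [4, 8, 1, 0, 5],
    [6, 3, 5, 7, 1],
    [3, 9, 1, 7, 2],
]
# Process 0 after resuming with skip_first_batches (epochs 2 and 3).
observed = [[1, 0, 5], [6, 3, 5, 7, 1]]

expected = expected_resume(original, resume_epoch=1, consumed=2)
print(expected)              # [[5, 7, 1], [3, 9, 1, 7, 2]]
print(observed == expected)  # False
```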

Test 3

Program:

import torch
from torch.utils.data import DataLoader, Dataset

from accelerate import Accelerator
from accelerate.utils import set_seed, DataLoaderConfiguration
import os

# from torchdata.stateful_dataloader import StatefulDataLoader

# Simple dataset with 10 elements
class SimpleDataset(Dataset):
    def __init__(self):
        self.data = list(range(10))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Function to print batch order in an epoch
def print_epoch_batches(epoch, dataloader, interrupt=False):

    output = f"\n--- Process {accelerator.process_index} ---\n"
    output += f"Epoch {epoch + 1}:\n"
    data = []
    for i, batch in enumerate(dataloader):
        data.append(batch.tolist())
        if interrupt and epoch == 1 and i == 1:
            accelerator.save_state(output_dir="debug_data_order_checkpoints")
            output += f"Random state saved\n"
    output += f"{data}\n"
    return output

if __name__ == "__main__":
    dataloader_config = DataLoaderConfiguration(
        use_stateful_dataloader=True
    )
    accelerator = Accelerator(
        dataloader_config=dataloader_config,
    )
    set_seed(42)

    # Create the dataset and DataLoader with shuffle=True
    dataset = SimpleDataset()
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
    )
    dataloader = accelerator.prepare(dataloader)

    # Check data order for 3 epochs
    all_outputs = []
    for epoch in range(3):
        epoch_output = print_epoch_batches(epoch, dataloader, interrupt=True)
        all_outputs.append(epoch_output)

    # Print all outputs at the end
    for output in all_outputs:
        print(output)

    accelerator.wait_for_everyone()

    # Resume from checkpoint
    accelerator = Accelerator(
        dataloader_config=dataloader_config,
    )
    set_seed(42)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
    )

    dataloader = accelerator.prepare(dataloader)

    # Load random state
    if os.path.exists("debug_data_order_checkpoints"):
        accelerator.load_state("debug_data_order_checkpoints")
        print("Random state loaded from debug_data_order_checkpoints")

    # skip_dataloader = accelerator.skip_first_batches(dataloader, 2)

    # Check data order for epochs 2 and 3
    all_outputs = []
    for epoch in range(1,3):
        epoch_output = print_epoch_batches(epoch, dataloader)
        all_outputs.append(epoch_output)

    # Print all outputs at the end
    for output in all_outputs:
        print(output)

Running accelerate launch --multi_gpu --num_processes 2 a.py prints:

--- Process 0 ---
Epoch 1:
[[4], [8], [1], [0], [5]]


--- Process 0 ---
Epoch 2:
Random state saved
[[6], [3], [5], [7], [1]]


--- Process 0 ---
Epoch 3:
[[3], [9], [1], [7], [2]]


--- Process 1 ---
Epoch 1:
[[2], [9], [7], [3], [6]]


--- Process 1 ---
Epoch 2:
Random state saved
[[2], [8], [4], [9], [0]]


--- Process 1 ---
Epoch 3:
[[4], [0], [8], [6], [5]]

Random state loaded from debug_data_order_checkpoints
Random state loaded from debug_data_order_checkpoints

--- Process 0 ---
Epoch 2:
[[8], [1], [0], [5]]


--- Process 0 ---
Epoch 3:
[[6], [3], [5], [7], [1]]


--- Process 1 ---
Epoch 2:
[[9], [7], [3], [6]]


--- Process 1 ---
Epoch 3:
[[2], [8], [4], [9], [0]]

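This shows that even when no compatibility problem occurs, StatefulDataLoader cannot restore the epoch correctly, because this dataloader's shuffle order still depends on the global random state. The resumed epoch 2 follows the epoch 1 order (process 0: [[8], [1], [0], [5]] instead of [[5], [7], [1]]), and the resumed epoch 3 repeats the original epoch 2.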
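For contrast, a stateful sampler only restores exactly if its state covers everything the order depends on. The following plain-Python sketch uses a hypothetical class, StatefulShuffle; it is not how torchdata's StatefulDataLoader is implemented. Its state_dict holds the seed, the epoch, and the position, and each epoch's order comes from a private generator. As a result, load_state_dict resumes at the right item of the right permutation.

```python
import random


class StatefulShuffle:
    """Hypothetical sampler whose state covers everything the order depends
    on: the seed, the epoch, and the position inside the epoch."""

    def __init__(self, n, seed):
        self.n, self.seed = n, seed
        self.epoch, self.pos = 0, 0

    def _order(self):
        # Private per-epoch generator; the global RNG is never consulted.
        order = list(range(self.n))
        random.Random(self.seed + self.epoch).shuffle(order)
        return order

    def __iter__(self):
        order = self._order()
        while self.pos < self.n:
            self.pos += 1
            yield order[self.pos - 1]
        # Epoch exhausted: move on to the next one.
        self.epoch, self.pos = self.epoch + 1, 0

    def state_dict(self):
        return {"seed": self.seed, "epoch": self.epoch, "pos": self.pos}

    def load_state_dict(self, state):
        self.seed, self.epoch, self.pos = state["seed"], state["epoch"], state["pos"]


sampler = StatefulShuffle(10, seed=42)
epoch0 = list(sampler)              # full epoch 0; sampler moves to epoch 1
it = iter(sampler)
consumed = [next(it), next(it)]     # 2 items into epoch 1
ckpt = sampler.state_dict()         # {"seed": 42, "epoch": 1, "pos": 2}
rest = list(it)                     # what the original run sees next

fresh = StatefulShuffle(10, seed=0)  # wrong seed on purpose; overwritten below
fresh.load_state_dict(ckpt)
assert list(fresh) == rest
```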

Test 4

Program:

import torch
from torch.utils.data import DataLoader, Dataset

from accelerate import Accelerator
from accelerate.utils import set_seed, DataLoaderConfiguration
import os

# from torchdata.stateful_dataloader import StatefulDataLoader

# Simple dataset with 10 elements
class SimpleDataset(Dataset):
    def __init__(self):
        self.data = list(range(10))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Function to print batch order in an epoch
def print_epoch_batches(epoch, dataloader, interrupt=False):

    output = f"\n--- Process {accelerator.process_index} ---\n"
    output += f"Epoch {epoch + 1}:\n"
    data = []
    for i, batch in enumerate(dataloader):
        data.append(batch.tolist())
        if interrupt and epoch == 1 and i == 1:
            accelerator.save_state(output_dir="debug_data_order_checkpoints")
            output += f"Random state saved\n"
    output += f"{data}\n"
    return output

if __name__ == "__main__":
    dataloader_config = DataLoaderConfiguration(
        use_seedable_sampler=True
    )
    accelerator = Accelerator(
        dataloader_config=dataloader_config,
    )
    set_seed(42)

    # Create the dataset and DataLoader with shuffle=True
    dataset = SimpleDataset()
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
    )
    dataloader = accelerator.prepare(dataloader)

    # Check data order for 3 epochs
    all_outputs = []
    for epoch in range(3):
        epoch_output = print_epoch_batches(epoch, dataloader, interrupt=True)
        all_outputs.append(epoch_output)

    # Print all outputs at the end
    for output in all_outputs:
        print(output)

    accelerator.wait_for_everyone()

    # Resume from checkpoint
    accelerator = Accelerator(
        dataloader_config=dataloader_config,
    )
    set_seed(42)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
    )

    dataloader = accelerator.prepare(dataloader)

    # Load random state
    if os.path.exists("debug_data_order_checkpoints"):
        accelerator.load_state("debug_data_order_checkpoints")
        print("Random state loaded from debug_data_order_checkpoints")

    # skip_dataloader = accelerator.skip_first_batches(dataloader, 2)

    # Check data order for epochs 2 and 3
    dataloader.set_epoch(1)
    all_outputs = []
    for epoch in range(1,3):
        epoch_output = print_epoch_batches(epoch, dataloader)
        all_outputs.append(epoch_output)

    # Print all outputs at the end
    for output in all_outputs:
        print(output)

Running accelerate launch --multi_gpu --num_processes 2 a.py prints:

--- Process 0 ---
Epoch 1:
[[2], [1], [4], [0], [3]]


--- Process 0 ---
Epoch 2:
Random state saved
[[8], [9], [5], [6], [2]]


--- Process 0 ---
Epoch 3:
[[2], [3], [7], [1], [5]]


--- Process 1 ---
Epoch 1:
[[6], [8], [5], [9], [7]]


--- Process 1 ---
Epoch 2:
Random state saved
[[4], [0], [1], [7], [3]]


--- Process 1 ---
Epoch 3:
[[6], [8], [9], [4], [0]]

Random state loaded from debug_data_order_checkpoints
Random state loaded from debug_data_order_checkpoints


--- Process 1 ---
Epoch 2:
[[4], [0], [1], [7], [3]]


--- Process 1 ---
Epoch 3:
[[6], [8], [9], [4], [0]]


--- Process 0 ---
Epoch 2:
[[8], [9], [5], [6], [2]]


--- Process 0 ---
Epoch 3:
[[2], [3], [7], [1], [5]]

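This shows that use_seedable_sampler works correctly: after dataloader.set_epoch(1), the resumed epochs 2 and 3 match the original run exactly on both processes. Note that this test only restores the epoch. The two batches consumed in epoch 2 before the checkpoint are replayed, so the position within the epoch still has to be skipped separately.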
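Because the order is a pure function of (seed, epoch), the batches consumed before the checkpoint can be skipped without side effects. The following plain-Python sketch shows the full resume logic. It assumes batch_size=1 on a single process, and the helpers epoch_order and resume_stream are hypothetical. divmod turns the number of consumed items into the epoch to resume and the number of items to skip within it. In accelerate terms this would roughly mean calling set_epoch on the resumed dataloader and skipping the first batches of that epoch only. I have not verified here how accelerator.skip_first_batches interacts with set_epoch under use_seedable_sampler.

```python
import random


def epoch_order(seed, epoch, n):
    """Order of one epoch as a pure function of (seed, epoch)."""
    order = list(range(n))
    random.Random(seed + epoch).shuffle(order)
    return order


def resume_stream(seed, n, num_epochs, consumed):
    """Yield the items still to be seen after `consumed` items were
    already trained on (batch_size=1, single process)."""
    start_epoch, skip = divmod(consumed, n)
    for epoch in range(start_epoch, num_epochs):
        order = epoch_order(seed, epoch, n)
        # Only the epoch that was interrupted is partially skipped.
        start = skip if epoch == start_epoch else 0
        yield from order[start:]


full_run = list(resume_stream(seed=42, n=10, num_epochs=3, consumed=0))
# Checkpoint taken 2 items into the second epoch (index 1): 10 + 2 consumed.
resumed = list(resume_stream(seed=42, n=10, num_epochs=3, consumed=12))
assert resumed == full_run[12:]
```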

Test 5

Program:

import torch
from torch.utils.data import DataLoader, Dataset

from accelerate import Accelerator
from accelerate.utils import set_seed, DataLoaderConfiguration
import os

# from torchdata.stateful_dataloader import StatefulDataLoader

# Simple dataset with 10 elements
class SimpleDataset(Dataset):
    def __init__(self):
        self.data = list(range(10))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Function to print batch order in an epoch
def print_epoch_batches(epoch, dataloader, interrupt=False):

    output = f"\n--- Process {accelerator.process_index} ---\n"
    output += f"Epoch {epoch + 1}:\n"
    data = []
    for i, batch in enumerate(dataloader):
        data.append(batch.tolist())
        if interrupt and epoch == 1 and i == 1:
            accelerator.save_state(output_dir="debug_data_order_checkpoints")
            output += f"Random state saved\n"
    output += f"{data}\n"
    return output

if __name__ == "__main__":
    dataloader_config = DataLoaderConfiguration(
        use_stateful_dataloader=True
    )
    accelerator = Accelerator(
        dataloader_config=dataloader_config,
    )
    set_seed(42)

    # Create the dataset and DataLoader with shuffle=True
    dataset = SimpleDataset()
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
    )
    dataloader = accelerator.prepare(dataloader)

    # Check data order for 3 epochs
    all_outputs = []
    for epoch in range(3):
        epoch_output = print_epoch_batches(epoch, dataloader, interrupt=True)
        all_outputs.append(epoch_output)

    # Print all outputs at the end
    for output in all_outputs:
        print(output)

    accelerator.wait_for_everyone()

    # Resume from checkpoint
    accelerator = Accelerator(
        dataloader_config=dataloader_config,
    )
    set_seed(42)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
    )

    dataloader = accelerator.prepare(dataloader)

    # Load random state
    if os.path.exists("debug_data_order_checkpoints"):
        accelerator.load_state("debug_data_order_checkpoints")
        print("Random state loaded from debug_data_order_checkpoints")

    # skip_dataloader = accelerator.skip_first_batches(dataloader, 2)

    # Check data order for epochs 2 and 3
    all_outputs = []
    for epoch in range(1,3):
        current_dataloader = dataloader
        if epoch == 1:
            # skip first 2 batches in epoch 2
            current_dataloader = accelerator.skip_first_batches(dataloader, 2)
        epoch_output = print_epoch_batches(epoch, current_dataloader)
        all_outputs.append(epoch_output)

    # Print all outputs at the end
    for output in all_outputs:
        print(output)

Running accelerate launch --multi_gpu --num_processes 2 a.py prints:

--- Process 1 ---
Epoch 1:
[[2], [9], [7], [3], [6]]


--- Process 1 ---
Epoch 2:
Random state saved
[[2], [8], [4], [9], [0]]


--- Process 1 ---
Epoch 3:
[[4], [0], [8], [6], [5]]


--- Process 0 ---
Epoch 1:
[[4], [8], [1], [0], [5]]


--- Process 0 ---
Epoch 2:
Random state saved
[[6], [3], [5], [7], [1]]


--- Process 0 ---
Epoch 3:
[[3], [9], [1], [7], [2]]

Random state loaded from debug_data_order_checkpoints
Random state loaded from debug_data_order_checkpoints

--- Process 1 ---
Epoch 2:
[[7], [3], [6]]


--- Process 1 ---
Epoch 3:
[[8], [4], [9], [0]]


--- Process 0 ---
Epoch 2:
[[1], [0], [5]]


--- Process 0 ---
Epoch 3:
[[3], [5], [7], [1]]

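This shows that StatefulDataLoader is incompatible with methods like skip_first_batches: batches are also skipped in the following epoch (the resumed epoch 3 has only 4 batches per process instead of 5).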