import os
import time
from datetime import datetime
from typing import Sequence, Text, Union

import torch
import torch.nn as nn
import torch.optim as opt


def save_checkpoint(path: Text,
                    epoch: int,
                    modules: Union[nn.Module, Sequence[nn.Module]],
                    optimizers: Union[opt.Optimizer, Sequence[opt.Optimizer]],
                    safe_replacement: bool = True):
    """
    Save a checkpoint of the current state of the training, so it can be resumed.
    This checkpointing function assumes that there are no learning rate schedulers
    or gradient scalers for automatic mixed precision.

    :param path: Path for your checkpoint file
    :param epoch: Current (completed) epoch
    :param modules: nn.Module containing the model or a list of nn.Module objects
    :param optimizers: Optimizer or list of optimizers
    :param safe_replacement: Keep the old checkpoint until the new one has been completed
    :return:
    """

    # This function can be called both as
    # save_checkpoint('/my/checkpoint/path.pth', my_epoch, my_module, my_opt)
    # or
    # save_checkpoint('/my/checkpoint/path.pth', my_epoch, [my_module1, my_module2], [my_opt1, my_opt2])
    if isinstance(modules, nn.Module):
        modules = [modules]
    if isinstance(optimizers, opt.Optimizer):
        optimizers = [optimizers]

    # Data dictionary to be saved
    data = {
        'epoch': epoch,
        # Current time (UNIX timestamp)
        'time': time.time(),
        # State dict for all the modules
        'modules': [m.state_dict() for m in modules],
        # State dict for all the optimizers
        'optimizers': [o.state_dict() for o in optimizers]
    }

    # Safe replacement of old checkpoint
    temp_file = None
    if os.path.exists(path) and safe_replacement:
        # There's an old checkpoint. Rename it!
        temp_file = path + '.old'
        os.rename(path, temp_file)

    # Save the new checkpoint
    with open(path, 'wb') as fp:
        torch.save(data, fp)
        # Flush and sync the FS
        fp.flush()
        os.fsync(fp.fileno())

    # Remove the old checkpoint
    if temp_file is not None:
        os.unlink(temp_file)


def load_checkpoint(path: Text,
                    default_epoch: int,
                    modules: Union[nn.Module, Sequence[nn.Module]],
                    optimizers: Union[opt.Optimizer, Sequence[opt.Optimizer]],
                    verbose: bool = True):
    """
    Try to load a checkpoint to resume the training.

    :param path: Path for your checkpoint file
    :param default_epoch: Initial value for "epoch" (in case there are no snapshots)
    :param modules: nn.Module containing the model or a list of nn.Module objects.
                    They are assumed to stay on the same device
    :param optimizers: Optimizer or list of optimizers
    :param verbose: Verbose mode
    :return: Next epoch
    """
    if isinstance(modules, nn.Module):
        modules = [modules]
    if isinstance(optimizers, opt.Optimizer):
        optimizers = [optimizers]

    # If there's a checkpoint
    if os.path.exists(path):
        # Load data on the device where the first module lives
        data = torch.load(path, map_location=next(modules[0].parameters()).device)

        # Inform the user that we are loading the checkpoint
        if verbose:
            print(f"Loaded checkpoint saved at "
                  f"{datetime.fromtimestamp(data['time']).strftime('%Y-%m-%d %H:%M:%S')}. "
                  f"Resuming from epoch {data['epoch']}")

        # Load state for all the modules
        for i, m in enumerate(modules):
            modules[i].load_state_dict(data['modules'][i])

        # Load state for all the optimizers
        for i, o in enumerate(optimizers):
            optimizers[i].load_state_dict(data['optimizers'][i])

        # Next epoch
        return data['epoch'] + 1
    else:
        return default_epoch
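

# A minimal usage sketch (not part of the functions above) showing how the two
# helpers fit into a training loop. The toy linear model, random data, number of
# epochs, and checkpoint path are illustrative assumptions only.
if __name__ == '__main__':
    checkpoint_path = 'checkpoint.pth'

    model = nn.Linear(10, 1)
    optimizer = opt.SGD(model.parameters(), lr=1e-2)

    # Resume from the epoch after the last completed one, or start at epoch 0
    first_epoch = load_checkpoint(checkpoint_path, 0, model, optimizer)

    for epoch in range(first_epoch, 5):
        # Dummy training step on random data
        inputs, targets = torch.randn(32, 10), torch.randn(32, 1)
        loss = nn.functional.mse_loss(model(inputs), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Save after every epoch, so an interruption costs at most one epoch of work
        save_checkpoint(checkpoint_path, epoch, model, optimizer)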