diff --git a/checkpoint.py b/checkpoint.py index 1c6e878..a5452c2 100644 --- a/checkpoint.py +++ b/checkpoint.py @@ -39,8 +39,19 @@ rank_logger = logging.getLogger("rank") sys.modules['__main__'].QuantizedWeight8bit = QuantizedWeight8bit +# Utility functions for file handling and shared memory + @contextlib.contextmanager def copy_to_shm(file: str): + """ + Context manager to copy a file to shared memory. + + Args: + file (str): The path to the file to be copied. + + Yields: + str: The path to the copied file in shared memory. + """ if file.startswith("/dev/shm/"): # Nothing to do, the file is already in shared memory. yield file @@ -58,6 +69,15 @@ def copy_to_shm(file: str): @contextlib.contextmanager def copy_from_shm(file: str): + """ + Context manager to copy a file from shared memory. + + Args: + file (str): The path to the file to be copied. + + Yields: + str: The path to the temporary file in shared memory. + """ tmp_dir = "/dev/shm/" fd, tmp_path = tempfile.mkstemp(dir=tmp_dir) try: @@ -69,19 +89,48 @@ def copy_from_shm(file: str): def fast_unpickle(path: str) -> Any: + """ + Unpickle an object from a file using shared memory for faster loading. + + Args: + path (str): The path to the file containing the pickled object. + + Returns: + Any: The unpickled object. + """ with copy_to_shm(path) as tmp_path: with open(tmp_path, "rb") as f: return pickle.load(f) def fast_pickle(obj: Any, path: str) -> None: + """ + Pickle an object to a file using shared memory for faster saving. + + Args: + obj (Any): The object to be pickled. + path (str): The path to the file where the object will be saved. + """ with copy_from_shm(path) as tmp_path: with open(tmp_path, "wb") as f: pickle.dump(obj, f) +# Tensor loading and path handling + def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None): - """Loads a set of arrays.""" + """ + Load a set of arrays from files in parallel using a thread pool. + + Args: + shaped_arrays (list): A list of shaped arrays to be loaded. + directory (str): The directory containing the tensor files. + mesh_config (tuple): The mesh configuration. + tensor_indices (list, optional): The indices of the tensors to load. Defaults to None. + + Returns: + list: A list of loaded arrays. + """ pool = ThreadPoolExecutor(max_workers=32) fs = list() num_tensors = 0 @@ -108,6 +157,15 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None): def path_tuple_to_string(path: tuple) -> str: + """ + Convert a path tuple to a string representation. + + Args: + path (tuple): The path tuple. + + Returns: + str: The string representation of the path. + """ pieces = [] for elem in path: if isinstance(elem, jax.tree_util.DictKey): @@ -124,6 +182,17 @@ def get_load_path_str( load_rename_rules: Optional[list[tuple[str, str]]] = None, load_exclude_rules: Optional[list[str]] = None, ) -> Optional[str]: + """ + Get the load path string based on the initial path string and renaming/exclusion rules. + + Args: + init_path_str (str): The initial path string. + load_rename_rules (list[tuple[str, str]], optional): The renaming rules. Defaults to None. + load_exclude_rules (list[str], optional): The exclusion rules. Defaults to None. + + Returns: + Optional[str]: The load path string if not excluded, otherwise None. + """ # Exclusion if load_exclude_rules is not None: for search_pattern in load_exclude_rules: @@ -148,6 +217,19 @@ def replace_with_load_state( load_exclude_rules: Optional[list[str]] = None, mesh_config: tuple = (1, 1), ) -> Any: + """ + Replace the initial state with the loaded state based on renaming and exclusion rules. + + Args: + init_state (Any): The initial state. + load_state (Any): The loaded state. + load_rename_rules (list[tuple[str, str]], optional): The renaming rules. Defaults to None. + load_exclude_rules (list[str], optional): The exclusion rules. Defaults to None. + mesh_config (tuple, optional): The mesh configuration. Defaults to (1, 1). + + Returns: + Any: The replaced state. + """ flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state) flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state) load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load} @@ -177,6 +259,8 @@ def replace_with_load_state( return jax.tree_util.tree_unflatten(structure_init, replaced) +# Checkpoint restoration + def restore( checkpoint_path: str, state_shapes: Any, @@ -186,6 +270,21 @@ def restore( state_sharding, init_state: Optional[Any] = None, ) -> Any: + """ + Restore the state from a checkpoint. + + Args: + checkpoint_path (str): The path to the checkpoint directory. + state_shapes (Any): The shapes of the state. + mesh: The mesh configuration. + between_hosts_config: The configuration for communication between hosts. + params_only (bool): Whether to restore only the parameters. + state_sharding: The sharding specification for the state. + init_state (Optional[Any], optional): The initial state. Defaults to None. + + Returns: + Any: The restored state. + """ ckpt_path = os.path.join(checkpoint_path, "ckpt-0") rank_logger.info("Loading checkpoint at {}".format(ckpt_path))