Update checkpoint.py

This commit is contained in:
Yahweh Rapha Bradford 2024-05-07 01:51:50 -04:00 committed by GitHub
parent 8f05ad77cf
commit 7a19c9eb9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,16 +1,4 @@
# Copyright 2024 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations from __future__ import annotations
@ -213,7 +201,7 @@ def restore(
state_sharding = jax.tree_util.tree_map( state_sharding = jax.tree_util.tree_map(
lambda x: jax.sharding.PartitionSpec() if x is None else x, lambda x: jax.sharding.PartitionSpec() if x is None else x,
state_sharding, state_sharding,
is_leaf=lambda x: x is None, is_leaf=lambda is None,
) )
state = multihost_utils.host_local_array_to_global_array(state, mesh, state_sharding) state = multihost_utils.host_local_array_to_global_array(state, mesh, state_sharding)
if params_only: if params_only: