mirror of
https://github.com/xai-org/grok-1
synced 2024-11-12 20:21:19 +08:00
Allows CPU-based execution
This commit is contained in:
parent
7050ed204b
commit
1101257c05
4
requirements-cpu.txt
Normal file
4
requirements-cpu.txt
Normal file
@ -0,0 +1,4 @@
|
||||
dm_haiku==0.0.12
|
||||
jax==0.4.25
|
||||
numpy==1.26.4
|
||||
sentencepiece==0.2.0
|
24
run.py
24
run.py
@ -12,11 +12,33 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import logging, os
|
||||
|
||||
from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit
|
||||
from runners import InferenceRunner, ModelRunner, sample_from_model
|
||||
|
||||
# Fall back to using CPU execution if less than 8 GPUs
|
||||
# ONLY MEANT FOR DEVELOPERS WITH 384GB RAM
|
||||
# CURRENTLY TOO SLOW FOR MEANINGFUL INFERENCE WORKLOADS
|
||||
#
|
||||
# Set True to run model on CPU only
|
||||
USE_CPU_ONLY = False
|
||||
|
||||
if USE_CPU_ONLY:
|
||||
# Simulate 8 devices via CPUs
|
||||
xla_flags = os.environ.get("XLA_FLAGS", "")
|
||||
xla_flags += " --xla_force_host_platform_device_count=8"
|
||||
os.environ["XLA_FLAGS"] = xla_flags
|
||||
# Enforce CPU-only execution
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
||||
# Suppress warnings about unused backends
|
||||
logging.getLogger("jax._src.xla_bridge").addFilter(logging.Filter("Unable to initialize backend"))
|
||||
# Suppress false warnings about stuck processes
|
||||
logging.getLogger("collective_ops_utils").addFilter(logging.Filter("This thread has been waiting for"))
|
||||
logging.getLogger("collective_ops_utils").addFilter(logging.Filter("Thread is unstuck"))
|
||||
# Suppress warnings about slow compiling
|
||||
logging.getLogger("slow_operation_alarm").addFilter(logging.Filter("Very slow compile"))
|
||||
|
||||
|
||||
CKPT_PATH = "./checkpoints/"
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user