Allows CPU-based execution

This commit is contained in:
louiehelm 2024-03-20 08:07:43 +05:00
parent 7050ed204b
commit 1101257c05
2 changed files with 27 additions and 1 deletions

4
requirements-cpu.txt Normal file
View File

@ -0,0 +1,4 @@
dm_haiku==0.0.12
jax==0.4.25
numpy==1.26.4
sentencepiece==0.2.0

24
run.py
View File

@ -12,11 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import logging, os
from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit
from runners import InferenceRunner, ModelRunner, sample_from_model
# Fall back to using CPU execution if less than 8 GPUs
# ONLY MEANT FOR DEVELOPERS WITH 384GB RAM
# CURRENTLY TOO SLOW FOR MEANINGFUL INFERENCE WORKLOADS
#
# Set True to run model on CPU only
USE_CPU_ONLY = False
if USE_CPU_ONLY:
# Simulate 8 devices via CPUs
xla_flags = os.environ.get("XLA_FLAGS", "")
xla_flags += " --xla_force_host_platform_device_count=8"
os.environ["XLA_FLAGS"] = xla_flags
# Enforce CPU-only execution
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Suppress warnings about unused backends
logging.getLogger("jax._src.xla_bridge").addFilter(logging.Filter("Unable to initialize backend"))
# Suppress false warnings about stuck processes
logging.getLogger("collective_ops_utils").addFilter(logging.Filter("This thread has been waiting for"))
logging.getLogger("collective_ops_utils").addFilter(logging.Filter("Thread is unstuck"))
# Suppress warnings about slow compiling
logging.getLogger("slow_operation_alarm").addFilter(logging.Filter("Very slow compile"))
CKPT_PATH = "./checkpoints/"