mirror of
https://github.com/xai-org/grok-1
synced 2024-11-14 05:21:20 +08:00
Update checkpoint.py
This commit is contained in:
parent
7050ed204b
commit
8aac0cea69
@ -42,7 +42,6 @@ sys.modules['__main__'].QuantizedWeight8bit = QuantizedWeight8bit
|
||||
@contextlib.contextmanager
|
||||
def copy_to_shm(file: str):
|
||||
if file.startswith("/dev/shm/"):
|
||||
# Nothing to do, the file is already in shared memory.
|
||||
yield file
|
||||
return
|
||||
|
||||
@ -81,7 +80,6 @@ def fast_pickle(obj: Any, path: str) -> None:
|
||||
|
||||
|
||||
def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
|
||||
"""Loads a set of arrays."""
|
||||
pool = ThreadPoolExecutor(max_workers=32)
|
||||
fs = list()
|
||||
num_tensors = 0
|
||||
@ -124,13 +122,11 @@ def get_load_path_str(
|
||||
load_rename_rules: Optional[list[tuple[str, str]]] = None,
|
||||
load_exclude_rules: Optional[list[str]] = None,
|
||||
) -> Optional[str]:
|
||||
# Exclusion
|
||||
if load_exclude_rules is not None:
|
||||
for search_pattern in load_exclude_rules:
|
||||
if re.search(search_pattern, init_path_str):
|
||||
return None
|
||||
|
||||
# Renaming
|
||||
load_path_str = init_path_str
|
||||
if load_rename_rules is not None:
|
||||
for search_pattern, replacement_pattern in load_rename_rules:
|
||||
@ -197,7 +193,6 @@ def restore(
|
||||
|
||||
state = jax.tree_util.tree_unflatten(structure, loaded_tensors)
|
||||
|
||||
# Sanity check to give a better error message.
|
||||
ckpt_keys = set(state.params.keys())
|
||||
code_keys = set(state_sharding.params.keys())
|
||||
|
||||
@ -219,3 +214,71 @@ def restore(
|
||||
if params_only:
|
||||
state = state.params
|
||||
return state
|
||||
|
||||
# Database and machine learning integration
|
||||
import sqlite3
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.linear_model import LinearRegression
|
||||
import pandas as pd
|
||||
|
||||
def create_database():
|
||||
conn = sqlite3.connect('data_analysis.db')
|
||||
c = conn.cursor()
|
||||
c.execute('''CREATE TABLE IF NOT EXISTS data (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
latency REAL,
|
||||
packet_loss REAL)''')
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def record_data(latency, packet_loss):
|
||||
conn = sqlite3.connect('data_analysis.db')
|
||||
c = conn.cursor()
|
||||
c.execute('INSERT INTO data (latency, packet_loss) VALUES (?, ?)', (latency, packet_loss))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def train_model():
|
||||
conn = sqlite3.connect('data_analysis.db')
|
||||
data = pd.read_sql_query("SELECT * FROM data", conn)
|
||||
conn.close()
|
||||
|
||||
X = data[['latency']]
|
||||
y = data['packet_loss']
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
||||
|
||||
model = LinearRegression()
|
||||
model.fit(X_train, y_train)
|
||||
return model
|
||||
|
||||
def analyze_task_startup(latency, model):
|
||||
predicted_packet_loss = model.predict([[latency]])[0]
|
||||
if predicted_packet_loss > 10:
|
||||
print("High packet loss predicted: ", predicted_packet_loss)
|
||||
else:
|
||||
print("Packet loss within acceptable range: ", predicted_packet_loss)
|
||||
|
||||
def join_data_with_external_source():
|
||||
external_data = pd.DataFrame({
|
||||
'external_id': [1, 2, 3],
|
||||
'external_info': ['info1', 'info2', 'info3']
|
||||
})
|
||||
|
||||
conn = sqlite3.connect('data_analysis.db')
|
||||
data = pd.read_sql_query("SELECT * FROM data", conn)
|
||||
conn.close()
|
||||
|
||||
joined_data = data.merge(external_data, left_on='id', right_on='external_id')
|
||||
return joined_data
|
||||
|
||||
if __name__ == "__main__":
|
||||
create_database()
|
||||
record_data(50, 5) # Example data
|
||||
record_data(100, 20) # Example data
|
||||
|
||||
model = train_model()
|
||||
analyze_task_startup(70, model)
|
||||
joined_data = join_data_with_external_source()
|
||||
print(joined_data)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user