diff --git a/model.py b/model.py index 87d700d..ff08762 100644 --- a/model.py +++ b/model.py @@ -90,9 +90,7 @@ def ffn_size(emb_size, widening_factor): def apply_rules(rules): - def _apply_rules(path, value): - del value # Unused. - + def _apply_rules(path, _): path_list = [str(i.key).split("/") for i in path if isinstance(i, jax.tree_util.DictKey)] flattened_path = jax.tree_util.tree_flatten(path_list)[0]