def _remap_general_mistral_args(config: dict) -> dict:
# Mistral key -> HF key
config_mapping = {
"dim": "hidden_size",
"norm_eps": "rms_norm_eps",
"n_kv_heads": "num_key_value_heads",
"n_layers": "num_hidden_layers",
"n_heads": "num_attention_heads",
"hidden_dim": "intermediate_size",
}
# HF key -> (Mistral key, default value)
top_level_mapping_with_default = {
"model_type": ("model_type", "transformer"),
"hidden_act": ("activation", "silu"),
"tie_word_embeddings": ("tied_embeddings", False),
"max_seq_len": ("max_seq_len", 128_000),
"max_position_embeddings": ("max_position_embeddings", 128_000),
}
for key, new_key in config_mapping.items():
if key in config:
config[new_key] = config.pop(key)
for new_key, (key,
default_value) in top_level_mapping_with_default.items():
config[new_key] = config.pop(key, default_value)
return config