Skip to content

Commit 672d7e5

Browse files
committed
feat: Add tie_weights parameter to Llava model initialization
1 parent 2037a86 commit 672d7e5

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

lmms_eval/models/llava.py

Lines changed: 3 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -58,6 +58,7 @@ def __init__(
5858
device_map="cuda:0",
5959
conv_template="vicuna_v1",
6060
use_cache=True,
61+
tie_weights: bool = True,
6162
truncate_context=False, # whether to truncate the context in generation, set it False for LLaVA-1.6
6263
customized_config=None, # ends in json
6364
**kwargs,
@@ -97,6 +98,8 @@ def __init__(
9798
self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, model_name, device_map=self.device_map, **llava_model_args)
9899
self._config = self._model.config
99100
self.model.eval()
101+
if tie_weights:
102+
self.model.tie_weights()
100103

101104
self.truncation = truncation
102105
self.batch_size_per_gpu = int(batch_size)

0 commit comments

Comments (0)