Skip to content

Commit 1a7a0b7

Browse files
committed
Add notes on using local models
1 parent 16ce8df commit 1a7a0b7

File tree

1 file changed

+104
-0
lines changed

1 file changed

+104
-0
lines changed

Diff for: examples/Detoxify.ipynb

+104
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,12 @@
188188
"text"
189189
]
190190
},
191+
{
192+
"metadata": {},
193+
"cell_type": "markdown",
194+
"source": "## Initializing TMaRCo",
195+
"id": "1eb7719e30054304"
196+
},
191197
{
192198
"cell_type": "code",
193199
"execution_count": 5,
@@ -198,6 +204,62 @@
198204
"tmarco = TMaRCo()"
199205
]
200206
},
207+
{
208+
"metadata": {},
209+
"cell_type": "markdown",
210+
"source": [
211+
"This will initialize `TMaRCo` using the default models, taken from HuggingFace.\n",
212+
"<div class=\"alert alert-info\">\n",
213+
"To use local models with TMaRCo, we need to have them in a local storage, accessible to TMaRCo, initialize separately, and pass them to TMaRCo's constructor.\n",
214+
"</div>\n",
215+
"For instance, to use the default `facebook/bart-large` model, but locally. First, we would need to retrieve the model:"
216+
],
217+
"id": "3e16ee305f4983d9"
218+
},
219+
{
220+
"metadata": {},
221+
"cell_type": "code",
222+
"outputs": [],
223+
"execution_count": null,
224+
"source": [
225+
"from huggingface_hub import snapshot_download\n",
226+
"\n",
227+
"snapshot_download(repo_id=\"facebook/bart-large\", local_dir=\"models/bart\")"
228+
],
229+
"id": "614c9ff6f46a0ea9"
230+
},
231+
{
232+
"metadata": {},
233+
"cell_type": "markdown",
234+
"source": "We now initialize the base model and tokenizer from local files and pass them to `TMaRCo`:",
235+
"id": "95bd792e757205d6"
236+
},
237+
{
238+
"metadata": {},
239+
"cell_type": "code",
240+
"source": [
241+
"from transformers import BartForConditionalGeneration, BartTokenizer\n",
242+
"\n",
243+
"tokenizer = BartTokenizer.from_pretrained(\n",
244+
" \"models/bart\", # Or directory where the local model is stored \n",
245+
" is_split_into_words=True, add_prefix_space=True\n",
246+
")\n",
247+
"\n",
248+
"tokenizer.pad_token_id = tokenizer.eos_token_id\n",
249+
"\n",
250+
"base = BartForConditionalGeneration.from_pretrained(\n",
251+
" \"models/bart\", # Or directory where the local model is stored\n",
252+
" max_length=150,\n",
253+
" forced_bos_token_id=tokenizer.bos_token_id,\n",
254+
")\n",
255+
"\n",
256+
"# Initialize TMaRCo with local models\n",
257+
"tmarco = TMaRCo(tokenizer=tokenizer, base_model=base)"
258+
],
259+
"id": "f0f24485822a7c3f",
260+
"outputs": [],
261+
"execution_count": null
262+
},
201263
{
202264
"cell_type": "code",
203265
"execution_count": 7,
@@ -223,6 +285,31 @@
223285
"tmarco.load_models([\"trustyai/gminus\", \"trustyai/gplus\"])"
224286
]
225287
},
288+
{
289+
"metadata": {},
290+
"cell_type": "markdown",
291+
"source": [
292+
"<div class=\"alert alert-info\">\n",
293+
"To use local expert/anti-expert models with TMaRCo, we need to have them in a local storage, accessible to TMaRCo, as previously.\n",
294+
"However, we don't need to initialize them separately, and can pass the directory directly.\n",
295+
"</div>\n",
296+
"If we want to use local models with `TMaRCo` (in this case the same default `gminus`/`gplus`):\n"
297+
],
298+
"id": "c113208c527c342e"
299+
},
300+
{
301+
"metadata": {},
302+
"cell_type": "code",
303+
"outputs": [],
304+
"execution_count": null,
305+
"source": [
306+
"snapshot_download(repo_id=\"trustyai/gminus\", local_dir=\"models/gminus\")\n",
307+
"snapshot_download(repo_id=\"trustyai/gplus\", local_dir=\"models/gplus\")\n",
308+
"\n",
309+
"tmarco.load_models([\"models/gminus\", \"models/gplus\"])"
310+
],
311+
"id": "dfa288dcb60102c"
312+
},
226313
{
227314
"cell_type": "code",
228315
"execution_count": 13,
@@ -362,6 +449,23 @@
362449
"tmarco.load_models([\"trustyai/gminus\", \"trustyai/gplus\"])"
363450
]
364451
},
452+
{
453+
"metadata": {},
454+
"cell_type": "markdown",
455+
"source": "As noted previously, to use local models, simply pass the initialized tokenizer and base model to the constructor, and the local path as the expert/anti-expert:",
456+
"id": "b0738c324227f57"
457+
},
458+
{
459+
"metadata": {},
460+
"cell_type": "code",
461+
"outputs": [],
462+
"execution_count": null,
463+
"source": [
464+
"tmarco = TMaRCo(tokenizer=tokenizer, base_model=base)\n",
465+
"tmarco.load_models([\"models/gminus\", \"models/gplus\"])"
466+
],
467+
"id": "b929e21a97ea914e"
468+
},
365469
{
366470
"cell_type": "markdown",
367471
"id": "5303f56b-85ff-40da-99bf-6962cf2f3395",

0 commit comments

Comments
 (0)