|
188 | 188 | "text"
|
189 | 189 | ]
|
190 | 190 | },
|
| 191 | + { |
| 192 | + "metadata": {}, |
| 193 | + "cell_type": "markdown", |
| 194 | + "source": "## Initializing TMaRCo", |
| 195 | + "id": "1eb7719e30054304" |
| 196 | + }, |
191 | 197 | {
|
192 | 198 | "cell_type": "code",
|
193 | 199 | "execution_count": 5,
|
|
198 | 204 | "tmarco = TMaRCo()"
|
199 | 205 | ]
|
200 | 206 | },
|
| 207 | + { |
| 208 | + "metadata": {}, |
| 209 | + "cell_type": "markdown", |
| 210 | + "source": [ |
| 211 | + "This will initialize `TMaRCo` using the default models, downloaded from Hugging Face.\n", |
| 212 | + "<div class=\"alert alert-info\">\n", |
| 213 | + "To use local models with TMaRCo, we need to have them in local storage accessible to TMaRCo, initialize them separately, and pass them to TMaRCo's constructor.\n", |
| 214 | + "</div>\n", |
| 215 | + "For instance, to use the default `facebook/bart-large` model locally, we first need to retrieve it:" |
| 216 | + ], |
| 217 | + "id": "3e16ee305f4983d9" |
| 218 | + }, |
| 219 | + { |
| 220 | + "metadata": {}, |
| 221 | + "cell_type": "code", |
| 222 | + "outputs": [], |
| 223 | + "execution_count": null, |
| 224 | + "source": [ |
| 225 | + "from huggingface_hub import snapshot_download\n", |
| 226 | + "\n", |
| 227 | + "snapshot_download(repo_id=\"facebook/bart-large\", local_dir=\"models/bart\")" |
| 228 | + ], |
| 229 | + "id": "614c9ff6f46a0ea9" |
| 230 | + }, |
| 231 | + { |
| 232 | + "metadata": {}, |
| 233 | + "cell_type": "markdown", |
| 234 | + "source": "We now initialize the base model and tokenizer from local files and pass them to `TMaRCo`:", |
| 235 | + "id": "95bd792e757205d6" |
| 236 | + }, |
| 237 | + { |
| 238 | + "metadata": {}, |
| 239 | + "cell_type": "code", |
| 240 | + "source": [ |
| 241 | + "from transformers import BartForConditionalGeneration, BartTokenizer\n", |
| 242 | + "\n", |
| 243 | + "tokenizer = BartTokenizer.from_pretrained(\n", |
| 244 | + "    \"models/bart\", # Or the directory where the local model is stored\n", |
| 245 | + " is_split_into_words=True, add_prefix_space=True\n", |
| 246 | + ")\n", |
| 247 | + "\n", |
| 248 | + "tokenizer.pad_token_id = tokenizer.eos_token_id\n", |
| 249 | + "\n", |
| 250 | + "base = BartForConditionalGeneration.from_pretrained(\n", |
| 251 | + " \"models/bart\", # Or directory where the local model is stored\n", |
| 252 | + " max_length=150,\n", |
| 253 | + " forced_bos_token_id=tokenizer.bos_token_id,\n", |
| 254 | + ")\n", |
| 255 | + "\n", |
| 256 | + "# Initialize TMaRCo with local models\n", |
| 257 | + "tmarco = TMaRCo(tokenizer=tokenizer, base_model=base)" |
| 258 | + ], |
| 259 | + "id": "f0f24485822a7c3f", |
| 260 | + "outputs": [], |
| 261 | + "execution_count": null |
| 262 | + }, |
201 | 263 | {
|
202 | 264 | "cell_type": "code",
|
203 | 265 | "execution_count": 7,
|
|
223 | 285 | "tmarco.load_models([\"trustyai/gminus\", \"trustyai/gplus\"])"
|
224 | 286 | ]
|
225 | 287 | },
|
| 288 | + { |
| 289 | + "metadata": {}, |
| 290 | + "cell_type": "markdown", |
| 291 | + "source": [ |
| 292 | + "<div class=\"alert alert-info\">\n", |
| 293 | + "To use local expert/anti-expert models with TMaRCo, we need to have them in local storage accessible to TMaRCo, as before.\n", |
| 294 | + "However, we don't need to initialize them separately; we can pass their directories directly.\n", |
| 295 | + "</div>\n", |
| 296 | + "For example, to use local copies of the default `gminus`/`gplus` models:\n" |
| 297 | + ], |
| 298 | + "id": "c113208c527c342e" |
| 299 | + }, |
| 300 | + { |
| 301 | + "metadata": {}, |
| 302 | + "cell_type": "code", |
| 303 | + "outputs": [], |
| 304 | + "execution_count": null, |
| 305 | + "source": [ |
| 306 | + "snapshot_download(repo_id=\"trustyai/gminus\", local_dir=\"models/gminus\")\n", |
| 307 | + "snapshot_download(repo_id=\"trustyai/gplus\", local_dir=\"models/gplus\")\n", |
| 308 | + "\n", |
| 309 | + "tmarco.load_models([\"models/gminus\", \"models/gplus\"])" |
| 310 | + ], |
| 311 | + "id": "dfa288dcb60102c" |
| 312 | + }, |
226 | 313 | {
|
227 | 314 | "cell_type": "code",
|
228 | 315 | "execution_count": 13,
|
|
362 | 449 | "tmarco.load_models([\"trustyai/gminus\", \"trustyai/gplus\"])"
|
363 | 450 | ]
|
364 | 451 | },
|
| 452 | + { |
| 453 | + "metadata": {}, |
| 454 | + "cell_type": "markdown", |
| 455 | + "source": "As noted previously, to use local models, pass the initialized tokenizer and base model to the constructor, and the local paths as the expert/anti-expert models:", |
| 456 | + "id": "b0738c324227f57" |
| 457 | + }, |
| 458 | + { |
| 459 | + "metadata": {}, |
| 460 | + "cell_type": "code", |
| 461 | + "outputs": [], |
| 462 | + "execution_count": null, |
| 463 | + "source": [ |
| 464 | + "tmarco = TMaRCo(tokenizer=tokenizer, base_model=base)\n", |
| 465 | + "tmarco.load_models([\"models/gminus\", \"models/gplus\"])" |
| 466 | + ], |
| 467 | + "id": "b929e21a97ea914e" |
| 468 | + }, |
365 | 469 | {
|
366 | 470 | "cell_type": "markdown",
|
367 | 471 | "id": "5303f56b-85ff-40da-99bf-6962cf2f3395",
|
|