Skip to content

Commit 207d2f8

Browse files
committed
Make code-llama and hf-tgi inference runnable as module
1 parent 5ac5d99 commit 207d2f8

File tree

10 files changed

+23
-9
lines changed

10 files changed

+23
-9
lines changed
+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3+
4+
import fire
5+
6+
from .inference import main
7+
8+
if __name__ == "__main__":
9+
fire.Fire(main)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

src/llama_recipes/inference/code-llama/code_completion_example.py renamed to src/llama_recipes/inference/code_llama/code_completion_example.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,16 @@
44
# from accelerate import init_empty_weights, load_checkpoint_and_dispatch
55

66
import fire
7-
import torch
87
import os
98
import sys
109
import time
11-
from typing import List
1210

11+
import torch
1312
from transformers import AutoTokenizer
14-
sys.path.append("..")
15-
from safety_utils import get_safety_checker
16-
from model_utils import load_model, load_peft_model, load_llama_from_config
13+
14+
from ..safety_utils import get_safety_checker
15+
from ..model_utils import load_model, load_peft_model
16+
1717

1818
def main(
1919
model_name,

src/llama_recipes/inference/code-llama/code_infilling_example.py renamed to src/llama_recipes/inference/code_llama/code_infilling_example.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,11 @@
88
import os
99
import sys
1010
import time
11-
from typing import List
1211

1312
from transformers import AutoTokenizer
14-
sys.path.append("..")
15-
from safety_utils import get_safety_checker
16-
from model_utils import load_model, load_peft_model, load_llama_from_config
13+
14+
from ..safety_utils import get_safety_checker
15+
from ..model_utils import load_model, load_peft_model
1716

1817
def main(
1918
model_name,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

0 commit comments

Comments
 (0)