Skip to content

Commit a4d5fd9

Browse files
authored
Create branch if missing in huggingface-cli upload (#1857)
* Create branch if missing in huggingface-cli upload * make test more robust
1 parent 793d691 commit a4d5fd9

File tree

4 files changed

+82
-7
lines changed

4 files changed

+82
-7
lines changed

docs/source/en/guides/cli.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,8 @@ By default, files are uploaded to the `main` branch. If you want to upload files
332332
...
333333
```
334334

335+
**Note:** if `revision` does not exist and `--create-pr` is not set, a branch will be created automatically from the `main` branch.
336+
335337
### Upload and create a PR
336338

337339
If you don't have the permission to push to a repo, you must open a PR and let the authors know about the changes you want to make. This can be done by setting the `--create-pr` option:

src/huggingface_hub/commands/upload.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
from huggingface_hub.commands import BaseHuggingfaceCLICommand
5454
from huggingface_hub.constants import HF_HUB_ENABLE_HF_TRANSFER
5555
from huggingface_hub.hf_api import HfApi
56-
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars
56+
from huggingface_hub.utils import RevisionNotFoundError, disable_progress_bars, enable_progress_bars
5757

5858

5959
logger = logging.get_logger(__name__)
@@ -83,7 +83,10 @@ def register_subcommand(parser: _SubParsersAction):
8383
upload_parser.add_argument(
8484
"--revision",
8585
type=str,
86-
help="An optional Git revision id which can be a branch name, a tag, or a commit hash.",
86+
help=(
87+
"An optional Git revision to push to. It can be a branch name or a PR reference. If revision does not"
88+
" exist and `--create-pr` is not set, a branch will be automatically created."
89+
),
8790
)
8891
upload_parser.add_argument(
8992
"--private",
@@ -255,6 +258,15 @@ def _upload(self) -> str:
255258
# ^ I'd rather not add CLI args to set it explicitly as we already have `huggingface-cli repo create` for that.
256259
).repo_id
257260

261+
# Check if branch already exists and if not, create it
262+
if self.revision is not None and not self.create_pr:
263+
try:
264+
self.api.repo_info(repo_id=repo_id, repo_type=self.repo_type, revision=self.revision)
265+
except RevisionNotFoundError:
266+
logger.info(f"Branch '{self.revision}' not found. Creating it...")
267+
self.api.create_branch(repo_id=repo_id, repo_type=self.repo_type, branch=self.revision, exist_ok=True)
268+
# ^ `exist_ok=True` to avoid failing on a race condition (branch created concurrently by another process)
269+
258270
# File-based upload
259271
if os.path.isfile(self.local_path):
260272
return self.api.upload_file(

tests/test_cli.py

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from huggingface_hub.commands.download import DownloadCommand
1212
from huggingface_hub.commands.scan_cache import ScanCacheCommand
1313
from huggingface_hub.commands.upload import UploadCommand
14-
from huggingface_hub.utils import SoftTemporaryDirectory, capture_output
14+
from huggingface_hub.utils import RevisionNotFoundError, SoftTemporaryDirectory, capture_output
1515

1616
from .testing_utils import DUMMY_MODEL_ID
1717

@@ -211,9 +211,10 @@ def test_every_as_float(self) -> None:
211211
cmd = UploadCommand(self.parser.parse_args(["upload", DUMMY_MODEL_ID, ".", "--every", "0.5"]))
212212
self.assertEqual(cmd.every, 0.5)
213213

214+
@patch("huggingface_hub.commands.upload.HfApi.repo_info")
214215
@patch("huggingface_hub.commands.upload.HfApi.upload_folder")
215216
@patch("huggingface_hub.commands.upload.HfApi.create_repo")
216-
def test_upload_folder_mock(self, create_mock: Mock, upload_mock: Mock) -> None:
217+
def test_upload_folder_mock(self, create_mock: Mock, upload_mock: Mock, repo_info_mock: Mock) -> None:
217218
with SoftTemporaryDirectory() as cache_dir:
218219
cmd = UploadCommand(
219220
self.parser.parse_args(
@@ -239,9 +240,10 @@ def test_upload_folder_mock(self, create_mock: Mock, upload_mock: Mock) -> None:
239240
delete_patterns=["*.json"],
240241
)
241242

243+
@patch("huggingface_hub.commands.upload.HfApi.repo_info")
242244
@patch("huggingface_hub.commands.upload.HfApi.upload_file")
243245
@patch("huggingface_hub.commands.upload.HfApi.create_repo")
244-
def test_upload_file_mock(self, create_mock: Mock, upload_mock: Mock) -> None:
246+
def test_upload_file_mock(self, create_mock: Mock, upload_mock: Mock, repo_info_mock: Mock) -> None:
245247
with SoftTemporaryDirectory() as cache_dir:
246248
file_path = Path(cache_dir) / "file.txt"
247249
file_path.write_text("content")
@@ -266,6 +268,65 @@ def test_upload_file_mock(self, create_mock: Mock, upload_mock: Mock) -> None:
266268
create_pr=True,
267269
)
268270

271+
@patch("huggingface_hub.commands.upload.HfApi.repo_info")
272+
@patch("huggingface_hub.commands.upload.HfApi.upload_file")
273+
@patch("huggingface_hub.commands.upload.HfApi.create_repo")
274+
def test_upload_file_no_revision_mock(self, create_mock: Mock, upload_mock: Mock, repo_info_mock: Mock) -> None:
275+
with SoftTemporaryDirectory() as cache_dir:
276+
file_path = Path(cache_dir) / "file.txt"
277+
file_path.write_text("content")
278+
cmd = UploadCommand(self.parser.parse_args(["upload", "my-model", str(file_path), "logs/file.txt"]))
279+
cmd.run()
280+
# Revision not specified => no need to check
281+
repo_info_mock.assert_not_called()
282+
283+
@patch("huggingface_hub.commands.upload.HfApi.create_branch")
284+
@patch("huggingface_hub.commands.upload.HfApi.repo_info")
285+
@patch("huggingface_hub.commands.upload.HfApi.upload_file")
286+
@patch("huggingface_hub.commands.upload.HfApi.create_repo")
287+
def test_upload_file_with_revision_mock(
288+
self, create_mock: Mock, upload_mock: Mock, repo_info_mock: Mock, create_branch_mock: Mock
289+
) -> None:
290+
repo_info_mock.side_effect = RevisionNotFoundError("revision not found")
291+
292+
with SoftTemporaryDirectory() as cache_dir:
293+
file_path = Path(cache_dir) / "file.txt"
294+
file_path.write_text("content")
295+
cmd = UploadCommand(
296+
self.parser.parse_args(
297+
["upload", "my-model", str(file_path), "logs/file.txt", "--revision", "my-branch"]
298+
)
299+
)
300+
cmd.run()
301+
302+
# Revision specified => check that it exists
303+
repo_info_mock.assert_called_once_with(
304+
repo_id=create_mock.return_value.repo_id, repo_type="model", revision="my-branch"
305+
)
306+
307+
# Revision does not exist => create it
308+
create_branch_mock.assert_called_once_with(
309+
repo_id=create_mock.return_value.repo_id, repo_type="model", branch="my-branch", exist_ok=True
310+
)
311+
312+
@patch("huggingface_hub.commands.upload.HfApi.repo_info")
313+
@patch("huggingface_hub.commands.upload.HfApi.upload_file")
314+
@patch("huggingface_hub.commands.upload.HfApi.create_repo")
315+
def test_upload_file_revision_and_create_pr_mock(
316+
self, create_mock: Mock, upload_mock: Mock, repo_info_mock: Mock
317+
) -> None:
318+
with SoftTemporaryDirectory() as cache_dir:
319+
file_path = Path(cache_dir) / "file.txt"
320+
file_path.write_text("content")
321+
cmd = UploadCommand(
322+
self.parser.parse_args(
323+
["upload", "my-model", str(file_path), "logs/file.txt", "--revision", "my-branch", "--create-pr"]
324+
)
325+
)
326+
cmd.run()
327+
# Revision specified but --create-pr => no need to check
328+
repo_info_mock.assert_not_called()
329+
269330
@patch("huggingface_hub.commands.upload.HfApi.create_repo")
270331
def test_upload_missing_path(self, create_mock: Mock) -> None:
271332
cmd = UploadCommand(self.parser.parse_args(["upload", "my-model", "/path/to/missing_file", "logs/file.txt"]))

tests/test_hf_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1563,12 +1563,12 @@ def test_create_branch_from_revision(self, repo_url: RepoUrl) -> None:
15631563
"""Test `create_branch` from a different revision than main HEAD."""
15641564
# Create commit and remember initial/latest commit
15651565
initial_commit = self._api.model_info(repo_url.repo_id).sha
1566-
self._api.create_commit(
1566+
commit = self._api.create_commit(
15671567
repo_url.repo_id,
15681568
operations=[CommitOperationAdd(path_in_repo="app.py", path_or_fileobj=b"content")],
15691569
commit_message="test commit",
15701570
)
1571-
latest_commit = self._api.model_info(repo_url.repo_id).sha
1571+
latest_commit = commit.oid
15721572

15731573
# Create branches
15741574
self._api.create_branch(repo_url.repo_id, branch="from-head")

0 commit comments

Comments
 (0)