Skip to content

Create branch if missing in hugginface-cli upload #1857

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/en/guides/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,8 @@ By default, files are uploaded to the `main` branch. If you want to upload files
...
```

**Note:** if `revision` does not exist and `--create-pr` is not set, a branch will be created automatically from the `main` branch.

### Upload and create a PR

If you don't have the permission to push to a repo, you must open a PR and let the authors know about the changes you want to make. This can be done by setting the `--create-pr` option:
Expand Down
16 changes: 14 additions & 2 deletions src/huggingface_hub/commands/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
from huggingface_hub.commands import BaseHuggingfaceCLICommand
from huggingface_hub.constants import HF_HUB_ENABLE_HF_TRANSFER
from huggingface_hub.hf_api import HfApi
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars
from huggingface_hub.utils import RevisionNotFoundError, disable_progress_bars, enable_progress_bars


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -83,7 +83,10 @@ def register_subcommand(parser: _SubParsersAction):
upload_parser.add_argument(
"--revision",
type=str,
help="An optional Git revision id which can be a branch name, a tag, or a commit hash.",
help=(
"An optional Git revision to push to. It can be a branch name or a PR reference. If revision does not"
" exist and `--create-pr` is not set, a branch will be automatically created."
),
)
upload_parser.add_argument(
"--private",
Expand Down Expand Up @@ -255,6 +258,15 @@ def _upload(self) -> str:
# ^ I'd rather not add CLI args to set it explicitly as we already have `huggingface-cli repo create` for that.
).repo_id

# Check if branch already exists and if not, create it
if self.revision is not None and not self.create_pr:
try:
self.api.repo_info(repo_id=repo_id, repo_type=self.repo_type, revision=self.revision)
except RevisionNotFoundError:
logger.info(f"Branch '{self.revision}' not found. Creating it...")
self.api.create_branch(repo_id=repo_id, repo_type=self.repo_type, branch=self.revision, exist_ok=True)
# ^ `exist_ok=True` to avoid race concurrency issues

# File-based upload
if os.path.isfile(self.local_path):
return self.api.upload_file(
Expand Down
67 changes: 64 additions & 3 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from huggingface_hub.commands.download import DownloadCommand
from huggingface_hub.commands.scan_cache import ScanCacheCommand
from huggingface_hub.commands.upload import UploadCommand
from huggingface_hub.utils import SoftTemporaryDirectory, capture_output
from huggingface_hub.utils import RevisionNotFoundError, SoftTemporaryDirectory, capture_output

from .testing_utils import DUMMY_MODEL_ID

Expand Down Expand Up @@ -211,9 +211,10 @@ def test_every_as_float(self) -> None:
cmd = UploadCommand(self.parser.parse_args(["upload", DUMMY_MODEL_ID, ".", "--every", "0.5"]))
self.assertEqual(cmd.every, 0.5)

@patch("huggingface_hub.commands.upload.HfApi.repo_info")
@patch("huggingface_hub.commands.upload.HfApi.upload_folder")
@patch("huggingface_hub.commands.upload.HfApi.create_repo")
def test_upload_folder_mock(self, create_mock: Mock, upload_mock: Mock) -> None:
def test_upload_folder_mock(self, create_mock: Mock, upload_mock: Mock, repo_info_mock: Mock) -> None:
with SoftTemporaryDirectory() as cache_dir:
cmd = UploadCommand(
self.parser.parse_args(
Expand All @@ -239,9 +240,10 @@ def test_upload_folder_mock(self, create_mock: Mock, upload_mock: Mock) -> None:
delete_patterns=["*.json"],
)

@patch("huggingface_hub.commands.upload.HfApi.repo_info")
@patch("huggingface_hub.commands.upload.HfApi.upload_file")
@patch("huggingface_hub.commands.upload.HfApi.create_repo")
def test_upload_file_mock(self, create_mock: Mock, upload_mock: Mock) -> None:
def test_upload_file_mock(self, create_mock: Mock, upload_mock: Mock, repo_info_mock: Mock) -> None:
with SoftTemporaryDirectory() as cache_dir:
file_path = Path(cache_dir) / "file.txt"
file_path.write_text("content")
Expand All @@ -266,6 +268,65 @@ def test_upload_file_mock(self, create_mock: Mock, upload_mock: Mock) -> None:
create_pr=True,
)

@patch("huggingface_hub.commands.upload.HfApi.repo_info")
@patch("huggingface_hub.commands.upload.HfApi.upload_file")
@patch("huggingface_hub.commands.upload.HfApi.create_repo")
def test_upload_file_no_revision_mock(self, create_mock: Mock, upload_mock: Mock, repo_info_mock: Mock) -> None:
with SoftTemporaryDirectory() as cache_dir:
file_path = Path(cache_dir) / "file.txt"
file_path.write_text("content")
cmd = UploadCommand(self.parser.parse_args(["upload", "my-model", str(file_path), "logs/file.txt"]))
cmd.run()
# Revision not specified => no need to check
repo_info_mock.assert_not_called()

@patch("huggingface_hub.commands.upload.HfApi.create_branch")
@patch("huggingface_hub.commands.upload.HfApi.repo_info")
@patch("huggingface_hub.commands.upload.HfApi.upload_file")
@patch("huggingface_hub.commands.upload.HfApi.create_repo")
def test_upload_file_with_revision_mock(
self, create_mock: Mock, upload_mock: Mock, repo_info_mock: Mock, create_branch_mock: Mock
) -> None:
repo_info_mock.side_effect = RevisionNotFoundError("revision not found")

with SoftTemporaryDirectory() as cache_dir:
file_path = Path(cache_dir) / "file.txt"
file_path.write_text("content")
cmd = UploadCommand(
self.parser.parse_args(
["upload", "my-model", str(file_path), "logs/file.txt", "--revision", "my-branch"]
)
)
cmd.run()

# Revision specified => check that it exists
repo_info_mock.assert_called_once_with(
repo_id=create_mock.return_value.repo_id, repo_type="model", revision="my-branch"
)

# Revision does not exist => create it
create_branch_mock.assert_called_once_with(
repo_id=create_mock.return_value.repo_id, repo_type="model", branch="my-branch", exist_ok=True
)

@patch("huggingface_hub.commands.upload.HfApi.repo_info")
@patch("huggingface_hub.commands.upload.HfApi.upload_file")
@patch("huggingface_hub.commands.upload.HfApi.create_repo")
def test_upload_file_revision_and_create_pr_mock(
self, create_mock: Mock, upload_mock: Mock, repo_info_mock: Mock
) -> None:
with SoftTemporaryDirectory() as cache_dir:
file_path = Path(cache_dir) / "file.txt"
file_path.write_text("content")
cmd = UploadCommand(
self.parser.parse_args(
["upload", "my-model", str(file_path), "logs/file.txt", "--revision", "my-branch", "--create-pr"]
)
)
cmd.run()
# Revision specified but --create-pr => no need to check
repo_info_mock.assert_not_called()

@patch("huggingface_hub.commands.upload.HfApi.create_repo")
def test_upload_missing_path(self, create_mock: Mock) -> None:
cmd = UploadCommand(self.parser.parse_args(["upload", "my-model", "/path/to/missing_file", "logs/file.txt"]))
Expand Down
4 changes: 2 additions & 2 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1563,12 +1563,12 @@ def test_create_branch_from_revision(self, repo_url: RepoUrl) -> None:
"""Test `create_branch` from a different revision than main HEAD."""
# Create commit and remember initial/latest commit
initial_commit = self._api.model_info(repo_url.repo_id).sha
self._api.create_commit(
commit = self._api.create_commit(
repo_url.repo_id,
operations=[CommitOperationAdd(path_in_repo="app.py", path_or_fileobj=b"content")],
commit_message="test commit",
)
latest_commit = self._api.model_info(repo_url.repo_id).sha
latest_commit = commit.oid

# Create branches
self._api.create_branch(repo_url.repo_id, branch="from-head")
Expand Down