|
| 1 | +import json |
| 2 | +import io |
1 | 3 | import asyncio
|
2 | 4 | import sys
|
3 |
| -from typing import AsyncIterator, Iterator, Optional, cast |
| 5 | +from typing import AsyncIterator, Iterator, Optional, cast, Type |
4 | 6 |
|
5 | 7 | import httpx
|
6 | 8 | import pytest
|
|
11 | 13 | from replicate.exceptions import ModelError, ReplicateError
|
12 | 14 | from replicate.helpers import FileOutput
|
13 | 15 |
|
| 16 | +import email |
| 17 | +from email.message import EmailMessage |
| 18 | +from email.parser import BytesParser |
| 19 | +from email.policy import HTTP |
| 20 | + |
14 | 21 |
|
15 | 22 | @pytest.mark.vcr("run.yaml")
|
16 | 23 | @pytest.mark.asyncio
|
@@ -581,6 +588,130 @@ async def test_run_with_model_error(mock_replicate_api_token):
|
581 | 588 | assert excinfo.value.prediction.status == "failed"
|
582 | 589 |
|
583 | 590 |
|
| 591 | +@pytest.mark.asyncio |
| 592 | +@pytest.mark.parametrize("async_flag", [True, False]) |
| 593 | +async def test_run_with_file_input_files_api(async_flag, mock_replicate_api_token): |
| 594 | + router = respx.Router(base_url="https://api.replicate.com/v1") |
| 595 | + mock_predictions_create = router.route(method="POST", path="/predictions").mock( |
| 596 | + return_value=httpx.Response( |
| 597 | + 201, |
| 598 | + json=_prediction_with_status("processing"), |
| 599 | + ) |
| 600 | + ) |
| 601 | + router.route( |
| 602 | + method="GET", |
| 603 | + path="/models/test/example/versions/v1", |
| 604 | + ).mock( |
| 605 | + return_value=httpx.Response( |
| 606 | + 200, |
| 607 | + json=_version_with_schema(), |
| 608 | + ) |
| 609 | + ) |
| 610 | + mock_files_create = router.route( |
| 611 | + method="POST", |
| 612 | + path="/files", |
| 613 | + ).mock( |
| 614 | + return_value=httpx.Response( |
| 615 | + 200, |
| 616 | + json={ |
| 617 | + "id": "file1", |
| 618 | + "name": "file.png", |
| 619 | + "content_type": "image/png", |
| 620 | + "size": 10, |
| 621 | + "etag": "123", |
| 622 | + "checksums": {}, |
| 623 | + "metadata": {}, |
| 624 | + "created_at": "", |
| 625 | + "expires_at": "", |
| 626 | + "urls": {"get": "https://api.replicate.com/files/file.txt"}, |
| 627 | + }, |
| 628 | + ) |
| 629 | + ) |
| 630 | + router.route(host="api.replicate.com").pass_through() |
| 631 | + |
| 632 | + client = Client( |
| 633 | + api_token="test-token", transport=httpx.MockTransport(router.handler) |
| 634 | + ) |
| 635 | + if async_flag: |
| 636 | + await client.async_run( |
| 637 | + "test/example:v1", |
| 638 | + input={"file": io.BytesIO(initial_bytes=b"hello world")}, |
| 639 | + ) |
| 640 | + else: |
| 641 | + client.run( |
| 642 | + "test/example:v1", |
| 643 | + input={"file": io.BytesIO(initial_bytes=b"hello world")}, |
| 644 | + ) |
| 645 | + |
| 646 | + assert mock_predictions_create.called |
| 647 | + prediction_payload = json.loads(mock_predictions_create.calls[0].request.content) |
| 648 | + assert ( |
| 649 | + prediction_payload.get("input", {}).get("file") |
| 650 | + == "https://api.replicate.com/files/file.txt" |
| 651 | + ) |
| 652 | + |
| 653 | + # Validate the Files API request |
| 654 | + req = mock_files_create.calls[0].request |
| 655 | + body = req.content |
| 656 | + content_type = req.headers["Content-Type"] |
| 657 | + |
| 658 | + # Parse the multipart data |
| 659 | + parser = BytesParser(EmailMessage, policy=HTTP) |
| 660 | + headers = f"Content-Type: {content_type}\n\n".encode("utf-8") |
| 661 | + parsed_message_generator = parser.parsebytes(headers + body).walk() |
| 662 | + next(parsed_message_generator) # wrapper |
| 663 | + input_file = next(parsed_message_generator) |
| 664 | + assert mock_files_create.called |
| 665 | + assert input_file.get_content() == b"hello world" |
| 666 | + assert input_file.get_content_type() == "application/octet-stream" |
| 667 | + |
| 668 | + |
| 669 | +@pytest.mark.asyncio |
| 670 | +@pytest.mark.parametrize("async_flag", [True, False]) |
| 671 | +async def test_run_with_file_input_data_url(async_flag, mock_replicate_api_token): |
| 672 | + router = respx.Router(base_url="https://api.replicate.com/v1") |
| 673 | + mock_predictions_create = router.route(method="POST", path="/predictions").mock( |
| 674 | + return_value=httpx.Response( |
| 675 | + 201, |
| 676 | + json=_prediction_with_status("processing"), |
| 677 | + ) |
| 678 | + ) |
| 679 | + router.route( |
| 680 | + method="GET", |
| 681 | + path="/models/test/example/versions/v1", |
| 682 | + ).mock( |
| 683 | + return_value=httpx.Response( |
| 684 | + 200, |
| 685 | + json=_version_with_schema(), |
| 686 | + ) |
| 687 | + ) |
| 688 | + router.route(host="api.replicate.com").pass_through() |
| 689 | + |
| 690 | + client = Client( |
| 691 | + api_token="test-token", transport=httpx.MockTransport(router.handler) |
| 692 | + ) |
| 693 | + |
| 694 | + if async_flag: |
| 695 | + await client.async_run( |
| 696 | + "test/example:v1", |
| 697 | + input={"file": io.BytesIO(initial_bytes=b"hello world")}, |
| 698 | + file_encoding_strategy="base64", |
| 699 | + ) |
| 700 | + else: |
| 701 | + client.run( |
| 702 | + "test/example:v1", |
| 703 | + input={"file": io.BytesIO(initial_bytes=b"hello world")}, |
| 704 | + file_encoding_strategy="base64", |
| 705 | + ) |
| 706 | + |
| 707 | + assert mock_predictions_create.called |
| 708 | + prediction_payload = json.loads(mock_predictions_create.calls[0].request.content) |
| 709 | + assert ( |
| 710 | + prediction_payload.get("input", {}).get("file") |
| 711 | + == "data:application/octet-stream;base64,aGVsbG8gd29ybGQ=" |
| 712 | + ) |
| 713 | + |
| 714 | + |
584 | 715 | @pytest.mark.asyncio
|
585 | 716 | async def test_run_with_file_output(mock_replicate_api_token):
|
586 | 717 | router = respx.Router(base_url="https://api.replicate.com/v1")
|
|
0 commit comments