@@ -426,3 +426,84 @@ def prediction_with_status(
426
426
427
427
assert output1 .read () == b"Hello,"
428
428
assert output2 .read () == b" world!"
429
+
430
+
431
+ @pytest .mark .asyncio
432
+ async def test_run_with_file_output_data_uri (mock_replicate_api_token ):
433
+ def prediction_with_status (
434
+ status : str , output : str | list [str ] | None = None
435
+ ) -> dict :
436
+ return {
437
+ "id" : "p1" ,
438
+ "model" : "test/example" ,
439
+ "version" : "v1" ,
440
+ "urls" : {
441
+ "get" : "https://api.replicate.com/v1/predictions/p1" ,
442
+ "cancel" : "https://api.replicate.com/v1/predictions/p1/cancel" ,
443
+ },
444
+ "created_at" : "2023-10-05T12:00:00.000000Z" ,
445
+ "source" : "api" ,
446
+ "status" : status ,
447
+ "input" : {"text" : "world" },
448
+ "output" : output ,
449
+ "error" : "OOM" if status == "failed" else None ,
450
+ "logs" : "" ,
451
+ }
452
+
453
+ router = respx .Router (base_url = "https://api.replicate.com/v1" )
454
+ router .route (method = "POST" , path = "/predictions" ).mock (
455
+ return_value = httpx .Response (
456
+ 201 ,
457
+ json = prediction_with_status ("processing" ),
458
+ )
459
+ )
460
+ router .route (method = "GET" , path = "/predictions/p1" ).mock (
461
+ return_value = httpx .Response (
462
+ 200 ,
463
+ json = prediction_with_status (
464
+ "succeeded" ,
465
+ "data:text/plain;base64,SGVsbG8sIHdvcmxkIQ==" ,
466
+ ),
467
+ )
468
+ )
469
+ router .route (
470
+ method = "GET" ,
471
+ path = "/models/test/example/versions/v1" ,
472
+ ).mock (
473
+ return_value = httpx .Response (
474
+ 201 ,
475
+ json = {
476
+ "id" : "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1" ,
477
+ "created_at" : "2024-07-18T00:35:56.210272Z" ,
478
+ "cog_version" : "0.9.10" ,
479
+ "openapi_schema" : {
480
+ "openapi" : "3.0.2" ,
481
+ },
482
+ },
483
+ )
484
+ )
485
+
486
+ client = Client (
487
+ api_token = "test-token" , transport = httpx .MockTransport (router .handler )
488
+ )
489
+ client .poll_interval = 0.001
490
+
491
+ output = cast (
492
+ FileOutput ,
493
+ client .run (
494
+ "test/example:v1" ,
495
+ input = {
496
+ "text" : "Hello, world!" ,
497
+ },
498
+ use_file_output = True ,
499
+ ),
500
+ )
501
+
502
+ assert output .url == "data:text/plain;base64,SGVsbG8sIHdvcmxkIQ=="
503
+ assert output .read () == b"Hello, world!"
504
+ for chunk in output :
505
+ assert chunk == b"Hello, world!"
506
+
507
+ assert await output .aread () == b"Hello, world!"
508
+ async for chunk in output :
509
+ assert chunk == b"Hello, world!"
0 commit comments