@@ -54,16 +54,28 @@ def test_run_batch(run_name, device_config_name):
54
54
circuit2 = cirq .Circuit (cirq .Y (a ))
55
55
params1 = [cirq .ParamResolver ({'t' : 1 })]
56
56
params2 = [cirq .ParamResolver ({'t' : 2 })]
57
- circuits = [circuit1 , circuit2 ]
58
- params_list = [params1 , params2 ]
59
- sampler .run_batch (circuits , params_list , 5 )
60
- processor .run_batch_async .assert_called_with (
61
- params_list = params_list ,
62
- programs = circuits ,
63
- repetitions = 5 ,
64
- run_name = run_name ,
65
- device_config_name = device_config_name ,
66
- )
57
+
58
+ sampler .run_batch ([circuit1 , circuit2 ], [params1 , params2 ], 5 )
59
+
60
+ expected_calls = [
61
+ mock .call (
62
+ program = circuit1 ,
63
+ params = params1 ,
64
+ repetitions = 5 ,
65
+ run_name = run_name ,
66
+ device_config_name = device_config_name ,
67
+ ),
68
+ mock .call ().results_async (),
69
+ mock .call (
70
+ program = circuit2 ,
71
+ params = params2 ,
72
+ repetitions = 5 ,
73
+ run_name = run_name ,
74
+ device_config_name = device_config_name ,
75
+ ),
76
+ mock .call ().results_async (),
77
+ ]
78
+ processor .run_sweep_async .assert_has_calls (expected_calls )
67
79
68
80
69
81
@pytest .mark .parametrize (
@@ -79,16 +91,28 @@ def test_run_batch_identical_repetitions(run_name, device_config_name):
79
91
circuit2 = cirq .Circuit (cirq .Y (a ))
80
92
params1 = [cirq .ParamResolver ({'t' : 1 })]
81
93
params2 = [cirq .ParamResolver ({'t' : 2 })]
82
- circuits = [circuit1 , circuit2 ]
83
- params_list = [params1 , params2 ]
84
- sampler .run_batch (circuits , params_list , [5 , 5 ])
85
- processor .run_batch_async .assert_called_with (
86
- params_list = params_list ,
87
- programs = circuits ,
88
- repetitions = 5 ,
89
- run_name = run_name ,
90
- device_config_name = device_config_name ,
91
- )
94
+
95
+ sampler .run_batch ([circuit1 , circuit2 ], [params1 , params2 ], [5 , 5 ])
96
+
97
+ expected_calls = [
98
+ mock .call (
99
+ program = circuit1 ,
100
+ params = params1 ,
101
+ repetitions = 5 ,
102
+ run_name = run_name ,
103
+ device_config_name = device_config_name ,
104
+ ),
105
+ mock .call ().results_async (),
106
+ mock .call (
107
+ program = circuit2 ,
108
+ params = params2 ,
109
+ repetitions = 5 ,
110
+ run_name = run_name ,
111
+ device_config_name = device_config_name ,
112
+ ),
113
+ mock .call ().results_async (),
114
+ ]
115
+ processor .run_sweep_async .assert_has_calls (expected_calls )
92
116
93
117
94
118
def test_run_batch_bad_number_of_repetitions ():
0 commit comments