Skip to content

Commit 65d783e

Browse files
authored
[cirqflow] Convenience method for loading results (#4720)
Add cg.ExecutableGroupResultFilesystemRecord.from_json(run_id)
1 parent 5da6335 commit 65d783e

File tree

3 files changed

+67
-0
lines changed

3 files changed

+67
-0
lines changed

cirq-google/cirq_google/workflow/io.py

+17
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,23 @@ class ExecutableGroupResultFilesystemRecord:
4444

4545
run_id: str
4646

47+
@classmethod
48+
def from_json(
49+
cls, *, run_id: str, base_data_dir: str = "."
50+
) -> 'ExecutableGroupResultFilesystemRecord':
51+
fn = f'{base_data_dir}/{run_id}/ExecutableGroupResultFilesystemRecord.json.gz'
52+
egr_record = cirq.read_json_gzip(fn)
53+
if not isinstance(egr_record, cls):
54+
raise ValueError(
55+
f"The file located at {fn} is not an `ExecutableGroupFilesystemRecord`."
56+
)
57+
if egr_record.run_id != run_id:
58+
raise ValueError(
59+
f"The loaded run_id {run_id} does not match the provided run_id {run_id}"
60+
)
61+
62+
return egr_record
63+
4764
def load(self, *, base_data_dir: str = ".") -> 'cg.ExecutableGroupResult':
4865
"""Using the filename references in this dataclass, load a `cg.ExecutableGroupResult`
4966
from its constituent parts.

cirq-google/cirq_google/workflow/io_test.py

+46
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import os
15+
16+
import pytest
1417

1518
import cirq
1619
import cirq_google as cg
@@ -47,6 +50,46 @@ def test_egr_filesystem_record_repr():
4750
cg_assert_equivalent_repr(egr_fs_record)
4851

4952

53+
def test_egr_filesystem_record_from_json(tmpdir):
54+
run_id = 'my-run-id'
55+
egr_fs_record = cg.ExecutableGroupResultFilesystemRecord(
56+
runtime_configuration_path='RuntimeConfiguration.json.gz',
57+
shared_runtime_info_path='SharedRuntimeInfo.jzon.gz',
58+
executable_result_paths=[
59+
'ExecutableResult.1.json.gz',
60+
'ExecutableResult.2.json.gz',
61+
],
62+
run_id=run_id,
63+
)
64+
65+
# Test 1: normal
66+
os.makedirs(f'{tmpdir}/{run_id}')
67+
cirq.to_json_gzip(
68+
egr_fs_record, f'{tmpdir}/{run_id}/ExecutableGroupResultFilesystemRecord.json.gz'
69+
)
70+
egr_fs_record2 = cg.ExecutableGroupResultFilesystemRecord.from_json(
71+
run_id=run_id, base_data_dir=tmpdir
72+
)
73+
assert egr_fs_record == egr_fs_record2
74+
75+
# Test 2: bad object type
76+
cirq.to_json_gzip(
77+
cirq.Circuit(), f'{tmpdir}/{run_id}/ExecutableGroupResultFilesystemRecord.json.gz'
78+
)
79+
with pytest.raises(ValueError, match=r'.*not an `ExecutableGroupFilesystemRecord`.'):
80+
cg.ExecutableGroupResultFilesystemRecord.from_json(run_id=run_id, base_data_dir=tmpdir)
81+
82+
# Test 3: Mismatched run id
83+
os.makedirs(f'{tmpdir}/questionable_run_id')
84+
cirq.to_json_gzip(
85+
egr_fs_record, f'{tmpdir}/questionable_run_id/ExecutableGroupResultFilesystemRecord.json.gz'
86+
)
87+
with pytest.raises(ValueError, match=r'.*does not match the provided run_id'):
88+
cg.ExecutableGroupResultFilesystemRecord.from_json(
89+
run_id='questionable_run_id', base_data_dir=tmpdir
90+
)
91+
92+
5093
def test_filesystem_saver(tmpdir, patch_cirq_default_resolvers):
5194
assert patch_cirq_default_resolvers
5295
run_id = 'asdf'
@@ -56,11 +99,13 @@ def test_filesystem_saver(tmpdir, patch_cirq_default_resolvers):
5699
shared_rt_info = cg.SharedRuntimeInfo(run_id=run_id)
57100
fs_saver.initialize(rt_config, shared_rt_info=shared_rt_info)
58101

102+
# Test 1: assert fs_saver.initialize() has worked.
59103
rt_config2 = cirq.read_json_gzip(f'{tmpdir}/{run_id}/QuantumRuntimeConfiguration.json.gz')
60104
shared_rt_info2 = cirq.read_json_gzip(f'{tmpdir}/{run_id}/SharedRuntimeInfo.json.gz')
61105
assert rt_config == rt_config2
62106
assert shared_rt_info == shared_rt_info2
63107

108+
# Test 2: assert `consume_result()` works.
64109
# you shouldn't actually mutate run_id in the shared runtime info, but we want to test
65110
# updating the shared rt info object:
66111
shared_rt_info.run_id = 'updated_run_id'
@@ -76,6 +121,7 @@ def test_filesystem_saver(tmpdir, patch_cirq_default_resolvers):
76121
assert shared_rt_info == shared_rt_info3
77122
assert exe_result == exe_result3
78123

124+
# Test 3: assert loading egr_record works.
79125
egr_record: cg.ExecutableGroupResultFilesystemRecord = cirq.read_json_gzip(
80126
f'{fs_saver.data_dir}/ExecutableGroupResultFilesystemRecord.json.gz'
81127
)

cirq-google/cirq_google/workflow/quantum_runtime_test.py

+4
Original file line numberDiff line numberDiff line change
@@ -182,10 +182,14 @@ def test_execute(tmpdir, run_id_in, patch_cirq_default_resolvers):
182182
f'{tmpdir}/{run_id}/ExecutableGroupResultFilesystemRecord.json.gz'
183183
)
184184
exegroup_result: cg.ExecutableGroupResult = egr_record.load(base_data_dir=tmpdir)
185+
helper_loaded_result = cg.ExecutableGroupResultFilesystemRecord.from_json(
186+
run_id=run_id, base_data_dir=tmpdir
187+
).load(base_data_dir=tmpdir)
185188

186189
# TODO(gh-4699): Don't null-out device once it's serializable.
187190
assert isinstance(returned_exegroup_result.shared_runtime_info.device, cg.SerializableDevice)
188191
returned_exegroup_result.shared_runtime_info.device = None
189192

190193
assert returned_exegroup_result == exegroup_result
191194
assert manual_exegroup_result == exegroup_result
195+
assert helper_loaded_result == exegroup_result

0 commit comments

Comments
 (0)