Commit 97a1818

Revert "Merge pull request #772 from Sinestro38/fix_license"
This reverts commit 140277d, reversing changes made to 7131f9e.
1 parent a413194 commit 97a1818

6 files changed (+653 -2 lines)

scripts/ci_install.sh (+2 -2)
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-wget https://github.com/bazelbuild/bazel/releases/download/5.1.0/bazel_5.1.0-linux-x86_64.deb
-sudo dpkg -i bazel_5.1.0-linux-x86_64.deb
+wget https://github.com/bazelbuild/bazel/releases/download/5.3.0/bazel_5.3.0-linux-x86_64.deb
+sudo dpkg -i bazel_5.3.0-linux-x86_64.deb
 pip install --upgrade pip setuptools wheel
 pip install -r requirements.txt
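
The hunk above only bumps the Bazel release pinned for CI from 5.1.0 to 5.3.0. As a quick sanity check after scripts/ci_install.sh runs, something like the following could confirm the expected release ended up on PATH (a minimal sketch, not part of the commit; the exact version string printed is an assumption):

    # Confirm the Bazel release installed by ci_install.sh is the one on PATH.
    bazel --version     # expected to print a line like "bazel 5.3.0"
    command -v bazel    # shows which bazel binary the build will pick up
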
@@ -0,0 +1,207 @@
/* Copyright 2020 The TensorFlow Quantum Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <vector>

#include <chrono>

#include "../qsim/lib/circuit.h"
#include "../qsim/lib/gate_appl.h"
#include "../qsim/lib/gates_cirq.h"
#include "../qsim/lib/gates_qsim.h"
#include "../qsim/lib/seqfor.h"
#include "../qsim/lib/simulator_cuda.h"
#include "../qsim/lib/statespace_cuda.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow_quantum/core/ops/parse_context.h"
#include "tensorflow_quantum/core/proto/pauli_sum.pb.h"
#include "tensorflow_quantum/core/proto/program.pb.h"
#include "tensorflow_quantum/core/src/util_qsim.h"

namespace tfq {

using ::tensorflow::Status;
using ::tfq::proto::PauliSum;
using ::tfq::proto::Program;

typedef qsim::Cirq::GateCirq<float> QsimGate;
typedef qsim::Circuit<QsimGate> QsimCircuit;


class TfqSimulateExpectationOpCuda : public tensorflow::OpKernel {
 public:
  explicit TfqSimulateExpectationOpCuda(tensorflow::OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(tensorflow::OpKernelContext* context) override {
    // TODO (mbbrough): add more dimension checks for other inputs here.
    const int num_inputs = context->num_inputs();
    OP_REQUIRES(context, num_inputs == 4,
                tensorflow::errors::InvalidArgument(absl::StrCat(
                    "Expected 4 inputs, got ", num_inputs, " inputs.")));

    // Create the output Tensor.
    const int output_dim_batch_size = context->input(0).dim_size(0);
    const int output_dim_op_size = context->input(3).dim_size(1);
    tensorflow::TensorShape output_shape;
    output_shape.AddDim(output_dim_batch_size);
    output_shape.AddDim(output_dim_op_size);

    tensorflow::Tensor* output = nullptr;
    tensorflow::AllocatorAttributes alloc_attr;
    alloc_attr.set_on_host(true);
    alloc_attr.set_gpu_compatible(true);
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output,
                                                     alloc_attr));
    auto output_tensor = output->matrix<float>();
    // Parse program protos.
    std::vector<Program> programs;
    std::vector<int> num_qubits;
    std::vector<std::vector<PauliSum>> pauli_sums;
    OP_REQUIRES_OK(context, GetProgramsAndNumQubits(context, &programs,
                                                    &num_qubits, &pauli_sums));

    std::vector<SymbolMap> maps;
    OP_REQUIRES_OK(context, GetSymbolMaps(context, &maps));

    OP_REQUIRES(context, programs.size() == maps.size(),
                tensorflow::errors::InvalidArgument(absl::StrCat(
                    "Number of circuits and symbol_values do not match. Got ",
                    programs.size(), " circuits and ", maps.size(),
                    " symbol values.")));

    // Construct qsim circuits.
    std::vector<QsimCircuit> qsim_circuits(programs.size(), QsimCircuit());
    std::vector<std::vector<qsim::GateFused<QsimGate>>> fused_circuits(
        programs.size(), std::vector<qsim::GateFused<QsimGate>>({}));

    Status parse_status = Status();
    auto p_lock = tensorflow::mutex();
    auto construct_f = [&](int start, int end) {
      for (int i = start; i < end; i++) {
        Status local =
            QsimCircuitFromProgram(programs[i], maps[i], num_qubits[i],
                                   &qsim_circuits[i], &fused_circuits[i]);
        NESTED_FN_STATUS_SYNC(parse_status, local, p_lock);
      }
    };

    const int num_cycles = 1000;
    context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
        programs.size(), num_cycles, construct_f);
    OP_REQUIRES_OK(context, parse_status);

    int max_num_qubits = 0;
    for (const int num : num_qubits) {
      max_num_qubits = std::max(max_num_qubits, num);
    }
    ComputeLarge(num_qubits, fused_circuits, pauli_sums, context,
                 &output_tensor);
  }

 private:
  int num_threads_in_sim_;
  int block_count_;

  // Define the GPU implementation that launches the CUDA kernel.
  void ComputeLarge(
      const std::vector<int>& num_qubits,
      const std::vector<std::vector<qsim::GateFused<QsimGate>>>& fused_circuits,
      const std::vector<std::vector<PauliSum>>& pauli_sums,
      tensorflow::OpKernelContext* context,
      tensorflow::TTypes<float, 1>::Matrix* output_tensor) {
    // Instantiate qsim objects.
    using Simulator = qsim::SimulatorCUDA<float>;
    using StateSpace = Simulator::StateSpace;
    // Begin simulation with default parameters.
    int largest_nq = 1;
    Simulator sim = Simulator();
    StateSpace ss = StateSpace(StateSpace::Parameter());
    auto sv = ss.Create(largest_nq);
    auto scratch = ss.Create(largest_nq);

    // Simulate programs one by one. Since we parallelize over state vectors,
    // we no longer parallelize over circuits. Each time we encounter a
    // larger circuit we grow the state vector as necessary.
    for (int i = 0; i < fused_circuits.size(); i++) {
      int nq = num_qubits[i];

      if (nq > largest_nq) {
        // need to switch to larger statespace.
        largest_nq = nq;
        sv = ss.Create(largest_nq);
        scratch = ss.Create(largest_nq);
      }
      // TODO: add heuristic here so that we do not always recompute
      // the state if there is a possibility that circuit[i] and
      // circuit[i + 1] produce the same state.
      ss.SetStateZero(sv);
      for (int j = 0; j < fused_circuits[i].size(); j++) {
        qsim::ApplyFusedGate(sim, fused_circuits[i][j], sv);
      }
      for (int j = 0; j < pauli_sums[i].size(); j++) {
        // (#679) Just ignore empty program
        if (fused_circuits[i].size() == 0) {
          (*output_tensor)(i, j) = -2.0;
          continue;
        }
        float exp_v = 0.0;
        OP_REQUIRES_OK(context,
                       ComputeExpectationQsim(pauli_sums[i][j], sim, ss, sv,
                                              scratch, &exp_v));
        (*output_tensor)(i, j) = exp_v;
      }
    }
  }

};

REGISTER_KERNEL_BUILDER(
    Name("TfqSimulateExpectationCuda").Device(tensorflow::DEVICE_CPU),
    TfqSimulateExpectationOpCuda);

REGISTER_OP("TfqSimulateExpectationCuda")
    .Input("programs: string")
    .Input("symbol_names: string")
    .Input("symbol_values: float")
    .Input("pauli_sums: string")
    .Output("expectations: float")
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) {
      tensorflow::shape_inference::ShapeHandle programs_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &programs_shape));

      tensorflow::shape_inference::ShapeHandle symbol_names_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &symbol_names_shape));

      tensorflow::shape_inference::ShapeHandle symbol_values_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &symbol_values_shape));

      tensorflow::shape_inference::ShapeHandle pauli_sums_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &pauli_sums_shape));

      tensorflow::shape_inference::DimensionHandle output_rows =
          c->Dim(programs_shape, 0);
      tensorflow::shape_inference::DimensionHandle output_cols =
          c->Dim(pauli_sums_shape, 1);
      c->set_output(0, c->Matrix(output_rows, output_cols));

      return ::tensorflow::Status();
    });

} // namespace tfq
