Skip to content

Commit 5cb7fa1

Browse files
committed
remove computeSmall from cuda and cuquantum kernels
1 parent f81e537 commit 5cb7fa1

File tree

2 files changed

+9
-154
lines changed

2 files changed

+9
-154
lines changed

tensorflow_quantum/core/ops/tfq_simulate_expectation_op_cuda.cu.cc

+4-77
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,10 @@ class TfqSimulateExpectationOpCuda : public tensorflow::OpKernel {
111111
for (const int num : num_qubits) {
112112
max_num_qubits = std::max(max_num_qubits, num);
113113
}
114-
if (max_num_qubits >= 26 || programs.size() == 1) {
115-
ComputeLarge(num_qubits, fused_circuits, pauli_sums, context,
116-
&output_tensor);
117-
} else {
118-
ComputeSmall(num_qubits, max_num_qubits, fused_circuits, pauli_sums,
119-
context, &output_tensor);
120-
}
114+
115+
ComputeLarge(num_qubits, fused_circuits, pauli_sums, context,
116+
&output_tensor);
117+
121118
}
122119

123120
private:
@@ -175,76 +172,6 @@ class TfqSimulateExpectationOpCuda : public tensorflow::OpKernel {
175172
}
176173
}
177174

178-
void ComputeSmall(
179-
const std::vector<int>& num_qubits, const int max_num_qubits,
180-
const std::vector<std::vector<qsim::GateFused<QsimGate>>>& fused_circuits,
181-
const std::vector<std::vector<PauliSum>>& pauli_sums,
182-
tensorflow::OpKernelContext* context,
183-
tensorflow::TTypes<float, 1>::Matrix* output_tensor) {
184-
using Simulator = qsim::SimulatorCUDA<float>;
185-
using StateSpace = Simulator::StateSpace;
186-
187-
StateSpace::Parameter param_default;
188-
const int output_dim_op_size = output_tensor->dimension(1);
189-
190-
Status compute_status = Status();
191-
auto c_lock = tensorflow::mutex();
192-
auto DoWork = [&](int start, int end) {
193-
int old_batch_index = -2;
194-
int cur_batch_index = -1;
195-
int largest_nq = 1;
196-
int cur_op_index;
197-
198-
// Begin simulation.
199-
auto sim = Simulator();
200-
auto ss = StateSpace(param_default);
201-
auto sv = ss.Create(largest_nq);
202-
auto scratch = ss.Create(largest_nq);
203-
for (int i = start; i < end; i++) {
204-
cur_batch_index = i / output_dim_op_size;
205-
cur_op_index = i % output_dim_op_size;
206-
207-
const int nq = num_qubits[cur_batch_index];
208-
209-
// (#679) Just ignore empty program
210-
if (fused_circuits[cur_batch_index].size() == 0) {
211-
(*output_tensor)(cur_batch_index, cur_op_index) = -2.0;
212-
continue;
213-
}
214-
215-
if (cur_batch_index != old_batch_index) {
216-
// We've run into a new state vector we must compute.
217-
// Only compute a new state vector when we have to.
218-
if (nq > largest_nq) {
219-
largest_nq = nq;
220-
sv = ss.Create(largest_nq);
221-
scratch = ss.Create(largest_nq);
222-
}
223-
// no need to update scratch_state since ComputeExpectation
224-
// will take care of things for us.
225-
ss.SetStateZero(sv);
226-
for (int j = 0; j < fused_circuits[cur_batch_index].size(); j++) {
227-
qsim::ApplyFusedGate(sim, fused_circuits[cur_batch_index][j], sv);
228-
}
229-
}
230-
231-
float exp_v = 0.0;
232-
NESTED_FN_STATUS_SYNC(
233-
compute_status,
234-
ComputeExpectationQsim(pauli_sums[cur_batch_index][cur_op_index],
235-
sim, ss, sv, scratch, &exp_v),
236-
c_lock);
237-
(*output_tensor)(cur_batch_index, cur_op_index) = exp_v;
238-
old_batch_index = cur_batch_index;
239-
}
240-
};
241-
242-
const int64_t num_cycles =
243-
200 * (int64_t(1) << static_cast<int64_t>(max_num_qubits));
244-
context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
245-
fused_circuits.size() * output_dim_op_size, num_cycles, DoWork);
246-
OP_REQUIRES_OK(context, compute_status);
247-
}
248175
};
249176

250177
REGISTER_KERNEL_BUILDER(

tensorflow_quantum/core/ops/tfq_simulate_expectation_op_cuquantum.cu.cc

+5-77
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class TfqSimulateExpectationOpCuQuantum : public tensorflow::OpKernel {
5151
: OpKernel(context) {}
5252

5353
void Compute(tensorflow::OpKernelContext* context) override {
54+
5455
// TODO (mbbrough): add more dimension checks for other inputs here.
5556
const int num_inputs = context->num_inputs();
5657
OP_REQUIRES(context, num_inputs == 4,
@@ -116,13 +117,10 @@ class TfqSimulateExpectationOpCuQuantum : public tensorflow::OpKernel {
116117
// create handles for simulator
117118
cublasCreate(&cublas_handle_);
118119
custatevecCreate(&custatevec_handle_);
119-
if (max_num_qubits >= 26 || programs.size() == 1) {
120-
ComputeLarge(num_qubits, fused_circuits, pauli_sums, context,
121-
&output_tensor); // HOW TO manage extraWorkspace size?
122-
} else {
123-
ComputeSmall(num_qubits, max_num_qubits, fused_circuits, pauli_sums,
124-
context, &output_tensor);
125-
}
120+
121+
ComputeLarge(num_qubits, fused_circuits, pauli_sums, context,
122+
&output_tensor); // HOW TO manage extraWorkspace size?
123+
126124
// destroy handles in sync with simulator lifetime
127125
cublasDestroy(cublas_handle_);
128126
custatevecDestroy(custatevec_handle_);
@@ -186,76 +184,6 @@ class TfqSimulateExpectationOpCuQuantum : public tensorflow::OpKernel {
186184
}
187185
}
188186
}
189-
190-
void ComputeSmall(
191-
const std::vector<int>& num_qubits, const int max_num_qubits,
192-
const std::vector<std::vector<qsim::GateFused<QsimGate>>>& fused_circuits,
193-
const std::vector<std::vector<PauliSum>>& pauli_sums,
194-
tensorflow::OpKernelContext* context,
195-
tensorflow::TTypes<float, 1>::Matrix* output_tensor) {
196-
using Simulator = qsim::SimulatorCuStateVec<float>;
197-
using StateSpace = Simulator::StateSpace;
198-
199-
const int output_dim_op_size = output_tensor->dimension(1);
200-
201-
Status compute_status = Status::OK();
202-
auto c_lock = tensorflow::mutex();
203-
auto DoWork = [&](int start, int end) {
204-
int old_batch_index = -2;
205-
int cur_batch_index = -1;
206-
int largest_nq = 1;
207-
int cur_op_index;
208-
209-
// Launch custatevec, begin simulation.
210-
auto sim = Simulator(cublas_handle_, custatevec_handle_);
211-
auto ss = StateSpace(cublas_handle_, custatevec_handle_);
212-
auto sv = ss.Create(largest_nq);
213-
auto scratch = ss.Create(largest_nq);
214-
for (int i = start; i < end; i++) {
215-
cur_batch_index = i / output_dim_op_size;
216-
cur_op_index = i % output_dim_op_size;
217-
218-
const int nq = num_qubits[cur_batch_index];
219-
220-
// (#679) Just ignore empty program
221-
if (fused_circuits[cur_batch_index].size() == 0) {
222-
(*output_tensor)(cur_batch_index, cur_op_index) = -2.0;
223-
continue;
224-
}
225-
226-
if (cur_batch_index != old_batch_index) {
227-
// We've run into a new state vector we must compute.
228-
// Only compute a new state vector when we have to.
229-
if (nq > largest_nq) {
230-
largest_nq = nq;
231-
sv = ss.Create(largest_nq);
232-
scratch = ss.Create(largest_nq);
233-
}
234-
// no need to update scratch_state since ComputeExpectation
235-
// will take care of things for us.
236-
ss.SetStateZero(sv);
237-
for (int j = 0; j < fused_circuits[cur_batch_index].size(); j++) {
238-
qsim::ApplyFusedGate(sim, fused_circuits[cur_batch_index][j], sv);
239-
}
240-
}
241-
242-
float exp_v = 0.0;
243-
NESTED_FN_STATUS_SYNC(
244-
compute_status,
245-
ComputeExpectationQsim(pauli_sums[cur_batch_index][cur_op_index],
246-
sim, ss, sv, scratch, &exp_v),
247-
c_lock);
248-
(*output_tensor)(cur_batch_index, cur_op_index) = exp_v;
249-
old_batch_index = cur_batch_index;
250-
}
251-
};
252-
253-
const int64_t num_cycles =
254-
200 * (int64_t(1) << static_cast<int64_t>(max_num_qubits));
255-
context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
256-
fused_circuits.size() * output_dim_op_size, num_cycles, DoWork);
257-
OP_REQUIRES_OK(context, compute_status);
258-
}
259187
};
260188

261189
REGISTER_KERNEL_BUILDER(

0 commit comments

Comments
 (0)