@@ -51,6 +51,7 @@ class TfqSimulateExpectationOpCuQuantum : public tensorflow::OpKernel {
51
51
: OpKernel(context) {}
52
52
53
53
void Compute (tensorflow::OpKernelContext* context) override {
54
+
54
55
// TODO (mbbrough): add more dimension checks for other inputs here.
55
56
const int num_inputs = context->num_inputs ();
56
57
OP_REQUIRES (context, num_inputs == 4 ,
@@ -116,13 +117,10 @@ class TfqSimulateExpectationOpCuQuantum : public tensorflow::OpKernel {
116
117
// create handles for simulator
117
118
cublasCreate (&cublas_handle_);
118
119
custatevecCreate (&custatevec_handle_);
119
- if (max_num_qubits >= 26 || programs.size () == 1 ) {
120
- ComputeLarge (num_qubits, fused_circuits, pauli_sums, context,
121
- &output_tensor); // HOW TO manage extraWorkspace size?
122
- } else {
123
- ComputeSmall (num_qubits, max_num_qubits, fused_circuits, pauli_sums,
124
- context, &output_tensor);
125
- }
120
+
121
+ ComputeLarge (num_qubits, fused_circuits, pauli_sums, context,
122
+ &output_tensor); // HOW TO manage extraWorkspace size?
123
+
126
124
// destroy handles in sync with simulator lifetime
127
125
cublasDestroy (cublas_handle_);
128
126
custatevecDestroy (custatevec_handle_);
@@ -186,76 +184,6 @@ class TfqSimulateExpectationOpCuQuantum : public tensorflow::OpKernel {
186
184
}
187
185
}
188
186
}
189
-
190
// ComputeSmall: writes one value per (circuit, pauli sum) pair into
// (*output_tensor)(batch, op). The flattened index range
// [0, fused_circuits.size() * output_dim_op_size) is handed to ParallelFor on
// the CPU worker threads; each DoWork call walks its [start, end) slice.
//   num_qubits[b]      qubit count used to size the state vector for circuit b.
//   max_num_qubits     only used to derive num_cycles for ParallelFor.
//   fused_circuits[b]  fused gates applied in order to a zeroed state vector.
//   pauli_sums[b][op]  operator handed to ComputeExpectationQsim.
//   context            supplies the worker pool; receives compute_status via
//                      OP_REQUIRES_OK at the end.
//   output_tensor      written in place; entries whose circuit is empty are
//                      set to -2.0.
- void ComputeSmall (
191
- const std::vector<int >& num_qubits, const int max_num_qubits,
192
- const std::vector<std::vector<qsim::GateFused<QsimGate>>>& fused_circuits,
193
- const std::vector<std::vector<PauliSum>>& pauli_sums,
194
- tensorflow::OpKernelContext* context,
195
- tensorflow::TTypes<float , 1 >::Matrix* output_tensor) {
196
- using Simulator = qsim::SimulatorCuStateVec<float >;
197
- using StateSpace = Simulator::StateSpace;
198
-
199
- const int output_dim_op_size = output_tensor->dimension (1 ); // ops per batch entry; divisor/modulus for the flattened index
200
-
201
- Status compute_status = Status::OK (); // captured by reference in DoWork; checked once after ParallelFor
202
- auto c_lock = tensorflow::mutex (); // passed to NESTED_FN_STATUS_SYNC with compute_status (presumably guards its update — confirm against the macro)
203
- auto DoWork = [&](int start, int end) {
204
- int old_batch_index = -2 ; // sentinel: never equals a non-negative batch index, so the first item always simulates
205
- int cur_batch_index = -1 ; // overwritten at the top of every iteration
206
- int largest_nq = 1 ; // qubit count sv/scratch were last created with
207
- int cur_op_index; // assigned at the top of every iteration before any read
208
-
209
- // Launch custatevec, begin simulation.
210
- auto sim = Simulator (cublas_handle_, custatevec_handle_); // NOTE(review): every DoWork call uses the same two handles — confirm they tolerate concurrent use
211
- auto ss = StateSpace (cublas_handle_, custatevec_handle_);
212
- auto sv = ss.Create (largest_nq);
213
- auto scratch = ss.Create (largest_nq);
214
- for (int i = start; i < end; i++) {
215
- cur_batch_index = i / output_dim_op_size;
216
- cur_op_index = i % output_dim_op_size;
217
-
218
- const int nq = num_qubits[cur_batch_index]; // read before the empty-program check below
219
-
220
- // (#679) Just ignore empty program
221
- if (fused_circuits[cur_batch_index].size () == 0 ) {
222
- (*output_tensor)(cur_batch_index, cur_op_index) = -2.0 ; // value written for an empty program
223
- continue ; // note: skips the old_batch_index update at the bottom of the loop
224
- }
225
-
226
- if (cur_batch_index != old_batch_index) {
227
- // We've run into a new state vector we must compute.
228
- // Only compute a new state vector when we have to.
229
- if (nq > largest_nq) { // grow only; otherwise the existing sv/scratch are reused
230
- largest_nq = nq;
231
- sv = ss.Create (largest_nq);
232
- scratch = ss.Create (largest_nq);
233
- }
234
- // no need to update scratch_state since ComputeExpectation
235
- // will take care of things for us.
236
- ss.SetStateZero (sv);
237
- for (int j = 0 ; j < fused_circuits[cur_batch_index].size (); j++) { // NOTE(review): int j is compared against an unsigned size()
238
- qsim::ApplyFusedGate (sim, fused_circuits[cur_batch_index][j], sv);
239
- }
240
- }
241
-
242
- float exp_v = 0.0 ;
243
- NESTED_FN_STATUS_SYNC (
244
- compute_status,
245
- ComputeExpectationQsim (pauli_sums[cur_batch_index][cur_op_index],
246
- sim, ss, sv, scratch, &exp_v),
247
- c_lock);
248
- (*output_tensor)(cur_batch_index, cur_op_index) = exp_v;
249
- old_batch_index = cur_batch_index; // later ops of the same batch entry skip the re-simulation branch above
250
- }
251
- };
252
-
253
- const int64_t num_cycles =
254
- 200 * (int64_t (1 ) << static_cast <int64_t >(max_num_qubits)); // 200 * 2^max_num_qubits
255
- context->device ()->tensorflow_cpu_worker_threads ()->workers ->ParallelFor ( // NOTE(review): num_cycles is presumably the per-item cost estimate — confirm against ParallelFor docs
256
- fused_circuits.size () * output_dim_op_size, num_cycles, DoWork);
257
- OP_REQUIRES_OK (context, compute_status);
258
- }
259
187
};
260
188
261
189
REGISTER_KERNEL_BUILDER (
0 commit comments