-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathimpala_client.py
executable file
·548 lines (478 loc) · 20.9 KB
/
impala_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
#!/usr/bin/env python
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import sasl
import sys
import time
from beeswaxd import BeeswaxService
from beeswaxd.BeeswaxService import QueryState
from ExecStats.ttypes import TExecStats
from ImpalaService import ImpalaService
from ErrorCodes.ttypes import TErrorCode
from Status.ttypes import TStatus
from thrift.protocol import TBinaryProtocol
from thrift_sasl import TSaslClientTransport
from thrift.transport.TSocket import TSocket
from thrift.transport.TTransport import TBufferedTransport, TTransportException
from thrift.Thrift import TApplicationException, TException
class RpcStatus:
  """Enumerates the possible outcomes of an RPC made to the impalad."""
  OK = 0
  ERROR = 1
class RPCException(Exception):
  """Raised when an RPC to the impalad fails or reports an error status."""

  def __init__(self, value=""):
    self.value = value

  def __str__(self):
    return self.value
class QueryStateException(Exception):
  """Raised when a query enters an error state on the server side."""

  def __init__(self, value=""):
    self.value = value

  def __str__(self):
    return self.value
class DisconnectedException(Exception):
  """Raised when an operation requires a live connection but none exists."""

  def __init__(self, value=""):
    self.value = value

  def __str__(self):
    return self.value
class QueryCancelledByShellException(Exception):
  """Marker exception raised when an RPC error is attributable to the user
  cancelling the query from the shell (CTRL+C); carries no message."""
  pass
def print_to_stderr(message):
  """Write message followed by a newline to stderr.

  Uses sys.stderr.write rather than the Python-2-only 'print >>' statement so
  the module stays parseable under both Python 2 and Python 3.
  """
  sys.stderr.write("%s\n" % (message,))
class ImpalaClient(object):
  """Client for executing queries against an impalad via the Beeswax service.

  Handles connection setup (optionally over SSL and/or SASL for Kerberos or
  LDAP), query submission, state polling, result fetching and cancellation.
  """

  def __init__(self, impalad, kerberos_host_fqdn, use_kerberos=False,
               kerberos_service_name="impala", use_ssl=False, ca_cert=None, user=None,
               ldap_password=None, use_ldap=False, client_connect_timeout_ms=60000,
               verbose=True):
    """Store connection parameters; no network activity happens here.

    impalad: (host, port) pair identifying the impalad to connect to.
    kerberos_host_fqdn: hostname to present to the SASL layer instead of the
      socket host (e.g. a load balancer's name), or None.
    ca_cert: path to CA certs for SSL validation; None disables validation.
    client_connect_timeout_ms: socket timeout applied only while connecting.
    """
    self.connected = False
    # sasl cannot handle unicode strings; normalise the host to ascii.
    self.impalad_host = impalad[0].encode('ascii', 'ignore')
    self.impalad_port = int(impalad[1])
    self.kerberos_host_fqdn = kerberos_host_fqdn
    # Both set by connect(); None until a connection has been attempted.
    self.imp_service = None
    self.transport = None
    self.use_kerberos = use_kerberos
    self.kerberos_service_name = kerberos_service_name
    self.use_ssl = use_ssl
    self.ca_cert = ca_cert
    self.user, self.ldap_password = user, ldap_password
    self.use_ldap = use_ldap
    self.client_connect_timeout_ms = int(client_connect_timeout_ms)
    # Populated by build_default_query_options_dict() after connecting.
    self.default_query_options = {}
    self.query_option_levels = {}
    # Maps beeswax QueryState names (e.g. "FINISHED") to their enum values.
    self.query_state = QueryState._NAMES_TO_VALUES
    self.fetch_batch_size = 1024
    # This is set from ImpalaShell's signal handler when a query is cancelled
    # from command line via CTRL+C. It is used to suppress error messages of
    # query cancellation.
    self.is_query_cancelled = False
    self.verbose = verbose
def _options_to_string_list(self, set_query_options):
return ["%s=%s" % (k, v) for (k, v) in set_query_options.iteritems()]
def build_default_query_options_dict(self):
  """The default query options are retrieved from a rpc call, and are dependent
  on the impalad to which a connection has been established. They need to be
  refreshed each time a connection is made. This is particularly helpful when
  there is a version mismatch between the shell and the impalad.

  Raises RPCException if the options RPC returns a non-OK status.
  """
  try:
    get_default_query_options = self.imp_service.get_default_configuration(False)
  except Exception:
    # Best-effort: an older impalad may not support this RPC. Narrowed from a
    # bare 'except:' so KeyboardInterrupt/SystemExit still propagate.
    return
  rpc_result = self._do_rpc(lambda: get_default_query_options)
  options, status = rpc_result
  if status != RpcStatus.OK:
    raise RPCException("Unable to retrieve default query options")
  for option in options:
    self.default_query_options[option.key.upper()] = option.value
    # If connected to an Impala that predates IMPALA-2181 then the received
    # options wouldn't contain a level attribute. In this case the
    # query_option_levels map is left empty.
    if option.level is not None:
      self.query_option_levels[option.key.upper()] = option.level
def build_summary_table(self, summary, idx, is_fragment_root, indent_level,
                        new_indent_level, output):
  """Direct translation of Coordinator::PrintExecSummary() to recursively build a list
  of rows of summary statistics, one per exec node

  summary: the TExecSummary object that contains all the summary data

  idx: the index of the node to print

  is_fragment_root: true if the node to print is the root of a fragment (and therefore
  feeds into an exchange)

  indent_level: the number of spaces to print before writing the node's label, to give
  the appearance of a tree. The 0th child of a node has the same indent_level as its
  parent. All other children have an indent_level of one greater than their parent.

  output: the list of rows into which to append the rows produced for this node and its
  children.

  Returns the index of the next exec node in summary.exec_nodes that should be
  processed, used internally to this method only.

  NOTE: This is duplicated in impala_beeswax.py, and changes made here should also be
  made there. TODO: refactor into a shared library. (IMPALA-5792)
  """
  attrs = ["latency_ns", "cpu_time_ns", "cardinality", "memory_used"]
  # Initialise aggregate and maximum stats
  agg_stats, max_stats = TExecStats(), TExecStats()
  for attr in attrs:
    setattr(agg_stats, attr, 0)
    setattr(max_stats, attr, 0)
  # Aggregate per-instance stats into totals and per-attribute maxima.
  node = summary.nodes[idx]
  if node.exec_stats is not None:
    for stats in node.exec_stats:
      for attr in attrs:
        val = getattr(stats, attr)
        if val is not None:
          setattr(agg_stats, attr, getattr(agg_stats, attr) + val)
          setattr(max_stats, attr, max(getattr(max_stats, attr), val))
  if node.exec_stats is not None and node.exec_stats:
    avg_time = agg_stats.latency_ns / len(node.exec_stats)
  else:
    avg_time = 0
  # If the node is a broadcast-receiving exchange node, the cardinality of rows produced
  # is the max over all instances (which should all have received the same number of
  # rows). Otherwise, the cardinality is the sum over all instances which process
  # disjoint partitions.
  if node.is_broadcast:
    cardinality = max_stats.cardinality
  else:
    cardinality = agg_stats.cardinality
  est_stats = node.estimated_stats
  # Build the tree-drawing prefix ("|  |--" style) for this node's label.
  label_prefix = ""
  if indent_level > 0:
    label_prefix = "|"
    label_prefix += " |" * (indent_level - 1)
    if new_indent_level:
      label_prefix += "--"
    else:
      label_prefix += " "

  # Scale val down by divisor until it fits a unit; integers for the smallest
  # unit, two decimals otherwise. Returns None if val exceeds every unit.
  def prettyprint(val, units, divisor):
    for unit in units:
      if val < divisor:
        if unit == units[0]:
          return "%d%s" % (val, unit)
        else:
          return "%3.2f%s" % (val, unit)
      val /= divisor

  def prettyprint_bytes(byte_val):
    return prettyprint(byte_val, [' B', ' KB', ' MB', ' GB', ' TB'], 1024.0)

  def prettyprint_units(unit_val):
    return prettyprint(unit_val, ["", "K", "M", "B"], 1000.0)

  def prettyprint_time(time_val):
    return prettyprint(time_val, ["ns", "us", "ms", "s"], 1000.0)

  hosts = 0
  if node.exec_stats is not None:
    hosts = len(node.exec_stats)
  # Sink nodes (node_id == -1) produce no rows, so cardinality is blanked.
  is_sink = node.node_id == -1
  row = [ label_prefix + node.label,
          hosts,
          prettyprint_time(avg_time),
          prettyprint_time(max_stats.latency_ns),
          "" if is_sink else prettyprint_units(cardinality),
          "" if is_sink else prettyprint_units(est_stats.cardinality),
          prettyprint_bytes(max_stats.memory_used),
          prettyprint_bytes(est_stats.memory_used),
          node.label_detail ]
  output.append(row)
  try:
    sender_idx = summary.exch_to_sender_map[idx]
    # This is an exchange node, so the sender is a fragment root, and should be printed
    # next.
    self.build_summary_table(summary, sender_idx, True, indent_level, False, output)
  except (KeyError, TypeError):
    # Fall through if idx not in map, or if exch_to_sender_map itself is not set
    pass
  idx += 1
  if node.num_children > 0:
    # The 0th child keeps its parent's indent level; its rows are buffered and
    # appended after the other children's rows.
    first_child_output = []
    idx = \
      self.build_summary_table(
        summary, idx, False, indent_level, False, first_child_output)
    for child_idx in xrange(1, node.num_children):
      # All other children are indented (we only have 0, 1 or 2 children for every exec
      # node at the moment)
      idx = self.build_summary_table(
        summary, idx, False, indent_level + 1, True, output)
    output += first_child_output
  return idx
def is_connected(self):
  """Return True iff the current Impala connection is alive."""
  if not self.connected:
    return False
  try:
    return self.imp_service.PingImpalaService()
  except TException:
    # The ping failed at the transport/protocol level: treat as disconnected.
    return False
def connect(self):
  """Connect to the configured impalad and ping it.

  Any previously open transport is closed first. Returns the result of
  PingImpalaService(), which carries the server version.
  """
  if self.transport is not None:
    self.transport.close()
    self.transport = None
  self.connected = False
  raw_sock, self.transport = self._get_socket_and_transport()
  # Bound only the initial connection attempt; long-running queries must not
  # be killed later by a fixed socket timeout.
  if self.client_connect_timeout_ms > 0:
    raw_sock.setTimeout(self.client_connect_timeout_ms)
  self.transport.open()
  if self.verbose:
    print_to_stderr('Opened TCP connection to %s:%s' % (self.impalad_host,
        self.impalad_port))
  # Setting a timeout of None disables timeouts on sockets
  raw_sock.setTimeout(None)
  proto = TBinaryProtocol.TBinaryProtocol(self.transport)
  self.imp_service = ImpalaService.Client(proto)
  ping_result = self.ping_impala_service()
  self.connected = True
  return ping_result
def ping_impala_service(self):
  """Ping the connected impalad; returns the server's ping response."""
  return self.imp_service.PingImpalaService()
def close_connection(self):
  """Close the underlying transport if one is open; no-op otherwise."""
  if self.transport:
    self.transport.close()
def _get_socket_and_transport(self):
  """Create a Transport.

  A non-kerberized impalad just needs a simple buffered transport. For
  the kerberized version, a sasl transport is created.

  If SSL is enabled, a TSSLSocket underlies the transport stack; otherwise a
  TSocket is used.

  This function returns the (socket, transport) pair; neither is opened here.
  """
  if self.use_ssl:
    # TSSLSocket needs the ssl module, which may not be standard on all Operating
    # Systems. Only attempt to import TSSLSocket if the user wants an SSL connection.
    from TSSLSocketWithWildcardSAN import TSSLSocketWithWildcardSAN

  # sasl does not accept unicode strings, explicitly encode the string into ascii.
  # The kerberos_host_fqdn option exposes the SASL client's hostname attribute to
  # the user. impala-shell checks to ensure this host matches the host in the kerberos
  # principal. So when a load balancer is configured to be used, its hostname is expected
  # by impala-shell. Setting this option to the load balancer hostname allows
  # impala-shell to connect directly to an impalad.
  if self.kerberos_host_fqdn is not None:
    sasl_host = self.kerberos_host_fqdn.split(':')[0].encode('ascii', 'ignore')
  else:
    sasl_host = self.impalad_host

  # Always use the hostname and port passed in to -i / --impalad as the host for the
  # purpose of creating the actual socket.
  sock_host = self.impalad_host
  sock_port = self.impalad_port
  if self.use_ssl:
    if self.ca_cert is None:
      # No CA cert means don't try to verify the certificate
      sock = TSSLSocketWithWildcardSAN(sock_host, sock_port, validate=False)
    else:
      sock = TSSLSocketWithWildcardSAN(sock_host, sock_port, validate=True, ca_certs=self.ca_cert)
  else:
    sock = TSocket(sock_host, sock_port)
  if not (self.use_ldap or self.use_kerberos):
    # No authentication required: a plain buffered transport suffices.
    return sock, TBufferedTransport(sock)

  # Initializes a sasl client; called lazily by TSaslClientTransport.
  def sasl_factory():
    sasl_client = sasl.Client()
    sasl_client.setAttr("host", sasl_host)
    if self.use_ldap:
      sasl_client.setAttr("username", self.user)
      sasl_client.setAttr("password", self.ldap_password)
    else:
      sasl_client.setAttr("service", self.kerberos_service_name)
    sasl_client.init()
    return sasl_client
  # GSSAPI is the underlying mechanism used by kerberos to authenticate.
  if self.use_kerberos:
    return sock, TSaslClientTransport(sasl_factory, "GSSAPI", sock)
  else:
    return sock, TSaslClientTransport(sasl_factory, "PLAIN", sock)
def create_beeswax_query(self, query_str, set_query_options):
  """Build a BeeswaxService.Query for query_str with the given options."""
  request = BeeswaxService.Query()
  request.hadoop_user = self.user
  request.query = query_str
  request.configuration = self._options_to_string_list(set_query_options)
  return request
def execute_query(self, query):
  """Submit a Beeswax query and return its query handle.

  Raises RPCException if the submission RPC fails.
  """
  # A new query starts with a clean cancellation flag.
  self.is_query_cancelled = False
  handle, status = self._do_rpc(lambda: self.imp_service.query(query))
  if status != RpcStatus.OK:
    raise RPCException("Error executing the query")
  return handle
def wait_to_finish(self, last_query_handle, periodic_callback=None):
  """Poll the query's state until it reaches FINISHED.

  Raises QueryStateException (with the server error log) on EXCEPTION state,
  or DisconnectedException if the connection was lost. periodic_callback, if
  given, is invoked once per polling iteration.
  """
  poll_start = time.time()
  finished_state = self.query_state["FINISHED"]
  exception_state = self.query_state["EXCEPTION"]
  while True:
    current = self.get_query_state(last_query_handle)
    if current == finished_state:
      return
    if current == exception_state:
      if not self.connected:
        raise DisconnectedException("Not connected to impalad.")
      raise QueryStateException(self.get_error_log(last_query_handle))
    if periodic_callback is not None:
      periodic_callback()
    time.sleep(self._get_sleep_interval(poll_start))
def fetch(self, query_handle):
  """Fetch all the results.

  Generator yielding batches of rows; each row is a list of column strings
  (the server returns tab-separated lines).
  """
  buffered = []
  while True:
    result, status = self._do_rpc(
      lambda: self.imp_service.fetch(query_handle, False,
                                     self.fetch_batch_size))
    if status != RpcStatus.OK:
      raise RPCException()
    buffered.extend(result.data)
    done = not result.has_more
    # Emit a batch once enough rows accumulate, or on the final response.
    if done or len(buffered) >= self.fetch_batch_size:
      yield [line.split('\t') for line in buffered]
      buffered = []
    if done:
      break
def close_dml(self, last_query_handle):
  """Fetch the results of a DML query.

  Returns (rows_modified, num_row_errors); the second element is None when
  the operation doesn't report row errors. Raises RPCException on failure.
  """
  insert_result, status = self._do_rpc(
    lambda: self.imp_service.CloseInsert(last_query_handle))
  if status != RpcStatus.OK:
    raise RPCException()
  # rows_modified maps partition -> count (as strings); total them.
  num_rows = sum(int(count) for count in insert_result.rows_modified.values())
  return (num_rows, insert_result.num_row_errors)
def close_query(self, last_query_handle, query_handle_closed=False):
  """Close the query handle; returns True on success.

  Idempotent: if the handle is already closed, succeeds without an RPC.
  """
  if query_handle_closed:
    return True
  _, status = self._do_rpc(lambda: self.imp_service.close(last_query_handle))
  return status == RpcStatus.OK
def cancel_query(self, last_query_handle, query_handle_closed=False):
  """Cancel a query on a keyboard interrupt from the shell.

  Cancel sets query_state to EXCEPTION before calling cancel() in the
  co-ordinator, so there is no need to wait. Returns True on success.
  """
  if query_handle_closed:
    return True
  # suppress_error_on_cancel=False: this IS the cancellation path, so errors
  # from it should surface.
  _, status = self._do_rpc(lambda: self.imp_service.Cancel(last_query_handle),
                           False)
  return status == RpcStatus.OK
def get_query_state(self, last_query_handle):
  """Return the query's beeswax state; EXCEPTION if the RPC itself failed."""
  state, status = self._do_rpc(
    lambda: self.imp_service.get_state(last_query_handle))
  if status == RpcStatus.OK:
    return state
  return self.query_state["EXCEPTION"]
def get_runtime_profile(self, last_query_handle):
  """Return the query's runtime profile, or None if it is unavailable.

  Explicit 'return None' added for consistency with get_summary(), instead of
  falling off the end of the function.
  """
  rpc_result = self._do_rpc(
    lambda: self.imp_service.GetRuntimeProfile(last_query_handle))
  profile, status = rpc_result
  if status == RpcStatus.OK and profile:
    return profile
  return None
def get_summary(self, last_query_handle):
  """Calls GetExecSummary() for the last query handle; None if unavailable."""
  summary, status = self._do_rpc(
    lambda: self.imp_service.GetExecSummary(last_query_handle))
  return summary if status == RpcStatus.OK and summary else None
def _do_rpc(self, rpc, suppress_error_on_cancel=True):
  """Execute the provided callable and return (result, RpcStatus).

  Raises:
    DisconnectedException: no connection, or a transport error while calling.
    QueryStateException: the server no longer knows the query handle.
    QueryCancelledByShellException: the error arrived after a shell-side
      cancel and suppress_error_on_cancel is True.
    RPCException: the server reported an application-level error.

  Changes vs. original: removed two unreachable 'return None, RpcStatus.ERROR'
  statements (one directly after a raise, one after a try/except in which
  every handler raises), and switched to the 'except X as e' spelling which is
  valid on Python 2.6+ and Python 3.
  """
  if not self.connected:
    raise DisconnectedException("Not connected (use CONNECT to establish a connection)")
  try:
    ret = rpc()
    status = RpcStatus.OK
    # TODO: In the future more advanced error detection/handling can be done based on
    # the TStatus return value. For now, just print any error(s) that were encountered
    # and validate the result of the operation was a success.
    if ret is not None and isinstance(ret, TStatus):
      if ret.status_code != TErrorCode.OK:
        if ret.error_msgs:
          raise RPCException('RPC Error: %s' % '\n'.join(ret.error_msgs))
        status = RpcStatus.ERROR
    return ret, status
  except BeeswaxService.QueryNotFoundException:
    if suppress_error_on_cancel and self.is_query_cancelled:
      raise QueryCancelledByShellException()
    raise QueryStateException('Error: Stale query handle')
  except BeeswaxService.BeeswaxException as b:
    # BeeswaxException's repr prints out the entire object; printing just the
    # message is far more readable/helpful.
    # Suppress the errors from cancelling a query that is in fetch state.
    if suppress_error_on_cancel and self.is_query_cancelled:
      raise QueryCancelledByShellException()
    raise RPCException("ERROR: %s" % b.message)
  except TTransportException as e:
    # Issue with the connection with the impalad.
    raise DisconnectedException("Error communicating with impalad: %s" % e)
  except TApplicationException as t:
    # Suppress the errors from cancelling a query that is in waiting_to_finish
    # state.
    if suppress_error_on_cancel and self.is_query_cancelled:
      raise QueryCancelledByShellException()
    raise RPCException("Application Exception : %s" % t)
def _get_sleep_interval(self, start_time):
"""Returns a step function of time to sleep in seconds before polling
again. Maximum sleep is 1s, minimum is 0.1s"""
elapsed = time.time() - start_time
if elapsed < 10.0:
return 0.1
elif elapsed < 60.0:
return 0.5
return 1.0
def get_column_names(self, last_query_handle):
  """Return the result-set column names, or None if metadata is unavailable.

  Fixed 'if not metadata is None' to the idiomatic 'if metadata is not None'
  (identical semantics) and made the None fall-through explicit.
  """
  rpc_result = self._do_rpc(
    lambda: self.imp_service.get_results_metadata(last_query_handle))
  metadata, _ = rpc_result
  if metadata is not None:
    return [fs.name for fs in metadata.schema.fieldSchemas]
  return None
def expect_result_metadata(self, query_str):
  """Given a query string, return True if impalad expects result metadata.

  Matching is a case-sensitive prefix test, as in the original.
  """
  excluded_query_types = ['use']
  for prefix in excluded_query_types:
    if query_str.startswith(prefix):
      return False
  return True
def get_warn_or_error_log(self, last_query_handle, warn):
  """Return the server log for last_query_handle, prefixed with 'WARNINGS:'
  when warn is True or 'ERROR:' when it is False.

  Note that the error log may contain messages that are not errors (e.g.
  warnings). Returns "" when the log is empty or whitespace-only.
  """
  if last_query_handle is None:
    return "Query could not be executed"
  log, status = self._do_rpc(
    lambda: self.imp_service.get_log(last_query_handle.log_context))
  if status != RpcStatus.OK:
    kind = "warn" if warn is True else "error"
    return "Failed to get %s log: %s" % (kind, status)
  if log and log.strip():
    kind = "WARNINGS" if warn is True else "ERROR"
    return "%s: %s" % (kind, log)
  return ""
def get_warning_log(self, last_query_handle):
  """Return the server log prefixed with 'WARNINGS:' for last_query_handle.

  Note that the error log may contain messages that are not errors (e.g.
  warnings).
  """
  return self.get_warn_or_error_log(last_query_handle, True)
def get_error_log(self, last_query_handle):
  """Return the server log prefixed with 'ERROR:' for last_query_handle.

  Note that the error log may contain messages that are not errors (e.g.
  warnings).
  """
  return self.get_warn_or_error_log(last_query_handle, False)