diff --git a/tensorboard/backend/http_util.py b/tensorboard/backend/http_util.py
index 6af2145cbe..63e6017514 100644
--- a/tensorboard/backend/http_util.py
+++ b/tensorboard/backend/http_util.py
@@ -22,12 +23,13 @@
 import gzip
 import json
 import re
+import struct
 import time
 import wsgiref.handlers
 
 import six
 import tensorflow as tf
-from werkzeug import wrappers
+import werkzeug
 
 from tensorboard.backend import json_util
 
@@ -119,19 +120,32 @@ def Respond(request,
   content = tf.compat.as_bytes(content, charset)
   if textual and not charset_match and mimetype not in _JSON_MIMETYPES:
     content_type += '; charset=' + charset
-  if (not content_encoding and textual and
-      _ALLOWS_GZIP_PATTERN.search(request.headers.get('Accept-Encoding', ''))):
+  gzip_accepted = _ALLOWS_GZIP_PATTERN.search(
+      request.headers.get('Accept-Encoding', ''))
+  # Automatically gzip uncompressed text data if accepted.
+  if textual and not content_encoding and gzip_accepted:
     out = six.BytesIO()
-    f = gzip.GzipFile(fileobj=out, mode='wb', compresslevel=3)
-    f.write(content)
-    f.close()
+    # Set mtime to zero to make payload for a given input deterministic.
+    with gzip.GzipFile(fileobj=out, mode='wb', compresslevel=3, mtime=0) as f:
+      f.write(content)
     content = out.getvalue()
     content_encoding = 'gzip'
 
-  if request.method == 'HEAD':
-    content = ''
-  headers = []
-  headers.append(('Content-Length', str(len(content))))
+  content_length = len(content)
+  direct_passthrough = False
+  # Automatically streamwise-gunzip precompressed data if not accepted.
+  if content_encoding == 'gzip' and not gzip_accepted:
+    gzip_file = gzip.GzipFile(fileobj=six.BytesIO(content), mode='rb')
+    # Last 4 bytes of gzip formatted data (little-endian) store the original
+    # content length mod 2^32; we just assume it's the content length. That
+    # means we can't streamwise-gunzip >4 GB precompressed file; this is ok.
+    content_length = struct.unpack('<I', content[-4:])[0]
+    content = werkzeug.wsgi.wrap_file(request.environ, gzip_file)
+    content_encoding = None
+    direct_passthrough = True
+
+  headers = []
+  headers.append(('Content-Length', str(content_length)))
   if content_encoding:
     headers.append(('Content-Encoding', content_encoding))
   if expires > 0:
@@ -142,5 +156,9 @@ def Respond(request,
     headers.append(('Expires', '0'))
     headers.append(('Cache-Control', 'no-cache, must-revalidate'))
 
-  return wrappers.Response(
-      response=content, status=code, headers=headers, content_type=content_type)
+  if request.method == 'HEAD':
+    content = None
+
+  return werkzeug.wrappers.Response(
+      response=content, status=code, headers=headers, content_type=content_type,
+      direct_passthrough=direct_passthrough)
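Note on the Content-Length shortcut above: per RFC 1952, every gzip member ends with an ISIZE trailer, the length of the uncompressed input modulo 2^32, stored little-endian, which is why a single struct.unpack over the last four bytes is enough as long as payloads stay under 4 GB. A minimal standalone sketch of that fact (the variable names here are illustrative, not part of the patch):

  import gzip
  import struct

  import six

  payload = b'x' * 1000

  out = six.BytesIO()
  # mtime=0 keeps the gzip header deterministic, mirroring the change above.
  with gzip.GzipFile(fileobj=out, mode='wb', compresslevel=3, mtime=0) as f:
    f.write(payload)
  compressed = out.getvalue()

  # ISIZE: last four bytes, little-endian, uncompressed size mod 2^32.
  assert struct.unpack('<I', compressed[-4:])[0] == len(payload) % 2 ** 32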
diff --git a/tensorboard/backend/http_util_test.py b/tensorboard/backend/http_util_test.py
index b686b408af..0eb7f573be 100644
--- a/tensorboard/backend/http_util_test.py
+++ b/tensorboard/backend/http_util_test.py
@@ -21,6 +21,7 @@
 from __future__ import unicode_literals
 
 import gzip
+import struct
 
 import six
 import tensorflow as tf
@@ -36,6 +37,7 @@ def testHelloWorld(self):
     r = http_util.Respond(q, '<b>hello world</b>', 'text/html')
     self.assertEqual(r.status_code, 200)
     self.assertEqual(r.response, [six.b('<b>hello world</b>')])
+    self.assertEqual(r.headers.get('Content-Length'), '18')
 
   def testHeadRequest_doesNotWrite(self):
     builder = wtest.EnvironBuilder(method='HEAD')
@@ -43,7 +45,8 @@ def testHeadRequest_doesNotWrite(self):
     request = wrappers.Request(env)
     r = http_util.Respond(request, '<b>hello world</b>', 'text/html')
     self.assertEqual(r.status_code, 200)
-    self.assertEqual(r.response, [six.b('')])
+    self.assertEqual(r.response, [])
+    self.assertEqual(r.headers.get('Content-Length'), '18')
 
   def testPlainText_appendsUtf8ToContentType(self):
     q = wrappers.Request(wtest.EnvironBuilder().get_environ())
@@ -136,6 +139,36 @@ def testAcceptGzip_compressesResponse(self):
     self.assertEqual(
         r.response, [fall_of_hyperion_canto1_stanza1.encode('utf-8')])
 
+  def testAcceptGzip_alreadyCompressed_sendsPrecompressedResponse(self):
+    gzip_text = _gzip(b'hello hello hello world')
+    e = wtest.EnvironBuilder(headers={'Accept-Encoding': 'gzip'}).get_environ()
+    q = wrappers.Request(e)
+    r = http_util.Respond(q, gzip_text, 'text/plain', content_encoding='gzip')
+    self.assertEqual(r.response, [gzip_text])  # Still singly zipped
+
+  def testPrecompressedResponse_noAcceptGzip_decompressesResponse(self):
+    orig_text = b'hello hello hello world'
+    gzip_text = _gzip(orig_text)
+    q = wrappers.Request(wtest.EnvironBuilder().get_environ())
+    r = http_util.Respond(q, gzip_text, 'text/plain', content_encoding='gzip')
+    # Streaming gunzip produces file-wrapper application iterator as response,
+    # so rejoin it into the full response before comparison.
+    full_response = b''.join(r.response)
+    self.assertEqual(full_response, orig_text)
+
+  def testPrecompressedResponse_streamingDecompression_catchesBadSize(self):
+    orig_text = b'hello hello hello world'
+    gzip_text = _gzip(orig_text)
+    # Corrupt the gzipped data's stored content size (last 4 bytes).
+    bad_text = gzip_text[:-4] + _bitflip(gzip_text[-4:])
+    q = wrappers.Request(wtest.EnvironBuilder().get_environ())
+    r = http_util.Respond(q, bad_text, 'text/plain', content_encoding='gzip')
+    # Streaming gunzip defers actual unzipping until response is used; once
+    # we iterate over the whole file-wrapper application iterator, the
+    # underlying GzipFile should be closed, and throw the size check error.
+    with six.assertRaisesRegex(self, IOError, 'Incorrect length'):
+      _ = list(r.response)
+
   def testJson_getsAutoSerialized(self):
     q = wrappers.Request(wtest.EnvironBuilder().get_environ())
     r = http_util.Respond(q, [1, 2, 3], 'application/json')
@@ -147,9 +180,21 @@ def testExpires_setsCruiseControl(self):
     self.assertEqual(r.headers.get('Cache-Control'), 'private, max-age=60')
 
 
+def _gzip(bs):
+  out = six.BytesIO()
+  with gzip.GzipFile(fileobj=out, mode='wb') as f:
+    f.write(bs)
+  return out.getvalue()
+
+
 def _gunzip(bs):
-  return gzip.GzipFile('', 'rb', 9, six.BytesIO(bs)).read()
+  with gzip.GzipFile(fileobj=six.BytesIO(bs), mode='rb') as f:
+    return f.read()
 
 
+def _bitflip(bs):
+  # Return bytestring with all its bits flipped.
+  return b''.join(struct.pack('B', 0xFF ^ struct.unpack_from('B', bs, i)[0])
+                  for i in range(len(bs)))
 if __name__ == '__main__':
   tf.test.main()
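The bad-size test above leans on the fact that GzipFile only validates its trailer once the stream is exhausted, so a corrupted ISIZE surfaces only when the response iterator is fully consumed. A rough standalone sketch of that failure mode, assuming the usual Python 2/3 behavior that six papers over (IOError vs. OSError/BadGzipFile, and the exact message text, are assumptions, not guaranteed by this patch):

  import gzip
  import struct

  import six

  out = six.BytesIO()
  with gzip.GzipFile(fileobj=out, mode='wb') as f:
    f.write(b'hello hello hello world')
  good = out.getvalue()

  # Flip only the ISIZE trailer; the CRC32 (the four bytes before it) still
  # matches, so the length check is the first thing that can fail.
  bad_isize = struct.pack('<I', 0xFFFFFFFF ^ struct.unpack('<I', good[-4:])[0])
  bad = good[:-4] + bad_isize

  try:
    gzip.GzipFile(fileobj=six.BytesIO(bad), mode='rb').read()
    raise AssertionError('expected the trailer length check to fail')
  except (IOError, OSError) as e:  # OSError also covers BadGzipFile on Python 3.
    assert 'Incorrect length' in str(e)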
diff --git a/tensorboard/plugins/core/core_plugin.py b/tensorboard/plugins/core/core_plugin.py
index a8d25101ad..0ea9129ed8 100644
--- a/tensorboard/plugins/core/core_plugin.py
+++ b/tensorboard/plugins/core/core_plugin.py
@@ -19,9 +19,11 @@
 from __future__ import print_function
 
 import functools
+import gzip
 import mimetypes
 import zipfile
 
+import six
 import tensorflow as tf
 from werkzeug import utils
 from werkzeug import wrappers
@@ -65,12 +67,13 @@ def get_plugin_apps(self):
         '/images': self._redirect_to_index,
     }
     if self._assets_zip_provider:
-      apps['/'] = functools.partial(self._serve_asset, 'index.html')
       with self._assets_zip_provider() as fp:
         with zipfile.ZipFile(fp) as zip_:
-          for info in zip_.infolist():
-            path = info.filename
-            apps['/' + path] = functools.partial(self._serve_asset, path)
+          for path in zip_.namelist():
+            gzipped_asset_bytes = _gzip(zip_.read(path))
+            apps['/' + path] = functools.partial(
+                self._serve_asset, path, gzipped_asset_bytes)
+      apps['/'] = apps['/index.html']
     return apps
 
   @wrappers.Request.application
@@ -82,14 +85,11 @@ def _redirect_to_index(self, unused_request):
     return utils.redirect('/')
 
   @wrappers.Request.application
-  def _serve_asset(self, path, request):
-    """Serves a static asset from the zip file."""
-    mimetype = mimetypes.guess_type(path)[0] or 'application/octet-stream'
-    with self._assets_zip_provider() as fp:
-      with zipfile.ZipFile(fp) as zip_:
-        with zip_.open(path) as file_:
-          html = file_.read()
-          return http_util.Respond(request, html, mimetype)
+  def _serve_asset(self, path, gzipped_asset_bytes, request):
+    """Serves a pre-gzipped static asset from the zip file."""
+    mimetype = mimetypes.guess_type(path)[0] or 'application/octet-stream'
+    return http_util.Respond(
+        request, gzipped_asset_bytes, mimetype, content_encoding='gzip')
 
   @wrappers.Request.application
   def _serve_logdir(self, request):
@@ -129,3 +129,11 @@ def get_first_event_timestamp(run_name):
     }
     run_names.sort(key=first_event_timestamps.get)
     return http_util.Respond(request, run_names, 'application/json')
+
+
+def _gzip(bytestring):
+  out = six.BytesIO()
+  # Set mtime to zero for deterministic results across TensorBoard launches.
+  with gzip.GzipFile(fileobj=out, mode='wb', compresslevel=3, mtime=0) as f:
+    f.write(bytestring)
+  return out.getvalue()
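Taken together, the core_plugin change reads each asset out of the zip once at startup, gzips it deterministically, and hands it to Respond precompressed; Respond then either passes the gzip bytes straight through or streamwise-gunzips them per request. A hedged usage sketch built from the same test utilities the patch uses (the asset bytes are made up, and running this outside the test harness assumes tensorboard.backend.http_util and a contemporaneous Werkzeug are importable):

  import gzip

  import six
  from werkzeug import test as wtest
  from werkzeug import wrappers

  from tensorboard.backend import http_util


  def _gzip(bytestring):
    out = six.BytesIO()
    with gzip.GzipFile(fileobj=out, mode='wb', compresslevel=3, mtime=0) as f:
      f.write(bytestring)
    return out.getvalue()


  asset = b'<html>hello</html>'       # stand-in for one zip entry
  precompressed = _gzip(asset)        # done once, when the plugin builds its routes

  # Client that accepts gzip: the precompressed bytes go out as-is.
  env = wtest.EnvironBuilder(headers={'Accept-Encoding': 'gzip'}).get_environ()
  r = http_util.Respond(wrappers.Request(env), precompressed, 'text/html',
                        content_encoding='gzip')
  assert r.headers.get('Content-Encoding') == 'gzip'
  assert r.response == [precompressed]

  # Client without gzip support: Respond streamwise-gunzips on the way out.
  env = wtest.EnvironBuilder().get_environ()
  r = http_util.Respond(wrappers.Request(env), precompressed, 'text/html',
                        content_encoding='gzip')
  assert b''.join(r.response) == asset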