-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Change histogram SQL query to use linspaced sampling #1022
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,11 +24,11 @@ | |
|
||
import collections | ||
import random | ||
import six | ||
from werkzeug import wrappers | ||
|
||
import numpy as np | ||
import six | ||
import tensorflow as tf | ||
from werkzeug import wrappers | ||
|
||
from tensorboard import plugin_util | ||
from tensorboard.backend import http_util | ||
|
@@ -47,6 +47,10 @@ class HistogramsPlugin(base_plugin.TBPlugin): | |
|
||
plugin_name = metadata.PLUGIN_NAME | ||
|
||
# Use a round number + 1 since sampling includes both start and end steps, | ||
# so N+1 samples corresponds to dividing the step sequence into N intervals. | ||
SAMPLE_SIZE = 51 | ||
|
||
def __init__(self, context): | ||
"""Instantiates HistogramsPlugin via TensorBoard core. | ||
|
||
|
@@ -119,7 +123,7 @@ def index_impl(self): | |
|
||
return result | ||
|
||
def histograms_impl(self, tag, run, downsample_to=50): | ||
def histograms_impl(self, tag, run, downsample_to=None): | ||
"""Result of the form `(body, mime_type)`, or `ValueError`. | ||
|
||
At most `downsample_to` events will be returned. If this value is | ||
|
@@ -128,45 +132,61 @@ def histograms_impl(self, tag, run, downsample_to=50): | |
if self._db_connection_provider: | ||
# Serve data from the database. | ||
db = self._db_connection_provider() | ||
# We select for steps greater than -1 because the writer inserts | ||
# placeholder rows en masse. The check for step filters out those rows. | ||
query = ''' | ||
SELECT | ||
Tensors.computed_time AS computed_time, | ||
Tensors.step AS step, | ||
Tensors.data AS data, | ||
Tensors.dtype AS dtype, | ||
Tensors.shape AS shape | ||
FROM Tensors | ||
JOIN Tags | ||
ON Tensors.series = Tags.tag_id | ||
JOIN Runs | ||
ON Tags.run_id = Runs.run_id | ||
WHERE | ||
Runs.run_name = ? | ||
AND Tags.tag_name = ? | ||
AND Tags.plugin_name = ? | ||
AND Tensors.step > -1 | ||
''' | ||
if downsample_to is not None: | ||
# Wrap the query in an outer one that samples. | ||
query = ''' | ||
SELECT * | ||
FROM | ||
(%(query)s | ||
ORDER BY RANDOM() | ||
LIMIT %(downsample_to)d) | ||
''' % { | ||
'query': query, | ||
'downsample_to': downsample_to, | ||
} | ||
query = ''' | ||
%s | ||
ORDER BY step | ||
''' % query | ||
cursor = db.execute(query, (run, tag, metadata.PLUGIN_NAME)) | ||
events = [(row[0], row[1], self._get_values(row[2], row[3], row[4])) | ||
for row in cursor] | ||
cursor = db.cursor() | ||
# Prefetch the tag ID matching this run and tag. | ||
cursor.execute( | ||
''' | ||
SELECT | ||
tag_id | ||
FROM Tags | ||
JOIN Runs USING (run_id) | ||
WHERE | ||
Runs.run_name = :run | ||
AND Tags.tag_name = :tag | ||
AND Tags.plugin_name = :plugin | ||
''', | ||
{'run': run, 'tag': tag, 'plugin': metadata.PLUGIN_NAME}) | ||
row = cursor.fetchone() | ||
if not row: | ||
raise ValueError('No histogram tag %r for run %r' % (tag, run)) | ||
(tag_id,) = row | ||
# Fetch tensor values, optionally with linear-spaced sampling by step. | ||
# For steps ranging from s_min to s_max and sample size k, this query | ||
# divides the range into k - 1 equal-sized intervals and returns the | ||
# lowest step at or above each of the k interval boundaries (which always | ||
# includes s_min and s_max, and may be fewer than k results if there are | ||
# intervals where no steps are present). For contiguous steps the results | ||
# can be formally expressed as the following: | ||
# [s_min + math.ceil(i / (k - 1) * (s_max - s_min)) for i in range(k)] | ||
cursor.execute( | ||
''' | ||
SELECT | ||
MIN(step) AS step, | ||
computed_time, | ||
data, | ||
dtype, | ||
shape | ||
FROM Tensors | ||
INNER JOIN ( | ||
SELECT | ||
MIN(step) AS min_step, | ||
MAX(step) AS max_step | ||
FROM Tensors | ||
/* Filter out NULL so we can use TensorSeriesStepIndex. */ | ||
WHERE series = :tag_id AND step IS NOT NULL | ||
) | ||
/* Ensure we omit reserved rows, which have NULL step values. */ | ||
WHERE series = :tag_id AND step IS NOT NULL | ||
/* Bucket rows into sample_size linearly spaced buckets, or do | ||
no sampling if sample_size is NULL. */ | ||
GROUP BY | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If I'm understanding correctly, this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, the GROUP BY basically partitions the series by a "bucket index" so that we divide the range into k - 1 even sized intervals, and then within each interval we take the minimum aka first step in that interval. The way we do the math, the max step is always assigned to its own singleton interval with bucket index k - 1, so it also get selected basically as the lower bound of a kth interval that just happens not to contain any other steps (so conceptually, I think of the selected steps as being the set of lower and upper bounds of k - 1 intervals instead). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note also that if the sample_size is null, we effectively set it to max_step - min_step + 1 which is the total number of distinct steps. As it happens that means our formal now has (max_step - min_step) in both numerator and denominator, so they cancel and each step gets assigned a unique bucket index of step - min_step. If sample_size is greater than the total number of steps, the bucket index for each step is assigned as m * step + C where m > 1, so each step will still get a unique bucket index and hence all steps will be selected. |
||
IFNULL(:sample_size - 1, max_step - min_step) | ||
* (step - min_step) / (max_step - min_step) | ||
ORDER BY step | ||
''', | ||
{'tag_id': tag_id, 'sample_size': downsample_to}) | ||
events = [(computed_time, step, self._get_values(data, dtype, shape)) | ||
for step, computed_time, data, dtype, shape in cursor] | ||
else: | ||
# Serve data from events files. | ||
try: | ||
|
@@ -204,7 +224,8 @@ def histograms_route(self, request): | |
tag = request.args.get('tag') | ||
run = request.args.get('run') | ||
try: | ||
(body, mime_type) = self.histograms_impl(tag, run) | ||
(body, mime_type) = self.histograms_impl( | ||
tag, run, downsample_to=self.SAMPLE_SIZE) | ||
code = 200 | ||
except ValueError as e: | ||
(body, mime_type) = (str(e), 'text/plain') | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for the lucid comments.