Change histogram SQL query to use linspaced sampling #1022


Merged: 1 commit, Mar 9, 2018
6 changes: 5 additions & 1 deletion tensorboard/plugins/distribution/distributions_plugin.py
@@ -41,6 +41,10 @@ class DistributionsPlugin(base_plugin.TBPlugin):

plugin_name = 'distributions'

# Use a round number + 1 since sampling includes both start and end steps,
# so N+1 samples corresponds to dividing the step sequence into N intervals.
SAMPLE_SIZE = 501

def __init__(self, context):
"""Instantiates DistributionsPlugin via TensorBoard core.

@@ -67,7 +71,7 @@ def is_active(self):
def distributions_impl(self, tag, run):
"""Result of the form `(body, mime_type)`, or `ValueError`."""
(histograms, mime_type) = self._histograms_plugin.histograms_impl(
tag, run, downsample_to=None)
tag, run, downsample_to=self.SAMPLE_SIZE)
return ([self._compress(histogram) for histogram in histograms],
mime_type)
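The "round number + 1" rationale in the comment above can be sanity-checked with a small sketch (illustrative values only; the step range here is hypothetical, not from the diff): sampling that includes both endpoints of a range divided into N equal intervals takes N + 1 points.

```python
# Hypothetical run with steps 0..1000: SAMPLE_SIZE = 501 sample points
# divide the step range into 500 equal intervals, endpoints included.
sample_size = 501
max_step = 1000  # hypothetical last step of a run

boundaries = [round(i * max_step / (sample_size - 1)) for i in range(sample_size)]

assert len(boundaries) == sample_size  # 501 sample points...
assert boundaries[0] == 0              # ...including the first step
assert boundaries[-1] == max_step      # ...and the last step
print(boundaries[:4])  # -> [0, 2, 4, 6]
```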

107 changes: 64 additions & 43 deletions tensorboard/plugins/histogram/histograms_plugin.py
@@ -24,11 +24,11 @@

import collections
import random
import six
from werkzeug import wrappers

import numpy as np
import six
import tensorflow as tf
from werkzeug import wrappers

from tensorboard import plugin_util
from tensorboard.backend import http_util
@@ -47,6 +47,10 @@ class HistogramsPlugin(base_plugin.TBPlugin):

plugin_name = metadata.PLUGIN_NAME

# Use a round number + 1 since sampling includes both start and end steps,
# so N+1 samples corresponds to dividing the step sequence into N intervals.
SAMPLE_SIZE = 51

def __init__(self, context):
"""Instantiates HistogramsPlugin via TensorBoard core.

@@ -119,7 +123,7 @@ def index_impl(self):

return result

def histograms_impl(self, tag, run, downsample_to=50):
def histograms_impl(self, tag, run, downsample_to=None):
"""Result of the form `(body, mime_type)`, or `ValueError`.

At most `downsample_to` events will be returned. If this value is
@@ -128,45 +132,61 @@ def histograms_impl(self, tag, run, downsample_to=None):
if self._db_connection_provider:
# Serve data from the database.
db = self._db_connection_provider()
# We select for steps greater than -1 because the writer inserts
# placeholder rows en masse. The check for step filters out those rows.
query = '''
SELECT
Tensors.computed_time AS computed_time,
Tensors.step AS step,
Tensors.data AS data,
Tensors.dtype AS dtype,
Tensors.shape AS shape
FROM Tensors
JOIN Tags
ON Tensors.series = Tags.tag_id
JOIN Runs
ON Tags.run_id = Runs.run_id
WHERE
Runs.run_name = ?
AND Tags.tag_name = ?
AND Tags.plugin_name = ?
AND Tensors.step > -1
'''
if downsample_to is not None:
# Wrap the query in an outer one that samples.
query = '''
SELECT *
FROM
(%(query)s
ORDER BY RANDOM()
LIMIT %(downsample_to)d)
''' % {
'query': query,
'downsample_to': downsample_to,
}
query = '''
%s
ORDER BY step
''' % query
cursor = db.execute(query, (run, tag, metadata.PLUGIN_NAME))
events = [(row[0], row[1], self._get_values(row[2], row[3], row[4]))
for row in cursor]
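For contrast with what follows, the removed query above downsamples with ORDER BY RANDOM() LIMIT, which gives no coverage guarantee over the step range. A sketch against a toy SQLite table (the schema is a simplified stand-in for the real Tensors table):

```python
import sqlite3

# Toy stand-in for the Tensors table with contiguous steps 0..99.
conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE Tensors (step INTEGER)')
conn.executemany('INSERT INTO Tensors VALUES (?)', [(s,) for s in range(100)])

# Random downsampling as in the removed query: pick 10 rows at random,
# then restore step order in an outer query.
rows = conn.execute('''
    SELECT * FROM (SELECT step FROM Tensors ORDER BY RANDOM() LIMIT 10)
    ORDER BY step
''').fetchall()
steps = [row[0] for row in rows]

assert len(steps) == 10 and steps == sorted(steps)
# The sample varies between calls and need not include step 0 or step 99.
```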
cursor = db.cursor()
# Prefetch the tag ID matching this run and tag.
cursor.execute(
'''
SELECT
tag_id
FROM Tags
JOIN Runs USING (run_id)
WHERE
Runs.run_name = :run
AND Tags.tag_name = :tag
AND Tags.plugin_name = :plugin
''',
{'run': run, 'tag': tag, 'plugin': metadata.PLUGIN_NAME})
row = cursor.fetchone()
if not row:
raise ValueError('No histogram tag %r for run %r' % (tag, run))
(tag_id,) = row
# Fetch tensor values, optionally with linear-spaced sampling by step.
# For steps ranging from s_min to s_max and sample size k, this query
# divides the range into k - 1 equal-sized intervals and returns the
# lowest step at or above each of the k interval boundaries (which always
# includes s_min and s_max, and may be fewer than k results if there are
# intervals where no steps are present). For contiguous steps the results
# can be formally expressed as the following:
# [s_min + math.ceil(i / (k - 1) * (s_max - s_min)) for i in range(k)]
cursor.execute(
'''
SELECT
MIN(step) AS step,
computed_time,
data,
dtype,
shape
FROM Tensors
INNER JOIN (
SELECT
MIN(step) AS min_step,
MAX(step) AS max_step
FROM Tensors
/* Filter out NULL so we can use TensorSeriesStepIndex. */
WHERE series = :tag_id AND step IS NOT NULL
)
/* Ensure we omit reserved rows, which have NULL step values. */
WHERE series = :tag_id AND step IS NOT NULL
/* Bucket rows into sample_size linearly spaced buckets, or do
no sampling if sample_size is NULL. */
GROUP BY
IFNULL(:sample_size - 1, max_step - min_step)
* (step - min_step) / (max_step - min_step)
ORDER BY step
''',
{'tag_id': tag_id, 'sample_size': downsample_to})
events = [(computed_time, step, self._get_values(data, dtype, shape))
for step, computed_time, data, dtype, shape in cursor]

Review comments on the GROUP BY clause:

Member: Thanks a lot for the lucid comments.

Member: If I'm understanding correctly, this GROUP BY makes it so that we only sample one event at each sampled step, right? I.e., either the min or max per bucket that you had noted earlier.

Contributor (author): Yeah, the GROUP BY basically partitions the series by a "bucket index", so that we divide the range into k - 1 evenly sized intervals, and then within each interval we take the minimum (i.e., first) step. The way the math works out, the max step is always assigned to its own singleton interval with bucket index k - 1, so it also gets selected, basically as the lower bound of a kth interval that just happens not to contain any other steps (so conceptually, I think of the selected steps as the set of lower and upper bounds of the k - 1 intervals).

Contributor (author): Note also that if sample_size is NULL, we effectively set it to max_step - min_step + 1, which is the total number of distinct steps. As it happens, that means our formula then has (max_step - min_step) in both numerator and denominator, so they cancel and each step gets assigned a unique bucket index of step - min_step.

If sample_size is greater than the total number of steps, the bucket index for each step is assigned as m * step + C where m > 1, so each step still gets a unique bucket index and hence all steps will be selected.
else:
# Serve data from events files.
try:
@@ -204,7 +224,8 @@ def histograms_route(self, request):
tag = request.args.get('tag')
run = request.args.get('run')
try:
(body, mime_type) = self.histograms_impl(tag, run)
(body, mime_type) = self.histograms_impl(
tag, run, downsample_to=self.SAMPLE_SIZE)
code = 200
except ValueError as e:
(body, mime_type) = (str(e), 'text/plain')
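The new linspaced-sampling query can be exercised end to end against an in-memory SQLite table. This is a sketch with a deliberately simplified schema: the table and the series/step columns mirror the diff, but the data and the fixed series id are toy stand-ins, and the value columns are omitted.

```python
import math
import sqlite3

# Toy stand-in for the Tensors table: one series with contiguous steps 0..99.
conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE Tensors (series INTEGER, step INTEGER)')
conn.executemany('INSERT INTO Tensors VALUES (1, ?)', [(s,) for s in range(100)])

def sample_steps(sample_size):
    """Linspaced sampling by step, shaped like the new query in the diff."""
    cursor = conn.execute('''
        SELECT MIN(step) AS step
        FROM Tensors
        INNER JOIN (
          SELECT MIN(step) AS min_step, MAX(step) AS max_step
          FROM Tensors
          WHERE series = 1 AND step IS NOT NULL
        )
        WHERE series = 1 AND step IS NOT NULL
        GROUP BY
          IFNULL(:sample_size - 1, max_step - min_step)
            * (step - min_step) / (max_step - min_step)
        ORDER BY step
    ''', {'sample_size': sample_size})
    return [row[0] for row in cursor]

k, s_min, s_max = 11, 0, 99
steps = sample_steps(k)
# For contiguous steps, the result matches the closed form: the lowest step
# at or above each of the k boundaries of k - 1 equal intervals.
assert steps == [s_min + math.ceil(i / (k - 1) * (s_max - s_min))
                 for i in range(k)]
assert steps[0] == s_min and steps[-1] == s_max
print(steps)  # -> [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 99]

# With sample_size NULL (None), the bucket index degenerates to
# step - min_step, so every distinct step is returned.
assert len(sample_steps(None)) == 100
```

Note that SQLite's `/` on integer operands performs integer division, which is what makes the GROUP BY expression a bucket index rather than a ratio.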