Skip to content

Commit 81f4eb7

Browse files
committed
Add shift and concatenate functions
Fixes tskit-dev#3164 Update python/tskit/tables.py
1 parent 942e383 commit 81f4eb7

File tree

5 files changed

+302
-0
lines changed

5 files changed

+302
-0
lines changed

python/CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
associated with each individual as a numpy array.
1313
(:user:`benjeffery`, :pr:`3153`)
1414

15+
- Add ``shift`` and ``concatenate`` methods to both ``TableCollection`` and
16+
``TreeSequence`` classes, allowing the coordinate system to be shifted and
17+
one tree sequence to be added to the right of another.
18+
(:user:`hyanwong`, :pr:`3165`, :issue:`3164`)
19+
1520

1621
**Fixes**
1722

python/tests/test_table_transforms.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,21 @@
3939
# we can remove this.
4040

4141

42+
class TestShift:
    """
    Smoke tests for the in-place ``TableCollection.shift``.

    The bulk of the shift testing targets the ``TreeSequence`` method; these
    cases only confirm that the table-level method applies the requested
    change even when the outcome would not be a valid tree sequence.
    """

    def test_too_negative(self):
        # A shift of -1 is applied verbatim, leaving a negative left coordinate.
        comb_ts = tskit.Tree.generate_comb(2).tree_sequence
        tc = comb_ts.dump_tables()
        tc.shift(-1)
        leftmost = np.min(tc.edges.left)
        assert leftmost == -1

    def test_bad_seq_len(self):
        # The explicit sequence length is stored as given, even though the
        # shifted edges now extend beyond it.
        comb_ts = tskit.Tree.generate_comb(2).tree_sequence
        tc = comb_ts.dump_tables()
        tc.shift(1, sequence_length=0.5)
        rightmost = np.max(tc.edges.right)
        assert tc.sequence_length == 0.5
        assert rightmost == 2
55+
56+
4257
def delete_older_definition(tables, time):
4358
node_time = tables.nodes.time
4459
edges = tables.edges.copy()

python/tests/test_topology.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7074,6 +7074,136 @@ def test_failure_with_migrations(self):
70747074
ts.trim()
70757075

70767076

7077+
class TestShift:
    """
    Tests for ``TreeSequence.shift``: coordinates, sequence length,
    provenance recording and rejection of invalid results.
    """

    @pytest.mark.parametrize("shift", [-0.5, 0, 0.5])
    def test_shift(self, shift):
        tables = tskit.Tree.generate_comb(2, span=2).tree_sequence.dump_tables()
        # Remove the interval [0, 1] so the remaining edges start at 1, which
        # leaves room for the negative shift.
        tables.delete_intervals([[0, 1]], simplify=False)
        tables.sites.add_row(1.5, "A")
        shifted = tables.tree_sequence().shift(shift)
        assert shifted.sequence_length == 2 + shift
        edge_table = shifted.tables.edges
        assert np.min(edge_table.left) == 1 + shift
        assert np.max(edge_table.right) == 2 + shift
        assert np.all(shifted.tables.sites.position == 1.5 + shift)

    def test_sequence_length(self):
        # An explicit sequence_length takes precedence over "old length + value".
        original = tskit.Tree.generate_comb(2).tree_sequence
        moved_right = original.shift(1, sequence_length=3)
        assert moved_right.sequence_length == 3
        moved_back = moved_right.shift(-1, sequence_length=1)
        assert moved_back.sequence_length == 1

    def test_empty(self):
        blank = tskit.TableCollection(1.0).tree_sequence()
        longer = blank.shift(1)
        assert longer.sequence_length == 2
        shorter = longer.shift(-1.5)
        assert shorter.sequence_length == 0.5
        assert shorter.num_nodes == 0

    def test_provenance(self):
        def last_parameters(tree_seq):
            # Parameters stored in the most recent provenance record
            return json.loads(tree_seq.provenance(-1).record)["parameters"]

        unrecorded = tskit.Tree.generate_comb(2).tree_sequence.shift(
            1, record_provenance=False
        )
        assert last_parameters(unrecorded)["command"] != "shift"
        recorded = unrecorded.shift(1, sequence_length=9)
        params = last_parameters(recorded)
        assert params["command"] == "shift"
        assert params["value"] == 1
        assert params["sequence_length"] == 9

    def test_too_negative(self):
        comb_ts = tskit.Tree.generate_comb(2).tree_sequence
        with pytest.raises(tskit.LibraryError, match="TSK_ERR_BAD_SEQUENCE_LENGTH"):
            comb_ts.shift(-1)

    def test_bad_seq_len(self):
        comb_ts = tskit.Tree.generate_comb(2).tree_sequence
        expected_error = "TSK_ERR_RIGHT_GREATER_SEQ_LENGTH"
        with pytest.raises(tskit.LibraryError, match=expected_error):
            comb_ts.shift(1, sequence_length=1)
7132+
7133+
7134+
class TestConcatenate:
    """
    Tests for ``TreeSequence.concatenate``: default sample matching, explicit
    node mappings, provenance recording and the sample-count check.
    """

    def test_simple(self):
        left_ts = tskit.Tree.generate_comb(5, span=2).tree_sequence
        right_ts = tskit.Tree.generate_balanced(5, arity=3, span=3).tree_sequence
        assert left_ts.num_samples == right_ts.num_samples
        assert left_ts.num_nodes != right_ts.num_nodes
        joint_ts = left_ts.concatenate(right_ts)
        # The 5 samples are shared, so they are only counted once
        assert joint_ts.num_nodes == left_ts.num_nodes + right_ts.num_nodes - 5
        total_span = left_ts.sequence_length + right_ts.sequence_length
        assert joint_ts.sequence_length == total_span
        assert joint_ts.num_samples == left_ts.num_samples
        left_part = joint_ts.delete_intervals([[2, 5]]).rtrim()
        # Have to simplify here, to remove the redundant nodes
        assert left_part.equals(left_ts.simplify(), ignore_provenance=True)
        right_part = joint_ts.delete_intervals([[0, 2]]).ltrim()
        assert right_part.equals(right_ts.simplify(), ignore_provenance=True)

    def test_empty(self):
        blank = tskit.TableCollection(10).tree_sequence()
        doubled = blank.concatenate(blank)
        assert doubled.num_nodes == 0
        assert doubled.sequence_length == 20

    def test_samples_at_end(self):
        left_ts = tskit.Tree.generate_comb(5, span=2).tree_sequence
        right_ts = tskit.Tree.generate_balanced(5, arity=3, span=3).tree_sequence
        # reverse the node order
        reversed_ids = np.arange(left_ts.num_nodes)[::-1]
        left_ts = left_ts.subset(reversed_ids)
        assert left_ts.num_samples == right_ts.num_samples
        assert np.all(left_ts.samples() != right_ts.samples())
        joint_ts = left_ts.concatenate(right_ts)
        assert joint_ts.num_samples == left_ts.num_samples
        assert np.all(joint_ts.samples() == left_ts.samples())

    def test_internal_samples(self):
        tables = tskit.Tree.generate_comb(4, span=2).tree_sequence.dump_tables()
        flags = tables.nodes.flags
        flags[:] = tskit.NODE_IS_SAMPLE
        flags[-1] = 0  # Only root is not a sample
        tables.nodes.flags = flags
        ts = tables.tree_sequence()
        joint_ts = ts.concatenate(ts)
        assert joint_ts.num_samples == ts.num_samples
        # Every node but the root is a shared sample, so only one node is new
        assert joint_ts.num_nodes == ts.num_nodes + 1
        assert joint_ts.sequence_length == ts.sequence_length * 2

    def test_some_shared_samples(self):
        left_ts = tskit.Tree.generate_comb(4, span=2).tree_sequence
        right_ts = tskit.Tree.generate_balanced(8, arity=3, span=3).tree_sequence
        # Map only the first two nodes of right_ts (swapped) onto left_ts
        mapping = np.full(right_ts.num_nodes, tskit.NULL)
        mapping[0] = 1
        mapping[1] = 0
        joint_ts = left_ts.concatenate(right_ts, node_mapping=mapping)
        total_span = left_ts.sequence_length + right_ts.sequence_length
        assert joint_ts.sequence_length == total_span
        assert joint_ts.num_samples == left_ts.num_samples + right_ts.num_samples - 2
        assert joint_ts.num_nodes == left_ts.num_nodes + right_ts.num_nodes - 2

    def test_provenance(self):
        def last_parameters(tree_seq):
            # Parameters stored in the most recent provenance record
            return json.loads(tree_seq.provenance(-1).record)["parameters"]

        ts = tskit.Tree.generate_comb(2).tree_sequence
        ts = ts.concatenate(ts, record_provenance=False)
        assert last_parameters(ts)["command"] != "concatenate"

        ts = ts.concatenate(ts)
        assert last_parameters(ts)["command"] == "concatenate"

    def test_unequal_samples(self):
        five_samples = tskit.Tree.generate_comb(5, span=2).tree_sequence
        four_samples = tskit.Tree.generate_balanced(4, arity=3, span=3).tree_sequence
        with pytest.raises(ValueError, match="must have the same number of samples"):
            five_samples.concatenate(four_samples)
7205+
7206+
70777207
class TestMissingData:
70787208
"""
70797209
Test various aspects of missing data functionality

python/tskit/tables.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3991,6 +3991,96 @@ def trim(self, record_provenance=True):
39913991
record=json.dumps(provenance.get_provenance_dict(parameters))
39923992
)
39933993

3994+
def shift(self, value, *, sequence_length=None, record_provenance=True):
    """
    Shift the coordinate system (used by edges, sites and migrations) of this
    TableCollection by a given value. This is identical to
    :meth:`TreeSequence.shift` but acts *in place* to alter the data in this
    :class:`TableCollection`.

    .. note::
        No attempt is made to check that the new coordinate system or sequence
        length is valid: if you wish to do this, use :meth:`TreeSequence.shift`
        instead.

    :param value: The amount by which to shift the coordinate system.
    :param sequence_length: The new sequence length of the tree sequence. If
        ``None`` (default) add ``value`` to the sequence length.
    :param bool record_provenance: If True (default), add a provenance record
        describing this call to ``shift``.
    """
    self.edges.left += value
    self.edges.right += value
    self.sites.position += value
    # Migrations also carry genomic intervals; leaving them behind would put
    # them in a different coordinate system from the edges and sites.
    self.migrations.left += value
    self.migrations.right += value
    if sequence_length is None:
        self.sequence_length += value
    else:
        self.sequence_length = sequence_length

    if record_provenance:
        parameters = {
            "command": "shift",
            "value": value,
            "sequence_length": sequence_length,
        }
        self.provenances.add_row(
            record=json.dumps(provenance.get_provenance_dict(parameters))
        )
4025+
4026+
def concatenate(
    self, other, *, node_mapping=None, record_provenance=True, **kwargs
):
    """
    Concatenate another table collection to the right of this one. A copy of
    ``other`` is shifted rightwards (see :meth:`shift`) by the sequence length
    of this table collection, then :meth:`union` is called with
    ``check_shared_equality=False`` and the provided ``node_mapping``. If no
    node mapping is given, the two table collections must have the same number
    of samples, and those are treated (in numerical order) as shared between
    the two table collections.
    This is identical to :meth:`TreeSequence.concatenate` but
    acts *in place* to alter the data in this :class:`TableCollection`.
    The ``other`` table collection is left unmodified.

    .. note::
        To add gaps between the concatenated tables, use :meth:`shift` before
        concatenating; to remove gaps, use :meth:`trim`.

    :param TableCollection other: The other table collection to add to the right
        of this one.
    :param list node_mapping: An array of integers of the same length as the number
        of nodes in ``other``, where the _k_'th element gives the id of the node in
        the current table collection corresponding to node _k_ in the other table
        collection (see :meth:`union`). If None (default), only the sample nodes
        between the two node tables, in numerical order, are mapped to each other.
    :param bool record_provenance: If True (default), record details of this call
        to ``concatenate`` in this table collection's provenance information.
    :param \\**kwargs: Additional keyword arguments to pass to :meth:`union`
        (e.g. ``add_populations``).
    :raises ValueError: If ``node_mapping`` is None and the two table
        collections have different numbers of sample nodes.
    """
    if node_mapping is None:
        samples = np.where(self.nodes.flags & tskit.NODE_IS_SAMPLE)[0]
        other_samples = np.where(other.nodes.flags & tskit.NODE_IS_SAMPLE)[0]
        if len(other_samples) != len(samples):
            raise ValueError(
                "each `other` must have the same number of samples as `self`"
            )
        # Non-sample nodes stay NULL, i.e. they are treated as new nodes
        node_mapping = np.full(other.nodes.num_rows, tskit.NULL, dtype=np.int32)
        node_mapping[other_samples] = samples
    # Shift a private copy so the caller's ``other`` keeps its own coordinates
    shifted = other.copy()
    shifted.shift(self.sequence_length, record_provenance=False)
    # After the shift, the copy's sequence length is the combined length
    self.sequence_length = shifted.sequence_length
    # NB: should we use a different default for add_populations?
    self.union(
        shifted,
        node_mapping=node_mapping,
        check_shared_equality=False,  # Needed as checks fail with internal samples
        record_provenance=False,
        **kwargs,
    )
    if record_provenance:
        parameters = {
            "command": "concatenate",
            "TODO": "add concatenate parameters",  # tricky as both have provenances
        }
        self.provenances.add_row(
            record=json.dumps(provenance.get_provenance_dict(parameters))
        )
4083+
39944084
def delete_older(self, time):
39954085
"""
39964086
Deletes edge, mutation and migration information at least as old as

python/tskit/trees.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7060,6 +7060,68 @@ def trim(self, record_provenance=True):
70607060
tables.trim(record_provenance)
70617061
return tables.tree_sequence()
70627062

7063+
def shift(self, value, sequence_length=None, record_provenance=True):
    """
    Shift the coordinate system (used by edges and sites) of this tree sequence
    by a given value, returning the result as a new tree sequence. Positive
    values shift the coordinate system to the right, negative values to the
    left. The sequence length of the tree sequence will be changed by
    ``value``, unless ``sequence_length`` is given, in which case this will be
    used for the new sequence length.

    :param value: The amount by which to shift the coordinate system.
    :param sequence_length: The new sequence length of the tree sequence. If
        ``None`` (default) add ``value`` to the sequence length.
    :param bool record_provenance: If True (default), record details of this
        call to ``shift`` in the returned tree sequence's provenance
        information.
    :return: A new tree sequence built from the shifted tables.
    :raises tskit.LibraryError: If the shifted tables do not describe a valid
        tree sequence (e.g., if shifting results in negative coordinates, or
        in coordinates beyond the requested ``sequence_length``).
    :raises ValueError: If building the new tree sequence fails with a
        ``ValueError``; it is re-raised with a message about bad coordinates.
    """
    tables = self.dump_tables()
    # The table-level shift performs no validation; validity is only checked
    # when the tables are turned back into a tree sequence below.
    tables.shift(
        value=value,
        sequence_length=sequence_length,
        record_provenance=record_provenance,
    )
    try:
        ts = tables.tree_sequence()
    except ValueError as e:
        raise ValueError("Cannot shift due to bad coordinate values") from e
    return ts
7088+
7089+
def concatenate(
    self, other, *, node_mapping=None, record_provenance=True, **kwargs
):
    """
    Concatenate another tree sequence to the right of this one, returning the
    result as a new tree sequence. This shifts the coordinate system of (a
    copy of) the other tree sequence rightwards, then calls :meth:`union`
    with the provided ``node_mapping``. If no node mapping is given, matches
    sample nodes only, in numerical order.

    .. note::
        To add gaps between the concatenated tree sequences, use :meth:`shift`
        before concatenating; to remove gaps, use :meth:`trim`.

    :param TreeSequence other: The other tree sequence to add to the right
        of this one.
    :param list node_mapping: An array of integers of the same length as the number
        of nodes in ``other``, where the _k_'th element gives the id of the node in
        the current tree sequence corresponding to node _k_ in the other tree
        sequence (see :meth:`union`). If None (default), only the sample nodes
        between the two node tables, in numerical order, are mapped to each other.
    :param bool record_provenance: If True (default), record details of this call
        to ``concatenate`` in the returned tree sequence's provenance information.
    :param \\**kwargs: Additional keyword arguments to pass to :meth:`union`,
        e.g. ``add_populations``.
    :return: A new tree sequence containing this tree sequence followed by
        ``other``.
    """
    tables = self.dump_tables()
    # Hand over a private, mutable copy of the other tables so that ``other``
    # can never be altered by the in-place table operation.
    tables.concatenate(
        other.dump_tables(),
        node_mapping=node_mapping,
        record_provenance=record_provenance,
        **kwargs,
    )
    return tables.tree_sequence()
7124+
70637125
def split_edges(self, time, *, flags=None, population=None, metadata=None):
70647126
"""
70657127
Returns a copy of this tree sequence in which we replace any

0 commit comments

Comments
 (0)