Skip to content

Commit a39e503

Browse files
committed
Add shift and concatenate functions
Fixes tskit-dev#3164
1 parent 589a037 commit a39e503

File tree

5 files changed

+318
-0
lines changed

5 files changed

+318
-0
lines changed

docs/python-api.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,10 +262,12 @@ which perform the same actions but modify the {class}`TableCollection` in place.
262262
TreeSequence.simplify
263263
TreeSequence.subset
264264
TreeSequence.union
265+
TreeSequence.concatenate
265266
TreeSequence.keep_intervals
266267
TreeSequence.delete_intervals
267268
TreeSequence.delete_sites
268269
TreeSequence.trim
270+
TreeSequence.shift
269271
TreeSequence.split_edges
270272
TreeSequence.decapitate
271273
TreeSequence.extend_haplotypes
@@ -750,6 +752,7 @@ a functional way, returning a new tree sequence while leaving the original uncha
750752
TableCollection.keep_intervals
751753
TableCollection.delete_sites
752754
TableCollection.trim
755+
TableCollection.shift
753756
TableCollection.union
754757
TableCollection.delete_older
755758
```

python/CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
associated with each individual as a numpy array.
1313
(:user:`benjeffery`, :pr:`3153`)
1414

15+
- Add ``shift`` method to both ``TableCollection`` and ``TreeSequence`` classes
16+
allowing the coordinate system to be shifted, and ``TreeSequence.concatenate``
17+
so a set of tree sequences can be added to the right of an existing one.
18+
(:user:`hyanwong`, :pr:`3165`, :issue:`3164`)
19+
1520

1621
**Fixes**
1722

python/tests/test_topology.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7074,6 +7074,182 @@ def test_failure_with_migrations(self):
70747074
ts.trim()
70757075

70767076

7077+
class TestShift:
    """
    Tests for ``TreeSequence.shift``: shifted coordinates, the resulting
    sequence length, provenance recording and error reporting.
    """

    @pytest.mark.parametrize("shift", [-0.5, 0, 0.5])
    def test_shift(self, shift):
        # Start from a span-2 comb tree, blank out [0, 1) and add a single site
        # so that edges and sites both have coordinates to be shifted.
        base_tables = tskit.Tree.generate_comb(2, span=2).tree_sequence.dump_tables()
        base_tables.delete_intervals([[0, 1]], simplify=False)
        base_tables.sites.add_row(1.5, "A")
        shifted = base_tables.tree_sequence().shift(shift)
        shifted_tables = shifted.tables
        assert shifted.sequence_length == 2 + shift
        assert np.min(shifted_tables.edges.left) == 1 + shift
        assert np.max(shifted_tables.edges.right) == 2 + shift
        assert np.all(shifted_tables.sites.position == 1.5 + shift)
        # Iterating over the trees must agree with the reported tree count
        assert len(list(shifted.trees())) == shifted.num_trees

    def test_sequence_length(self):
        original = tskit.Tree.generate_comb(2).tree_sequence
        # An explicit sequence_length overrides the default "length + value"
        lengthened = original.shift(1, sequence_length=3)
        assert lengthened.sequence_length == 3
        restored = lengthened.shift(-1, sequence_length=1)
        assert restored.sequence_length == 1

    def test_empty(self):
        no_data = tskit.TableCollection(1.0).tree_sequence()
        moved_right = no_data.shift(1)
        assert moved_right.sequence_length == 2
        moved_left = moved_right.shift(-1.5)
        assert moved_left.sequence_length == 0.5
        assert moved_left.num_nodes == 0

    def test_migrations(self):
        comb = tskit.Tree.generate_comb(2, span=2).tree_sequence
        tables = comb.dump_tables()
        tables.populations.add_row()
        tables.migrations.add_row(0, 1, 0, 0, 0, 0)
        shifted_migrations = tables.tree_sequence().shift(10).tables.migrations
        assert np.all(shifted_migrations.left == 10)
        assert np.all(shifted_migrations.right == 11)

    def test_provenance(self):
        def latest_parameters(tree_seq):
            # Parameters stored in the most recent provenance record
            return json.loads(tree_seq.provenance(-1).record)["parameters"]

        unrecorded = tskit.Tree.generate_comb(2).tree_sequence.shift(
            1, record_provenance=False
        )
        assert latest_parameters(unrecorded)["command"] != "shift"

        recorded = unrecorded.shift(1, sequence_length=9)
        params = latest_parameters(recorded)
        assert params["command"] == "shift"
        assert params["value"] == 1
        assert params["sequence_length"] == 9

    def test_too_negative(self):
        comb = tskit.Tree.generate_comb(2).tree_sequence
        expected_error = "TSK_ERR_BAD_SEQUENCE_LENGTH"
        with pytest.raises(tskit.LibraryError, match=expected_error):
            comb.shift(-1)

    def test_bad_seq_len(self):
        comb = tskit.Tree.generate_comb(2).tree_sequence
        expected_error = "TSK_ERR_RIGHT_GREATER_SEQ_LENGTH"
        with pytest.raises(tskit.LibraryError, match=expected_error):
            comb.shift(1, sequence_length=1)
7142+
7143+
class TestConcatenate:
    """
    Tests for ``TreeSequence.concatenate``: node and sample bookkeeping,
    sequence lengths, node mappings, provenance and error reporting.
    """

    def test_simple(self):
        left_ts = tskit.Tree.generate_comb(5, span=2).tree_sequence
        right_ts = tskit.Tree.generate_balanced(5, arity=3, span=3).tree_sequence
        assert left_ts.num_samples == right_ts.num_samples
        assert left_ts.num_nodes != right_ts.num_nodes
        joint_ts = left_ts.concatenate(right_ts)
        assert joint_ts.num_nodes == left_ts.num_nodes + right_ts.num_nodes - 5
        expected_length = left_ts.sequence_length + right_ts.sequence_length
        assert joint_ts.sequence_length == expected_length
        assert joint_ts.num_samples == left_ts.num_samples
        # Cutting the joint tree sequence back apart should recover the inputs.
        # Have to simplify here, to remove the redundant nodes
        left_part = joint_ts.delete_intervals([[2, 5]]).rtrim()
        assert left_part.equals(left_ts.simplify(), ignore_provenance=True)
        right_part = joint_ts.delete_intervals([[0, 2]]).ltrim()
        assert right_part.equals(right_ts.simplify(), ignore_provenance=True)

    def test_multiple(self):
        np.random.seed(42)
        pieces = [
            tskit.Tree.generate_comb(5, span=2).tree_sequence,
            tskit.Tree.generate_balanced(5, arity=3, span=3).tree_sequence,
            tskit.Tree.generate_star(5, span=5).tree_sequence,
        ]
        # shuffle the sample nodes (of all but the first tree sequence) so they
        # don't have the same IDs
        pieces[1:] = [
            piece.subset(np.random.permutation(piece.num_nodes))
            for piece in pieces[1:]
        ]
        first, second, third = pieces
        assert not np.all(first.samples() == second.samples())
        assert not np.all(first.samples() == third.samples())
        assert not np.all(second.samples() == third.samples())
        joined = first.concatenate(second, third)
        assert joined.sequence_length == sum(p.sequence_length for p in pieces)
        assert joined.num_nodes - joined.num_samples == sum(
            p.num_nodes - p.num_samples for p in pieces
        )
        assert np.all(joined.samples() == first.samples())

    def test_empty(self):
        no_data = tskit.TableCollection(10).tree_sequence()
        joined = no_data.concatenate(no_data, no_data, no_data)
        assert joined.num_nodes == 0
        assert joined.sequence_length == 40

    def test_samples_at_end(self):
        comb = tskit.Tree.generate_comb(5, span=2).tree_sequence
        right_ts = tskit.Tree.generate_balanced(5, arity=3, span=3).tree_sequence
        # reverse the node order
        left_ts = comb.subset(np.arange(comb.num_nodes)[::-1])
        assert left_ts.num_samples == right_ts.num_samples
        assert np.all(left_ts.samples() != right_ts.samples())
        joint_ts = left_ts.concatenate(right_ts)
        assert joint_ts.num_samples == left_ts.num_samples
        assert np.all(joint_ts.samples() == left_ts.samples())

    def test_internal_samples(self):
        tables = tskit.Tree.generate_comb(4, span=2).tree_sequence.dump_tables()
        flags = tables.nodes.flags
        flags[:-1] = tskit.NODE_IS_SAMPLE
        flags[-1] = 0  # Only root is not a sample
        tables.nodes.flags = flags
        base_ts = tables.tree_sequence()
        joint_ts = base_ts.concatenate(base_ts)
        assert joint_ts.num_samples == base_ts.num_samples
        assert joint_ts.num_nodes == base_ts.num_nodes + 1
        assert joint_ts.sequence_length == base_ts.sequence_length * 2

    def test_some_shared_samples(self):
        left_ts = tskit.Tree.generate_comb(4, span=2).tree_sequence
        right_ts = tskit.Tree.generate_balanced(8, arity=3, span=3).tree_sequence
        # Only the first two nodes of the right-hand input are mapped (swapped)
        # onto nodes of the left-hand one; every other node is left unmapped.
        shared = np.full(right_ts.num_nodes, tskit.NULL)
        shared[:2] = [1, 0]
        joint_ts = left_ts.concatenate(right_ts, node_mappings=[shared])
        expected_length = left_ts.sequence_length + right_ts.sequence_length
        assert joint_ts.sequence_length == expected_length
        assert joint_ts.num_samples == left_ts.num_samples + right_ts.num_samples - 2
        assert joint_ts.num_nodes == left_ts.num_nodes + right_ts.num_nodes - 2

    def test_provenance(self):
        def latest_parameters(tree_seq):
            # Parameters stored in the most recent provenance record
            return json.loads(tree_seq.provenance(-1).record)["parameters"]

        comb = tskit.Tree.generate_comb(2).tree_sequence
        unrecorded = comb.concatenate(comb, record_provenance=False)
        assert latest_parameters(unrecorded)["command"] != "concatenate"

        recorded = unrecorded.concatenate(unrecorded)
        assert latest_parameters(recorded)["command"] == "concatenate"

    def test_unequal_samples(self):
        five_samples = tskit.Tree.generate_comb(5, span=2).tree_sequence
        four_samples = tskit.Tree.generate_balanced(4, arity=3, span=3).tree_sequence
        with pytest.raises(ValueError, match="must have the same number of samples"):
            five_samples.concatenate(four_samples)

    @pytest.mark.skip(
        reason="union bug: https://github.com/tskit-dev/tskit/issues/3168"
    )
    def test_duplicate_ts(self):
        full_ts = tskit.Tree.generate_comb(3, span=4).tree_sequence
        quarter = full_ts.keep_intervals([[0, 1]]).trim()  # a quarter of the original
        identity = np.arange(quarter.num_nodes)  # all nodes identical
        rebuilt = quarter.concatenate(
            quarter,
            quarter,
            quarter,
            node_mappings=[identity] * 3,
            add_populations=False,
        )
        rebuilt = rebuilt.simplify()  # squash the edges
        assert full_ts.equals(rebuilt, ignore_provenance=True)

    def test_node_mappings_bad_len(self):
        comb = tskit.Tree.generate_comb(3, span=2).tree_sequence
        identity = np.arange(comb.num_nodes)
        with pytest.raises(ValueError, match="same number of node_mappings"):
            comb.concatenate(comb, comb, comb, node_mappings=[identity, identity])
7252+
70777253
class TestMissingData:
70787254
"""
70797255
Test various aspects of missing data functionality

python/tskit/tables.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3991,6 +3991,40 @@ def trim(self, record_provenance=True):
39913991
record=json.dumps(provenance.get_provenance_dict(parameters))
39923992
)
39933993

3994+
def shift(self, value, *, sequence_length=None, record_provenance=True):
    """
    Shift the coordinate system (used by edges, sites, and migrations) of this
    TableCollection by a given value. This is identical to :meth:`TreeSequence.shift`
    but acts *in place* to alter the data in this :class:`TableCollection`.

    .. note::
        No attempt is made to check that the new coordinate system or sequence length
        is valid: if you wish to do this, use :meth:`TreeSequence.shift` instead.

    :param value: The amount by which to shift the coordinate system.
    :param sequence_length: The new sequence length of the tree sequence. If
        ``None`` (default) add ``value`` to the sequence length.
    :param bool record_provenance: If True (default), record details of this
        call to ``shift`` in the provenance table of this TableCollection.
    """
    parameters = None
    if record_provenance:
        # Build the provenance parameters *before* touching any table, and
        # store plain Python floats: numpy integer scalars (e.g. ``np.int64``)
        # are not JSON serialisable, and failing in ``json.dumps`` after the
        # coordinates had been shifted would leave the tables half-updated.
        parameters = {
            "command": "shift",
            "value": float(value),
            "sequence_length": (
                None if sequence_length is None else float(sequence_length)
            ),
        }

    self.drop_index()
    self.edges.left += value
    self.edges.right += value
    self.migrations.left += value
    self.migrations.right += value
    self.sites.position += value
    if sequence_length is None:
        self.sequence_length += value
    else:
        self.sequence_length = sequence_length

    if parameters is not None:
        self.provenances.add_row(
            record=json.dumps(provenance.get_provenance_dict(parameters))
        )
4027+
39944028
def delete_older(self, time):
39954029
"""
39964030
Deletes edge, mutation and migration information at least as old as

python/tskit/trees.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import functools
3333
import io
3434
import itertools
35+
import json
3536
import math
3637
import numbers
3738
import warnings
@@ -46,6 +47,7 @@
4647
import tskit.combinatorics as combinatorics
4748
import tskit.drawing as drawing
4849
import tskit.metadata as metadata_module
50+
import tskit.provenance as provenance
4951
import tskit.tables as tables
5052
import tskit.text_formats as text_formats
5153
import tskit.util as util
@@ -7060,6 +7062,104 @@ def trim(self, record_provenance=True):
70607062
tables.trim(record_provenance)
70617063
return tables.tree_sequence()
70627064

7065+
def shift(self, value, sequence_length=None, record_provenance=True):
    """
    Return a new tree sequence in which the coordinate system (used by edges,
    sites, and migrations) has been shifted by a given value. Positive values
    shift the coordinate system to the right, negative values to the left. The
    sequence length of the tree sequence will be changed by ``value``, unless
    ``sequence_length`` is given, in which case this will be used for the new
    sequence length.

    .. note::
        By setting ``value=0``, this method will simply return a tree sequence
        with a new sequence length.

    :param value: The amount by which to shift the coordinate system.
    :param sequence_length: The new sequence length of the tree sequence. If
        ``None`` (default) add ``value`` to the sequence length.
    :param bool record_provenance: If True (default), record details of this
        call to ``shift`` in the returned tree sequence's provenance
        information.
    :raises tskit.LibraryError: If the new coordinate system is invalid (e.g., if
        shifting the coordinate system results in negative coordinates, or if
        ``sequence_length`` is smaller than a shifted right-hand coordinate).
    """
    # Work on a copy of the tables; the coordinates are validated when the
    # shifted tables are turned back into a tree sequence.
    shifted_tables = self.dump_tables()
    shifted_tables.shift(
        value,
        sequence_length=sequence_length,
        record_provenance=record_provenance,
    )
    return shifted_tables.tree_sequence()
7090+
7091+
def concatenate(
    self, *args, node_mappings=None, record_provenance=True, add_populations=None
):
    r"""
    Concatenate a set of tree sequences to the right of this one, by repeatedly
    calling :meth:`union` with an (optional) node mapping for each of the tree
    sequences in ``args``. If any node mapping is ``None``, only map the sample
    nodes between the input tree sequence and this one, based on the numerical
    order of sample node IDs.

    .. note::
        To add gaps between the concatenated tables, use :meth:`shift` or
        to remove gaps, use :meth:`trim` before concatenating.

    :param TreeSequence \*args: A list of other tree sequences to append to
        the right of this one.
    :param Union[list, None] node_mappings: A list of node mappings for each
        input tree sequence in ``args``. Each should either be an array of
        integers of the same length as the number of nodes in the equivalent
        input tree sequence (see :meth:`union` for details), or ``None``.
        If ``None``, only sample nodes are mapped to each other.
        Default: ``None``, treated as ``[None] * len(args)``.
    :param bool record_provenance: If True (default), record details of this
        call to ``concatenate`` in the returned tree sequence's provenance
        information.
    :param bool add_populations: If True, nodes new to ``self`` will be
        assigned new population IDs (see :meth:`union`). Default: ``None``,
        treated as ``True``.
    :raises ValueError: If the number of ``node_mappings`` differs from the
        number of tree sequences in ``args``, or if a tree sequence whose node
        mapping is ``None`` does not have the same number of samples as this
        one.
    """
    if add_populations is None:
        add_populations = True
    if node_mappings is None:
        node_mappings = [None] * len(args)
    if len(args) != len(node_mappings):
        raise ValueError(
            "You must provide the same number of node_mappings as args"
        )

    own_samples = self.samples()
    joined = self.dump_tables()
    joined.drop_index()

    for other, mapping in zip(args, node_mappings):
        if mapping is None:
            # Default mapping: pair up the samples in order of node ID and
            # leave every other node of `other` unmapped.
            their_samples = other.samples()
            if len(their_samples) != len(own_samples):
                raise ValueError(
                    "each `other` must have the same number of samples as `self`"
                )
            mapping = np.full(other.num_nodes, tskit.NULL, dtype=np.int32)
            mapping[their_samples] = own_samples
        # Move `other` so that it starts where the joined tables currently
        # end, then extend the joined tables to cover the shifted interval.
        appended = other.dump_tables()
        appended.shift(joined.sequence_length, record_provenance=False)
        joined.sequence_length = appended.sequence_length
        # NB: should we use a different default for add_populations?
        joined.union(
            appended,
            node_mapping=mapping,
            check_shared_equality=False,  # Else checks fail with internal samples
            record_provenance=False,
            add_populations=add_populations,
        )

    if record_provenance:
        # Recording the full parameters is tricky, as both inputs have
        # provenances; for now only the command name is meaningful.
        parameters = {
            "command": "concatenate",
            "TODO": "add concatenate parameters",
        }
        joined.provenances.add_row(
            record=json.dumps(provenance.get_provenance_dict(parameters))
        )

    return joined.tree_sequence()
7162+
70637163
def split_edges(self, time, *, flags=None, population=None, metadata=None):
70647164
"""
70657165
Returns a copy of this tree sequence in which we replace any

0 commit comments

Comments
 (0)