@@ -7074,6 +7074,182 @@ def test_failure_with_migrations(self):
7074
7074
ts .trim ()
7075
7075
7076
7076
7077
+ class TestShift :
7078
+ """
7079
+ Test the shift functionality
7080
+ """
7081
+
7082
+ @pytest .mark .parametrize ("shift" , [- 0.5 , 0 , 0.5 ])
7083
+ def test_shift (self , shift ):
7084
+ ts = tskit .Tree .generate_comb (2 , span = 2 ).tree_sequence
7085
+ tables = ts .dump_tables ()
7086
+ tables .delete_intervals ([[0 , 1 ]], simplify = False )
7087
+ tables .sites .add_row (1.5 , "A" )
7088
+ ts = tables .tree_sequence ()
7089
+ ts = ts .shift (shift )
7090
+ assert ts .sequence_length == 2 + shift
7091
+ assert np .min (ts .tables .edges .left ) == 1 + shift
7092
+ assert np .max (ts .tables .edges .right ) == 2 + shift
7093
+ assert np .all (ts .tables .sites .position == 1.5 + shift )
7094
+ assert len (list (ts .trees ())) == ts .num_trees
7095
+
7096
+ def test_sequence_length (self ):
7097
+ ts = tskit .Tree .generate_comb (2 ).tree_sequence
7098
+ ts = ts .shift (1 , sequence_length = 3 )
7099
+ assert ts .sequence_length == 3
7100
+ ts = ts .shift (- 1 , sequence_length = 1 )
7101
+ assert ts .sequence_length == 1
7102
+
7103
+ def test_empty (self ):
7104
+ empty_ts = tskit .TableCollection (1.0 ).tree_sequence ()
7105
+ empty_ts = empty_ts .shift (1 )
7106
+ assert empty_ts .sequence_length == 2
7107
+ empty_ts = empty_ts .shift (- 1.5 )
7108
+ assert empty_ts .sequence_length == 0.5
7109
+ assert empty_ts .num_nodes == 0
7110
+
7111
+ def test_migrations (self ):
7112
+ tables = tskit .Tree .generate_comb (2 , span = 2 ).tree_sequence .dump_tables ()
7113
+ tables .populations .add_row ()
7114
+ tables .migrations .add_row (0 , 1 , 0 , 0 , 0 , 0 )
7115
+ ts = tables .tree_sequence ().shift (10 )
7116
+ assert np .all (ts .tables .migrations .left == 10 )
7117
+ assert np .all (ts .tables .migrations .right == 11 )
7118
+
7119
+ def test_provenance (self ):
7120
+ ts = tskit .Tree .generate_comb (2 ).tree_sequence
7121
+ ts = ts .shift (1 , record_provenance = False )
7122
+ params = json .loads (ts .provenance (- 1 ).record )["parameters" ]
7123
+ assert params ["command" ] != "shift"
7124
+ ts = ts .shift (1 , sequence_length = 9 )
7125
+ params = json .loads (ts .provenance (- 1 ).record )["parameters" ]
7126
+ assert params ["command" ] == "shift"
7127
+ assert params ["value" ] == 1
7128
+ assert params ["sequence_length" ] == 9
7129
+
7130
+ def test_too_negative (self ):
7131
+ ts = tskit .Tree .generate_comb (2 ).tree_sequence
7132
+ with pytest .raises (tskit .LibraryError , match = "TSK_ERR_BAD_SEQUENCE_LENGTH" ):
7133
+ ts .shift (- 1 )
7134
+
7135
+ def test_bad_seq_len (self ):
7136
+ ts = tskit .Tree .generate_comb (2 ).tree_sequence
7137
+ with pytest .raises (
7138
+ tskit .LibraryError , match = "TSK_ERR_RIGHT_GREATER_SEQ_LENGTH"
7139
+ ):
7140
+ ts .shift (1 , sequence_length = 1 )
7141
+
7142
+
7143
+ class TestConcatenate :
7144
+ def test_simple (self ):
7145
+ ts1 = tskit .Tree .generate_comb (5 , span = 2 ).tree_sequence
7146
+ ts2 = tskit .Tree .generate_balanced (5 , arity = 3 , span = 3 ).tree_sequence
7147
+ assert ts1 .num_samples == ts2 .num_samples
7148
+ assert ts1 .num_nodes != ts2 .num_nodes
7149
+ joint_ts = ts1 .concatenate (ts2 )
7150
+ assert joint_ts .num_nodes == ts1 .num_nodes + ts2 .num_nodes - 5
7151
+ assert joint_ts .sequence_length == ts1 .sequence_length + ts2 .sequence_length
7152
+ assert joint_ts .num_samples == ts1 .num_samples
7153
+ ts3 = joint_ts .delete_intervals ([[2 , 5 ]]).rtrim ()
7154
+ # Have to simplify here, to remove the redundant nodes
7155
+ assert ts3 .equals (ts1 .simplify (), ignore_provenance = True )
7156
+ ts4 = joint_ts .delete_intervals ([[0 , 2 ]]).ltrim ()
7157
+ assert ts4 .equals (ts2 .simplify (), ignore_provenance = True )
7158
+
7159
+ def test_multiple (self ):
7160
+ np .random .seed (42 )
7161
+ ts3 = [
7162
+ tskit .Tree .generate_comb (5 , span = 2 ).tree_sequence ,
7163
+ tskit .Tree .generate_balanced (5 , arity = 3 , span = 3 ).tree_sequence ,
7164
+ tskit .Tree .generate_star (5 , span = 5 ).tree_sequence ,
7165
+ ]
7166
+ for i in range (1 , len (ts3 )):
7167
+ # shuffle the sample nodes so they don't have the same IDs
7168
+ ts3 [i ] = ts3 [i ].subset (np .random .permutation (ts3 [i ].num_nodes ))
7169
+ assert not np .all (ts3 [0 ].samples () == ts3 [1 ].samples ())
7170
+ assert not np .all (ts3 [0 ].samples () == ts3 [2 ].samples ())
7171
+ assert not np .all (ts3 [1 ].samples () == ts3 [2 ].samples ())
7172
+ ts = ts3 [0 ].concatenate (* ts3 [1 :])
7173
+ assert ts .sequence_length == sum ([t .sequence_length for t in ts3 ])
7174
+ assert ts .num_nodes - ts .num_samples == sum (
7175
+ [t .num_nodes - t .num_samples for t in ts3 ]
7176
+ )
7177
+ assert np .all (ts .samples () == ts3 [0 ].samples ())
7178
+
7179
+ def test_empty (self ):
7180
+ empty_ts = tskit .TableCollection (10 ).tree_sequence ()
7181
+ ts = empty_ts .concatenate (empty_ts , empty_ts , empty_ts )
7182
+ assert ts .num_nodes == 0
7183
+ assert ts .sequence_length == 40
7184
+
7185
+ def test_samples_at_end (self ):
7186
+ ts1 = tskit .Tree .generate_comb (5 , span = 2 ).tree_sequence
7187
+ ts2 = tskit .Tree .generate_balanced (5 , arity = 3 , span = 3 ).tree_sequence
7188
+ # reverse the node order
7189
+ ts1 = ts1 .subset (np .arange (ts1 .num_nodes )[::- 1 ])
7190
+ assert ts1 .num_samples == ts2 .num_samples
7191
+ assert np .all (ts1 .samples () != ts2 .samples ())
7192
+ joint_ts = ts1 .concatenate (ts2 )
7193
+ assert joint_ts .num_samples == ts1 .num_samples
7194
+ assert np .all (joint_ts .samples () == ts1 .samples ())
7195
+
7196
+ def test_internal_samples (self ):
7197
+ tables = tskit .Tree .generate_comb (4 , span = 2 ).tree_sequence .dump_tables ()
7198
+ nodes_flags = tables .nodes .flags
7199
+ nodes_flags [:] = tskit .NODE_IS_SAMPLE
7200
+ nodes_flags [- 1 ] = 0 # Only root is not a sample
7201
+ tables .nodes .flags = nodes_flags
7202
+ ts = tables .tree_sequence ()
7203
+ joint_ts = ts .concatenate (ts )
7204
+ assert joint_ts .num_samples == ts .num_samples
7205
+ assert joint_ts .num_nodes == ts .num_nodes + 1
7206
+ assert joint_ts .sequence_length == ts .sequence_length * 2
7207
+
7208
+ def test_some_shared_samples (self ):
7209
+ ts1 = tskit .Tree .generate_comb (4 , span = 2 ).tree_sequence
7210
+ ts2 = tskit .Tree .generate_balanced (8 , arity = 3 , span = 3 ).tree_sequence
7211
+ shared = np .full (ts2 .num_nodes , tskit .NULL )
7212
+ shared [0 ] = 1
7213
+ shared [1 ] = 0
7214
+ joint_ts = ts1 .concatenate (ts2 , node_mappings = [shared ])
7215
+ assert joint_ts .sequence_length == ts1 .sequence_length + ts2 .sequence_length
7216
+ assert joint_ts .num_samples == ts1 .num_samples + ts2 .num_samples - 2
7217
+ assert joint_ts .num_nodes == ts1 .num_nodes + ts2 .num_nodes - 2
7218
+
7219
+ def test_provenance (self ):
7220
+ ts = tskit .Tree .generate_comb (2 ).tree_sequence
7221
+ ts = ts .concatenate (ts , record_provenance = False )
7222
+ params = json .loads (ts .provenance (- 1 ).record )["parameters" ]
7223
+ assert params ["command" ] != "concatenate"
7224
+
7225
+ ts = ts .concatenate (ts )
7226
+ params = json .loads (ts .provenance (- 1 ).record )["parameters" ]
7227
+ assert params ["command" ] == "concatenate"
7228
+
7229
+ def test_unequal_samples (self ):
7230
+ ts1 = tskit .Tree .generate_comb (5 , span = 2 ).tree_sequence
7231
+ ts2 = tskit .Tree .generate_balanced (4 , arity = 3 , span = 3 ).tree_sequence
7232
+ with pytest .raises (ValueError , match = "must have the same number of samples" ):
7233
+ ts1 .concatenate (ts2 )
7234
+
7235
+ @pytest .mark .skip (
7236
+ reason = "union bug: https://github.com/tskit-dev/tskit/issues/3168"
7237
+ )
7238
+ def test_duplicate_ts (self ):
7239
+ ts1 = tskit .Tree .generate_comb (3 , span = 4 ).tree_sequence
7240
+ ts = ts1 .keep_intervals ([[0 , 1 ]]).trim () # a quarter of the original
7241
+ nm = np .arange (ts .num_nodes ) # all nodes identical
7242
+ ts2 = ts .concatenate (ts , ts , ts , node_mappings = [nm ] * 3 , add_populations = False )
7243
+ ts2 = ts2 .simplify () # squash the edges
7244
+ assert ts1 .equals (ts2 , ignore_provenance = True )
7245
+
7246
+ def test_node_mappings_bad_len (self ):
7247
+ ts = tskit .Tree .generate_comb (3 , span = 2 ).tree_sequence
7248
+ nm = np .arange (ts .num_nodes )
7249
+ with pytest .raises (ValueError , match = "same number of node_mappings" ):
7250
+ ts .concatenate (ts , ts , ts , node_mappings = [nm , nm ])
7251
+
7252
+
7077
7253
class TestMissingData :
7078
7254
"""
7079
7255
Test various aspects of missing data functionality
0 commit comments