diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index 251ddaccdf..0f7a2d2ac3 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -2281,7 +2281,7 @@ test_paper_ex_two_site(void) result_size = num_sites * num_sites; tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size * num_sample_sets, result, truth_one_set); @@ -2295,7 +2295,7 @@ test_paper_ex_two_site(void) tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size * num_sample_sets, result, truth_two_sets); @@ -2309,7 +2309,7 @@ test_paper_ex_two_site(void) tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal_nan( @@ -2320,6 +2320,128 @@ test_paper_ex_two_site(void) tsk_safe_free(col_sites); } +static void +test_paper_ex_two_branch(void) +{ + int ret; + tsk_treeseq_t ts; + double result[27]; + tsk_size_t i, result_size, num_sample_sets; + tsk_flags_t options = 0; + double truth_one_set[9] + = { 0.001066666666666695, -0.00012666666666665688, -0.0001266666666666534, + -0.00012666666666665688, 6.016666666665456e-05, 6.016666666665629e-05, + -0.0001266666666666534, 6.016666666665629e-05, 6.016666666665629e-05 }; + double truth_two_sets[18] + = { 0.001066666666666695, 0.001066666666666695, -0.00012666666666665688, + -0.00012666666666665688, -0.0001266666666666534, -0.0001266666666666534, + -0.00012666666666665688, -0.00012666666666665688, 6.016666666665456e-05, + 6.016666666665456e-05, 6.016666666665629e-05, 6.016666666665629e-05, + -0.0001266666666666534, -0.0001266666666666534, 6.016666666665629e-05, + 6.016666666665629e-05, 6.016666666665629e-05, 6.016666666665629e-05 }; + double truth_three_sets[27] = { 0.001066666666666695, 0.001066666666666695, NAN, + -0.00012666666666665688, -0.00012666666666665688, NAN, -0.0001266666666666534, + -0.0001266666666666534, NAN, -0.00012666666666665688, -0.00012666666666665688, + NAN, 6.016666666665456e-05, 6.016666666665456e-05, NAN, 6.016666666665629e-05, + 6.016666666665629e-05, NAN, -0.0001266666666666534, -0.0001266666666666534, NAN, + 6.016666666665629e-05, 6.016666666665629e-05, NAN, 6.016666666665629e-05, + 6.016666666665629e-05, NAN }; + double truth_positions_subset_1[12] = { 0.001066666666666695, 0.001066666666666695, + NAN, 0.001066666666666695, 0.001066666666666695, NAN, 0.001066666666666695, + 0.001066666666666695, NAN, 0.001066666666666695, 0.001066666666666695, NAN }; + double truth_positions_subset_2[12] = { 6.016666666665456e-05, 6.016666666665456e-05, + NAN, 6.016666666665456e-05, 6.016666666665456e-05, NAN, 6.016666666665456e-05, + 6.016666666665456e-05, NAN, 6.016666666665456e-05, 6.016666666665456e-05, NAN }; + double truth_positions_subset_3[12] = { 6.016666666665456e-05, 6.016666666665456e-05, + NAN, 6.016666666665456e-05, 6.016666666665456e-05, NAN, 6.016666666665456e-05, + 6.016666666665456e-05, NAN, 6.016666666665456e-05, 6.016666666665456e-05, NAN }; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, + paper_ex_mutations, paper_ex_individuals, NULL, 0); + + tsk_size_t sample_set_sizes[3]; + tsk_id_t sample_sets[ts.num_samples * 3]; + tsk_size_t num_trees = ts.num_trees; + double *row_positions = tsk_malloc(num_trees * sizeof(*row_positions)); + double *col_positions = tsk_malloc(num_trees * sizeof(*col_positions)); + double positions_subset_1[2] = { 0., 0.1 }; + double positions_subset_2[2] = { 2., 6. }; + double positions_subset_3[2] = { 9., 9.999 }; + + // First sample set contains all of the samples + sample_set_sizes[0] = ts.num_samples; + num_sample_sets = 1; + for (i = 0; i < ts.num_samples; i++) { + sample_sets[i] = (tsk_id_t) i; + } + for (i = 0; i < num_trees; i++) { + row_positions[i] = ts.breakpoints[i]; + col_positions[i] = ts.breakpoints[i]; + } + + options |= TSK_STAT_BRANCH; + + result_size = num_trees * num_trees * num_sample_sets; + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_trees, NULL, row_positions, num_trees, NULL, col_positions, options, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_one_set); + + // Second sample set contains all of the samples + sample_set_sizes[1] = ts.num_samples; + num_sample_sets = 2; + for (i = ts.num_samples; i < ts.num_samples * 2; i++) { + sample_sets[i] = (tsk_id_t) i - (tsk_id_t) ts.num_samples; + } + + result_size = num_trees * num_trees * num_sample_sets; + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_trees, NULL, row_positions, num_trees, NULL, col_positions, options, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_two_sets); + + // Third sample set contains the first two samples + sample_set_sizes[2] = 2; + num_sample_sets = 3; + for (i = ts.num_samples * 2; i < (ts.num_samples * 3) - 2; i++) { + sample_sets[i] = (tsk_id_t) i - (tsk_id_t) ts.num_samples * 2; + } + + result_size = num_trees * num_trees * num_sample_sets; + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_trees, NULL, row_positions, num_trees, NULL, col_positions, options, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal_nan(result_size, result, truth_three_sets); + + result_size = 4 * num_sample_sets; + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, 2, + NULL, positions_subset_1, 2, NULL, positions_subset_1, options, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal_nan(result_size, result, truth_positions_subset_1); + + result_size = 4 * num_sample_sets; + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, 2, + NULL, positions_subset_2, 2, NULL, positions_subset_2, options, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal_nan(result_size, result, truth_positions_subset_2); + + result_size = 4 * num_sample_sets; + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, 2, + NULL, positions_subset_3, 2, NULL, positions_subset_3, options, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal_nan(result_size, result, truth_positions_subset_3); + + tsk_treeseq_free(&ts); + tsk_safe_free(row_positions); + tsk_safe_free(col_positions); +} + static void test_two_site_correlated_multiallelic(void) { @@ -2401,43 +2523,43 @@ test_two_site_correlated_multiallelic(void) tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_D(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_D); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_D2); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_r2); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_D_prime(&ts, num_sample_sets, sample_set_sizes, sample_sets, - num_sites, row_sites, num_sites, col_sites, 0, result); + num_sites, row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_D_prime); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_r); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_Dz(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_Dz); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_pi2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_pi2); @@ -2532,43 +2654,43 @@ test_two_site_uncorrelated_multiallelic(void) tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_D(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_D); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_D2); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_r2); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_D_prime(&ts, num_sample_sets, sample_set_sizes, sample_sets, - num_sites, row_sites, num_sites, col_sites, 0, result); + num_sites, row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_D_prime); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_r); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_Dz(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_Dz); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_pi2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_pi2); @@ -2637,7 +2759,7 @@ test_two_site_backmutation(void) tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_r2); @@ -2675,7 +2797,7 @@ test_paper_ex_two_site_subset(void) result_size = 2 * 2; tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 2, - row_sites, 2, col_sites, 0, result); + row_sites, NULL, 2, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size * num_sample_sets, result, result_truth_1); @@ -2683,7 +2805,7 @@ test_paper_ex_two_site_subset(void) tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); col_sites[0] = 2; ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 1, - row_sites, 1, col_sites, 0, result); + row_sites, NULL, 1, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size * num_sample_sets, result, result_truth_2); @@ -2694,7 +2816,7 @@ test_paper_ex_two_site_subset(void) col_sites[0] = 0; col_sites[1] = 1; ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 2, - row_sites, 2, col_sites, 0, result); + row_sites, NULL, 2, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size * num_sample_sets, result, result_truth_3); @@ -2716,8 +2838,9 @@ test_two_locus_stat_input_errors(void) tsk_size_t sample_set_sizes[1] = { ts.num_samples }; tsk_size_t num_sample_sets = 1; tsk_id_t sample_sets[ts.num_samples]; - tsk_size_t result_size = num_sites * num_sites; - double result[result_size]; + double positions[10] = { 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9 }; + double bad_col_positions[2] = { 0., 0. }; // used in 1 test to cover column check + double result[100]; tsk_size_t s; for (s = 0; s < ts.num_samples; s++) { @@ -2736,66 +2859,100 @@ test_two_locus_stat_input_errors(void) sample_sets[1] = 0; ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SAMPLE); sample_sets[1] = 1; ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, TSK_STAT_SITE | TSK_STAT_BRANCH, result); + row_sites, NULL, num_sites, col_sites, NULL, TSK_STAT_SITE | TSK_STAT_BRANCH, + result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MULTIPLE_STAT_MODES); - ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, TSK_STAT_BRANCH, result); - CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE); - ret = tsk_treeseq_r2(&ts, 0, sample_set_sizes, sample_sets, num_sites, row_sites, - num_sites, col_sites, 0, result); + NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INSUFFICIENT_SAMPLE_SETS); sample_set_sizes[0] = 0; ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_EMPTY_SAMPLE_SET); sample_set_sizes[0] = ts.num_samples; sample_sets[1] = 10; ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); sample_sets[1] = 1; row_sites[0] = 1000; ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SITE_OUT_OF_BOUNDS); row_sites[0] = 0; col_sites[num_sites - 1] = (tsk_id_t) num_sites; ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SITE_OUT_OF_BOUNDS); col_sites[num_sites - 1] = (tsk_id_t) num_sites - 1; row_sites[0] = 1; row_sites[1] = 0; ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); - CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSORTED_SITES); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_STAT_UNSORTED_SITES); row_sites[0] = 0; row_sites[1] = 1; row_sites[0] = 1; row_sites[1] = 1; ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); - CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSORTED_SITES); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_STAT_DUPLICATE_SITES); row_sites[0] = 0; row_sites[1] = 1; // Not an error condition, but we want to record this behavior - ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, result); - CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, + NULL, 0, NULL, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + positions[9] = 1; + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 10, NULL, + positions, 10, NULL, positions, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_POSITION_OUT_OF_BOUNDS); + positions[9] = 0.9; + + positions[0] = -0.1; + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 10, NULL, + positions, 10, NULL, positions, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_POSITION_OUT_OF_BOUNDS); + positions[0] = 0; + + positions[0] = 0.1; + positions[1] = 0; + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 10, NULL, + positions, 10, NULL, positions, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_STAT_UNSORTED_POSITIONS); + positions[0] = 0; + positions[1] = 0.1; + + // rows always fail first, check columns + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 10, NULL, + positions, 2, NULL, bad_col_positions, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_STAT_DUPLICATE_POSITIONS); + + positions[0] = 0; + positions[1] = 0; + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 10, NULL, + positions, 10, NULL, positions, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_STAT_DUPLICATE_POSITIONS); + positions[0] = 0; + positions[1] = 0.1; + + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 10, NULL, + positions, 10, NULL, positions, TSK_STAT_NODE, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE); tsk_treeseq_free(&ts); tsk_safe_free(row_sites); @@ -3101,6 +3258,7 @@ main(int argc, char **argv) { "test_ld_silent_mutations", test_ld_silent_mutations }, { "test_paper_ex_two_site", test_paper_ex_two_site }, + { "test_paper_ex_two_branch", test_paper_ex_two_branch }, { "test_two_site_correlated_multiallelic", test_two_site_correlated_multiallelic }, { "test_two_site_uncorrelated_multiallelic", diff --git a/c/tskit/core.c b/c/tskit/core.c index 32120ba2f8..609979edd8 100644 --- a/c/tskit/core.c +++ b/c/tskit/core.c @@ -226,6 +226,9 @@ tsk_strerror_internal(int err) ret = "One of the kept rows in the table refers to a deleted row. " "(TSK_ERR_KEEP_ROWS_MAP_TO_DELETED)"; break; + case TSK_ERR_POSITION_OUT_OF_BOUNDS: + ret = "Position out of bounds. (TSK_ERR_POSITION_OUT_OF_BOUNDS)"; + break; /* Edge errors */ case TSK_ERR_NULL_PARENT: @@ -502,6 +505,24 @@ tsk_strerror_internal(int err) ret = "Times must be strictly increasing. (TSK_ERR_UNSORTED_TIMES)"; break; + /* Two locus errors */ + case TSK_ERR_STAT_UNSORTED_POSITIONS: + ret = "The provided positions are not sorted in strictly increasing " + "order. (TSK_ERR_STAT_UNSORTED_POSITIONS)"; + break; + case TSK_ERR_STAT_DUPLICATE_POSITIONS: + ret = "The provided positions contain duplicates. " + "(TSK_ERR_STAT_DUPLICATE_POSITIONS)"; + break; + case TSK_ERR_STAT_UNSORTED_SITES: + ret = "The provided sites are not sorted in strictly increasing position " + "order. (TSK_ERR_STAT_UNSORTED_SITES)"; + break; + case TSK_ERR_STAT_DUPLICATE_SITES: + ret = "The provided sites contain duplicated entries. " + "(TSK_ERR_STAT_DUPLICATE_SITES)"; + break; + /* Mutation mapping errors */ case TSK_ERR_GENOTYPES_ALL_MISSING: ret = "Must provide at least one non-missing genotype. " diff --git a/c/tskit/core.h b/c/tskit/core.h index 641400a44c..93e407fe68 100644 --- a/c/tskit/core.h +++ b/c/tskit/core.h @@ -370,6 +370,11 @@ One of the rows in the retained table refers to a row that has been deleted. */ #define TSK_ERR_KEEP_ROWS_MAP_TO_DELETED -212 +/** +A genomic position was less than zero or greater equal to the sequence +length +*/ +#define TSK_ERR_POSITION_OUT_OF_BOUNDS -213 /** @} */ @@ -710,6 +715,22 @@ The vector of quantiles is out of bounds or in nonascending order. Times are not in ascending order */ #define TSK_ERR_UNSORTED_TIMES -917 +/* +The provided positions are not provided in strictly increasing order +*/ +#define TSK_ERR_STAT_UNSORTED_POSITIONS -918 +/** +The provided positions are not unique +*/ +#define TSK_ERR_STAT_DUPLICATE_POSITIONS -919 +/** +The provided sites are not provided in strictly increasing position order +*/ +#define TSK_ERR_STAT_UNSORTED_SITES -920 +/** +The provided sites are not unique +*/ +#define TSK_ERR_STAT_DUPLICATE_SITES -921 /** @} */ /** diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 1432d257ee..906e9f5c0c 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -808,7 +808,7 @@ tsk_treeseq_get_individuals_time(const tsk_treeseq_t *self, double *output) /* Stats functions */ -#define GET_2D_ROW(array, row_len, row) (array + (((size_t)(row_len)) * (size_t) row)) +#define GET_2D_ROW(array, row_len, row) (array + (((size_t)(row_len)) * (size_t)(row))) static inline double * GET_3D_ROW(double *base, tsk_size_t num_nodes, tsk_size_t output_dim, @@ -2621,9 +2621,12 @@ check_sites(const tsk_id_t *sites, tsk_size_t num_sites, tsk_size_t num_site_row ret = TSK_ERR_SITE_OUT_OF_BOUNDS; goto out; } - if (sites[i] >= sites[i + 1]) { - // TODO: this checks no repeats, but error is ambiguous - ret = TSK_ERR_UNSORTED_SITES; + if (sites[i] > sites[i + 1]) { + ret = TSK_ERR_STAT_UNSORTED_SITES; + goto out; + } + if (sites[i] == sites[i + 1]) { + ret = TSK_ERR_STAT_DUPLICATE_SITES; goto out; } } @@ -2636,12 +2639,564 @@ check_sites(const tsk_id_t *sites, tsk_size_t num_sites, tsk_size_t num_site_row return ret; } +static int +check_positions( + const double *positions, tsk_size_t num_positions, double sequence_length) +{ + int ret = 0; + tsk_size_t i; + + if (num_positions == 0) { + return ret; // No need to verify positions if there aren't any + } + + for (i = 0; i < num_positions - 1; i++) { + if (positions[i] < 0 || positions[i] >= sequence_length) { + ret = TSK_ERR_POSITION_OUT_OF_BOUNDS; + goto out; + } + if (positions[i] > positions[i + 1]) { + ret = TSK_ERR_STAT_UNSORTED_POSITIONS; + goto out; + } + if (positions[i] == positions[i + 1]) { + ret = TSK_ERR_STAT_DUPLICATE_POSITIONS; + goto out; + } + } + // check bounds of last value + if (positions[i] < 0 || positions[i] >= sequence_length) { + ret = TSK_ERR_POSITION_OUT_OF_BOUNDS; + goto out; + } +out: + return ret; +} + +static int +positions_to_tree_indexes(const tsk_treeseq_t *ts, const double *positions, + tsk_size_t num_positions, tsk_id_t **tree_indexes) +{ + int ret = 0; + tsk_id_t tree_index = 0; + tsk_size_t i, num_trees = ts->num_trees; + + *tree_indexes = tsk_malloc(num_positions * sizeof(*tree_indexes)); + if (tree_indexes == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + tsk_memset(*tree_indexes, TSK_NULL, num_positions * sizeof(**tree_indexes)); + for (i = 0; i < num_positions; i++) { + while (ts->breakpoints[tree_index + 1] <= positions[i]) { + tree_index++; + } + (*tree_indexes)[i] = tree_index; + } + tsk_bug_assert(tree_index <= (tsk_id_t)(num_trees - 1)); + +out: + return ret; +} + +static int +get_index_counts( + const tsk_id_t *indexes, tsk_size_t num_indexes, tsk_size_t **out_counts) +{ + int ret = 0; + tsk_id_t index = indexes[0]; + tsk_size_t count, i; + tsk_size_t *counts = tsk_calloc( + (tsk_size_t)(indexes[num_indexes - 1] - indexes[0] + 1), sizeof(*counts)); + if (counts == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + + count = 1; + for (i = 1; i < num_indexes; i++) { + if (indexes[i] == indexes[i - 1]) { + count++; + } else { + counts[index - indexes[0]] = count; + count = 1; + index = indexes[i]; + } + } + counts[index - indexes[0]] = count; + *out_counts = counts; +out: + return ret; +} + +typedef struct { + tsk_tree_t *tree; + tsk_bit_array_t *node_samples; + tsk_id_t *parent; + tsk_id_t *edges_out; + tsk_id_t *edges_in; + double *branch_len; + tsk_size_t n_edges_out; + tsk_size_t n_edges_in; +} iter_state; + +static int +iter_state_init(iter_state *self, const tsk_treeseq_t *ts, tsk_size_t state_dim) +{ + int ret = 0; + const tsk_size_t num_nodes = ts->tables->nodes.num_rows; + + self->tree = tsk_malloc(sizeof(*self->tree)); + self->node_samples = tsk_calloc(1, sizeof(*self->node_samples)); + ret = tsk_tree_init(self->tree, ts, TSK_NO_SAMPLE_COUNTS); + if (ret != 0) { + goto out; + } + ret = tsk_bit_array_init(self->node_samples, ts->num_samples, state_dim * num_nodes); + if (ret != 0) { + goto out; + } + self->parent = tsk_malloc(num_nodes * sizeof(*self->parent)); + self->edges_out = tsk_malloc(num_nodes * sizeof(*self->edges_out)); + self->edges_in = tsk_malloc(num_nodes * sizeof(*self->edges_in)); + self->branch_len = tsk_calloc(num_nodes, sizeof(*self->branch_len)); + if (self->parent == NULL || self->edges_out == NULL || self->edges_in == NULL + || self->branch_len == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } +out: + return ret; +} + +static int +get_node_samples(const tsk_treeseq_t *ts, tsk_size_t state_dim, + const tsk_bit_array_t *sample_sets, tsk_bit_array_t *node_samples) +{ + int ret = 0; + tsk_size_t n, k; + tsk_bit_array_t sample_set_row, node_samples_row; + tsk_size_t num_nodes = ts->tables->nodes.num_rows; + tsk_bit_array_value_t sample; + const tsk_id_t *restrict sample_index_map = ts->sample_index_map; + const tsk_flags_t *restrict flags = ts->tables->nodes.flags; + + ret = tsk_bit_array_init(node_samples, ts->num_samples, num_nodes * state_dim); + if (ret != 0) { + goto out; + } + for (k = 0; k < state_dim; k++) { + tsk_bit_array_get_row(sample_sets, k, &sample_set_row); + for (n = 0; n < num_nodes; n++) { + if (flags[n] & TSK_NODE_IS_SAMPLE) { + sample = (tsk_bit_array_value_t) sample_index_map[n]; + if (tsk_bit_array_contains(&sample_set_row, sample)) { + tsk_bit_array_get_row( + node_samples, (state_dim * n) + k, &node_samples_row); + tsk_bit_array_add_bit(&node_samples_row, sample); + } + } + } + } +out: + return ret; +} + +static void +iter_state_clear(iter_state *self, tsk_size_t state_dim, tsk_size_t num_nodes, + const tsk_bit_array_t *node_samples) +{ + self->n_edges_out = 0; + self->n_edges_in = 0; + tsk_tree_clear(self->tree); + tsk_memset(self->parent, TSK_NULL, num_nodes * sizeof(*self->parent)); + tsk_memset(self->edges_out, TSK_NULL, num_nodes * sizeof(*self->edges_out)); + tsk_memset(self->edges_in, TSK_NULL, num_nodes * sizeof(*self->edges_in)); + tsk_memset(self->branch_len, 0, num_nodes * sizeof(*self->branch_len)); + tsk_memcpy(self->node_samples->data, node_samples->data, + node_samples->size * state_dim * num_nodes * sizeof(*node_samples->data)); +} + +static void +iter_state_free(iter_state *self) +{ + tsk_tree_free(self->tree); + tsk_bit_array_free(self->node_samples); + tsk_safe_free(self->tree); + tsk_safe_free(self->node_samples); + tsk_safe_free(self->parent); + tsk_safe_free(self->edges_out); + tsk_safe_free(self->edges_in); + tsk_safe_free(self->branch_len); +} + +static int +advance_collect_edges(iter_state *s, tsk_id_t index) +{ + int ret = 0; + tsk_id_t j, e; + tsk_size_t i; + double left, right; + tsk_tree_position_t pos; + tsk_tree_t *tree = s->tree; + const double *restrict edge_left = s->tree->tree_sequence->tables->edges.left; + const double *restrict edge_right = s->tree->tree_sequence->tables->edges.right; + + // Either we're seeking forward one step from some nonzero position in the tree, or + // from the beginning of the tree sequence. + if (tree->index != TSK_NULL || index == 0) { + ret = tsk_tree_next(tree); + if (ret < 0) { + goto out; + } + pos = tree->tree_pos; + i = 0; + for (j = pos.out.start; j != pos.out.stop; j++) { + s->edges_out[i] = pos.out.order[j]; + i++; + } + s->n_edges_out = i; + i = 0; + for (j = pos.in.start; j != pos.in.stop; j++) { + s->edges_in[i] = pos.in.order[j]; + i++; + } + s->n_edges_in = i; + } else { + // Seek from an arbitrary nonzero position from an uninitialized tree. + tsk_bug_assert(tree->index == -1); + ret = tsk_tree_seek_index(tree, index, 0); + if (ret < 0) { + goto out; + } + pos = tree->tree_pos; + i = 0; + if (pos.direction == TSK_DIR_FORWARD) { + left = pos.interval.left; + for (j = pos.in.start; j != pos.in.stop; j++) { + e = pos.in.order[j]; + if (edge_left[e] <= left && left < edge_right[e]) { + s->edges_in[i] = pos.in.order[j]; + i++; + } + } + } else { + right = pos.interval.right; + for (j = pos.in.start; j != pos.in.stop; j--) { + e = pos.in.order[j]; + if (edge_right[e] >= right && right > edge_left[e]) { + s->edges_in[i] = pos.in.order[j]; + i++; + } + } + } + s->n_edges_out = 0; + s->n_edges_in = i; + } + ret = 0; +out: + return ret; +} + +static int +compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c, + tsk_bit_array_t *child_samples, const iter_state *A_state, const iter_state *B_state, + tsk_size_t state_dim, tsk_size_t result_dim, int sign, general_stat_func_t *f, + sample_count_stat_params_t *f_params, double *result) +{ + int ret = 0; + double a_len, b_len; + double *restrict B_branch_len = B_state->branch_len; + // TODO: is this early return okay? + b_len = B_branch_len[c] * sign; + if (b_len == 0) { + return ret; + } + double *weights_row; + tsk_size_t n, k, a_row, b_row; + tsk_bit_array_t A_samples, B_samples, AB_samples, B_samples_tmp; + const double *restrict A_branch_len = A_state->branch_len; + const tsk_bit_array_t *restrict A_state_samples = A_state->node_samples; + const tsk_bit_array_t *restrict B_state_samples = B_state->node_samples; + tsk_size_t num_samples = ts->num_samples; + tsk_size_t num_nodes = ts->tables->nodes.num_rows; + double *weights = tsk_calloc(3 * state_dim, sizeof(*weights)); + double *result_tmp = tsk_calloc(result_dim, sizeof(*result_tmp)); + + tsk_memset(&AB_samples, 0, sizeof(AB_samples)); + tsk_memset(&B_samples_tmp, 0, sizeof(B_samples_tmp)); + + if (weights == NULL || result_tmp == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + ret = tsk_bit_array_init(&AB_samples, num_samples, 1); + if (ret != 0) { + goto out; + } + ret = tsk_bit_array_init(&B_samples_tmp, num_samples, 1); + if (ret != 0) { + goto out; + } + + for (n = 0; n < num_nodes; n++) { + a_len = A_branch_len[n]; + if (a_len == 0) { + continue; + } + for (k = 0; k < state_dim; k++) { + a_row = (state_dim * n) + k; + // TODO: what if c is TSK_NULL? + b_row = (state_dim * (tsk_size_t) c) + k; + tsk_bit_array_get_row(A_state_samples, a_row, &A_samples); + tsk_bit_array_get_row(B_state_samples, b_row, &B_samples); + tsk_bit_array_intersect(&A_samples, &B_samples, &AB_samples); + weights_row = GET_2D_ROW(weights, 3, k); + weights_row[0] = (double) tsk_bit_array_count(&AB_samples); // w_AB + weights_row[1] + = (double) tsk_bit_array_count(&A_samples) - weights_row[0]; // w_Ab + weights_row[2] + = (double) tsk_bit_array_count(&B_samples) - weights_row[0]; // w_aB + } + ret = f(state_dim, weights, result_dim, result_tmp, f_params); + if (ret != 0) { + goto out; + } + for (k = 0; k < result_dim; k++) { + result[k] += result_tmp[k] * a_len * b_len; + } + + if (child_samples != NULL) { + for (k = 0; k < state_dim; k++) { + a_row = (state_dim * n) + k; + // TODO: what if c is TSK_NULL? + b_row = (state_dim * (tsk_size_t) c) + k; + tsk_bit_array_get_row(B_state_samples, b_row, &B_samples); + tsk_bit_array_add(&B_samples_tmp, &B_samples); + tsk_bit_array_subtract(&B_samples_tmp, child_samples); + tsk_bit_array_get_row(A_state_samples, a_row, &A_samples); + tsk_bit_array_intersect(&A_samples, &B_samples_tmp, &AB_samples); + weights_row = GET_2D_ROW(weights, 3, k); + weights_row[0] = (double) tsk_bit_array_count(&AB_samples); // w_AB + weights_row[1] + = (double) tsk_bit_array_count(&A_samples) - weights_row[0]; // w_Ab + weights_row[2] = (double) tsk_bit_array_count(&B_samples_tmp) + - weights_row[0]; // w_aB + } + ret = f(state_dim, weights, result_dim, result_tmp, f_params); + if (ret != 0) { + goto out; + } + for (k = 0; k < result_dim; k++) { + result[k] -= result_tmp[k] * a_len * b_len; + } + } + } +out: + tsk_safe_free(weights); + tsk_safe_free(result_tmp); + tsk_bit_array_free(&AB_samples); + tsk_bit_array_free(&B_samples_tmp); + return ret; +} + +static int +compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state, + iter_state *r_state, general_stat_func_t *f, sample_count_stat_params_t *f_params, + tsk_size_t result_dim, tsk_size_t state_dim, double *result) +{ + int ret = 0; + tsk_id_t e, c, p; + tsk_size_t j, k; + tsk_bit_array_t child_samples, child_samples_row, samples_row, *in_parent; + const double *restrict time = ts->tables->nodes.time; + const tsk_id_t *restrict edges_child = ts->tables->edges.child; + const tsk_id_t *restrict edges_parent = ts->tables->edges.parent; + tsk_bit_array_t *r_samples = r_state->node_samples; + + tsk_memset(&child_samples, 0, sizeof(child_samples)); + ret = tsk_bit_array_init(&child_samples, ts->num_samples, state_dim); + if (ret != 0) { + goto out; + } + for (j = 0; j < r_state->n_edges_out; j++) { + e = r_state->edges_out[j]; + c = edges_child[e]; + p = edges_parent[e]; + tsk_memset(child_samples.data, 0, + child_samples.size * state_dim * sizeof(tsk_bit_array_value_t)); + tsk_bug_assert(c != TSK_NULL); // TODO: are these checks necessary? + tsk_bug_assert(p != TSK_NULL); + for (k = 0; k < state_dim; k++) { + tsk_bit_array_get_row( + r_samples, (state_dim * (tsk_size_t) c) + k, &samples_row); + tsk_bit_array_get_row(&child_samples, k, &child_samples_row); + tsk_bit_array_add(&child_samples_row, &samples_row); + } + in_parent = NULL; + while (p != TSK_NULL) { + compute_two_tree_branch_state_update(ts, c, in_parent, l_state, r_state, + state_dim, result_dim, -1, f, f_params, result); + if (in_parent != NULL) { + for (k = 0; k < state_dim; k++) { + tsk_bit_array_get_row( + r_samples, (state_dim * (tsk_size_t) c) + k, &samples_row); + tsk_bit_array_get_row(&child_samples, k, &child_samples_row); + tsk_bit_array_subtract(&samples_row, &child_samples_row); + } + } + in_parent = &child_samples; + c = p; + p = r_state->parent[p]; + } + for (k = 0; k < state_dim; k++) { + tsk_bit_array_get_row( + r_samples, (state_dim * (tsk_size_t) c) + k, &samples_row); + tsk_bit_array_get_row(&child_samples, k, &child_samples_row); + tsk_bit_array_subtract(&samples_row, &child_samples_row); + } + c = edges_child[e]; + r_state->branch_len[c] = 0; + r_state->parent[c] = TSK_NULL; + } + for (j = 0; j < r_state->n_edges_in; j++) { + e = r_state->edges_in[j]; + c = edges_child[e]; + p = edges_parent[e]; + tsk_memset(child_samples.data, 0, + child_samples.size * state_dim * sizeof(tsk_bit_array_value_t)); + for (k = 0; k < state_dim; k++) { + tsk_bit_array_get_row( + r_samples, (state_dim * (tsk_size_t) c) + k, &samples_row); + tsk_bit_array_get_row(&child_samples, k, &child_samples_row); + tsk_bit_array_add(&child_samples_row, &samples_row); + } + r_state->branch_len[c] = time[p] - time[c]; + r_state->parent[c] = p; + + in_parent = NULL; + while (p != TSK_NULL) { + for (k = 0; k < state_dim; k++) { + tsk_bit_array_get_row( + r_samples, (state_dim * (tsk_size_t) p) + k, &samples_row); + tsk_bit_array_get_row(&child_samples, k, &child_samples_row); + tsk_bit_array_add(&samples_row, &child_samples_row); + } + compute_two_tree_branch_state_update(ts, c, in_parent, l_state, r_state, + state_dim, result_dim, +1, f, f_params, result); + in_parent = &child_samples; + c = p; + p = r_state->parent[p]; + } + } +out: + tsk_bit_array_free(&child_samples); + return ret; +} + +static int +tsk_treeseq_two_branch_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, + const tsk_bit_array_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, + sample_count_stat_params_t *f_params, norm_func_t *TSK_UNUSED(norm_f), + tsk_size_t n_rows, const double *row_positions, tsk_size_t n_cols, + const double *col_positions, tsk_flags_t TSK_UNUSED(options), double *result) +{ + int ret = 0; + const tsk_size_t num_nodes = self->tables->nodes.num_rows; + tsk_id_t *row_indexes, *col_indexes; + tsk_size_t i, j, k, r, c, row, col, *row_repeats, *col_repeats; + tsk_bit_array_t node_samples; + iter_state l_state, r_state; + double *result_tmp, *result_row; + + tsk_memset(&node_samples, 0, sizeof(node_samples)); + tsk_memset(&l_state, 0, sizeof(l_state)); + tsk_memset(&r_state, 0, sizeof(r_state)); + result_tmp = tsk_malloc(result_dim * sizeof(*result_tmp)); + if (result_tmp == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + if ((ret = iter_state_init(&l_state, self, state_dim)) != 0) { + goto out; + } + if ((ret = iter_state_init(&r_state, self, state_dim)) != 0) { + goto out; + } + if ((ret = positions_to_tree_indexes(self, row_positions, n_rows, &row_indexes)) + != 0) { + goto out; + } + if ((ret = positions_to_tree_indexes(self, col_positions, n_cols, &col_indexes)) + != 0) { + goto out; + } + if ((ret = get_index_counts(row_indexes, n_rows, &row_repeats)) != 0) { + goto out; + } + if ((ret = get_index_counts(col_indexes, n_cols, &col_repeats)) != 0) { + goto out; + } + if ((ret = get_node_samples(self, state_dim, sample_sets, &node_samples)) != 0) { + goto out; + } + iter_state_clear(&l_state, state_dim, num_nodes, &node_samples); + row = 0; + for (r = 0; (tsk_id_t) r < (row_indexes[n_rows - 1] - row_indexes[0] + 1); r++) { + tsk_memset(result_tmp, 0, result_dim * sizeof(*result_tmp)); + iter_state_clear(&r_state, state_dim, num_nodes, &node_samples); + ret = advance_collect_edges(&l_state, (tsk_id_t) r + row_indexes[0]); + if (ret != 0) { + goto out; + } + result_row = GET_2D_ROW(result, result_dim * n_cols, row); + ret = compute_two_tree_branch_stat( + self, &r_state, &l_state, f, f_params, result_dim, state_dim, result_tmp); + if (ret != 0) { + goto out; + } + col = 0; + for (c = 0; (tsk_id_t) c < (col_indexes[n_cols - 1] - col_indexes[0] + 1); c++) { + ret = advance_collect_edges(&r_state, (tsk_id_t) c + col_indexes[0]); + if (ret != 0) { + goto out; + } + ret = compute_two_tree_branch_stat(self, &l_state, &r_state, f, f_params, + result_dim, state_dim, result_tmp); + if (ret != 0) { + goto out; + } + for (i = 0; i < row_repeats[r]; i++) { + for (j = 0; j < col_repeats[c]; j++) { + result_row = GET_2D_ROW(result, result_dim * n_cols, row + i); + for (k = 0; k < result_dim; k++) { + result_row[col + (j * result_dim) + k] = result_tmp[k]; + } + } + } + col += (col_repeats[c] * result_dim); + } + row += row_repeats[r]; + } +out: + tsk_safe_free(result_tmp); + tsk_safe_free(row_indexes); + tsk_safe_free(col_indexes); + tsk_safe_free(row_repeats); + tsk_safe_free(col_repeats); + iter_state_free(&l_state); + iter_state_free(&r_state); + tsk_bit_array_free(&node_samples); + return ret; +} + static int tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f, norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites, - tsk_size_t out_cols, const tsk_id_t *col_sites, tsk_flags_t options, double *result) + const double *row_positions, tsk_size_t out_cols, const tsk_id_t *col_sites, + const double *col_positions, tsk_flags_t options, double *result) { // TODO: generalize this function if we ever decide to do weighted two_locus stats. // We only implement count stats and therefore we don't handle weights. @@ -2658,6 +3213,11 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl tsk_memset(&sample_sets_bits, 0, sizeof(sample_sets_bits)); + // We do not support two-locus node stats + if (!!(options & TSK_STAT_NODE)) { + ret = TSK_ERR_UNSUPPORTED_STAT_MODE; + goto out; + } // If no mode is specified, we default to site mode if (!(stat_site || stat_branch)) { stat_site = true; @@ -2696,8 +3256,20 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl ret = tsk_treeseq_two_site_count_stat(self, state_dim, &sample_sets_bits, result_dim, f, &f_params, norm_f, out_rows, row_sites, out_cols, col_sites, options, result); - } else { - ret = TSK_ERR_UNSUPPORTED_STAT_MODE; + } else if (stat_branch) { + ret = check_positions( + row_positions, out_rows, tsk_treeseq_get_sequence_length(self)); + if (ret != 0) { + goto out; + } + ret = check_positions( + col_positions, out_cols, tsk_treeseq_get_sequence_length(self)); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_two_branch_count_stat(self, state_dim, &sample_sets_bits, + result_dim, f, &f_params, norm_f, out_rows, row_positions, out_cols, + col_positions, options, result); } out: @@ -3527,13 +4099,15 @@ D_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_D(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result) + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) { options |= TSK_STAT_POLARISED; // TODO: allow user to pick? return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, D_summary_func, norm_total_weighted, - num_rows, row_sites, num_cols, col_sites, options, result); + num_rows, row_sites, row_positions, num_cols, col_sites, col_positions, options, + result); } static int @@ -3564,12 +4138,14 @@ D2_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_D2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result) + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) { return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, D2_summary_func, norm_total_weighted, - num_rows, row_sites, num_cols, col_sites, options, result); + num_rows, row_sites, row_positions, num_cols, col_sites, col_positions, options, + result); } static int @@ -3602,12 +4178,13 @@ r2_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_r2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result) + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) { return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, r2_summary_func, norm_hap_weighted, num_rows, - row_sites, num_cols, col_sites, options, result); + row_sites, row_positions, num_cols, col_sites, col_positions, options, result); } static int @@ -3643,13 +4220,15 @@ D_prime_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_D_prime(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result) + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) { options |= TSK_STAT_POLARISED; // TODO: allow user to pick? return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, D_prime_summary_func, norm_hap_weighted, - num_rows, row_sites, num_cols, col_sites, options, result); + num_rows, row_sites, row_positions, num_cols, col_sites, col_positions, options, + result); } static int @@ -3682,13 +4261,15 @@ r_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_r(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result) + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) { options |= TSK_STAT_POLARISED; // TODO: allow user to pick? return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, r_summary_func, norm_total_weighted, - num_rows, row_sites, num_cols, col_sites, options, result); + num_rows, row_sites, row_positions, num_cols, col_sites, col_positions, options, + result); } static int @@ -3720,12 +4301,14 @@ Dz_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_Dz(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result) + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) { return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, Dz_summary_func, norm_total_weighted, - num_rows, row_sites, num_cols, col_sites, options, result); + num_rows, row_sites, row_positions, num_cols, col_sites, col_positions, options, + result); } static int @@ -3754,12 +4337,127 @@ pi2_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_pi2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result) + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) { return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, pi2_summary_func, norm_total_weighted, - num_rows, row_sites, num_cols, col_sites, options, result); + num_rows, row_sites, row_positions, num_cols, col_sites, col_positions, options, + result); +} + +static int +D2_unbiased_summary_func(tsk_size_t state_dim, const double *state, + tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + double n; + const double *state_row; + tsk_size_t j; + + for (j = 0; j < state_dim; j++) { + n = (double) args.sample_set_sizes[j]; + state_row = GET_2D_ROW(state, 3, j); + double w_AB = state_row[0]; + double w_Ab = state_row[1]; + double w_aB = state_row[2]; + double w_ab = n - (w_AB + w_Ab + w_aB); + result[j] = (1 / (n * (n - 1) * (n - 2) * (n - 3))) + * ((w_aB * w_aB * (w_Ab - 1) * w_Ab) + + ((w_ab - 1) * w_ab * (w_AB - 1) * w_AB) + - (w_aB * w_Ab * (w_Ab + (2 * w_ab * w_AB) - 1))); + } + return 0; +} + +int +tsk_treeseq_D2_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) +{ + return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, D2_unbiased_summary_func, + norm_total_weighted, num_rows, row_sites, row_positions, num_cols, col_sites, + col_positions, options, result); +} + +static int +Dz_unbiased_summary_func(tsk_size_t state_dim, const double *state, + tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + double n; + const double *state_row; + tsk_size_t j; + + for (j = 0; j < state_dim; j++) { + n = (double) args.sample_set_sizes[j]; + state_row = GET_2D_ROW(state, 3, j); + double w_AB = state_row[0]; + double w_Ab = state_row[1]; + double w_aB = state_row[2]; + double w_ab = n - (w_AB + w_Ab + w_aB); + result[j] = (1 / (n * (n - 1) * (n - 2) * (n - 3))) + * ((((w_AB * w_ab) - (w_Ab * w_aB)) * (w_aB + w_ab - w_AB - w_Ab) + * (w_Ab + w_ab - w_AB - w_aB)) + - ((w_AB * w_ab) * (w_AB + w_ab - w_Ab - w_aB - 2)) + - ((w_Ab * w_aB) * (w_Ab + w_aB - w_AB - w_ab - 2))); + } + return 0; +} + +int +tsk_treeseq_Dz_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) +{ + return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, Dz_unbiased_summary_func, + norm_total_weighted, num_rows, row_sites, row_positions, num_cols, col_sites, + col_positions, options, result); +} + +static int +pi2_unbiased_summary_func(tsk_size_t state_dim, const double *state, + tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + double n; + const double *state_row; + tsk_size_t j; + + for (j = 0; j < state_dim; j++) { + n = (double) args.sample_set_sizes[j]; + state_row = GET_2D_ROW(state, 3, j); + double w_AB = state_row[0]; + double w_Ab = state_row[1]; + double w_aB = state_row[2]; + double w_ab = n - (w_AB + w_Ab + w_aB); + result[j] + = (1 / (n * (n - 1) * (n - 2) * (n - 3))) + * (((w_AB + w_Ab) * (w_aB + w_ab) * (w_AB + w_aB) * (w_Ab + w_ab)) + - ((w_AB * w_ab) * (w_AB + w_ab + (3 * w_Ab) + (3 * w_aB) - 1)) + - ((w_Ab * w_aB) * (w_Ab + w_aB + (3 * w_AB) + (3 * w_ab) - 1))); + } + return 0; +} + +int +tsk_treeseq_pi2_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) +{ + return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, pi2_unbiased_summary_func, + norm_total_weighted, num_rows, row_sites, row_positions, num_cols, col_sites, + col_positions, options, result); } /*********************************** diff --git a/c/tskit/trees.h b/c/tskit/trees.h index d7b64d0701..b23fa55320 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1073,36 +1073,59 @@ int tsk_treeseq_genetic_relatedness(const tsk_treeseq_t *self, typedef int two_locus_count_stat_method(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, const tsk_id_t *row_sites, - tsk_size_t num_cols, const tsk_id_t *col_sites, tsk_flags_t options, double *result); + const double *row_positions, tsk_size_t num_cols, const tsk_id_t *col_sites, + const double *col_positions, tsk_flags_t options, double *result); int tsk_treeseq_D(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result); + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); int tsk_treeseq_D2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result); + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); int tsk_treeseq_r2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result); + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); int tsk_treeseq_D_prime(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result); + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); int tsk_treeseq_r(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result); + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); int tsk_treeseq_Dz(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result); + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); int tsk_treeseq_pi2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result); + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); +int tsk_treeseq_D2_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); +int tsk_treeseq_Dz_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); +int tsk_treeseq_pi2_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); /* Three way sample set stats */ int tsk_treeseq_Y3(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index c403b12af6..4beca9a87f 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -10135,24 +10135,71 @@ TreeSequence_pair_coalescence_quantiles( return ret; } +static PyArrayObject * +parse_sites(TreeSequence *self, PyObject *sites, npy_intp *out_dim) +{ + PyArrayObject *array; + tsk_size_t num_sites = tsk_treeseq_get_num_sites(self->tree_sequence); + + if (sites == Py_None) { + array = (PyArrayObject *) PyArray_Arange(0, num_sites, 1, NPY_INT32); + if (array == NULL) { + goto out; + } + *out_dim = PyArray_DIM(array, 0); + } else { + array = (PyArrayObject *) PyArray_FROMANY( + sites, NPY_INT32, 1, 1, NPY_ARRAY_IN_ARRAY); + if (array == NULL) { + goto out; + } + *out_dim = PyArray_DIM(array, 0); + } + +out: + return array; +} + +static PyArrayObject * +parse_positions(TreeSequence *self, PyObject *positions, npy_intp *out_dim) +{ + PyArrayObject *array; + + if (positions == Py_None) { + array = (PyArrayObject *) TreeSequence_get_breakpoints(self); + if (array == NULL) { + goto out; + } + *out_dim = PyArray_DIM(array, 0) - 1; // NB the last element must be truncated + } else { + array = (PyArrayObject *) PyArray_FROMANY( + positions, NPY_FLOAT64, 1, 1, NPY_ARRAY_IN_ARRAY); + if (array == NULL) { + goto out; + } + *out_dim = PyArray_DIM(array, 0); + } +out: + return array; +} + static PyObject * TreeSequence_ld_matrix(TreeSequence *self, PyObject *args, PyObject *kwds, two_locus_count_stat_method *method) { PyObject *ret = NULL; - static char *kwlist[] - = { "sample_set_sizes", "sample_sets", "row_sites", "col_sites", "mode", NULL }; + static char *kwlist[] = { "sample_set_sizes", "sample_sets", "row_sites", + "col_sites", "row_positions", "column_positions", "mode", NULL }; - PyObject *row_sites = NULL; - PyObject *col_sites = NULL; - PyObject *sample_set_sizes = NULL; - PyObject *sample_sets = NULL; - PyArrayObject *sample_set_sizes_array = NULL; - PyArrayObject *sample_sets_array = NULL; - PyArrayObject *row_sites_array = NULL; - PyArrayObject *col_sites_array = NULL; - PyArrayObject *result_matrix = NULL; - npy_intp result_shape[3]; + PyObject *row_sites = NULL, *col_sites = NULL, *row_positions = NULL, + *col_positions = NULL, *sample_set_sizes = NULL, *sample_sets = NULL; + PyArrayObject *row_sites_array = NULL, *col_sites_array = NULL, + *row_positions_array = NULL, *col_positions_array = NULL, + *sample_sets_array = NULL, *sample_set_sizes_array = NULL, + *result_matrix = NULL; + tsk_id_t *row_sites_parsed = NULL, *col_sites_parsed = NULL; + double *row_positions_parsed = NULL, *col_positions_parsed = NULL; + npy_intp result_dim[3]; char *mode = NULL; tsk_size_t num_sample_sets; tsk_flags_t options = 0; @@ -10161,8 +10208,9 @@ TreeSequence_ld_matrix(TreeSequence *self, PyObject *args, PyObject *kwds, if (TreeSequence_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOOs", kwlist, &sample_set_sizes, - &sample_sets, &row_sites, &col_sites, &mode)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOOOO|s", kwlist, &sample_set_sizes, + &sample_sets, &row_sites, &col_sites, &row_positions, &col_positions, + &mode)) { goto out; } if (parse_stats_mode(mode, &options) != 0) { @@ -10173,22 +10221,37 @@ TreeSequence_ld_matrix(TreeSequence *self, PyObject *args, PyObject *kwds, != 0) { goto out; } - row_sites_array = (PyArrayObject *) PyArray_FROMANY( - row_sites, NPY_INT32, 1, 1, NPY_ARRAY_IN_ARRAY); - if (row_sites_array == NULL) { - goto out; - } - col_sites_array = (PyArrayObject *) PyArray_FROMANY( - col_sites, NPY_INT32, 1, 1, NPY_ARRAY_IN_ARRAY); - if (col_sites_array == NULL) { - goto out; + + if (options & TSK_STAT_SITE) { + if (row_positions != Py_None || col_positions != Py_None) { + PyErr_SetString(PyExc_ValueError, "Cannot specify positions in site mode"); + goto out; + } + row_sites_array = parse_sites(self, row_sites, &(result_dim[0])); + col_sites_array = parse_sites(self, col_sites, &(result_dim[1])); + if (row_sites_array == NULL || col_sites_array == NULL) { + goto out; + } + row_sites_parsed = PyArray_DATA(row_sites_array); + col_sites_parsed = PyArray_DATA(col_sites_array); + } else if (options & TSK_STAT_BRANCH) { + if (row_sites != Py_None || col_sites != Py_None) { + PyErr_SetString(PyExc_ValueError, "Cannot specify sites in branch mode"); + goto out; + } + row_positions_array = parse_positions(self, row_positions, &(result_dim[0])); + col_positions_array = parse_positions(self, col_positions, &(result_dim[1])); + if (col_positions_array == NULL || row_positions_array == NULL) { + goto out; + } + row_positions_parsed = PyArray_DATA(row_positions_array); + col_positions_parsed = PyArray_DATA(col_positions_array); } - result_shape[0] = PyArray_DIM(row_sites_array, 0); - result_shape[1] = PyArray_DIM(col_sites_array, 0); - result_shape[2] = num_sample_sets; - result_matrix = (PyArrayObject *) PyArray_ZEROS(3, result_shape, NPY_FLOAT64, 0); + result_dim[2] = num_sample_sets; + result_matrix = (PyArrayObject *) PyArray_ZEROS(3, result_dim, NPY_FLOAT64, 0); if (result_matrix == NULL) { + PyErr_NoMemory(); goto out; } @@ -10196,8 +10259,8 @@ TreeSequence_ld_matrix(TreeSequence *self, PyObject *args, PyObject *kwds, Py_BEGIN_ALLOW_THREADS err = method(self->tree_sequence, num_sample_sets, PyArray_DATA(sample_set_sizes_array), PyArray_DATA(sample_sets_array), - result_shape[0], PyArray_DATA(row_sites_array), result_shape[1], - PyArray_DATA(col_sites_array), options, PyArray_DATA(result_matrix)); + result_dim[0], row_sites_parsed, row_positions_parsed, result_dim[1], + col_sites_parsed, col_positions_parsed, options, PyArray_DATA(result_matrix)); Py_END_ALLOW_THREADS // clang-format on @@ -10211,8 +10274,10 @@ TreeSequence_ld_matrix(TreeSequence *self, PyObject *args, PyObject *kwds, out: Py_XDECREF(row_sites_array); Py_XDECREF(col_sites_array); - Py_XDECREF(sample_set_sizes_array); + Py_XDECREF(row_positions_array); + Py_XDECREF(col_positions_array); Py_XDECREF(sample_sets_array); + Py_XDECREF(sample_set_sizes_array); Py_XDECREF(result_matrix); return ret; } @@ -10259,6 +10324,24 @@ TreeSequence_pi2_matrix(TreeSequence *self, PyObject *args, PyObject *kwds) return TreeSequence_ld_matrix(self, args, kwds, tsk_treeseq_pi2); } +static PyObject * +TreeSequence_pi2_unbiased_matrix(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_ld_matrix(self, args, kwds, tsk_treeseq_pi2_unbiased); +} + +static PyObject * +TreeSequence_D2_unbiased_matrix(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_ld_matrix(self, args, kwds, tsk_treeseq_D2_unbiased); +} + +static PyObject * +TreeSequence_Dz_unbiased_matrix(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_ld_matrix(self, args, kwds, tsk_treeseq_Dz_unbiased); +} + static PyObject * TreeSequence_get_num_mutations(TreeSequence *self) { @@ -11023,6 +11106,18 @@ static PyMethodDef TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_pi2_matrix, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Computes the pi2 matrix." }, + { .ml_name = "D2_unbiased_matrix", + .ml_meth = (PyCFunction) TreeSequence_D2_unbiased_matrix, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the unbiased D2 matrix." }, + { .ml_name = "Dz_unbiased_matrix", + .ml_meth = (PyCFunction) TreeSequence_Dz_unbiased_matrix, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the unbiased Dz matrix." }, + { .ml_name = "pi2_unbiased_matrix", + .ml_meth = (PyCFunction) TreeSequence_pi2_unbiased_matrix, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the unbiased pi2 matrix." }, { NULL } /* Sentinel */ }; diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 761658224e..17a88c359c 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -25,7 +25,6 @@ import contextlib import io from dataclasses import dataclass -from itertools import combinations from itertools import combinations_with_replacement from itertools import permutations from itertools import product @@ -631,14 +630,14 @@ def get_index_repeats(indices): def two_branch_count_stat( ts: tskit.TreeSequence, func: Callable[[int, np.ndarray, np.ndarray, Dict[str, Any]], None], - norm_func, # TODO: might need for polarisation + norm_func, num_sample_sets: int, sample_set_sizes: np.ndarray, sample_sets: BitSet, sample_index_map: np.ndarray, row_trees: np.ndarray, col_trees: np.ndarray, - polarised: bool, # TODO: polarisation + polarised: bool, ) -> np.ndarray: """ Compute a tree X tree LD matrix by walking along the tree sequence and @@ -1090,8 +1089,8 @@ def d2_unbiased( "D_prime": D_prime_summary_func, "pi2": pi2_summary_func, "Dz": Dz_summary_func, - "d2_unbiased": d2_unbiased, - "dz_unbiased": dz_unbiased, + "D2_unbiased": d2_unbiased, + "Dz_unbiased": dz_unbiased, "pi2_unbiased": pi2_unbiased, } @@ -1103,9 +1102,9 @@ def d2_unbiased( pi2_summary_func: norm_total_weighted, r_summary_func: norm_total_weighted, r2_summary_func: norm_hap_weighted, - d2_unbiased: None, - dz_unbiased: None, - pi2_unbiased: None, + d2_unbiased: norm_total_weighted, + dz_unbiased: norm_total_weighted, + pi2_unbiased: norm_total_weighted, } POLARIZATION = { @@ -1276,7 +1275,11 @@ def test_subset_positions(partition): bp = ts.breakpoints(as_array=True) mid = (bp[1:] + bp[:-1]) / 2 np.testing.assert_allclose( - ld_matrix(ts, mode="branch", stat="d2_unbiased", positions=[mid[a], mid[b]]), + ld_matrix(ts, mode="branch", stat="D2_unbiased", positions=[mid[a], mid[b]]), + PAPER_EX_BRANCH_TRUTH_MATRIX[a[0] : a[-1] + 1, b[0] : b[-1] + 1], + ) + np.testing.assert_allclose( + ts.ld_matrix(mode="branch", stat="D2_unbiased", positions=[mid[a], mid[b]]), PAPER_EX_BRANCH_TRUTH_MATRIX[a[0] : a[-1] + 1, b[0] : b[-1] + 1], ) @@ -1321,7 +1324,13 @@ def test_subset_positions_one_list(tree_index): bp = ts.breakpoints(as_array=True) mid = (bp[1:] + bp[:-1]) / 2 np.testing.assert_allclose( - ld_matrix(ts, mode="branch", stat="d2_unbiased", positions=[mid[tree_index]]), + ld_matrix(ts, mode="branch", stat="D2_unbiased", positions=[mid[tree_index]]), + PAPER_EX_BRANCH_TRUTH_MATRIX[ + tree_index[0] : tree_index[-1] + 1, tree_index[0] : tree_index[-1] + 1 + ], + ) + np.testing.assert_allclose( + ts.ld_matrix(mode="branch", stat="D2_unbiased", positions=[mid[tree_index]]), PAPER_EX_BRANCH_TRUTH_MATRIX[ tree_index[0] : tree_index[-1] + 1, tree_index[0] : tree_index[-1] + 1 ], @@ -1363,7 +1372,11 @@ def test_repeated_position_elements(tree_index): np.testing.assert_allclose( truth, - ld_matrix(ts, mode="branch", stat="d2_unbiased", positions=[l_pos, r_pos]), + ld_matrix(ts, mode="branch", stat="D2_unbiased", positions=[l_pos, r_pos]), + ) + np.testing.assert_allclose( + truth, + ts.ld_matrix(mode="branch", stat="D2_unbiased", positions=[l_pos, r_pos]), ) @@ -1378,7 +1391,7 @@ def test_sample_sets(partition): :param partition: length 2 list of [ss_1, ss_2]. """ ts = get_paper_ex_ts() - np.testing.assert_array_almost_equal( + np.testing.assert_allclose( ld_matrix(ts, sample_sets=partition), ts.ld_matrix(sample_sets=partition) ) @@ -1394,7 +1407,7 @@ def test_compare_to_ld_calculator(): @pytest.mark.parametrize( "stat", - sorted(SUMMARY_FUNCS.keys() - {"d2_unbiased", "dz_unbiased", "pi2_unbiased"}), + sorted(SUMMARY_FUNCS.keys()), ) def test_multiallelic_with_back_mutation(stat): ts = msprime.sim_ancestry( @@ -1417,7 +1430,7 @@ def test_multiallelic_with_back_mutation(stat): # TODO: port unbiased summary functions @pytest.mark.parametrize( "stat", - sorted(SUMMARY_FUNCS.keys() - {"d2_unbiased", "dz_unbiased", "pi2_unbiased"}), + sorted(SUMMARY_FUNCS.keys()), ) def test_ld_matrix(ts, stat): np.testing.assert_array_almost_equal( @@ -1479,7 +1492,9 @@ def __init__(self, ts, sample_sets, num_sample_sets, sample_index_map): for n in range(ts.num_nodes): for k in range(num_sample_sets): if sample_sets.contains(k, sample_index_map[n]): - self.node_samples.add((num_sample_sets * n) + k, n) + self.node_samples.add( + (num_sample_sets * n) + k, sample_index_map[n] + ) # these are empty for the uninitialized state (index = -1) self.edges_in = [] self.edges_out = [] @@ -1730,188 +1745,22 @@ def compute_branch_stat(ts, stat_func, stat, params, state_dim, l_state, r_state return stat, r_state -# What follows is an implementation of two-locus statistics as described in -# McVean 2002 (https://doi.org/10.1093/genetics/162.2.987). We compute the -# covariance between coalescent times to produce expectations of coalescent -# times between three sampling patterns of samples. These expectations can be -# compined to produce D2, Dz, and pi2. These are for testing and to demonstrate -# conceptual parity between our method and McVean's method. - - -def tmrca(tr, x, y): - """ - Mirror the functionality in the branch two-locus stats. We want to compute - the contribution of each subset of samples. If there is no most recent common - ancestor, we walk up the tree and find each sample's individual MRCA (which - as written is realy just the root of the tree). This is to work around the case - of empty, gapped, and decapitated trees. - """ - try: - # First, we try to get the tmrca - return tr.tmrca(x, y) - except ValueError as e: - # If we cannot, crawl up as far as the sample is connected - x_mrca, y_mrca = -1, -1 - if "not share a common ancestor" not in str(e): - raise e - for r in tr.roots: - if x in set(tr.samples(r)): - x_mrca = r - if y in set(tr.samples(r)): - y_mrca = r - if x_mrca == -1 or y_mrca == -1: - raise ValueError - return (tr.time(x_mrca) + tr.time(y_mrca)) / 2 - - -def compute_D2(x, y, ij, ijk, ijkl): - E_ijij = 0 - E_ijik = 0 - E_ijkl = 0 - if len(ij) == 0 or len(ijk) == 0 or len(ijkl) == 0: - # this method requires at least 4 samples - return float("nan") - for i, j in ij: - i_time = x.time(i) - j_time = x.time(j) - ij_time = (i_time + j_time) / 2 - E_ijij += (tmrca(x, i, j) - ij_time) * (tmrca(y, i, j) - ij_time) - for i, j, k in ijk: - i_time = x.time(i) - j_time = x.time(j) - k_time = x.time(k) - ij_time = (i_time + j_time) / 2 - ik_time = (i_time + k_time) / 2 - E_ijik += (tmrca(x, i, j) - ij_time) * (tmrca(y, i, k) - ik_time) - for i, j, k, l in ijkl: - i_time = x.time(i) - j_time = x.time(j) - k_time = x.time(k) - l_time = x.time(l) - ij_time = (i_time + j_time) / 2 - kl_time = (k_time + l_time) / 2 - E_ijkl += (tmrca(x, i, j) - ij_time) * (tmrca(y, k, l) - kl_time) - E_ijij = E_ijij / len(ij) - E_ijik = E_ijik / len(ijk) - E_ijkl = E_ijkl / len(ijkl) - return E_ijij - 2 * E_ijik + E_ijkl - - -def compute_Dz(x, y, ij, ijk, ijkl): - E_ijik = 0 - E_ijkl = 0 - if len(ijk) == 0 or len(ijkl) == 0: - # this method requires at least 4 samples - return float("nan") - for i, j, k in ijk: - i_time = x.time(i) - j_time = x.time(j) - k_time = x.time(k) - ij_time = (i_time + j_time) / 2 - ik_time = (i_time + k_time) / 2 - E_ijik += (tmrca(x, i, j) - ij_time) * (tmrca(y, i, k) - ik_time) - for i, j, k, l in ijkl: - i_time = x.time(i) - j_time = x.time(j) - k_time = x.time(k) - l_time = x.time(l) - ij_time = (i_time + j_time) / 2 - kl_time = (k_time + l_time) / 2 - E_ijkl += (tmrca(x, i, j) - ij_time) * (tmrca(y, k, l) - kl_time) - E_ijik = E_ijik / len(ijk) - E_ijkl = E_ijkl / len(ijkl) - return 4 * (E_ijik - E_ijkl) - - -def compute_pi2(x, y, ij, ijk, ijkl): - E_ijkl = 0 - if len(ijkl) == 0: - # this method requires at least 4 samples - return float("nan") - for i, j, k, l in ijkl: - i_time = x.time(i) - j_time = x.time(j) - k_time = x.time(k) - l_time = x.time(l) - ij_time = (i_time + j_time) / 2 - kl_time = (k_time + l_time) / 2 - E_ijkl += (tmrca(x, i, j) - ij_time) * (tmrca(y, k, l) - kl_time) - E_ijkl = E_ijkl / len(ijkl) - return E_ijkl - - -def combine(samples): - # All combinations where i != j - ij = list(combinations(samples, 2)) - # All combinations where i != {j,k} and j != k - ijk = [ - (i, j, k) - for i, j, k in product(samples, repeat=3) - if i != k and i != j and j != k - ] - # All combinations where i != {k,l} and j != {k,l} - ijkl = [ - (i, j, samples[k], samples[l]) - for i, j in combinations(samples, 2) - for k in range(len(samples)) - for l in range(k + 1, len(samples)) # noqa: E741 - if i != samples[k] and j != samples[k] and samples[l] != i and samples[l] != j - ] - return ij, ijk, ijkl - - -def naive_matrix(ts, stat_func, sample_set=None): - """Compute a tree x tree LD matrix for a given tree sequence and two-locus - statistic. This produces a matrix of LD that is generated from the - covariance in gene genealogies, as described in McVean 2002. - - :param ts: Tree sequence to gather data from. - :param stat_func: Function to compute a two-locus statistic from two - materialized trees and sample combinations. - :returns: Pairwise branch LD matrix for an entire tree sequence. - """ - result = np.zeros((ts.num_trees, ts.num_trees), dtype=np.float64) - # These stats require at least 4 samples in the tree - ij, ijk, ijkl = combine(sample_set or ts.samples()) - for i, j in combinations_with_replacement(range(ts.num_trees), 2): - val = stat_func(ts.at_index(i), ts.at_index(j), ij, ijk, ijkl) - result[i, j] = val - tri_idx = np.tril_indices(len(result), k=-1) - result[tri_idx] = result.T[tri_idx] - return result - - @pytest.mark.parametrize( "ts", [ ts for ts in get_example_tree_sequences() - # no_samples and empty_ts aren't handled here. if ts.id - in { - # We only perform tests on a useful subset of the example trees due to - # runtime constraints of the naive McVean implementation. We plan to expand - # coverage to more examples after implementing the C version - "all_nodes_samples", - "internal_nodes_samples", - "mixed_internal_leaf_samples", - "n=2_m=32_rho=0.5", - "bottleneck_n=10_mutated", - "rev_node_order", - "decapitate", - } + not in { + "no_samples", + "empty_ts", + } # , "bottleneck_n=100_mutated", "n=100_m=32_rho=0.1", "n=100_m=32_rho=0.5"} ], ) -@pytest.mark.parametrize( - "stat,stat_func", - zip( - ["d2_unbiased", "dz_unbiased", "pi2_unbiased"], - [compute_D2, compute_Dz, compute_pi2], - ), -) -def test_branch_ld_matrix(ts, stat, stat_func): +@pytest.mark.parametrize("stat", sorted(SUMMARY_FUNCS.keys())) +def test_branch_ld_matrix(ts, stat): np.testing.assert_array_almost_equal( - ld_matrix(ts, stat=stat, mode="branch"), naive_matrix(ts, stat_func) + ts.ld_matrix(stat=stat, mode="branch"), ld_matrix(ts, stat=stat, mode="branch") ) @@ -1941,15 +1790,11 @@ def get_test_branch_sample_set_test_cases(): @pytest.mark.parametrize("ts,sample_set", get_test_branch_sample_set_test_cases()) -@pytest.mark.parametrize( - "stat,stat_func", - zip( - ["d2_unbiased", "dz_unbiased", "pi2_unbiased"], - [compute_D2, compute_Dz, compute_pi2], - ), -) -def test_branch_ld_matrix_sample_sets(ts, sample_set, stat, stat_func): +@pytest.mark.parametrize("stat", sorted(SUMMARY_FUNCS.keys())) +def test_branch_ld_matrix_sample_sets(ts, sample_set, stat): np.testing.assert_array_almost_equal( - ld_matrix(ts, stat=stat, mode="branch", sample_sets=sample_set), - naive_matrix(ts, stat_func, sample_set[0]), + np.expand_dims( + ld_matrix(ts, stat=stat, mode="branch", sample_sets=sample_set), axis=0 + ), + ts.ld_matrix(stat=stat, mode="branch", sample_sets=sample_set), ) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 36706bdf50..942c9021fc 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -7576,31 +7576,53 @@ def __one_way_sample_set_stat( stat = stat[()] return stat + def parse_sites(self, sites): + row_sites, col_sites = None, None + if sites is not None: + if any(not hasattr(a, "__getitem__") or isinstance(a, str) for a in sites): + raise ValueError("Sites must be a list of lists, tuples, or ndarrays") + if len(sites) == 2: + row_sites, col_sites = sites + elif len(sites) == 1: + row_sites = col_sites = sites[0] + else: + raise ValueError( + f"Sites must be a length 1 or 2 list, got a length {len(sites)} list" + ) + return row_sites, col_sites + + def parse_positions(self, positions): + row_positions, col_positions = None, None + if positions is not None: + if any( + not hasattr(a, "__getitem__") or isinstance(a, str) for a in positions + ): + raise ValueError( + "Positions must be a list of lists, tuples, or ndarrays" + ) + if len(positions) == 2: + row_positions, col_positions = positions + elif len(positions) == 1: + row_positions = col_positions = positions[0] + else: + raise ValueError( + "Positions must be a length 1 or 2 list, " + f"got a length {len(positions)} list" + ) + return row_positions, col_positions + def __two_locus_sample_set_stat( self, ll_method, sample_sets, sites=None, + positions=None, mode=None, ): if sample_sets is None: sample_sets = self.samples() - if sites is not None and any( - not hasattr(a, "__getitem__") or isinstance(a, str) for a in sites - ): - raise ValueError("Sites must be a list of lists, tuples, or ndarrays") - - if sites is None: - row_sites = np.arange(self.num_sites, dtype=np.int32) - col_sites = np.arange(self.num_sites, dtype=np.int32) - elif len(sites) == 2: - row_sites, col_sites = sites - elif len(sites) == 1: - row_sites = col_sites = sites[0] - else: - raise ValueError( - f"Sites must be a length 1 or 2 list, got a length {len(sites)} list" - ) + row_sites, col_sites = self.parse_sites(sites) + row_positions, col_positions = self.parse_positions(positions) # First try to convert to a 1D numpy array. If we succeed, then we strip off # the corresponding dimension from the output. @@ -7624,7 +7646,15 @@ def __two_locus_sample_set_stat( flattened = util.safe_np_int_cast(np.hstack(sample_sets), np.int32) - result = ll_method(sample_set_sizes, flattened, row_sites, col_sites, mode) + result = ll_method( + sample_set_sizes, + flattened, + row_sites, + col_sites, + row_positions, + col_positions, + mode, + ) if drop_dimension: result = result.reshape(result.shape[:2]) @@ -9522,7 +9552,9 @@ def impute_unknown_mutations_time( mutations_time[unknown] = self.nodes_time[self.mutations_node[unknown]] return mutations_time - def ld_matrix(self, sample_sets=None, sites=None, mode="site", stat="r2"): + def ld_matrix( + self, sample_sets=None, sites=None, positions=None, mode="site", stat="r2" + ): stats = { "D": self._ll_tree_sequence.D_matrix, "D2": self._ll_tree_sequence.D2_matrix, @@ -9531,6 +9563,9 @@ def ld_matrix(self, sample_sets=None, sites=None, mode="site", stat="r2"): "r": self._ll_tree_sequence.r_matrix, "Dz": self._ll_tree_sequence.Dz_matrix, "pi2": self._ll_tree_sequence.pi2_matrix, + "Dz_unbiased": self._ll_tree_sequence.Dz_unbiased_matrix, + "D2_unbiased": self._ll_tree_sequence.D2_unbiased_matrix, + "pi2_unbiased": self._ll_tree_sequence.pi2_unbiased_matrix, } try: @@ -9544,6 +9579,7 @@ def ld_matrix(self, sample_sets=None, sites=None, mode="site", stat="r2"): two_locus_stat, sample_sets, sites=sites, + positions=positions, mode=mode, )