Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring #326

Merged
merged 18 commits into from
Aug 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
699 changes: 223 additions & 476 deletions src/aln.cpp

Large diffs are not rendered by default.

16 changes: 11 additions & 5 deletions src/aln.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,23 @@ struct AlignmentStatistics {
}
};

struct mapping_params {
struct MappingParameters {
int r { 150 };
int max_secondary { 0 };
float dropoff_threshold { 0.5 };
int R { 2 };
int maxTries { 20 };
int rescue_level { 2 };
int max_tries { 20 };
int rescue_cutoff;
bool is_sam_out { true };
CigarOps cigar_ops{CigarOps::M};
bool output_unmapped { true };
bool details{false};

void verify() const {
if (max_tries < 1) {
throw BadParameter("max_tries must be greater than zero");
}
}
};

class i_dist_est {
Expand All @@ -88,7 +94,7 @@ void align_PE_read(
AlignmentStatistics& statistics,
i_dist_est& isize_est,
const Aligner& aligner,
const mapping_params& map_param,
const MappingParameters& map_param,
const IndexParameters& index_parameters,
const References& references,
const StrobemerIndex& index
Expand All @@ -100,7 +106,7 @@ void align_SE_read(
std::string& outstring,
AlignmentStatistics& statistics,
const Aligner& aligner,
const mapping_params& map_param,
const MappingParameters& map_param,
const IndexParameters& index_parameters,
const References& references,
const StrobemerIndex& index
Expand Down
4 changes: 2 additions & 2 deletions src/cmdline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ CommandLineOptions parse_command_line_arguments(int argc, char **argv) {
// Search parameters
if (f) { opt.f = args::get(f); }
if (S) { opt.dropoff_threshold = args::get(S); }
if (M) { opt.maxTries = args::get(M); }
if (R) { opt.R = args::get(R); }
if (M) { opt.max_tries = args::get(M); }
if (R) { opt.rescue_level = args::get(R); }

// Reference and read files
opt.ref_filename = args::get(ref_filename);
Expand Down
4 changes: 2 additions & 2 deletions src/cmdline.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ struct CommandLineOptions {
// Search parameters
float f { 0.0002 };
float dropoff_threshold { 0.5 };
int maxTries { 20 };
int R { 2 };
int max_tries { 20 };
int rescue_level { 2 };

// Reference and read files
std::string ref_filename; // This is either a fasta file or an index file - if fasta, indexing will be run
Expand Down
15 changes: 8 additions & 7 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@ void warn_if_no_optimizations() {
}
}

void log_parameters(const IndexParameters& index_parameters, const mapping_params& map_param, const alignment_params& aln_params) {
void log_parameters(const IndexParameters& index_parameters, const MappingParameters& map_param, const alignment_params& aln_params) {
logger.debug() << "Using" << std::endl
<< "k: " << index_parameters.syncmer.k << std::endl
<< "s: " << index_parameters.syncmer.s << std::endl
<< "w_min: " << index_parameters.randstrobe.w_min << std::endl
<< "w_max: " << index_parameters.randstrobe.w_max << std::endl
<< "Read length (r): " << map_param.r << std::endl
<< "Maximum seed length: " << index_parameters.randstrobe.max_dist + index_parameters.syncmer.k << std::endl
<< "R: " << map_param.R << std::endl
<< "R: " << map_param.rescue_level << std::endl
<< "Expected [w_min, w_max] in #syncmers: [" << index_parameters.randstrobe.w_min << ", " << index_parameters.randstrobe.w_max << "]" << std::endl
<< "Expected [w_min, w_max] in #nucleotides: [" << (index_parameters.syncmer.k - index_parameters.syncmer.s + 1) * index_parameters.randstrobe.w_min << ", " << (index_parameters.syncmer.k - index_parameters.syncmer.s + 1) * index_parameters.randstrobe.w_max << "]" << std::endl
<< "A: " << aln_params.match << std::endl
Expand Down Expand Up @@ -168,16 +168,17 @@ int run_strobealign(int argc, char **argv) {
aln_params.gap_extend = opt.E;
aln_params.end_bonus = opt.end_bonus;

mapping_params map_param;
MappingParameters map_param;
map_param.r = opt.r;
map_param.max_secondary = opt.max_secondary;
map_param.dropoff_threshold = opt.dropoff_threshold;
map_param.R = opt.R;
map_param.maxTries = opt.maxTries;
map_param.rescue_level = opt.rescue_level;
map_param.max_tries = opt.max_tries;
map_param.is_sam_out = opt.is_sam_out;
map_param.cigar_ops = opt.cigar_eqx ? CigarOps::EQX : CigarOps::M;
map_param.output_unmapped = opt.output_unmapped;
map_param.details = opt.details;
map_param.verify();

log_parameters(index_parameters, map_param, aln_params);
logger.debug() << "Threads: " << opt.n_threads << std::endl;
Expand Down Expand Up @@ -257,7 +258,7 @@ int run_strobealign(int argc, char **argv) {
// Map/align reads

Timer map_align_timer;
map_param.rescue_cutoff = map_param.R < 100 ? map_param.R * index.filter_cutoff : 1000;
map_param.rescue_cutoff = map_param.rescue_level < 100 ? map_param.rescue_level * index.filter_cutoff : 1000;
logger.debug() << "Using rescue cutoff: " << map_param.rescue_cutoff << std::endl;

std::streambuf* buf;
Expand Down Expand Up @@ -331,7 +332,7 @@ int main(int argc, char **argv) {
try {
return run_strobealign(argc, argv);
} catch (BadParameter& e) {
logger.error() << "A mapping or seeding parameter is invalid: " << e.what() << std::endl;
logger.error() << "A parameter is invalid: " << e.what() << std::endl;
} catch (const std::runtime_error& e) {
logger.error() << "strobealign: " << e.what() << std::endl;
}
Expand Down
74 changes: 37 additions & 37 deletions src/nam.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,28 @@
namespace {

struct Hit {
int query_s;
int query_e;
int ref_s;
int ref_e;
int query_start;
int query_end;
int ref_start;
int ref_end;
bool is_rc = false;
};

void add_to_hits_per_ref(
robin_hood::unordered_map<unsigned int, std::vector<Hit>>& hits_per_ref,
int query_s,
int query_e,
int query_start,
int query_end,
bool is_rc,
const StrobemerIndex& index,
size_t position,
int min_diff
) {
for (const auto hash = index.get_hash(position); index.get_hash(position) == hash; ++position) {
int ref_s = index.get_strobe1_position(position);
int ref_e = ref_s + index.strobe2_offset(position) + index.k();
int diff = std::abs((query_e - query_s) - (ref_e - ref_s));
int ref_start = index.get_strobe1_position(position);
int ref_end = ref_start + index.strobe2_offset(position) + index.k();
int diff = std::abs((query_end - query_start) - (ref_end - ref_start));
if (diff <= min_diff) {
hits_per_ref[index.reference_index(position)].push_back(Hit{query_s, query_e, ref_s, ref_e, is_rc});
hits_per_ref[index.reference_index(position)].push_back(Hit{query_start, query_end, ref_start, ref_end, is_rc});
min_diff = diff;
}
}
Expand All @@ -41,7 +41,7 @@ std::vector<Nam> merge_hits_into_nams(
if (sort) {
std::sort(hits.begin(), hits.end(), [](const Hit& a, const Hit& b) -> bool {
// first sort on query starts, then on reference starts
return (a.query_s < b.query_s) || ( (a.query_s == b.query_s) && (a.ref_s < b.ref_s) );
return (a.query_start < b.query_start) || ( (a.query_start == b.query_start) && (a.ref_start < b.ref_start) );
}
);
}
Expand All @@ -53,24 +53,24 @@ std::vector<Nam> merge_hits_into_nams(
for (auto & o : open_nams) {

// Extend NAM
if (( o.is_rc == h.is_rc) && (o.query_prev_hit_startpos < h.query_s) && (h.query_s <= o.query_e ) && (o.ref_prev_hit_startpos < h.ref_s) && (h.ref_s <= o.ref_e) ){
if ( (h.query_e > o.query_e) && (h.ref_e > o.ref_e) ) {
o.query_e = h.query_e;
o.ref_e = h.ref_e;
if (( o.is_rc == h.is_rc) && (o.query_prev_hit_startpos < h.query_start) && (h.query_start <= o.query_end ) && (o.ref_prev_hit_startpos < h.ref_start) && (h.ref_start <= o.ref_end) ){
if ( (h.query_end > o.query_end) && (h.ref_end > o.ref_end) ) {
o.query_end = h.query_end;
o.ref_end = h.ref_end;
// o.previous_query_start = h.query_s;
// o.previous_ref_start = h.ref_s; // keeping track so that we don't . Can be caused by interleaved repeats.
o.query_prev_hit_startpos = h.query_s; // log the last strobemer hit in case of outputting paf
o.ref_prev_hit_startpos = h.ref_s; // log the last strobemer hit in case of outputting paf
o.query_prev_hit_startpos = h.query_start; // log the last strobemer hit in case of outputting paf
o.ref_prev_hit_startpos = h.ref_start; // log the last strobemer hit in case of outputting paf
o.n_hits ++;
// o.score += (float)1/ (float)h.count;
is_added = true;
break;
}
else if ((h.query_e <= o.query_e) && (h.ref_e <= o.ref_e)) {
else if ((h.query_end <= o.query_end) && (h.ref_end <= o.ref_end)) {
// o.previous_query_start = h.query_s;
// o.previous_ref_start = h.ref_s; // keeping track so that we don't . Can be caused by interleaved repeats.
o.query_prev_hit_startpos = h.query_s; // log the last strobemer hit in case of outputting paf
o.ref_prev_hit_startpos = h.ref_s; // log the last strobemer hit in case of outputting paf
o.query_prev_hit_startpos = h.query_start; // log the last strobemer hit in case of outputting paf
o.ref_prev_hit_startpos = h.ref_start; // log the last strobemer hit in case of outputting paf
o.n_hits ++;
// o.score += (float)1/ (float)h.count;
is_added = true;
Expand All @@ -84,27 +84,27 @@ std::vector<Nam> merge_hits_into_nams(
Nam n;
n.nam_id = nam_id_cnt;
nam_id_cnt ++;
n.query_s = h.query_s;
n.query_e = h.query_e;
n.ref_s = h.ref_s;
n.ref_e = h.ref_e;
n.query_start = h.query_start;
n.query_end = h.query_end;
n.ref_start = h.ref_start;
n.ref_end = h.ref_end;
n.ref_id = ref_id;
// n.previous_query_start = h.query_s;
// n.previous_ref_start = h.ref_s;
n.query_prev_hit_startpos = h.query_s;
n.ref_prev_hit_startpos = h.ref_s;
n.query_prev_hit_startpos = h.query_start;
n.ref_prev_hit_startpos = h.ref_start;
n.n_hits = 1;
n.is_rc = h.is_rc;
// n.score += (float)1 / (float)h.count;
open_nams.push_back(n);
}

// Only filter if we have advanced at least k nucleotides
if (h.query_s > prev_q_start + k) {
if (h.query_start > prev_q_start + k) {

// Output all NAMs from open_matches to final_nams that the current hit have passed
for (auto &n : open_nams) {
if (n.query_e < h.query_s) {
if (n.query_end < h.query_start) {
int n_max_span = std::max(n.query_span(), n.ref_span());
int n_min_span = std::min(n.query_span(), n.ref_span());
float n_score;
Expand All @@ -116,10 +116,10 @@ std::vector<Nam> merge_hits_into_nams(
}

// Remove all NAMs from open_matches that the current hit have passed
auto c = h.query_s;
auto predicate = [c](decltype(open_nams)::value_type const &nam) { return nam.query_e < c; };
auto c = h.query_start;
auto predicate = [c](decltype(open_nams)::value_type const &nam) { return nam.query_end < c; };
open_nams.erase(std::remove_if(open_nams.begin(), open_nams.end(), predicate), open_nams.end());
prev_q_start = h.query_s;
prev_q_start = h.query_start;
}
}

Expand Down Expand Up @@ -180,13 +180,13 @@ std::vector<Nam> find_nams_rescue(
struct RescueHit {
unsigned int count;
size_t position;
unsigned int query_s;
unsigned int query_e;
unsigned int query_start;
unsigned int query_end;
bool is_rc;

bool operator< (const RescueHit& rhs) const {
return std::tie(count, query_s, query_e, is_rc)
< std::tie(rhs.count, rhs.query_s, rhs.query_e, rhs.is_rc);
return std::tie(count, query_start, query_end, is_rc)
< std::tie(rhs.count, rhs.query_start, rhs.query_end, rhs.is_rc);
}
};

Expand Down Expand Up @@ -218,7 +218,7 @@ std::vector<Nam> find_nams_rescue(
if ((rh.count > filter_cutoff && cnt >= 5) || rh.count > 1000) {
break;
}
add_to_hits_per_ref(hits_per_ref, rh.query_s, rh.query_e, rh.is_rc, index, rh.position, 1000);
add_to_hits_per_ref(hits_per_ref, rh.query_start, rh.query_end, rh.is_rc, index, rh.position, 1000);
cnt++;
}
}
Expand All @@ -227,6 +227,6 @@ std::vector<Nam> find_nams_rescue(
}

std::ostream& operator<<(std::ostream& os, const Nam& n) {
os << "Nam(ref_id=" << n.ref_id << ", query: " << n.query_s << ".." << n.query_e << ", ref: " << n.ref_s << ".." << n.ref_e << ", score=" << n.score << ")";
os << "Nam(ref_id=" << n.ref_id << ", query: " << n.query_start << ".." << n.query_end << ", ref: " << n.ref_start << ".." << n.ref_end << ", score=" << n.score << ")";
return os;
}
12 changes: 6 additions & 6 deletions src/nam.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
// Non-overlapping approximate match
struct Nam {
int nam_id;
int query_s;
int query_e;
int query_start;
int query_end;
int query_prev_hit_startpos;
int ref_s;
int ref_e;
int ref_start;
int ref_end;
int ref_prev_hit_startpos;
int n_hits = 0;
int ref_id;
Expand All @@ -22,11 +22,11 @@ struct Nam {
bool is_rc = false;

int ref_span() const {
return ref_e - ref_s;
return ref_end - ref_start;
}

int query_span() const {
return query_e - query_s;
return query_end - query_start;
}
};

Expand Down
8 changes: 4 additions & 4 deletions src/paf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
* 12 mapping quality (0-255; 255 for missing)
*/
void output_hits_paf_PE(std::string &paf_output, const Nam &n, const std::string &query_name, const References& references, int k, int read_len) {
if (n.ref_s < 0 ) {
if (n.ref_start < 0 ) {
return;
}
paf_output.append(query_name);
paf_output.append("\t");
paf_output.append(std::to_string(read_len));
paf_output.append("\t");
paf_output.append(std::to_string(n.query_s));
paf_output.append(std::to_string(n.query_start));
paf_output.append("\t");
paf_output.append(std::to_string(n.query_prev_hit_startpos + k));
paf_output.append("\t");
Expand All @@ -32,13 +32,13 @@ void output_hits_paf_PE(std::string &paf_output, const Nam &n, const std::string
paf_output.append("\t");
paf_output.append(std::to_string(references.lengths[n.ref_id]));
paf_output.append("\t");
paf_output.append(std::to_string(n.ref_s));
paf_output.append(std::to_string(n.ref_start));
paf_output.append("\t");
paf_output.append(std::to_string(n.ref_prev_hit_startpos + k));
paf_output.append("\t");
paf_output.append(std::to_string(n.n_hits));
paf_output.append("\t");
paf_output.append(std::to_string(n.ref_prev_hit_startpos + k - n.ref_s));
paf_output.append(std::to_string(n.ref_prev_hit_startpos + k - n.ref_start));
paf_output.append("\t255\n");
}

Expand Down
2 changes: 1 addition & 1 deletion src/pc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ void perform_task(
AlignmentStatistics& statistics,
int& done,
const alignment_params &aln_params,
const mapping_params &map_param,
const MappingParameters &map_param,
const IndexParameters& index_parameters,
const References& references,
const StrobemerIndex& index,
Expand Down
2 changes: 1 addition & 1 deletion src/pc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class OutputBuffer {

void perform_task(InputBuffer &input_buffer, OutputBuffer &output_buffer,
AlignmentStatistics& statistics, int& done, const alignment_params &aln_params,
const mapping_params &map_param, const IndexParameters& index_parameters, const References& references, const StrobemerIndex& index, const std::string& read_group_id);
const MappingParameters &map_param, const IndexParameters& index_parameters, const References& references, const StrobemerIndex& index, const std::string& read_group_id);

bool same_name(const std::string& n1, const std::string& n2);

Expand Down
8 changes: 4 additions & 4 deletions src/python/strobealign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,10 @@ NB_MODULE(strobealign_extension, m_) {
nb::bind_vector<QueryRandstrobeVector>(m, "QueryRandstrobeVector");

nb::class_<Nam>(m, "Nam")
.def_ro("query_start", &Nam::query_s)
.def_ro("query_end", &Nam::query_e)
.def_ro("ref_start", &Nam::ref_s)
.def_ro("ref_end", &Nam::ref_e)
.def_ro("query_start", &Nam::query_start)
.def_ro("query_end", &Nam::query_end)
.def_ro("ref_start", &Nam::ref_start)
.def_ro("ref_end", &Nam::ref_end)
.def_ro("score", &Nam::score)
.def_ro("n_hits", &Nam::n_hits)
.def_ro("reference_index", &Nam::ref_id)
Expand Down
Loading