Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in FsaFromTensor for empty FSA; make index select from ragged… #481

Merged
merged 1 commit into from
Dec 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions k2/csrc/fsa.cu
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,10 @@ Fsa FsaFromArray1(Array1<Arc> &array, bool *error) {
const Arc *arcs_data = array.Data();
ContextPtr &c = array.Context();
int32_t num_arcs = array.Dim();
// We choose to return an Fsa with no states and no arcs. We could also have
// chosen to return an Fsa with 2 states and no arcs.
if (num_arcs == 0)
return Fsa(EmptyRaggedShape(c, 2), Array1<Arc>(c, 0));
*error = false;

// If the FSA has arcs entering the final state, that will
Expand Down
12 changes: 7 additions & 5 deletions k2/csrc/fsa.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,13 +200,17 @@ std::ostream &operator<<(std::ostream &os, const DenseFsaVec &dfsavec);
*/
Fsa FsaFromTensor(Tensor &t, bool *error);


Fsa FsaFromArray1(Array1<Arc> &arc, bool *error);

/*
Returns a single Tensor that represents the FSA; this is just the vector of
Arc reinterpreted as num_arcs by 4 Tensor of int32_t. It can be converted
back to an equivalent FSA using `FsaFromTensor`. Notice: this is not the
same format as we use to serialize FsaVec.
Arc reinterpreted as a (num_arcs by 4) Tensor of int32_t. It can be converted
back to an equivalent FSA using `FsaFromTensor`. Notice: this is not the same
format as we use to serialize FsaVec. Also the round-trip conversion to
Tensor and back may not preserve the number of states for FSAs that had no
arcs entering the final-state, since we have to guess the number of states in
this case.

It is an error if `fsa.NumAxes() != 2`.
*/
Expand Down Expand Up @@ -266,8 +270,6 @@ Tensor FsaVecToTensor(const FsaVec &fsa_vec);
*/
FsaVec FsaVecFromTensor(Tensor &t, bool *error);

FsaVec FsaVecFromArray1(Array1<Arc> &arc, bool *error); // TODO: implement it

/*
Return one Fsa in an FsaVec. Note, this has to make copies of the
row offsets and strides but can use a sub-array of the arcs array
Expand Down
25 changes: 19 additions & 6 deletions k2/python/csrc/torch/index_select.cu
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ static torch::Tensor SimpleRaggedIndexSelect1D(torch::Tensor src,
int64_t src_stride = src.strides()[0];
int64_t ans_stride = ans.strides()[0];

#if !defined(NDEBUG)
// check if there is at most one non-zero element in src for each sub-list
Ragged<int32_t> non_zero_elems(indexes.shape,
Array1<int32_t>(context, indexes_num_elems));
Expand All @@ -189,15 +188,29 @@ static torch::Tensor SimpleRaggedIndexSelect1D(torch::Tensor src,
Array1<int32_t> counts(context, indexes_dim0);
SumPerSublist(non_zero_elems, 0, &counts);
const int32_t *counts_data = counts.Data();
Array1<int32_t> status(context, 1, 0);
Array1<int32_t> status(context, 1, 0); // 0 -> success; otherwise 1 + row_id
// of bad row in `indexes`
int32_t *status_data = status.Data();
K2_EVAL(
context, counts.Dim(), lambda_check_status, (int32_t i)->void {
if (counts_data[i] > 1) status_data[0] = 1;
if (counts_data[i] > 1) status_data[0] = 1 + i;
});
K2_CHECK_EQ(status[0], 0) << "There must be at most one non-zero "
"element in src for any sub-list in indexes";
#endif
int32_t s = status[0];
if (s != 0) {
Array1<T> indexed_values(context, indexes_num_elems);
T *indexed_values_data = indexed_values.Data();
K2_EVAL(context, indexes_num_elems, lambda_set_values, (int32_t i) -> void {
int32_t src_index = indexes_data[i];
indexed_values_data[i] = src_data[src_index * src_stride];
});
Array1<int32_t> row_splits = indexes.RowSplits(1);

K2_LOG(FATAL) << "There must be at most one non-zero "
"element in src for any sub-list in indexes; sub-list "
<< (s-1) << " has too many elements: "
<< indexed_values.Arange(row_splits[s-1],
row_splits[s]);
}

K2_EVAL(
context, indexes_num_elems, lambda_set_ans_data, (int32_t i)->void {
Expand Down
3 changes: 2 additions & 1 deletion k2/python/k2/fsa_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,8 @@ def determinize(fsa: Fsa) -> Fsa:
'''Determinize the input Fsa.

Caution:
It only works on for CPU and doesn't support autograd.
It only works on for CPU and doesn't support autograd (for now;
this is not a fundamental limitation).

Args:
fsa:
Expand Down