diff --git a/k2/csrc/fsa.cu b/k2/csrc/fsa.cu index 585488bb4..4bff92d11 100644 --- a/k2/csrc/fsa.cu +++ b/k2/csrc/fsa.cu @@ -238,6 +238,10 @@ Fsa FsaFromArray1(Array1 &array, bool *error) { const Arc *arcs_data = array.Data(); ContextPtr &c = array.Context(); int32_t num_arcs = array.Dim(); + // We choose to return an Fsa with no states and no arcs. We could also have + // chosen to return an Fsa with 2 states and no arcs. + if (num_arcs == 0) + return Fsa(EmptyRaggedShape(c, 2), Array1(c, 0)); *error = false; // If the FSA has arcs entering the final state, that will diff --git a/k2/csrc/fsa.h b/k2/csrc/fsa.h index 23599300d..a3028d883 100644 --- a/k2/csrc/fsa.h +++ b/k2/csrc/fsa.h @@ -200,13 +200,17 @@ std::ostream &operator<<(std::ostream &os, const DenseFsaVec &dfsavec); */ Fsa FsaFromTensor(Tensor &t, bool *error); + Fsa FsaFromArray1(Array1 &arc, bool *error); /* Returns a single Tensor that represents the FSA; this is just the vector of - Arc reinterpreted as num_arcs by 4 Tensor of int32_t. It can be converted - back to an equivalent FSA using `FsaFromTensor`. Notice: this is not the - same format as we use to serialize FsaVec. + Arc reinterpreted as a (num_arcs by 4) Tensor of int32_t. It can be converted + back to an equivalent FSA using `FsaFromTensor`. Notice: this is not the same + format as we use to serialize FsaVec. Also the round-trip conversion to + Tensor and back may not preserve the number of states for FSAs that had no + arcs entering the final-state, since we have to guess the number of states in + this case. It is an error if `fsa.NumAxes() != 2`. */ @@ -266,8 +270,6 @@ Tensor FsaVecToTensor(const FsaVec &fsa_vec); */ FsaVec FsaVecFromTensor(Tensor &t, bool *error); -FsaVec FsaVecFromArray1(Array1 &arc, bool *error); // TODO: implement it - /* Return one Fsa in an FsaVec. 
Note, this has to make copies of the row offsets and strides but can use a sub-array of the arcs array diff --git a/k2/python/csrc/torch/index_select.cu b/k2/python/csrc/torch/index_select.cu index 80a9e39f9..bca7fff2f 100644 --- a/k2/python/csrc/torch/index_select.cu +++ b/k2/python/csrc/torch/index_select.cu @@ -173,7 +173,6 @@ static torch::Tensor SimpleRaggedIndexSelect1D(torch::Tensor src, int64_t src_stride = src.strides()[0]; int64_t ans_stride = ans.strides()[0]; -#if !defined(NDEBUG) // check if there is at most one non-zero element in src for each sub-list Ragged non_zero_elems(indexes.shape, Array1(context, indexes_num_elems)); @@ -189,15 +188,29 @@ static torch::Tensor SimpleRaggedIndexSelect1D(torch::Tensor src, Array1 counts(context, indexes_dim0); SumPerSublist(non_zero_elems, 0, &counts); const int32_t *counts_data = counts.Data(); - Array1 status(context, 1, 0); + Array1 status(context, 1, 0); // 0 -> success; otherwise 1 + row_id + // of bad row in `indexes` int32_t *status_data = status.Data(); K2_EVAL( context, counts.Dim(), lambda_check_status, (int32_t i)->void { - if (counts_data[i] > 1) status_data[0] = 1; + if (counts_data[i] > 1) status_data[0] = 1 + i; }); - K2_CHECK_EQ(status[0], 0) << "There must be at most one non-zero " - "element in src for any sub-list in indexes"; -#endif + int32_t s = status[0]; + if (s != 0) { + Array1 indexed_values(context, indexes_num_elems); + T *indexed_values_data = indexed_values.Data(); + K2_EVAL(context, indexes_num_elems, lambda_set_values, (int32_t i) -> void { + int32_t src_index = indexes_data[i]; + indexed_values_data[i] = src_data[src_index * src_stride]; + }); + Array1 row_splits = indexes.RowSplits(1); + + K2_LOG(FATAL) << "There must be at most one non-zero " + "element in src for any sub-list in indexes; sub-list " + << (s-1) << " has too many elements: " + << indexed_values.Arange(row_splits[s-1], + row_splits[s]); + } K2_EVAL( context, indexes_num_elems, lambda_set_ans_data, (int32_t i)->void 
{ diff --git a/k2/python/k2/fsa_algo.py b/k2/python/k2/fsa_algo.py index e9e2f6991..0c14d945b 100644 --- a/k2/python/k2/fsa_algo.py +++ b/k2/python/k2/fsa_algo.py @@ -299,7 +299,8 @@ def determinize(fsa: Fsa) -> Fsa: '''Determinize the input Fsa. Caution: - It only works on for CPU and doesn't support autograd. + It only works on CPU and doesn't support autograd (for now; + this is not a fundamental limitation). Args: fsa: