Refactor onnx math #187

Merged · 4 commits · Oct 27, 2021
Conversation

JackSullivan (Member)

Description

As we add ONNXExportable to new models, a lot of logic for creating OnnxMl.TensorProto instances is duplicated. This PR adds (hopefully) fully general methods for generating TensorProtos and replaces the existing implementations with them.

@JackSullivan JackSullivan requested a review from Craigacp October 23, 2021 03:17
@Craigacp Craigacp (Member) left a comment:

That's a nice tidy up. I've got a few comments about the typing, and some suggestions that would make it easier (for me) to figure out what's going on.

@@ -307,7 +307,7 @@ private ONNXOperators(String value, int numInputs, int numOptionalInputs, int nu
         for (String o : outputs) {
             nodeBuilder.addOutput(o);
         }
-        nodeBuilder.setName(context.generateUniqueName(opName));
+        nodeBuilder.setName(context.generateUniqueName(opName) + ":" + outputs[0]);
Comment (Member):

ONNX names are required by the spec to be alphanumeric + underscores.
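To illustrate the constraint the reviewer is pointing at, here is a minimal sketch of a name sanitizer. `OnnxNameSanitizer` is a hypothetical helper, not Tribuo code; it simply rewrites any character outside the spec-allowed set before the name would be set on a node builder.

```java
// Hypothetical helper sketching the reviewer's point: ONNX graph/node names
// should contain only alphanumerics and underscores, so a separator such as
// ':' must be rewritten before it is written into the proto.
public final class OnnxNameSanitizer {
    private OnnxNameSanitizer() {}

    // Replaces every character outside [A-Za-z0-9_] with '_'.
    public static String sanitize(String name) {
        return name.replaceAll("[^A-Za-z0-9_]", "_");
    }

    public static void main(String[] args) {
        System.out.println(sanitize("Gemm_0:output"));
    }
}
```

With this in place, a name like `Gemm_0:output` becomes `Gemm_0_output` and stays within the spec's character set.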

    return OnnxMl.TensorProto.newBuilder()
            .setName(context.generateUniqueName(name))
            .setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber())
            .addAllDims(() -> dims.stream().map(Integer::longValue).iterator())
Comment (Member):

I think I'd prefer this to just collect(Collectors.toList()) rather than making a lambda to fabricate the Iterable. Plus on Java 17 we can replace it with .toList(), which is shorter.

Comment (Member):

Or even just make it accept List<Long> as that's what ONNX is expecting, and the upcast on the calling side when creating the list might be done automatically?
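The two alternatives under discussion can be sketched side by side. `DimsConversion` is a stand-in class for illustration only; in Tribuo the consumer would be the proto builder's `addAllDims`, not a plain list.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

// Illustration of the two idioms in the review thread (hypothetical class).
public final class DimsConversion {
    private DimsConversion() {}

    // The idiom in the diff: wrap the stream in a lambda so that an
    // Iterable<Long>-accepting API can consume it lazily.
    public static Iterable<Long> asIterable(List<Integer> dims) {
        return () -> dims.stream().map(Integer::longValue).iterator();
    }

    // The reviewer's suggestion: materialise a List<Long> up front.
    // (On Java 17+, .collect(Collectors.toList()) can become .toList().)
    public static List<Long> asList(List<Integer> dims) {
        return dims.stream().map(Integer::longValue).collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(asList(Arrays.asList(3, 4)));
    }
}
```

Note that simply changing the parameter type to `List<Long>` would still require the caller to box `long` values when building the list; `int` elements do not widen automatically inside generics, which is likely why the reviewer phrases it as a question.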

        ? floatTensorBuilder(context, name, Collections.singletonList(parameters.length),
                fb -> Arrays.stream(parameters).forEachOrdered(d -> fb.put((float) d)))
        : doubleTensorBuilder(context, name, Collections.singletonList(parameters.length),
                db -> Arrays.stream(parameters).forEachOrdered(db::put));
Comment (Member):

No 5 line ternary operators. An if statement with multiple returns is fine.
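The shape the reviewer is asking for can be sketched like this. `TernaryRewrite` and its two builder methods are simplified stand-ins filling NIO buffers; in the real code both branches return an `OnnxMl.TensorProto`.

```java
import java.nio.Buffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;

// Hypothetical sketch: the five-line conditional expression rewritten as an
// if statement with multiple returns, per the review comment.
public final class TernaryRewrite {
    private TernaryRewrite() {}

    // Stand-in for floatTensorBuilder: narrows each double to float.
    public static FloatBuffer floatTensor(double[] parameters) {
        FloatBuffer fb = FloatBuffer.allocate(parameters.length);
        for (double d : parameters) {
            fb.put((float) d);
        }
        return fb;
    }

    // Stand-in for doubleTensorBuilder: keeps full double precision.
    public static DoubleBuffer doubleTensor(double[] parameters) {
        return DoubleBuffer.wrap(parameters.clone());
    }

    // An if with two returns instead of a multi-line ternary.
    public static Buffer tensorFor(boolean singlePrecision, double[] parameters) {
        if (singlePrecision) {
            return floatTensor(parameters);
        }
        return doubleTensor(parameters);
    }
}
```

The behaviour is identical; the gain is purely readability, since each branch now reads as a plain statement rather than one arm of a large expression.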

*/
public static OnnxMl.TensorProto floatVectorBuilder(ONNXContext context, String name, SGDVector vector) {
    return floatTensorBuilder(context, name, Collections.singletonList(vector.size()),
            fb -> vector.forEach(vt -> fb.put(vt.index, (float) vt.value)));
Comment (Member):

I think it might be nicer to put the type on fb here. (FloatBuffer fb) -> ... means it's immediately obvious what's going on, otherwise you have to go look at floatTensorBuilder to figure out the inferred type. Ditto for the rest of the times this idiom appears in this file.
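A minimal illustration of the difference, using a hypothetical `fill` method in place of `floatTensorBuilder` (the callback-accepting shape is the same):

```java
import java.nio.FloatBuffer;
import java.util.function.Consumer;

// Hypothetical stand-in for the builder-with-callback idiom in the PR.
public final class TypedLambdaExample {
    private TypedLambdaExample() {}

    // Allocates a buffer and hands it to the caller's writer callback.
    public static FloatBuffer fill(int size, Consumer<FloatBuffer> writer) {
        FloatBuffer fb = FloatBuffer.allocate(size);
        writer.accept(fb);
        return fb;
    }

    public static void main(String[] args) {
        // Untyped: the reader must look at fill() to learn what fb is.
        fill(2, fb -> fb.put(0, 1.0f));
        // Reviewer's suggestion: spell the type out at the call site, so the
        // lambda is self-documenting.
        FloatBuffer out = fill(2, (FloatBuffer fb) -> fb.put(1, 2.0f));
        System.out.println(out.get(1));
    }
}
```

Both forms compile to the same thing; the explicit `(FloatBuffer fb)` just saves the reader a trip to the builder's signature.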

    DenseVector[] denseWeights = new DenseVector[weights.length];
    for (int i = 0; i < denseWeights.length; i++) {
        denseWeights[i] = weights[i].densify();
    }
Comment (Member):

I think denseWeights should be pulled out of this lambda and be a local variable. It feels like a lot of magic to do in the lambda. Or even just make it into a DenseSparseMatrix and pass that in to floatMatrixBuilder. It would be a little easier to see what's going on then.
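The hoisting refactor can be sketched as follows. `HoistDensify` and `densify` are placeholders standing in for Tribuo's `SparseVector.densify()`, since the real types are not available here; the point is only the structure: do the densification once, in a plainly named local, before any lambda is involved.

```java
import java.util.Arrays;

// Hypothetical sketch of the suggested refactor: densification is pulled out
// of the tensor-building lambda into an ordinary local variable.
public final class HoistDensify {
    private HoistDensify() {}

    // Placeholder for SparseVector.densify(); here it just copies the array.
    public static double[] densify(double[] sparse) {
        return sparse.clone();
    }

    // The loop from the diff, now a named helper producing a local value that
    // can be passed straight into a matrix builder.
    public static double[][] densifyAll(double[][] weights) {
        double[][] dense = new double[weights.length][];
        for (int i = 0; i < weights.length; i++) {
            dense[i] = densify(weights[i]);
        }
        return dense;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.deepToString(densifyAll(new double[][]{{1.0, 2.0}})));
    }
}
```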

@Craigacp (Member):

Could you rebase this PR to get rid of all the factorization machines commits it pulled in again?

@Craigacp Craigacp added Oracle employee This PR is from an Oracle employee squash-commits Squash the commits when merging this PR labels Oct 25, 2021
@Craigacp Craigacp (Member) left a comment:

LGTM.

@Craigacp Craigacp merged commit 0dea62a into oracle:main Oct 27, 2021