Refactor onnx math #187
That's a nice tidy up. I've got a few comments about the typing and things which will make it easier (for me) to figure out what's going on.
@@ -307,7 +307,7 @@ private ONNXOperators(String value, int numInputs, int numOptionalInputs, int nu
  for (String o : outputs) {
  nodeBuilder.addOutput(o);
  }
- nodeBuilder.setName(context.generateUniqueName(opName));
+ nodeBuilder.setName(context.generateUniqueName(opName) + ":" + outputs[0]);
ONNX names are required by the spec to be alphanumeric + underscores.
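One way to keep the output name in the node name without the colon is to build the name and then replace anything outside the allowed character set. This is only a sketch: `OnnxNameSketch`, `sanitize` and `nodeName` are hypothetical helpers written for this note, not part of the PR or of any ONNX library, and the allowed set is taken from the comment above (alphanumerics plus underscore).

```java
import java.util.regex.Pattern;

final class OnnxNameSketch {
    // Assumption: anything outside [A-Za-z0-9_] counts as illegal, per the review comment.
    private static final Pattern ILLEGAL = Pattern.compile("[^A-Za-z0-9_]");

    // Replaces every illegal character with an underscore.
    static String sanitize(String raw) {
        return ILLEGAL.matcher(raw).replaceAll("_");
    }

    // Joins the (already unique) op name and the first output with an underscore instead of ":".
    static String nodeName(String uniqueOpName, String firstOutput) {
        return sanitize(uniqueOpName + "_" + firstOutput);
    }

    public static void main(String[] args) {
        System.out.println(nodeName("Gemm_1", "output:0"));
    }
}
```

Sanitizing can map two distinct raw names onto the same string, so a real implementation would probably still want the uniqueness check to run on the final name.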
return OnnxMl.TensorProto.newBuilder()
.setName(context.generateUniqueName(name))
.setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber())
.addAllDims(() -> dims.stream().map(Integer::longValue).iterator())
I think I'd prefer this to just collect(Collectors.toList()) rather than make a lambda to fabricate the iterable. Plus in Java 17 we can replace it with .toList(), which is shorter.
Or even just make it accept List<Long>, as that's what ONNX is expecting, and the upcast on the calling side when creating the list might be done automatically?
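The two suggestions above can be sketched side by side. `DimsSketch` and its methods are hypothetical stand-ins for the tensor builder's dimension handling, not code from the PR. On the open question: Java will not box an `int` into a `Long` on its own, so with a `List<Long>` parameter the caller would need an explicit cast or a `long` literal.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

final class DimsSketch {
    // Suggestion 1: keep List<Integer> at the call site and collect to a real list.
    static List<Long> toLongDims(List<Integer> dims) {
        return dims.stream().map(Integer::longValue).collect(Collectors.toList());
    }

    // Suggestion 2: accept List<Long> directly; the caller has to widen explicitly,
    // because Collections.singletonList(length) alone would infer List<Integer>.
    static List<Long> singleDim(int length) {
        return Collections.singletonList((long) length);
    }

    public static void main(String[] args) {
        System.out.println(toLongDims(Arrays.asList(3, 4)));
        System.out.println(singleDim(5));
    }
}
```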
? floatTensorBuilder(context, name, Collections.singletonList(parameters.length),
fb -> Arrays.stream(parameters).forEachOrdered(d -> fb.put((float)d)))
: doubleTensorBuilder(context, name, Collections.singletonList(parameters.length),
db -> Arrays.stream(parameters).forEachOrdered(db::put));
No 5 line ternary operators. An if statement with multiple returns is fine.
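As an illustration of the shape being asked for, here is the same float-or-double branch written as an `if` with two returns. It is a sketch only: `TernarySketch.encode` and its `asFloat` flag are hypothetical, and a plain `ByteBuffer` stands in for the `TensorProto` that the real builders return.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

final class TernarySketch {
    // Encodes the parameters as 32-bit or 64-bit values depending on the flag.
    static ByteBuffer encode(double[] parameters, boolean asFloat) {
        if (asFloat) {
            ByteBuffer floats = ByteBuffer.allocate(parameters.length * Float.BYTES)
                    .order(ByteOrder.LITTLE_ENDIAN);
            for (double d : parameters) {
                floats.putFloat((float) d);
            }
            floats.flip();
            return floats;
        }
        ByteBuffer doubles = ByteBuffer.allocate(parameters.length * Double.BYTES)
                .order(ByteOrder.LITTLE_ENDIAN);
        for (double d : parameters) {
            doubles.putDouble(d);
        }
        doubles.flip();
        return doubles;
    }

    public static void main(String[] args) {
        System.out.println(encode(new double[] {1.0, 2.0}, true).remaining());
    }
}
```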
*/
public static OnnxMl.TensorProto floatVectorBuilder(ONNXContext context, String name, SGDVector vector) {
return floatTensorBuilder(context, name, Collections.singletonList(vector.size()),
fb -> vector.forEach(vt -> fb.put(vt.index,(float) vt.value)));
I think it might be nicer to put the type on fb here. (FloatBuffer fb) -> ... means it's immediately obvious what's going on, otherwise you have to go look at floatTensorBuilder to figure out the inferred type. Ditto for the rest of the times this idiom appears in this file.
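A minimal sketch of the explicitly typed lambda follows. `TypedLambdaSketch.fill` is a hypothetical stand-in for `floatTensorBuilder` that only assumes the builder takes a `Consumer<FloatBuffer>`, which the PR excerpt does not confirm; the point is just that the parameter type is readable at the call site.

```java
import java.nio.FloatBuffer;
import java.util.function.Consumer;

final class TypedLambdaSketch {
    // Hypothetical stand-in for floatTensorBuilder: allocates a buffer and lets the caller fill it.
    static float[] fill(int size, Consumer<FloatBuffer> populate) {
        FloatBuffer buffer = FloatBuffer.allocate(size);
        populate.accept(buffer);
        return buffer.array();
    }

    static float[] copy(double[] values) {
        // The explicit (FloatBuffer fb) shows the buffer type without looking at fill's signature.
        return fill(values.length, (FloatBuffer fb) -> {
            for (int i = 0; i < values.length; i++) {
                fb.put(i, (float) values[i]);
            }
        });
    }

    public static void main(String[] args) {
        System.out.println(java.util.Arrays.toString(copy(new double[] {1.5, 2.0})));
    }
}
```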
DenseVector[] denseWeights = new DenseVector[weights.length];
for (int i = 0; i < denseWeights.length; i++) {
denseWeights[i] = weights[i].densify();
}
I think denseWeights should be pulled out of this lambda and be a local variable. It feels like a lot of magic to do in the lambda. Or even just make it into a DenseSparseMatrix and pass that in to floatMatrixBuilder. It would be a little easier to see what's going on then.
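The first option, hoisting the densify step into a local variable so the lambda only writes values, might look roughly like this. Everything here is hypothetical: `HoistSketch` uses index/value arrays in place of Tribuo's sparse vectors and a plain `FloatBuffer` in place of `floatMatrixBuilder`, so it shows the structure rather than the real API.

```java
import java.nio.FloatBuffer;
import java.util.function.Consumer;

final class HoistSketch {
    // Hypothetical stand-in for densify(): expands one sparse row into a dense array.
    static double[] densify(int[] indices, double[] values, int size) {
        double[] dense = new double[size];
        for (int i = 0; i < indices.length; i++) {
            dense[indices[i]] = values[i];
        }
        return dense;
    }

    static float[] matrix(int[][] indices, double[][] values, int cols) {
        // Densify up front as a local variable, outside the lambda.
        double[][] denseWeights = new double[indices.length][];
        for (int i = 0; i < denseWeights.length; i++) {
            denseWeights[i] = densify(indices[i], values[i], cols);
        }
        // The lambda is now only responsible for writing values in row-major order.
        Consumer<FloatBuffer> populate = (FloatBuffer fb) -> {
            for (double[] row : denseWeights) {
                for (double v : row) {
                    fb.put((float) v);
                }
            }
        };
        FloatBuffer buffer = FloatBuffer.allocate(indices.length * cols);
        populate.accept(buffer);
        return buffer.array();
    }

    public static void main(String[] args) {
        float[] out = matrix(new int[][] {{1}, {0}}, new double[][] {{2.0}, {3.0}}, 2);
        System.out.println(java.util.Arrays.toString(out));
    }
}
```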
Could you rebase this PR to get rid of all the factorization machines commits it pulled in again?
8800de5 to 56db2fe
LGTM.
Description
As we add ONNXExportable to new models, there is a bunch of duplicated logic for creating OnnxML.TensorProto instances. This PR creates (hopefully) fully general methods for generating TensorProtos and replaces existing implementations with them.