Skip to content

Commit

Permalink
[WebNN] Support MLOperandDataType to align with latest WebNN API Spec (
Browse files Browse the repository at this point in the history
…web-platform-tests#42419)

* [WebNN] Support MLOperandDataType to align with latest WebNN API Spec

* Duplicate "type" key with "dataType“ key
  • Loading branch information
BruceDai authored Nov 14, 2023
1 parent d0d3b96 commit 205a73d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
4 changes: 2 additions & 2 deletions webnn/idlharness.https.any.js
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ idl_test(
}

self.builder = new MLGraphBuilder(self.context);
self.input = builder.input('input', {type: 'float32', dimensions: [1, 1, 5, 5]});
self.filter = builder.constant({type: 'float32', dimensions: [1, 1, 3, 3]}, new Float32Array(9).fill(1));
self.input = builder.input('input', {dataType: 'float32', dimensions: [1, 1, 5, 5]});
self.filter = builder.constant({dataType: 'float32', dimensions: [1, 1, 3, 3]}, new Float32Array(9).fill(1));
self.relu = builder.relu();
self.output = builder.conv2d(input, filter, {activation: relu, inputLayout: "nchw"});

Expand Down
20 changes: 10 additions & 10 deletions webnn/resources/utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

const ExecutionArray = ['sync', 'async'];

// https://webmachinelearning.github.io/webnn/#enumdef-mloperandtype
// https://webmachinelearning.github.io/webnn/#enumdef-mloperanddatatype
const TypedArrayDict = {
float32: Float32Array,
int32: Int32Array,
Expand Down Expand Up @@ -349,7 +349,7 @@ const getPrecisonTolerance = (operationName, metricType, resources) => {
* @param {Number} value
* @param {String} dataType - A data type string, like "float32", "float16",
* more types, please see:
* https://webmachinelearning.github.io/webnn/#enumdef-mloperandtype
* https://webmachinelearning.github.io/webnn/#enumdef-mloperanddatatype
* @return {Number} A 64-bit signed integer.
*/
const getBitwise = (value, dataType) => {
Expand All @@ -375,7 +375,7 @@ const getBitwise = (value, dataType) => {
* @param {Number} nulp - A BigInt value indicates acceptable ULP distance.
* @param {String} dataType - A data type string, value: "float32",
* more types, please see:
* https://webmachinelearning.github.io/webnn/#enumdef-mloperandtype
* https://webmachinelearning.github.io/webnn/#enumdef-mloperanddatatype
* @param {String} description - Description of the condition being tested.
*/
const assert_array_approx_equals_ulp = (actual, expected, nulp, dataType, description) => {
Expand Down Expand Up @@ -408,7 +408,7 @@ const assert_array_approx_equals_ulp = (actual, expected, nulp, dataType, descri
* @param {Number} tolerance
* @param {String} operandType - An operand type string, value: "float32",
* more types, please see:
* https://webmachinelearning.github.io/webnn/#enumdef-mloperandtype
* https://webmachinelearning.github.io/webnn/#enumdef-mloperanddatatype
* @param {String} metricType - Value: 'ULP', 'ATOL'
*/
const doAssert = (operationName, actual, expected, tolerance, operandType, metricType) => {
Expand Down Expand Up @@ -465,7 +465,7 @@ const checkResults = (operationName, namedOutputOperands, outputs, resources) =>
*/
const createConstantOperand = (builder, resources) => {
const bufferView = new TypedArrayDict[resources.type](resources.data);
return builder.constant({type: resources.type, dimensions: resources.shape}, bufferView);
return builder.constant({dataType: resources.type, type: resources.type, dimensions: resources.shape}, bufferView);
};

/**
Expand All @@ -478,7 +478,7 @@ const createConstantOperand = (builder, resources) => {
const createSingleInputOperand = (builder, resources, inputOperandName) => {
inputOperandName = inputOperandName ? inputOperandName : Object.keys(resources.inputs)[0];
const inputResources = resources.inputs[inputOperandName];
return builder.input(inputOperandName, {type: inputResources.type, dimensions: inputResources.shape});
return builder.input(inputOperandName, {dataType: inputResources.type, type: inputResources.type, dimensions: inputResources.shape});
};

/**
Expand Down Expand Up @@ -525,7 +525,7 @@ const buildOperationWithSingleInput = (operationName, builder, resources) => {
* @param {Object} resources - Resources used for building a graph
* @returns {MLNamedOperands}
*/
const buildOperationWithTwoInputs= (operationName, builder, resources) => {
const buildOperationWithTwoInputs = (operationName, builder, resources) => {
// For example: MLOperand matmul(MLOperand a, MLOperand b);
const namedOutputOperand = {};
const [inputOperandA, inputOperandB] = createMultiInputOperands(builder, resources);
Expand Down Expand Up @@ -561,7 +561,7 @@ const buildConcat = (operationName, builder, resources) => {
const namedOutputOperand = {};
const inputOperands = [];
for (let input of resources.inputs) {
inputOperands.push(builder.input(input.name, {type: input.type, dimensions: input.shape}));
inputOperands.push(builder.input(input.name, {dataType: input.type, type: input.type, dimensions: input.shape}));
}
// invoke builder.concat()
namedOutputOperand[resources.expected.name] = builder[operationName](inputOperands, resources.axis);
Expand All @@ -583,7 +583,7 @@ const buildConvTranspose2d = (operationName, builder, resources) => {
return namedOutputOperand;
};

const buildConv2d= (operationName, builder, resources) => {
const buildConv2d = (operationName, builder, resources) => {
// MLOperand conv2d(MLOperand input, MLOperand filter, optional MLConv2dOptions options = {});
const namedOutputOperand = {};
const [inputOperand, filterOperand] = createMultiInputOperands(builder, resources);
Expand All @@ -598,7 +598,7 @@ const buildConv2d= (operationName, builder, resources) => {
return namedOutputOperand;
};

const buildGemm= (operationName, builder, resources) => {
const buildGemm = (operationName, builder, resources) => {
// MLOperand gemm(MLOperand a, MLOperand b, optional MLGemmOptions options = {});
const namedOutputOperand = {};
const [inputOperandA, inputOperandB] = createMultiInputOperands(builder, resources);
Expand Down

0 comments on commit 205a73d

Please sign in to comment.