Skip to content

Commit

Permalink
fix: use static predicate if class instanceof check fails in validate…
Browse files Browse the repository at this point in the history
…Args
  • Loading branch information
AlCalzone committed Nov 14, 2024
1 parent b7b2a82 commit 9c8ba90
Show file tree
Hide file tree
Showing 11 changed files with 125 additions and 25 deletions.
2 changes: 1 addition & 1 deletion packages/cc/src/cc/NotificationCC.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1241,7 +1241,7 @@ export class NotificationCCReport extends NotificationCC {
}
} else if (isUint8Array(this.eventParameters)) {
message["event parameters"] = buffer2hex(this.eventParameters);
} else if (this.eventParameters instanceof Duration) {
} else if (Duration.isDuration(this.eventParameters)) {
message["event parameters"] = this.eventParameters.toString();
} else {
message["event parameters"] = Object.entries(
Expand Down
2 changes: 1 addition & 1 deletion packages/cc/src/cc/SceneActuatorConfigurationCC.ts
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ export class SceneActuatorConfigurationCCAPI extends CCAPI {
);
return this.set(propertyKey, dimmingDuration, value);
} else if (property === "dimmingDuration") {
if (typeof value !== "string" && !(value instanceof Duration)) {
if (typeof value !== "string" && !Duration.isDuration(value)) {
throwWrongValueType(
this.ccId,
property,
Expand Down
2 changes: 1 addition & 1 deletion packages/cc/src/cc/SceneControllerConfigurationCC.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ export class SceneControllerConfigurationCCAPI extends CCAPI {
return this.set(propertyKey, value, dimmingDuration);
}
} else if (property === "dimmingDuration") {
if (typeof value !== "string" && !(value instanceof Duration)) {
if (typeof value !== "string" && !Duration.isDuration(value)) {
throwWrongValueType(
this.ccId,
property,
Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/values/Cache.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ export function serializeCacheValue(value: unknown): SerializedValue {
),
[SPECIAL_TYPE_KEY]: "map",
};
} else if (value instanceof Duration) {
} else if (Duration.isDuration(value)) {
const valueAsJSON = value.toJSON();
return {
...(typeof valueAsJSON === "string"
Expand Down
11 changes: 10 additions & 1 deletion packages/core/src/values/Duration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@ export class Duration {
return new Duration(0, "default");
}

public static isDuration(value: any): value is Duration {
return typeof value === "object"
&& value != null
&& "value" in value
&& typeof value.value === "number"
&& "unit" in value
&& typeof value.unit === "string";
}

/** Parses a duration as represented in Report commands */
public static parseReport(payload?: number): Duration | undefined {
if (payload == undefined) return undefined;
Expand Down Expand Up @@ -97,7 +106,7 @@ export class Duration {
public static from(input?: Duration | string): Duration | undefined;

public static from(input?: Duration | string): Duration | undefined {
if (input instanceof Duration) {
if (Duration.isDuration(input)) {
return input;
} else if (input) {
return Duration.parseString(input);
Expand Down
9 changes: 9 additions & 0 deletions packages/core/src/values/Timeout.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ export class Timeout {
this._value = clamp(v, 0, this.unit === "seconds" ? 60 : 191);
}

public static isTimeout(value: any): value is Timeout {
return typeof value === "object"
&& value != null
&& "value" in value
&& typeof value.value === "number"
&& "unit" in value
&& typeof value.unit === "string";
}

/** Parses a timeout as represented in Report commands */
public static parse(payload: number): Timeout;
public static parse(payload: undefined): undefined;
Expand Down
76 changes: 63 additions & 13 deletions packages/transformers/src/validateArgs/visitor-type-check.ts
Original file line number Diff line number Diff line change
Expand Up @@ -173,29 +173,79 @@ function visitClassType(type: ts.ObjectType, visitorContext: VisitorContext) {
}
}

const makeClass = () =>
importPath
// Create an import for classes in other files
// require("./bar").Foo
? f.createPropertyAccessExpression(
f.createCallExpression(
f.createIdentifier("require"),
undefined,
[f.createStringLiteral(importPath)],
),
f.createIdentifier(identifier),
)
// Foo
: f.createIdentifier(identifier);
const makeInstanceofCheck = () =>
f.createBinaryExpression(
// foo
VisitorUtils.objectIdentifier,
// instanceof
f.createToken(ts.SyntaxKind.InstanceOfKeyword),
// require("./bar").Foo
makeClass(),
);
// require("./bar").Foo.isFoo
const makePredicate = () =>
f.createPropertyAccessExpression(
makeClass(),
f.createIdentifier(`is${identifier}`),
);
// typeof require("./bar").Foo.isFoo === "function"
const makeFunctionCheck = () =>
f.createBinaryExpression(
f.createTypeOfExpression(makePredicate()),
f.createToken(ts.SyntaxKind.EqualsEqualsEqualsToken),
f.createStringLiteral("function"),
);
// require("./bar").Foo.isFoo(foo)
const makePredicateCall = () =>
f.createCallExpression(
makePredicate(),
undefined,
[VisitorUtils.objectIdentifier],
);

return VisitorUtils.setFunctionIfNotExists(name, visitorContext, () => {
return VisitorUtils.createAssertionFunction(
// !(foo instanceof require("./bar").Foo)
// !( ...
f.createPrefixUnaryExpression(
ts.SyntaxKind.ExclamationToken,
f.createParenthesizedExpression(
f.createBinaryExpression(
VisitorUtils.objectIdentifier,
f.createToken(ts.SyntaxKind.InstanceOfKeyword),
// Create an import for classes in other files
importPath
? f.createPropertyAccessExpression(
f.createCallExpression(
f.createIdentifier("require"),
undefined,
[f.createStringLiteral(importPath)],
// foo instanceof require("./bar").Foo
makeInstanceofCheck(),
// ||
f.createToken(ts.SyntaxKind.BarBarToken),
// (...
f.createParenthesizedExpression(
f.createBinaryExpression(
// typeof require("./bar").Foo.isFoo === "function"
makeFunctionCheck(),
// &&
f.createToken(
ts.SyntaxKind.AmpersandAmpersandToken,
),
f.createIdentifier(identifier),
)
: f.createIdentifier(identifier),
// require("./bar").Foo.isFoo(foo)
makePredicateCall(),
),
),
// ...)
),
),
),
// ...)
{ type: "class", name: identifier },
name,
visitorContext,
Expand Down
15 changes: 15 additions & 0 deletions packages/transformers/test/fixtures/_includes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,18 @@ export class FooBar {

public foo: "foo";
}

export class Baz {
constructor() {
this.baz = "baz";
}

public baz: "baz";

public static isBaz(value: any): value is Baz {
return typeof value === "object"
&& value != null
&& "baz" in value
&& value.baz === "baz";
}
}
27 changes: 22 additions & 5 deletions packages/transformers/test/fixtures/testClass.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { validateArgs } from "@zwave-js/transformers";
import assert from "node:assert";
import type { FooBar as Imported } from "./_includes";
// eslint-disable-next-line @typescript-eslint/no-var-requires
const ImportedClass = require("./_includes").FooBar;
import type { Baz, FooBar as Imported } from "./_includes";

const ImportedFooBar = require("./_includes").FooBar;

class Local {
constructor() {
Expand All @@ -12,6 +12,12 @@ class Local {
public foo: "bar";
}

// Tests for the static predicate function
// LocalBaz is structurally compatible with Baz, but not the same class
class LocalBaz {
public baz: "baz" = "baz" as const;
}

class Test {
@validateArgs()
foo(arg1: Local): void {
Expand All @@ -24,12 +30,19 @@ class Test {
arg1;
return void 0;
}

@validateArgs()
baz(arg1: Baz): void {
arg1;
return void 0;
}
}

const test = new Test();
// These should not throw
test.foo(new Local());
test.bar(new ImportedClass());
test.bar(new ImportedFooBar());
test.baz(new LocalBaz());

// These should throw
assert.throws(
Expand All @@ -42,7 +55,7 @@ assert.throws(
() => test.foo(undefined),
/arg1 is not a Local/,
);
assert.throws(() => test.foo(new ImportedClass()), /arg1 is not a Local/);
assert.throws(() => test.foo(new ImportedFooBar()), /arg1 is not a Local/);
assert.throws(
// @ts-expect-error
() => test.bar("string"),
Expand All @@ -58,3 +71,7 @@ assert.throws(
() => test.bar(new Local()),
/arg1 is not a Imported/,
);
assert.throws(
() => test.baz(new ImportedFooBar()),
/arg1 is not a Baz/,
);
2 changes: 1 addition & 1 deletion packages/zwave-js/src/lib/driver/Driver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2798,7 +2798,7 @@ export class Driver extends TypedEventEmitter<DriverEventCallbacks>
if (ccArgs.parameters) {
if (isUint8Array(ccArgs.parameters)) {
msg.parameters = buffer2hex(ccArgs.parameters);
} else if (ccArgs.parameters instanceof Duration) {
} else if (Duration.isDuration(ccArgs.parameters)) {
msg.duration = ccArgs.parameters.toString();
} else if (isObject(ccArgs.parameters)) {
Object.assign(msg, ccArgs.parameters);
Expand Down
2 changes: 1 addition & 1 deletion packages/zwave-js/src/lib/node/Node.ts
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ export class ZWaveNode extends ZWaveNodeMixins implements QuerySecurityClasses {
public set defaultTransitionDuration(value: string | Duration | undefined) {
// Normalize to strings
if (typeof value === "string") value = Duration.from(value);
if (value instanceof Duration) value = value.toString();
if (Duration.isDuration(value)) value = value.toString();

this.driver.cacheSet(
cacheKeys.node(this.id).defaultTransitionDuration,
Expand Down

0 comments on commit 9c8ba90

Please sign in to comment.