Skip to content

Commit

Permalink
feat: setup workflow crud tests
Browse files Browse the repository at this point in the history
  • Loading branch information
DaniAkash committed Mar 19, 2024
1 parent 37dbbf5 commit 7b7b8ba
Show file tree
Hide file tree
Showing 13 changed files with 471 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -356,4 +356,4 @@ example/*
docs/*

# Test outputs
tests/client/workflow/fixtures/export_general.yml
tests/client/workflow/export_general.yml
60 changes: 32 additions & 28 deletions src/client/app.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ import { StatusCode } from "clarifai-nodejs-grpc/proto/clarifai/api/status/statu
import * as fs from "fs";
import * as yaml from "js-yaml";
import { validateWorkflow } from "../workflows/validate";
import { getYamlOutputInfoProto, isSameYamlModel } from "../workflows/utils";
import { getYamlOutputInfoProto } from "../workflows/utils";
import { Model as ModelConstructor } from "./model";
import { uuid } from "uuidv4";
import { fromProtobufObject } from "from-protobuf-object";
import { fromPartialProtobufObject } from "./fromPartialProtobufObject";

type AppConfig =
| {
Expand Down Expand Up @@ -340,9 +341,10 @@ export class App extends Lister {
}): Promise<Model.AsObject> {
const request = new PostModelsRequest();
request.setUserAppId(this.userAppId);
const newModel = new Model();
newModel.setId(modelId);
mapParamsToRequest(params, newModel);
const newModel = fromPartialProtobufObject(Model, {
id: modelId,
...params,
});
request.setModelsList([newModel]);
const postModels = promisifyGrpcCall(
this.STUB.client.postModels,
Expand Down Expand Up @@ -408,16 +410,18 @@ export class App extends Lister {

// Get all model objects from the workflow nodes.
const allModels: Model.AsObject[] = [];
let modelObject: Model.AsObject | undefined;
for (const node of workflow["nodes"]) {
for (const node of workflow.nodes) {
let modelObject: Model.AsObject | undefined;
const outputInfo = getYamlOutputInfoProto(node?.model?.outputInfo ?? {});
try {
const model = await this.model({
modelId: node.model.modelId,
modelVersionId: node.model.modelVersionId ?? "",
});
modelObject = model;
if (model) allModels.push(model);
if (model) {
allModels.push(model);
}
} catch (e) {
// model doesn't exist, create a new model from yaml config
if (
Expand Down Expand Up @@ -450,27 +454,27 @@ export class App extends Lister {
}

// If the model version ID is specified, or if the yaml model is the same as the one in the api
if (
(node.model.modelVersionId ?? "") ||
(modelObject && isSameYamlModel(modelObject, node.model))
) {
allModels.push(modelObject!);
} else if (modelObject && outputInfo) {
const model = new ModelConstructor({
modelId: modelObject.id,
authConfig: {
pat: this.pat,
appId: this.userAppId.getAppId(),
userId: this.userAppId.getUserId(),
},
});
const modelVersion = await model.createVersion({
outputInfo: outputInfo.toObject(),
});
if (modelVersion.model) {
allModels.push(modelVersion.model);
}
}
// if (
// (node.model.modelVersionId ?? "") ||
// (modelObject && isSameYamlModel(modelObject, node.model))
// ) {
// allModels.push(modelObject!);
// } else if (modelObject && outputInfo) {
// const model = new ModelConstructor({
// modelId: modelObject.id,
// authConfig: {
// pat: this.pat,
// appId: this.userAppId.getAppId(),
// userId: this.userAppId.getUserId(),
// },
// });
// const modelVersion = await model.createVersion({
// outputInfo: outputInfo.toObject(),
// });
// if (modelVersion.model) {
// allModels.push(modelVersion.model);
// }
// }
}

// Convert nodes to resources_pb2.WorkflowNodes.
Expand Down
235 changes: 235 additions & 0 deletions src/client/fromPartialProtobufObject.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
import { Message, Map as ProtobufMap } from "google-protobuf";

type MessageConstructor<T extends Message> = new () => T;

type AsObject<T extends Message> = ReturnType<T["toObject"]>;

const enum PREFIX {
SET = "set",
GET = "get",
CLEAR = "clear",
}

/**
* Based on "from-protobuf-object" package - except accepts data with Partial keys
*/
export function fromPartialProtobufObject<T extends Message>(
MessageType: MessageConstructor<T>,
data: Partial<AsObject<T>>,
): T {
const instance = new MessageType();
validateMissingProps(instance, data);
for (const [prop, value] of Object.entries(
filterExtraProps(instance, data),
)) {
if (Array.isArray(value) && isProtobufMap(instance, prop)) {
const mapMethod = getMethod(prop, PREFIX.GET);
const map = callMethod(instance, mapMethod) as ProtobufMap<
unknown,
unknown
>;
const NestedType = retrieveNestedMapTypePatch(instance, prop);
for (const [k, v] of value) {
if (!isObject(v, prop)) {
map.set(k, v);
continue;
}
if (!NestedType) {
throw new Error("Unable to retrieve nested type");
}
map.set(k, fromPartialProtobufObject(NestedType, v));
}
continue;
}
const result = getResult(instance, prop, value);
validateType(instance, prop, value);
const setter = getMethod(prop, PREFIX.SET);
callMethod(instance, setter, result);
}
return instance;
}

function getResult<T extends Message>(
instance: T,
prop: string,
value: unknown,
): unknown {
if (value instanceof Uint8Array) {
return value;
}
if (Array.isArray(value)) {
if (value.length === 0 || !isArrayOfObjects(value, prop)) {
return value;
}
const NestedType = retrieveNestedRepeatedTypePatch(instance, prop);
if (!NestedType) {
throw new Error("Unable to retrieve nested type");
}
return value.map((child) => fromPartialProtobufObject(NestedType, child));
}
if (isObject(value, prop)) {
const NestedType = retrieveNestedTypePatch(instance, prop);
if (!NestedType) {
throw new Error("Unable to retrieve nested type");
}
return fromPartialProtobufObject(NestedType, value as object);
}
return value;
}

function callMethod<T extends object, R>(
obj: T,
key: string,
value?: unknown,
): R {
return (obj[key as keyof T] as (value: unknown) => R)(value);
}

function getProp(key: string, prefix: PREFIX): string {
const prop = key.slice(prefix.length);
return prop.slice(0, 1).toLowerCase() + prop.slice(1);
}

function getMethod(prop: string, prefix: PREFIX): string {
return `${prefix}${prop[0].toUpperCase()}${prop.slice(1)}`;
}

function getInstancePropsFromKeys(keys: string[], prefix: PREFIX): string[] {
return keys
.filter((key) => key.startsWith(prefix))
.map((key) => getProp(key, prefix));
}

function getInstanceProps<T extends Message>(instance: T): string[] {
const keys = Object.keys(Object.getPrototypeOf(instance));
const setters = getInstancePropsFromKeys(keys, PREFIX.SET);
const maps = getInstancePropsFromKeys(keys, PREFIX.CLEAR).filter((prop) =>
isProtobufMap(instance, prop),
);
return [...setters, ...maps];
}

function isProtobufMap<T extends Message>(instance: T, prop: string): boolean {
return (
callMethod(instance, getMethod(prop, PREFIX.GET)) instanceof ProtobufMap
);
}

function isOptional<T extends Message>(instance: T, prop: string): boolean {
const clearMethod = getMethod(prop, PREFIX.CLEAR);
return clearMethod in instance;
}

function validateMissingProps<T extends Message>(
instance: T,
data: Partial<AsObject<T>>,
): void {
const instanceProps = getInstanceProps(instance);
const dataProps = Object.keys(data);
for (const prop of instanceProps) {
if (!dataProps.includes(prop) && !isOptional(instance, prop)) {
// throw new Error(`Missing property '${prop}'`);
}
}
}

function filterExtraProps<T extends Message>(
instance: T,
data: Partial<AsObject<T>>,
): Partial<AsObject<T>> {
const instanceProps = getInstanceProps(instance);
return Object.fromEntries(
Object.entries(data).filter(
([key, value]) => instanceProps.includes(key) && value !== undefined,
),
) as Partial<AsObject<T>>;
}

function isObject(value: unknown, prop: string): boolean {
if (value === null) {
throw new Error(`Null value for key '${prop}'`);
}
return typeof value === "object";
}

function isArrayOfObjects(arr: unknown[], prop: string): boolean {
if (arr.every((item) => isObject(item, prop))) {
return true;
}
if (arr.every((item) => !isObject(item, prop))) {
return false;
}
throw new Error(`Mixed array for '${prop}'`);
}

function validateType<T extends Message>(
instance: T,
prop: string,
value: unknown,
): void {
const getter = getMethod(prop, PREFIX.GET);
const instanceValue = callMethod(instance, getter);
const expectedType =
instanceValue !== undefined ? typeof instanceValue : "object";
const actualType = value instanceof Uint8Array ? "string" : typeof value;
if (Array.isArray(instanceValue) && !Array.isArray(value)) {
throw new Error(
`Invalid type for '${prop}' (expected array, got '${actualType}')`,
);
}
if (!Array.isArray(instanceValue) && Array.isArray(value)) {
throw new Error(
`Invalid type for '${prop}' (expected '${expectedType}', got array)`,
);
}
if (expectedType !== actualType) {
throw new Error(
`Invalid type for '${prop}' (expected '${expectedType}', got '${actualType}')`,
);
}
}

function retrieveNestedTypePatch<T extends Message, N extends Message>(
instance: T,
prop: string,
): MessageConstructor<N> | null {
const getWrapperField = Message.getWrapperField;
let nestedType: MessageConstructor<Message> | null = null;
Message.getWrapperField = function (msg, ctor, fieldNumber, required) {
nestedType = ctor;
return getWrapperField(msg, ctor, fieldNumber, required);
};
callMethod(instance, getMethod(prop, PREFIX.GET));
Message.getWrapperField = getWrapperField;
return nestedType;
}

function retrieveNestedRepeatedTypePatch<T extends Message, N extends Message>(
instance: T,
prop: string,
): MessageConstructor<N> | null {
const getRepeatedWrapperField = Message.getRepeatedWrapperField;
let nestedType: MessageConstructor<Message> | null = null;
Message.getRepeatedWrapperField = function (msg, ctor, fieldNumber) {
nestedType = ctor;
return getRepeatedWrapperField(msg, ctor, fieldNumber);
};
callMethod(instance, getMethod(prop, PREFIX.GET));
Message.getRepeatedWrapperField = getRepeatedWrapperField;
return nestedType;
}

function retrieveNestedMapTypePatch<T extends Message, N extends Message>(
instance: T,
prop: string,
): MessageConstructor<N> | null {
const getMapField = Message.getMapField;
let nestedType: typeof Message | null = null;
Message.getMapField = function (msg, fieldNumber, noLazyCreate, valueCtor) {
nestedType = valueCtor ?? null;
return getMapField(msg, fieldNumber, noLazyCreate, valueCtor);
};
callMethod(instance, getMethod(prop, PREFIX.GET));
Message.getMapField = getMapField;
return nestedType;
}
6 changes: 3 additions & 3 deletions src/client/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import {
JavaScriptValue,
Struct,
} from "google-protobuf/google/protobuf/struct_pb";
import { fromPartialProtobufObject } from "./fromPartialProtobufObject";

export class Model extends Lister {
// @ts-expect-error - Variable yet to be used
Expand Down Expand Up @@ -369,8 +370,7 @@ export class Model extends Lister {
const request = new PostModelVersionsRequest();
request.setUserAppId(this.userAppId);
request.setModelId(this.id);
const modelVersion = new ModelVersion();
mapParamsToRequest(args, modelVersion);
const modelVersion = fromPartialProtobufObject(ModelVersion, args);
request.setModelVersionsList([modelVersion]);

const postModelVersions = promisifyGrpcCall(
Expand All @@ -383,7 +383,7 @@ export class Model extends Lister {
const responseObject = response.toObject();

if (responseObject.status?.code !== StatusCode.SUCCESS) {
throw new Error(responseObject.status?.toString());
throw new Error(responseObject.status?.description);
}

return responseObject;
Expand Down
17 changes: 17 additions & 0 deletions tests/client/workflow/fixtures/general.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
workflow:
id: General
nodes:
- id: general-v1.5-concept
model:
modelId: general-image-recognition
modelVersionId: aa7f35c01e0642fda5cf400f543e7c40
- id: general-v1.5-embed
model:
modelId: general-image-embedding
modelVersionId: bb186755eda04f9cbb6fe32e816be104
- id: general-v1.5-cluster
model:
modelId: general-clusterering
modelVersionId: cc2074cff6dc4c02b6f4e1b8606dcb54
nodeInputs:
- nodeId: general-v1.5-embed
11 changes: 11 additions & 0 deletions tests/client/workflow/fixtures/multi_branch.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
workflow:
id: test-mb
nodes:
- id: detector
model:
modelId: face-detection
modelVersionId: 45fb9a671625463fa646c3523a3087d5
- id: moderation
model:
modelId: moderation-recognition
modelVersionId: 7cde8e92896340b0998b8260d47f1502
Loading

0 comments on commit 7b7b8ba

Please sign in to comment.