Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stream files to the ModelScan API endpoint rather than hold entire files in memory #1663

Merged
merged 3 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions backend/src/clients/modelScan.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import fetch, { Response } from 'node-fetch'
import { Readable } from 'stream'

import config from '../utils/config.js'
import { BadReq, InternalError } from '../utils/error.js'
Expand Down Expand Up @@ -65,13 +66,21 @@ export async function getModelScanInfo() {
return (await res.json()) as ModelScanInfoResponse
}

export async function scanFile(file: Blob, file_name: string) {
export async function scanStream(stream: Readable, file_name: string, file_size: number) {
PE39806 marked this conversation as resolved.
Show resolved Hide resolved
const url = `${config.avScanning.modelscan.protocol}://${config.avScanning.modelscan.host}:${config.avScanning.modelscan.port}`
let res: Response

try {
const formData = new FormData()
formData.append('in_file', file, file_name)
formData.append(
'in_file',
{
[Symbol.toStringTag]: 'File',
size: file_size,
stream: () => stream,
},
file_name,
)

res = await fetch(`${url}/scan/file`, {
method: 'POST',
Expand Down
7 changes: 2 additions & 5 deletions backend/src/connectors/fileScanning/modelScan.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { Response } from 'node-fetch'
import { Readable } from 'stream'

import { getModelScanInfo, scanFile } from '../../clients/modelScan.js'
import { getModelScanInfo, scanStream } from '../../clients/modelScan.js'
import { getObjectStream } from '../../clients/s3.js'
import { FileInterfaceDoc, ScanState } from '../../models/File.js'
import log from '../../services/log.js'
Expand Down Expand Up @@ -39,9 +38,7 @@ export class ModelScanFileScanningConnector extends BaseFileScanningConnector {

const s3Stream = (await getObjectStream(file.bucket, file.path)).Body as Readable
try {
// TODO: see if it's possible to directly send the Readable stream rather than a blob
const fileBlob = await new Response(s3Stream).blob()
const scanResults = await scanFile(fileBlob, file.name)
const scanResults = await scanStream(s3Stream, file.name, file.size)

const issues = scanResults.summary.total_issues
const isInfected = issues > 0
Expand Down
17 changes: 12 additions & 5 deletions backend/test/clients/__snapshots__/modelScan.spec.ts.snap
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ exports[`clients > modelScan > getModelScanInfo > success 1`] = `
]
`;

exports[`clients > modelScan > scanFile > success 1`] = `
exports[`clients > modelScan > scanStream > success 1`] = `
[
[
"undefined://undefined:undefined/scan/file",
Expand All @@ -23,10 +23,17 @@ exports[`clients > modelScan > scanFile > success 1`] = `
Symbol(state): [
{
"name": "in_file",
"value": File {
Symbol(kHandle): Blob {},
Symbol(kLength): 0,
Symbol(kType): "application/x-hdf5",
"value": FileLike {
Symbol(state): {
"blobLike": {
"size": 0,
"stream": [Function],
Symbol(Symbol.toStringTag): "File",
},
"lastModified": 0,
"name": "safe_model.h5",
"type": undefined,
},
},
},
],
Expand Down
19 changes: 12 additions & 7 deletions backend/test/clients/modelScan.spec.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { PassThrough } from 'stream'
import { describe, expect, test, vi } from 'vitest'

import { getModelScanInfo, scanFile } from '../../src/clients/modelScan.js'
import { getModelScanInfo, scanStream } from '../../src/clients/modelScan.js'

const configMock = vi.hoisted(() => ({
avScanning: {
Expand Down Expand Up @@ -59,7 +60,7 @@ describe('clients > modelScan', () => {
expect(() => getModelScanInfo()).rejects.toThrowError(/^Unable to communicate with the ModelScan service./)
})

test('scanFile > success', async () => {
test('scanStream > success', async () => {
const expectedResponse = {
summary: {
total_issues: 0,
Expand Down Expand Up @@ -90,28 +91,32 @@ describe('clients > modelScan', () => {
text: vi.fn(),
json: vi.fn(() => expectedResponse),
})
const response = await scanFile(new Blob([''], { type: 'application/x-hdf5' }), 'safe_model.h5')
// force lastModified to be 0
const date = new Date(1970, 0, 1, 0)
vi.setSystemTime(date)

const response = await scanStream(new PassThrough(), 'safe_model.h5', 0)

expect(fetchMock.default).toBeCalled()
expect(fetchMock.default.mock.calls).toMatchSnapshot()
expect(response).toStrictEqual(expectedResponse)
})

test('scanFile > bad response', async () => {
test('scanStream > bad response', async () => {
fetchMock.default.mockResolvedValueOnce({
ok: false,
text: vi.fn(() => 'Unrecognised response'),
json: vi.fn(),
})
expect(() => scanFile(new Blob([''], { type: 'application/x-hdf5' }), 'safe_model.h5')).rejects.toThrowError(
expect(() => scanStream(new PassThrough(), 'safe_model.h5', 0)).rejects.toThrowError(
/^Unrecognised response returned by the ModelScan service./,
)
})

test('scanFile > rejected', async () => {
test('scanStream > rejected', async () => {
fetchMock.default.mockRejectedValueOnce('Unable to communicate with the inferencing service.')

expect(() => scanFile(new Blob([''], { type: 'application/x-hdf5' }), 'safe_model.h5')).rejects.toThrowError(
expect(() => scanStream(new PassThrough(), 'safe_model.h5', 0)).rejects.toThrowError(
/^Unable to communicate with the ModelScan service./,
)
})
Expand Down
Loading