Skip to content

Commit

Permalink
BAI-1459 convert modelscan class to client service
Browse files Browse the repository at this point in the history
  • Loading branch information
PE39806 committed Nov 25, 2024
1 parent f99cb65 commit 8c4c160
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 153 deletions.
93 changes: 93 additions & 0 deletions backend/src/clients/modelScan.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import fetch, { Response } from 'node-fetch'

import config from '../utils/config.js'
import { BadReq, InternalError } from '../utils/error.js'

/**
 * Shape that `getModelScanInfo` casts the JSON body of the ModelScan
 * service's `/info` endpoint to.
 *
 * NOTE(review): the response is only cast, never validated at runtime, so
 * these fields describe the expected payload rather than a guaranteed one.
 */
interface ModelScanInfoResponse {
apiName: string
apiVersion: string
scannerName: string
// presumably the version of the underlying modelscan tool - confirm against the service
modelscanVersion: string
}

/**
 * Shape that `scanFile` casts the JSON body of the ModelScan service's
 * `/scan/file` endpoint to.
 *
 * The response is only cast, never validated at runtime, so this describes
 * the expected payload rather than a guaranteed one.
 */
interface ModelScanResponse {
  summary: {
    total_issues: number
    total_issues_by_severity: {
      LOW: number
      MEDIUM: number
      HIGH: number
      CRITICAL: number
    }
    input_path: string
    absolute_path: string
    modelscan_version: string
    timestamp: string
    scanned: {
      total_scanned: number
      scanned_files: string[]
    }
    skipped: {
      total_skipped: number
      skipped_files: string[]
    }
  }
  // Declared as an array rather than a one-element tuple (`[{...}]`): the
  // number of reported issues is not fixed, and a tuple type would reject
  // both an empty list and a list with more than one entry.
  issues: {
    description: string
    operator: string
    module: string
    source: string
    scanner: string
    severity: string
  }[]
  // TODO: currently unknown what this might look like
  errors: object[]
}

/**
 * Fetches service metadata from the ModelScan service's `/info` endpoint.
 *
 * The endpoint address is assembled from the `protocol`, `host` and `port`
 * values under `config.avScanning.modelscan`.
 *
 * @returns The JSON response body, cast (not validated) to `ModelScanInfoResponse`.
 * @throws InternalError when the `fetch` call itself rejects.
 * @throws BadReq when the service answers with a non-OK status.
 */
export async function getModelScanInfo() {
  const { protocol, host, port } = config.avScanning.modelscan
  const infoEndpoint = `${protocol}://${host}:${port}/info`

  let response: Response
  try {
    response = await fetch(infoEndpoint, {
      headers: { 'Content-Type': 'application/json' },
      method: 'GET',
    })
  } catch (err) {
    // Only transport-level failures end up here; HTTP error statuses are handled below.
    throw InternalError('Unable to communicate with the ModelScan service.', { err })
  }

  if (response.ok) {
    const info = await response.json()
    return info as ModelScanInfoResponse
  }

  throw BadReq('Unrecognised response returned by the ModelScan service.')
}

/**
 * Uploads a file to the ModelScan service's `/scan/file` endpoint and returns
 * the scan report.
 *
 * The file is sent as the `in_file` field of a multipart form. The service
 * address is assembled from the `protocol`, `host` and `port` values under
 * `config.avScanning.modelscan`.
 *
 * @param file - Contents of the file to scan.
 * @param file_name - File name attached to the `in_file` form field.
 * @returns The JSON response body, cast (not validated) to `ModelScanResponse`.
 * @throws InternalError when building or sending the request fails.
 * @throws BadReq when the service answers with a non-OK status; the response
 *   body is attached to the error as `body`.
 */
export async function scanFile(file: Blob, file_name: string) {
  const url = `${config.avScanning.modelscan.protocol}://${config.avScanning.modelscan.host}:${config.avScanning.modelscan.port}`
  let res: Response

  try {
    const formData = new FormData()
    formData.append('in_file', file, file_name)

    res = await fetch(`${url}/scan/file`, {
      method: 'POST',
      headers: {
        accept: 'application/json',
      },
      body: formData,
    })
  } catch (err) {
    throw InternalError('Unable to communicate with the ModelScan service.', { err })
  }
  if (!res.ok) {
    // Read the error body as text first. Parsing it directly as JSON would
    // throw a raw parse error for non-JSON bodies (e.g. a plain-text or HTML
    // error page) and hide the BadReq that callers are meant to see.
    const rawBody = await res.text()
    let body = rawBody
    try {
      // For JSON bodies keep the previous normalised representation.
      body = JSON.stringify(JSON.parse(rawBody))
    } catch {
      // Not JSON: report the body verbatim.
    }
    throw BadReq('Unrecognised response returned by the ModelScan service.', { body })
  }

  return (await res.json()) as ModelScanResponse
}
2 changes: 1 addition & 1 deletion backend/src/connectors/fileScanning/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ export function runFileScanners(cache = true) {
case FileScanKind.ModelScan:
try {
const scanner = new ModelScanFileScanningConnector()
await scanner.init()
await scanner.ping()
fileScanConnectors.push(scanner)
} catch (error) {
throw ConfigurationError('Could not configure or initialise ModelScan')
Expand Down
174 changes: 22 additions & 152 deletions backend/src/connectors/fileScanning/modelScan.ts
Original file line number Diff line number Diff line change
@@ -1,153 +1,16 @@
import fetch, { Response } from 'node-fetch'
import { Response } from 'node-fetch'
import { Readable } from 'stream'

import { getModelScanInfo, scanFile } from '../../clients/modelScan.js'
import { getObjectStream } from '../../clients/s3.js'
import { FileInterfaceDoc, ScanState } from '../../models/File.js'
import log from '../../services/log.js'
import config from '../../utils/config.js'
import { BadReq, ConfigurationError, InternalError } from '../../utils/error.js'
import { ConfigurationError } from '../../utils/error.js'
import { BaseFileScanningConnector, FileScanResult } from './Base.js'

let av: NodeModelScanAPI
export const modelScanToolName = 'ModelScan'

interface ModelScanOptions {
host: string
port: number
protocol: string
}

type ModelScanInfoResponse = {
apiName: string
apiVersion: string
scannerName: string
modelscanVersion: string
}

type ModelScanResponse = {
summary: {
total_issues: number
total_issues_by_severity: {
LOW: number
MEDIUM: number
HIGH: number
CRITICAL: number
}
input_path: string
absolute_path: string
modelscan_version: string
timestamp: string
scanned: {
total_scanned: number
scanned_files: string[]
}
skipped: {
total_skipped: number
skipped_files: string[]
}
}
issues: [
{
description: string
operator: string
module: string
source: string
scanner: string
severity: string
},
]
// TODO: currently unknown what this might look like
errors: object[]
}

class NodeModelScanAPI {
initialised: boolean
settings!: ModelScanOptions
url!: string

constructor() {
this.initialised = false
}

async init(options: ModelScanOptions): Promise<this> {
if (this.initialised) return this

this.settings = options
this.url = `${this.settings.protocol}://${this.settings.host}:${this.settings.port}`

// ping to check that the service is running
return this._getInfo().then((_) => {
this.initialised = true
return this
})
}

// TODO: try and convert this to work with a stream
async scanFile(file: Blob, file_name: string): Promise<{ isInfected: boolean; viruses: string[] }> {
if (!this.initialised)
throw ConfigurationError('NodeModelScanAPI has not been initialised.', { NodeModelScanAPI: this })

return this._postScanFile(file, file_name).then((json) => {
// map modelscan result to our format
const issues = json.summary.total_issues
const isInfected = issues > 0
const viruses: string[] = []
if (isInfected) {
for (const issue of json.issues) {
viruses.push(`${issue.severity}: ${issue.description}. ${issue.scanner}`)
}
}
return { isInfected, viruses }
})
}

async _getInfo(): Promise<unknown> {
// hit the /info endpoint
let res: Response

try {
res = await fetch(`${this.url}/info`, {
method: 'GET',
headers: { 'Content-Type': 'application/json' },
})
} catch (err) {
throw InternalError('Unable to communicate with the inferencing service.', { err })
}
if (!res.ok) {
throw BadReq('Unrecognised response returned by the inferencing service.')
}

return (await res.json()) as ModelScanInfoResponse
}

async _postScanFile(file: Blob, file_name: string): Promise<ModelScanResponse> {
// hit the /scan/file endpoint
let res: Response

try {
const formData = new FormData()
formData.append('in_file', file, file_name)

res = await fetch(`${this.url}/scan/file`, {
method: 'POST',
headers: {
accept: 'application/json',
},
body: formData,
})
} catch (err) {
throw InternalError('Unable to communicate with the inferencing service.', { err })
}
if (!res.ok) {
throw BadReq('Unrecognised response returned by the inferencing service.', {
body: JSON.stringify(await res.json()),
})
}

return (await res.json()) as ModelScanResponse
}
}

export class ModelScanFileScanningConnector extends BaseFileScanningConnector {
constructor() {
super()
Expand All @@ -157,30 +20,37 @@ export class ModelScanFileScanningConnector extends BaseFileScanningConnector {
return [modelScanToolName]
}

async init() {
async ping() {
try {
av = await new NodeModelScanAPI().init(config.avScanning.modelscan)
// discard the results as we only want to know if the endpoint is reachable
await getModelScanInfo()
} catch (error) {
throw ConfigurationError('Could not scan file as ModelScan is not running.', {
modelScanConfig: config.avScanning.modelscan,
})
}
}

async scan(file: FileInterfaceDoc): Promise<FileScanResult[]> {
if (!av) {
throw ConfigurationError(
'ModelScan does not look like it is running. Check that it has been correctly initialised by calling the init function.',
'ModelScan does not look like it is running. Check that the service configuration is correct.',
{
modelScanConfig: config.avScanning.modelscan,
},
)
}
}

async scan(file: FileInterfaceDoc): Promise<FileScanResult[]> {
this.ping()

const s3Stream = (await getObjectStream(file.bucket, file.path)).Body as Readable
try {
// TODO: see if it's possible to directly send the Readable stream rather than a blob
const fileBlob = await new Response(s3Stream).blob()
const { isInfected, viruses } = await av.scanFile(fileBlob, file.name)
const scanResults = await scanFile(fileBlob, file.name)

const issues = scanResults.summary.total_issues
const isInfected = issues > 0
const viruses: string[] = []
if (isInfected) {
for (const issue of scanResults.issues) {
viruses.push(`${issue.severity}: ${issue.description}. ${issue.scanner}`)
}
}
log.info(
{ modelId: file.modelId, fileId: file._id, name: file.name, result: { isInfected, viruses } },
'Scan complete.',
Expand Down

0 comments on commit 8c4c160

Please sign in to comment.