From 7c4ae63699d24a8cd6c8087f43d83a86d7053739 Mon Sep 17 00:00:00 2001 From: dev-737 <73829355+dev-737@users.noreply.github.com> Date: Fri, 2 Feb 2024 20:44:17 +0530 Subject: [PATCH] fix: improved nsfw image detection capabilities --- package.json | 29 ++++++++++++++++------------- src/api/routes/nsfw.ts | 33 ++++++++++++++++++++++++++++----- src/index.ts | 5 +---- 3 files changed, 45 insertions(+), 22 deletions(-) diff --git a/package.json b/package.json index 53970378..1197f143 100644 --- a/package.json +++ b/package.json @@ -13,16 +13,16 @@ "register:commands": "node build/utils/RegisterCmdCli.js", "release": "standard-version", "lint": "eslint --cache --fix .", - "prepare": "husky install" + "prepare": "husky" }, "engines": { "node": ">=18.0.0" }, "type": "module", "dependencies": { - "@prisma/client": "^5.8.1", - "@sentry/node": "^7.98.0", - "@tensorflow/tfjs-node": "^4.16.0", + "@prisma/client": "^5.9.1", + "@sentry/node": "^7.99.0", + "@tensorflow/tfjs": "3.18.0", "@top-gg/sdk": "^3.1.6", "common-tags": "^1.8.2", "discord-hybrid-sharding": "^2.1.4", @@ -31,31 +31,34 @@ "express": "^4.18.2", "google-translate-api-x": "^10.6.8", "i18n": "^0.15.1", + "jpeg-js": "^0.4.4", + "js-yaml": "^4.1.0", "lodash": "^4.17.21", "lz-string": "^1.5.0", "nsfwjs": "^2.4.2", "parse-duration": "^1.1.0", + "sharp": "^0.33.2", "source-map-support": "^0.5.21", - "winston": "^3.11.0", - "yaml": "^2.3.4" + "winston": "^3.11.0" }, "devDependencies": { "@types/common-tags": "^1.8.4", "@types/express": "^4.17.21", "@types/i18n": "^0.13.10", - "@types/jest": "^29.5.11", + "@types/jest": "^29.5.12", + "@types/js-yaml": "^4.0.9", "@types/lodash": "^4.14.202", - "@types/node": "^20.11.10", + "@types/node": "^20.11.16", "@types/source-map-support": "^0.5.10", - "@typescript-eslint/eslint-plugin": "^6.19.1", - "@typescript-eslint/parser": "^6.19.1", + "@typescript-eslint/eslint-plugin": "^6.20.0", + "@typescript-eslint/parser": "^6.20.0", "cz-conventional-changelog": "^3.3.0", "eslint": "^8.56.0", - "husky": "^9.0.6", + "husky": "^9.0.10", "jest": "^29.7.0", - "lint-staged": "^15.2.0", + "lint-staged": "^15.2.1", "prettier": "^3.2.4", - "prisma": "^5.8.1", + "prisma": "^5.9.1", "standard-version": "^9.5.0", "ts-jest": "^29.1.2", "tsc-watch": "^6.0.4", diff --git a/src/api/routes/nsfw.ts b/src/api/routes/nsfw.ts index 78cf07ec..7c476a98 100644 --- a/src/api/routes/nsfw.ts +++ b/src/api/routes/nsfw.ts @@ -1,13 +1,37 @@ import { load } from 'nsfwjs'; import { Router } from 'express'; import { captureException } from '@sentry/node'; -import { node } from '@tensorflow/tfjs-node'; +import { tensor3d, enableProdMode } from '@tensorflow/tfjs'; import Logger from '../../utils/Logger.js'; +import sharp from 'sharp'; +import jpeg from 'jpeg-js'; -const nsfwModel = await load(); +// disable tfjs logs +enableProdMode(); + +const nsfwModel = await load('https://nsfwjs.com/quant_nsfw_mobilenet/'); const router: Router = Router(); +const imageToTensor = async (rawImageData: ArrayBuffer) => { + rawImageData = await sharp(rawImageData).raw().jpeg().toBuffer(); + const decoded = jpeg.decode(rawImageData); // This is key for the prediction to work well + const { width, height, data } = decoded; + const buffer = new Uint8Array(width * height * 3); + let offset = 0; + for (let i = 0; i < buffer.length; i += 3) { + buffer[i] = data[offset]; + buffer[i + 1] = data[offset + 1]; + buffer[i + 2] = data[offset + 2]; + + offset += 4; + } + + return tensor3d(buffer, [height, width, 3]); +}; + + router.get('/nsfw', async (req, res) => { + const url = new URL(req.url, `http://${req.headers.host}`); const imageUrl = url.searchParams.get('url'); @@ -30,9 +54,8 @@ router.get('/nsfw', async (req, res) => { try { const imageBuffer = await (await fetch(imageUrl)).arrayBuffer(); - const imageTensor = (await node.decodeImage(Buffer.from(imageBuffer), 3)) as any; // eslint-disable-line @typescript-eslint/no-explicit-any + const imageTensor = await imageToTensor(imageBuffer); const predictions = await nsfwModel.classify(imageTensor); - imageTensor.dispose(); res.writeHead(200, { 'Content-Type': 'application/json' }); res.end(JSON.stringify(predictions)); @@ -45,4 +68,4 @@ router.get('/nsfw', async (req, res) => { } }); -export default router; \ No newline at end of file +export default router; diff --git a/src/index.ts b/src/index.ts index d45dc2ac..b2487b6e 100644 --- a/src/index.ts +++ b/src/index.ts @@ -38,12 +38,9 @@ clusterManager.on('clusterCreate', async (cluster) => { deleteOldMessages(); deleteExpiredInvites(); - // update top.gg stats every 10 minutes - scheduler.addRecurringTask('syncBotlistStats', 60 * 10_000, () => syncBotlistStats(clusterManager)); - // delete expired invites every 1 hour scheduler.addRecurringTask('deleteExpiredInvites', 60 * 60 * 1_000, deleteExpiredInvites); - // delete old network messages every 12 hours scheduler.addRecurringTask('deleteOldMessages', 60 * 60 * 12_000, deleteOldMessages); + scheduler.addRecurringTask('syncBotlistStats', 60 * 10_000, () => syncBotlistStats(clusterManager)); } });