likes
comments
collection
share

SpringBoot集成TensorFlow对图片内容进行安全检测 : 全流程指南在这篇指南中,我们将深入探讨如何使用S

作者站长头像
站长
· 阅读数 26

一、TensorFlow介绍

1. 什么是TensorFlow

TensorFlow 是由Google开发的开源机器学习框架,广泛用于构建和训练神经网络模型。它提供了灵活的工具集,支持各种机器学习和深度学习任务,包括图像分类、自然语言处理和时间序列分析。TensorFlow的强大之处在于它的跨平台特性,支持从移动设备到大型分布式系统的部署,同时有丰富的API和社区支持,适合从初学者到专家的不同需求。

2. TensorFlow的应用场景

  1. 图像识别与分类:TensorFlow最常见的应用之一是图像识别和分类。例如,TensorFlow被用于自动化检测图片中的物体,如汽车、人、动植物等。这种技术广泛应用于自动驾驶、智能监控和医疗影像分析等领域。
  2. 自然语言处理(NLP) :TensorFlow支持构建复杂的自然语言处理模型,如语音识别、机器翻译和情感分析。通过使用TensorFlow训练的模型,应用程序可以实现从文本中提取信息、生成自然语言回复、甚至进行跨语言的实时翻译。
  3. 推荐系统:许多在线平台使用TensorFlow来构建个性化推荐系统。例如,电商平台通过分析用户的浏览历史和购买行为,利用TensorFlow的深度学习模型生成个性化的商品推荐,从而提升用户体验和销售额。
  4. 时间序列预测:TensorFlow在金融、气象和交通领域用于时间序列预测。通过历史数据的学习,模型可以预测未来的趋势或事件,如股票价格走势、天气变化或交通流量预测。
  5. 生成对抗网络(GANs) :TensorFlow支持开发和训练生成对抗网络,用于生成逼真的图像、视频或音频。例如,GANs可以被用于生成虚拟人物形象、创作艺术作品或增强现实内容。
  6. 医疗诊断:在医疗领域,TensorFlow被用来开发诊断工具。例如,通过分析医学影像(如X光片或MRI),TensorFlow模型可以辅助医生早期发现病变或其他健康问题,提升诊断效率和准确性。

二、NSFW介绍

1. 什么是NSFW

NSFW(Not Safe For Work)模型是一类专门用于检测不适合在公共场所或工作环境中展示的内容的模型,通常指包含成人内容、暴力或其他不宜公开展示的图片或视频。NSFW模型通过训练深度学习算法,能够自动识别这些不适合公开的内容,广泛应用于社交媒体平台、内容审核系统以及其他需要过滤敏感内容的应用场景。

三、具体实现

看完了 TensorFlow 和 NSFW 的介绍,下面让我们来快速实现一下吧(工具类和初始化模型不懂也没关系复制粘贴即可,只需要了解大致在做什么,方法返回结果是什么,就可以进行最简单的使用

1. 引入依赖

<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow</artifactId>
</dependency>

2. 初始化NSFW模型

NSFW 模型可以通过 GitHub 获取,推荐使用 nsfwjs 项目中的模型。可以使用 Python 将模型转换为 TensorFlow 的 saved_model 格式,然后在 SpringBoot 应用中进行加载。当然,也可以直接找到已转换好的 NSFW saved_model 格式的模型进行加载。

在使用 TensorFlow API 后,加载 NSFW 模型变得非常简单。我们将模型存放在项目的 \resources\tensorflow\saved_model\nsfw 目录中,因此在加载时,只需将该目录的绝对路径传入即可轻松完成模型的初始化 , 然后通过TensorFlow的session传入对应张量就可以获取到nsfw分类比例。

/**
 * NSFW 模型
 *
 * @author : YiFei
 */
@Getter
@Component
public class NSFWModelService {

    // 提供方法来获取 TensorFlow Session
    private Session session;

    @PostConstruct
    public void init() {
        // 加载 TensorFlow 模型
        try {
            String modelAbsolutePath = new ClassPathResource("tensorflow/saved_model/nsfw").getFile().getAbsolutePath();
            SavedModelBundle model = SavedModelBundle.load(modelPath, "serve");
            this.session = model.session();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * 在销毁 Bean 时关闭 TensorFlow Session
     */
    @PreDestroy
    public void closeSession() {
        this.session.close();
    }
}

完成上述操作,我们即可在本地加载模型,如果需要部署还会出现一些细节上的问题,我们将在最后进行解析。如果需要打印加载模型的详细信息,可以加上以下代码

//            以下是获取模型 Inputs 数据格式 、输入张量名  , output 数据格式 、输出张量名
//            MetaGraphDef metaGraphDef = MetaGraphDef.parseFrom(model.metaGraphDef());
//            Map<String, SignatureDef> signatureDefMap = metaGraphDef.getSignatureDefMap();
//
//            for (Map.Entry<String, SignatureDef> entry : signatureDefMap.entrySet()) {
//                System.out.println("SignatureDef key: " + entry.getKey());
//
//                SignatureDef signatureDef = entry.getValue();
//                System.out.println("Inputs:");
//                for (Map.Entry<String, TensorInfo> inputEntry : signatureDef.getInputsMap().entrySet()) {
//                    String inputKey = inputEntry.getKey();
//                    TensorInfo inputTensorInfo = inputEntry.getValue();
//
//                    // 打印输入张量的名称
//                    System.out.println("  Key: " + inputKey);
//                    System.out.println("  Name: " + inputTensorInfo.getName());
//
//                    // 打印输入张量的形状
//                    if (inputTensorInfo.hasTensorShape()) {
//                        TensorShapeProto tensorShape = inputTensorInfo.getTensorShape();
//                        System.out.println("  Shape: " + tensorShape);
//                    }
//
//                    // 打印输入张量的数据类型
//                    System.out.println("  Data Type: " + inputTensorInfo.getDtype());
//                }
//
//                System.out.println("Outputs:");
//                for (Map.Entry<String, TensorInfo> outputEntry : signatureDef.getOutputsMap().entrySet()) {
//                    String outputKey = outputEntry.getKey();
//                    TensorInfo outputTensorInfo = outputEntry.getValue();
//
//                    // 打印输出张量的名称
//                    System.out.println("  Key: " + outputKey);
//                    System.out.println("  Name: " + outputTensorInfo.getName());
//
//                    // 打印输出张量的形状
//                    if (outputTensorInfo.hasTensorShape()) {
//                        TensorShapeProto tensorShape = outputTensorInfo.getTensorShape();
//                        System.out.println("  Shape: " + tensorShape.toString());
//                    }
//
//                    // 打印输出张量的数据类型
//                    System.out.println("  Data Type: " + outputTensorInfo.getDtype());
//                }
//            }

3. 编写工具类

尽管 TensorFlow API 已经相对简洁,但在实际使用中仍可能显得繁琐。为了解决这一问题,我们可以封装一个工具类,使接口更加友好,让开发者只需一行代码即可完成图片内容安全的校验,而无需编写大量冗余的代码。通过这个工具类,您可以更轻松地集成 NSFW 模型,并提高项目的开发效率。

  • 解释NSFW分类任务,您可以设定对应概率阈值,来判断文件是否违规
    • DRAWINGS: 卡通或漫画图片
    • HENTAI: 带有情色成分的动画或漫画
    • NEUTRAL: 正常的、适合公开展示的图像
    • PORN: 色情图片
    • SEXY: 暗示性强的图像,但不完全是色情

这个工具类 NSFWAnalyzerUtils 主要用于对上传的图像文件进行 NSFW(Not Safe For Work)内容检测。它提供了一些实用的方法,简述如下:

  1. NSFW 内容预测

    • 方法 getNsfwPredictions 通过调用预训练的 NSFW 模型,返回一个拥有 NSFW 类型及其对应概率的映射。图像会被处理并转换为模型可以理解的格式,预测结果会被格式化为百分比字符串并存储在 Map 中。
  2. 文件安全性判断

    • 方法 isNsfwFile 提供了基于默认或自定义阈值判断上传文件是否包含不安全内容的功能。该方法调用模型进行推理,并根据输出结果和设定的阈值判断文件是否安全。
  3. 图像处理

    • 工具类内部还包括图像的预处理步骤,如调整图像大小、创建适用于 TensorFlow 的张量输入格式,以及将图像的像素值归一化以便模型推理。

整体上,这个工具类简化了 NSFW 模型的加载和调用流程,封装了图像预处理和安全性判断的逻辑,方便在应用中集成 NSFW 内容检测功能。

/**
 * nsfw 文件校验工具类
 *
 * @author : YiFei
 */
@Slf4j
@Component
@RequiredArgsConstructor
public class NSFWAnalyzerUtils {

    public static final float NSFWThreshold = 0.28f;
    public static final String PERCENTAGE_FORMAT = "#.###%"; // 格式化为百分比,保留三位小数
    private final NSFWModelService nsfwModelService;

    /**
     * 对上传的文件进行 NSFW 校验,返回按概率排序的 NSFW 类型及其概率。
     *
     * @param file 上传的文件(MultipartFile)
     * @return 按概率排序的 NSFW 类型及其概率的 Map
     */
    @SneakyThrows
    public Map<String, String> getNsfwPredictions(MultipartFile file) {
        // 使用模型从文件中提取 NSFW 预测结果,这里假设 extractNSFWModelPredictions 返回一个 float[][] output
        float[][] output = extractNSFWModelPredictions(file);

        // NSFW 类型 ( "涂鸦", "色情动漫", "中性", "色情", "性感" )
        String[] nsfwTypes = {"Drawing", "Hentai", "Neutral", "Porn", "Sexy"};

        // 创建一个 Map 用于存储 NSFW 类型及其概率
        Map<String, String> nsfwProbabilities = new LinkedHashMap<>();

        // 将模型输出的概率数组中的值按顺序放入 Map 中,并格式化为百分比字符串
        DecimalFormat df = new DecimalFormat(PERCENTAGE_FORMAT);
        for (int i = 0; i < output[0].length && i < nsfwTypes.length; i++) {
            // 格式化为百分比字符串
            String formattedProbability = df.format(output[0][i]);
            nsfwProbabilities.put(nsfwTypes[i], formattedProbability);
        }

        return nsfwProbabilities;
    }

    /**
     * 对上传的文件进行 NSFW 校验,返回文件是否不安全的判断结果。
     * 使用默认的 NSFW 阈值进行判断。
     *
     * @param file 上传的文件(MultipartFile)
     * @return 如果文件被判断为不安全,则返回 true;否则返回 false
     */
    @SneakyThrows
    public boolean isNsfwFile(MultipartFile file) {
        // 使用模型从文件中提取 NSFW 预测结果
        float[][] output = extractNSFWModelPredictions(file);
        // 使用默认的 NSFW 阈值判断文件是否不安全
        return isUnsafe(output, NSFWThreshold);
    }

    /**
     * 对上传的文件进行 NSFW 校验,返回文件是否不安全的判断结果。
     *
     * @param file          上传的文件(MultipartFile)
     * @param NSFWThreshold 判断文件不安全的阈值
     * @return 如果文件被判断为不安全,则返回 true;否则返回 false
     */
    @SneakyThrows
    public boolean isNsfwFile(MultipartFile file, float NSFWThreshold) {
        // 使用模型从文件中提取 NSFW 预测结果
        float[][] output = extractNSFWModelPredictions(file);
        // 使用指定的 NSFW 阈值判断文件是否不安全
        return isUnsafe(output, NSFWThreshold);
    }

    /**
     * 从上传的文件中获取浮点数数组。
     *
     * @param file 上传的文件(MultipartFile)
     * @return 包含模型推理结果的浮点数数组,形状为 [1, 5]
     * @throws IOException 如果读取文件或处理图像时发生错误
     */
    private float[][] extractNSFWModelPredictions(MultipartFile file) throws IOException {
        // 读取上传的图像文件并进行处理
        BufferedImage image = ImageIO.read(file.getInputStream());

        if (image == null) {
            throw new IOException("Failed to read image from file. / 无法读取文件 :" + file.getOriginalFilename());
        }

        // 将图像调整大小为模型期望的尺寸
        BufferedImage resizedImage = resizeImage(image, 224, 224);

        // 创建输入张量
        Tensor<?> inputTensor = createImageTensor(resizedImage, resizedImage.getWidth(), resizedImage.getHeight());

        // 运行模型推理并获取结果
        Tensor<?> result = nsfwModelService.getSession()
                .runner()
                .feed("serving_default_input:0", inputTensor) // 使用正确的输入张量名称
                .fetch("StatefulPartitionedCall:0") // 使用正确的输出张量名称
                .run()
                .get(0);

        // 处理模型输出结果
        float[][] output = new float[1][5]; // 根据模型的输出格式调整数组大小
        result.copyTo(output);

        return output;
    }

    /**
     * 根据给定的 BufferedImage 创建对应的 TensorFlow 图像 Tensor。
     * 图像将被转换为指定的尺寸,并将像素值归一化到 [0, 1] 范围内作为 Tensor 数据。
     *
     * @param image  要转换为 Tensor 的 BufferedImage 对象
     * @param width  目标图像宽度
     * @param height 目标图像高度
     * @return 表示图像的 TensorFlow Tensor 对象
     */
    private Tensor<?> createImageTensor(BufferedImage image, int width, int height) {
        int channels = 3; // 图像通道数为 RGB

        // 创建用于存储图像像素数据的一维数组
        float[] tensorData = new float[height * width * channels];

        // 遍历图像的每个像素,并将 RGB 值归一化后存储到 tensorData 中
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int rgb = image.getRGB(x, y);
                // 提取并归一化每个像素的 RGB 分量,存储到 tensorData 中
                tensorData[(y * width + x) * channels] = ((rgb >> 16) & 0xFF) / 255.0f;     // 红色通道
                tensorData[(y * width + x) * channels + 1] = ((rgb >> 8) & 0xFF) / 255.0f;  // 绿色通道
                tensorData[(y * width + x) * channels + 2] = (rgb & 0xFF) / 255.0f;         // 蓝色通道
            }
        }

        // 创建 TensorFlow Tensor 的形状(Shape),即 [batch_size, height, width, channels]
        long[] shape = {1, height, width, channels};
        // 使用归一化后的像素数据创建 TensorFlow Tensor 对象
        return Tensor.create(shape, FloatBuffer.wrap(tensorData));
    }

    /**
     * 调整给定的 BufferedImage 到指定的宽度和高度。
     * 使用 Image.SCALE_SMOOTH 缩放算法以得到更平滑的输出图像。
     *
     * @param originalImage 要调整大小的原始 BufferedImage 对象
     * @param width         目标图像宽度
     * @param height        目标图像高度
     * @return 调整大小后的 BufferedImage 对象
     */
    private BufferedImage resizeImage(BufferedImage originalImage, int width, int height) {
        // 使用指定的宽度和高度缩放原始图像
        Image resultingImage = originalImage.getScaledInstance(width, height, Image.SCALE_SMOOTH);

        // 创建新的 BufferedImage 作为输出图像
        BufferedImage outputImage = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);

        // 获取输出图像的绘图上下文
        Graphics2D g2d = outputImage.createGraphics();

        // 在输出图像上绘制缩放后的图像
        g2d.drawImage(resultingImage, 0, 0, null);

        // 释放绘图上下文资源
        g2d.dispose();

        // 返回调整大小后的 BufferedImage 对象
        return outputImage;
    }


    /**
     * 判断文件是否不安全。
     * <p>
     * 0 : Drawing -> safe for work drawings (including anime)      安全
     * 1 : Hentai -> hentai and pornographic drawings               危险
     * 2 : Neutral -> safe for work neutral images                  安全
     * 3 : Porn -> pornographic images, sexual acts                 危险
     * 4 : Sexy -> sexually explicit images, not pornography        危险
     *
     * @param output    模型输出结果的浮点数数组,形状为 [1, 5]
     * @param threshold 判定不安全的阈值
     * @return 如果文件被认为是不安全的,返回 true;否则返回 false
     */
    private boolean isUnsafe(float[][] output, float threshold) {
        // 模型输出的类别索引:
        final int DRAWING = 0;
        final int HENTAI = 1;
        final int NEUTRAL = 2;
        final int PORN = 3;
        final int SEXY = 4;

        // 安全类别的索引集合
        Set<Integer> safeCategories = Set.of(DRAWING, NEUTRAL);

        // 遍历模型输出结果
        for (int i = 0; i < output[0].length; i++) {
            // 如果当前类别不是安全类别,并且其概率超过阈值,则判定为不安全
            if (!safeCategories.contains(i) && output[0][i] > threshold) {
                return true;
            }
        }

        // 如果所有类别的概率都低于阈值,或都是安全类别,则判定为安全
        return false;
    }
}

4. 提供对应接口

其中使用工具类的 getNsfwPredictions 方法, 他会返回图像的预测结果其中包含Drawing、Hentai、Neutral、Porn、Sexy的对应概率

@RestController
@RequestMapping("nsfw")
@RequiredArgsConstructor
public class NsfwController {

    private final NSFWAnalyzerUtils nsfwAnalyzerUtils;

    @Operation(summary = "图片检测")
    @PreventDuplicateSubmit
    @PostMapping("/check")
    public Result<Map<String, String>> nsfwCheck(MultipartFile file) {
        try {
            return Result.success(nsfwAnalyzerUtils.getNsfwPredictions(file));
        } catch (Exception e) {
            throw new ServiceException(ResultCode.FILE_ANALYZER_ERROR);
        }
    }

}

5. 测试

  • 当我们上传一张正常的图片

SpringBoot集成TensorFlow对图片内容进行安全检测 : 全流程指南在这篇指南中,我们将深入探讨如何使用S

  • 当我们上传一张好看的图片(具体内容靠大家想象)

SpringBoot集成TensorFlow对图片内容进行安全检测 : 全流程指南在这篇指南中,我们将深入探讨如何使用S

注意 : 判断结果与模型有关

6. 解决部署问题

在开发阶段,项目可以直接访问 resources 目录下的文件,但在打包为 JAR 后,资源文件会嵌入到 JAR 中,无法作为文件系统中的路径直接访问。例如,以下代码在打包为 JAR 后会导致访问失败:

String modelAbsolutePath = new ClassPathResource("tensorflow/saved_model/nsfw").getFile().getAbsolutePath();
  • 解决方案:

    • 将项目部署在服务器的指定位置
    • 生成临时文件后再加载模型

我采用了生成临时文件的方法,即将 resources 目录下的字节流文件复制到临时目录,并返回临时目录的路径以加载模型。代码逻辑很简单:在 Temp 目录下创建文件,将资源目录中的文件拷贝到临时目录,最后返回临时文件的路径用于模型加载。

public static String getModelPath(String classPathResource) throws IOException {
    File tempModelDir = Files.createTempDirectory(TEMP_TENSORFLOW_MODEL_PATH).toFile();
    Path tempModelDirPath = tempModelDir.toPath();

    copyTempModel("", classPathResource, tempModelDirPath);
    copyTempModel("assets", classPathResource, tempModelDirPath);
    copyTempModel("variables", classPathResource, tempModelDirPath);

    log.info("classPathResource: {} , ===> 临时文件存储在 {} ", classPathResource, tempModelDir.getAbsolutePath());
    return tempModelDir.getAbsolutePath();
}

更详细的代码可以参考 TensorFlowUtil

四、源码

源码地址 | 👀 在线演示 | 觉得不错可以给个start

前端源码位置 : yf/ yf-vue-admin / src / views / demo / nsfw

后端源码位置 :

yf/ yf-boot-admin / yf-integration / yf-file

yf/ yf-boot-admin / yf-shared / src / main / java / com / yf / utils

注意事项 :

    1. 平台一人一号,账号可以通过邮箱、第三方平台自动注册。用户名密码方式登录请联系管理员手动添加、手机号不可用。(敏感数据以做信息脱敏)
    1. 在线聊天功能(消息已做脏词过滤,群发、系统、AI消息不会被平台记录)
    1. 欢迎大家提出意见,欢迎畅聊与项目相关问题
转载自:https://juejin.cn/post/7408062004844625960
评论
请登录