Exploring the Canvas Series: Intelligent Image Processing with Transformers.js

Introduction

I currently maintain a powerful open-source creative drawing board. It integrates a large number of interesting brushes and auxiliary drawing features, giving users a brand-new drawing experience, and it works well on both mobile and PC with smooth interaction and polished visual feedback.

In this article I will explain in detail how to use Transformers.js to implement background removal and click-to-mark image segmentation. The results look like this:

Image-1

Link: https://songlh.top/paint-board/

GitHub: https://github.com/LHRUN/paint-board, a Star ⭐️ is always welcome

Transformers.js

Transformers.js is a powerful JavaScript library based on Hugging Face Transformers that runs directly in the browser, with no server-side computation required. This means you can run models locally, which improves efficiency and reduces deployment and maintenance costs.

Transformers.js currently provides 1000+ models on Hugging Face, covering a wide range of domains and able to meet most of your needs: image processing, text generation, translation, sentiment analysis and other tasks can all be handled easily with Transformers.js. The model search page looks like this:

Image-2
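
To get a feel for the API, here is a minimal sketch (not taken from the paint-board source) of running a task entirely in the browser with the library's `pipeline` helper; the default model for the task is downloaded on first use and cached locally:

    import { pipeline } from '@huggingface/transformers'

    // Create a sentiment-analysis pipeline; the default model for this task
    // is fetched from the Hugging Face Hub and cached by the browser.
    const classifier = await pipeline('sentiment-analysis')

    // Inference runs locally, with no server round-trip.
    const result = await classifier('I love drawing on this canvas!')
    console.log(result) // e.g. [{ label: 'POSITIVE', score: 0.99 }]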

The main branch of Transformers.js has now been updated to V3, which adds a lot of exciting features. For details see: Transformers.js v3: WebGPU Support, New Models & Tasks, and More…

Both features added in this article rely on the WebGPU support introduced in V3, which greatly improves processing speed; inference can now reach millisecond level. Note, however, that not many browsers support WebGPU yet, so I recommend using the latest version of Google Chrome.
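
The components below simply disable their feature when WebGPU is missing. As an alternative, a hedged sketch of detecting support via `navigator.gpu` and falling back to the slower WASM backend instead:

    import { AutoModel } from '@huggingface/transformers'

    // Detect WebGPU support; navigator.gpu is only defined in browsers that expose it.
    const hasWebGPU = typeof navigator !== 'undefined' && !!navigator.gpu

    const model = await AutoModel.from_pretrained('Xenova/modnet', {
      // 'webgpu' gives the millisecond-level speed mentioned above;
      // 'wasm' is slower but works in far more browsers.
      device: hasWebGPU ? 'webgpu' : 'wasm'
    })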

Feature 1: Background removal

For background removal I use the Xenova/modnet model, which looks like this:

Image-3

The processing logic can be divided into three steps:

  • Initialize the state, and load the model and processor.
  • The interface display: this part is entirely up to your own design, so don't copy mine.
  • Show the result: again, design this however you like; a popular approach right now is a draggable divider that dynamically compares the image before and after background removal.
  • The code logic is as follows (`React + TS`); for the full implementation see my project's source at src/components/boardOperation/uploadImage/index.tsx

    import { useState, FC, useRef, useEffect, useMemo } from 'react'
    import {
      env,
      AutoModel,
      AutoProcessor,
      RawImage,
      PreTrainedModel,
      Processor
    } from '@huggingface/transformers'
    
    const REMOVE_BACKGROUND_STATUS = {
      LOADING: 0,
      NO_SUPPORT_WEBGPU: 1,
      LOAD_ERROR: 2,
      LOAD_SUCCESS: 3,
      PROCESSING: 4,
      PROCESSING_SUCCESS: 5
    }
    
    type RemoveBackgroundStatusType =
      (typeof REMOVE_BACKGROUND_STATUS)[keyof typeof REMOVE_BACKGROUND_STATUS]
    
    const UploadImage: FC<{ url: string }> = ({ url }) => {
      const [removeBackgroundStatus, setRemoveBackgroundStatus] =
        useState<RemoveBackgroundStatusType>()
      const [processedImage, setProcessedImage] = useState('')
    
      const modelRef = useRef<PreTrainedModel>()
      const processorRef = useRef<Processor>()
    
      const removeBackgroundBtnTip = useMemo(() => {
        switch (removeBackgroundStatus) {
          case REMOVE_BACKGROUND_STATUS.LOADING:
            return 'Remove background function loading'
          case REMOVE_BACKGROUND_STATUS.NO_SUPPORT_WEBGPU:
            return 'WebGPU is not supported in this browser, to use the remove background function, please use the latest version of Google Chrome'
          case REMOVE_BACKGROUND_STATUS.LOAD_ERROR:
            return 'Remove background function failed to load'
          case REMOVE_BACKGROUND_STATUS.LOAD_SUCCESS:
            return 'Remove background function loaded successfully'
          case REMOVE_BACKGROUND_STATUS.PROCESSING:
            return 'Remove Background Processing'
          case REMOVE_BACKGROUND_STATUS.PROCESSING_SUCCESS:
            return 'Remove Background Processing Success'
          default:
            return ''
        }
      }, [removeBackgroundStatus])
    
      useEffect(() => {
        ;(async () => {
          try {
            if (removeBackgroundStatus === REMOVE_BACKGROUND_STATUS.LOADING) {
              return
            }
            setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.LOADING)
    
            // Checking WebGPU Support
            if (!navigator?.gpu) {
              setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.NO_SUPPORT_WEBGPU)
              return
            }
            const model_id = 'Xenova/modnet'
            if (env.backends.onnx.wasm) {
              env.backends.onnx.wasm.proxy = false
            }
    
            // Load model and processor
            modelRef.current ??= await AutoModel.from_pretrained(model_id, {
              device: 'webgpu'
            })
            processorRef.current ??= await AutoProcessor.from_pretrained(model_id)
            setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.LOAD_SUCCESS)
          } catch (err) {
            console.log('err', err)
            setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.LOAD_ERROR)
          }
        })()
      }, [])
    
      const processImages = async () => {
        const model = modelRef.current
        const processor = processorRef.current
    
        if (!model || !processor) {
          return
        }
    
        setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.PROCESSING)
    
        // load image
        const img = await RawImage.fromURL(url)
    
        // Pre-processed image
        const { pixel_values } = await processor(img)
    
        // Generate image mask
        const { output } = await model({ input: pixel_values })
        const maskData = (
          await RawImage.fromTensor(output[0].mul(255).to('uint8')).resize(
            img.width,
            img.height
          )
        ).data
    
        // Create a new canvas
        const canvas = document.createElement('canvas')
        canvas.width = img.width
        canvas.height = img.height
        const ctx = canvas.getContext('2d') as CanvasRenderingContext2D
    
        // Draw the original image
        ctx.drawImage(img.toCanvas(), 0, 0)
    
        // Updating the mask area
        const pixelData = ctx.getImageData(0, 0, img.width, img.height)
        for (let i = 0; i < maskData.length; ++i) {
          pixelData.data[4 * i + 3] = maskData[i]
        }
        ctx.putImageData(pixelData, 0, 0)
    
        // Save new image
        setProcessedImage(canvas.toDataURL('image/png'))
        setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.PROCESSING_SUCCESS)
      }
    
      // The markup below is simplified; see the project source for the full UI
      return (
        <div className="upload-image">
          <button onClick={processImages}>Remove background</button>
          <p>{removeBackgroundBtnTip}</p>
          {processedImage && <img src={processedImage} alt="Processed result" />}
        </div>
      )
    }

    export default UploadImage
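
The processed result is kept as a PNG data URL, so saving it to disk only needs standard browser APIs. A small sketch (this helper is not part of the project source above):

    // Hypothetical helper: trigger a download of the exported data URL.
    const downloadProcessedImage = (dataUrl: string, filename = 'image-no-bg.png') => {
      const link = document.createElement('a')
      link.href = dataUrl
      link.download = filename
      link.click()
    }

    // Usage: downloadProcessedImage(processedImage)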

Feature 2: Image marking and segmentation

This feature uses the Xenova/slimsam-77-uniform model to implement point-prompted image segmentation. The effect is shown below: once the image has loaded, you can click on it, and a segmentation mask is generated from the coordinates of your clicks.

Image-4

The processing logic can be divided into five steps:

  • Initialize the state, and load the model and processor.
  • Fetch and load the image, then save the loaded image data and its embedding data.
  • Listen for click events on the image and record the click data, split into positive and negative markers. After each click, decode the accumulated clicks into mask data, then draw the segmentation overlay from that mask.
  • The interface display: design this however you like; don't take mine as the standard.
  • Saving the image on click: use the mask pixel data to pick the matching pixels from the original image, then draw and export them via canvas (a sketch of this step follows the component code below).
  • The code logic is as follows (`React + TS`); for the full implementation see my project's source at src/components/boardOperation/uploadImage/imageSegmentation.tsx

    import { useState, useRef, useEffect, useMemo, MouseEvent, FC } from 'react'
    import {
      SamModel,
      AutoProcessor,
      RawImage,
      PreTrainedModel,
      Processor,
      Tensor,
      SamImageProcessorResult
    } from '@huggingface/transformers'
    
    import LoadingIcon from '@/components/icons/loading.svg?react'
    import PositiveIcon from '@/components/icons/boardOperation/image-segmentation-positive.svg?react'
    import NegativeIcon from '@/components/icons/boardOperation/image-segmentation-negative.svg?react'
    
    interface MarkPoint {
      position: number[]
      label: number
    }
    
    const SEGMENTATION_STATUS = {
      LOADING: 0,
      NO_SUPPORT_WEBGPU: 1,
      LOAD_ERROR: 2,
      LOAD_SUCCESS: 3,
      PROCESSING: 4,
      PROCESSING_SUCCESS: 5
    }
    
    type SegmentationStatusType =
      (typeof SEGMENTATION_STATUS)[keyof typeof SEGMENTATION_STATUS]
    
    const ImageSegmentation: FC<{ url: string }> = ({ url }) => {
      const [markPoints, setMarkPoints] = useState<MarkPoint[]>([])
      const [segmentationStatus, setSegmentationStatus] =
        useState<SegmentationStatusType>()
      const [pointStatus, setPointStatus] = useState(true)
    
      const maskCanvasRef = useRef<HTMLCanvasElement>(null) // Segmentation mask canvas
      const modelRef = useRef<PreTrainedModel>() // model
      const processorRef = useRef<Processor>() // processor
      const imageInputRef = useRef<RawImage>() // original image
      const imageProcessed = useRef<SamImageProcessorResult>() // Processed image
      const imageEmbeddings = useRef<Record<string, Tensor>>() // Embedding data
    
      const segmentationTip = useMemo(() => {
        switch (segmentationStatus) {
          case SEGMENTATION_STATUS.LOADING:
            return 'Image Segmentation function Loading'
          case SEGMENTATION_STATUS.NO_SUPPORT_WEBGPU:
            return 'WebGPU is not supported in this browser, to use the image segmentation function, please use the latest version of Google Chrome.'
          case SEGMENTATION_STATUS.LOAD_ERROR:
            return 'Image Segmentation function failed to load'
          case SEGMENTATION_STATUS.LOAD_SUCCESS:
            return 'Image Segmentation function loaded successfully'
          case SEGMENTATION_STATUS.PROCESSING:
            return 'Image Processing...'
          case SEGMENTATION_STATUS.PROCESSING_SUCCESS:
            return 'The image has been processed successfully, you can click on the image to mark it, the green mask area is the segmentation area.'
          default:
            return ''
        }
      }, [segmentationStatus])
    
      // 1. load model and processor
      useEffect(() => {
        ;(async () => {
          try {
            if (segmentationStatus === SEGMENTATION_STATUS.LOADING) {
              return
            }
    
            setSegmentationStatus(SEGMENTATION_STATUS.LOADING)
            if (!navigator?.gpu) {
              setSegmentationStatus(SEGMENTATION_STATUS.NO_SUPPORT_WEBGPU)
              return
            }
            const model_id = 'Xenova/slimsam-77-uniform'
            modelRef.current ??= await SamModel.from_pretrained(model_id, {
              dtype: 'fp16', // or "fp32"
              device: 'webgpu'
            })
            processorRef.current ??= await AutoProcessor.from_pretrained(model_id)
    
            setSegmentationStatus(SEGMENTATION_STATUS.LOAD_SUCCESS)
          } catch (err) {
            console.log('err', err)
            setSegmentationStatus(SEGMENTATION_STATUS.LOAD_ERROR)
          }
        })()
      }, [])
    
      // 2. process image
      useEffect(() => {
        ;(async () => {
          try {
            if (
              !modelRef.current ||
              !processorRef.current ||
              !url ||
              segmentationStatus === SEGMENTATION_STATUS.PROCESSING
            ) {
              return
            }
            setSegmentationStatus(SEGMENTATION_STATUS.PROCESSING)
            clearPoints()
    
            imageInputRef.current = await RawImage.fromURL(url)
            imageProcessed.current = await processorRef.current(
              imageInputRef.current
            )
            imageEmbeddings.current = await (
              modelRef.current as any
            ).get_image_embeddings(imageProcessed.current)
    
            setSegmentationStatus(SEGMENTATION_STATUS.PROCESSING_SUCCESS)
          } catch (err) {
            console.log('err', err)
          }
        })()
      }, [url, modelRef.current, processorRef.current])
    
      // Updating the mask effect
      function updateMaskOverlay(mask: RawImage, scores: Float32Array) {
        const maskCanvas = maskCanvasRef.current
        if (!maskCanvas) {
          return
        }
        const maskContext = maskCanvas.getContext('2d') as CanvasRenderingContext2D
    
        // Update canvas dimensions (if different)
        if (maskCanvas.width !== mask.width || maskCanvas.height !== mask.height) {
          maskCanvas.width = mask.width
          maskCanvas.height = mask.height
        }
    
        // Allocate buffer for pixel data
        const imageData = maskContext.createImageData(
          maskCanvas.width,
          maskCanvas.height
        )
    
        // Select best mask
        const numMasks = scores.length // 3
        let bestIndex = 0
        for (let i = 1; i < numMasks; ++i) {
          if (scores[i] > scores[bestIndex]) {
            bestIndex = i
          }
        }
    
        // Fill mask with colour
        const pixelData = imageData.data
        for (let i = 0; i < pixelData.length; ++i) {
          if (mask.data[numMasks * i + bestIndex] === 1) {
            const offset = 4 * i
            pixelData[offset] = 101 // r
            pixelData[offset + 1] = 204 // g
            pixelData[offset + 2] = 138 // b
            pixelData[offset + 3] = 255 // a
          }
        }
    
        // Draw image data to context
        maskContext.putImageData(imageData, 0, 0)
      }
    
      // 3. Decoding based on click data
      const decode = async (markPoints: MarkPoint[]) => {
        if (
          !modelRef.current ||
          !imageEmbeddings.current ||
          !processorRef.current ||
          !imageProcessed.current
        ) {
          return
        }
    
        // No click on the data directly clears the segmentation effect
        if (!markPoints.length && maskCanvasRef.current) {
          const maskContext = maskCanvasRef.current.getContext(
            '2d'
          ) as CanvasRenderingContext2D
          maskContext.clearRect(
            0,
            0,
            maskCanvasRef.current.width,
            maskCanvasRef.current.height
          )
          return
        }
    
        // Prepare inputs for decoding
        const reshaped = imageProcessed.current.reshaped_input_sizes[0]
        const points = markPoints
          .map((x) => [x.position[0] * reshaped[1], x.position[1] * reshaped[0]])
          .flat(Infinity)
        const labels = markPoints.map((x) => BigInt(x.label)).flat(Infinity)
    
        const num_points = markPoints.length
        const input_points = new Tensor('float32', points, [1, 1, num_points, 2])
        const input_labels = new Tensor('int64', labels, [1, 1, num_points])
    
        // Generate the mask
        const { pred_masks, iou_scores } = await modelRef.current({
          ...imageEmbeddings.current,
          input_points,
          input_labels
        })
    
        // Post-process the mask
        const masks = await (processorRef.current as any).post_process_masks(
          pred_masks,
          imageProcessed.current.original_sizes,
          imageProcessed.current.reshaped_input_sizes
        )
    
        updateMaskOverlay(RawImage.fromTensor(masks[0][0]), iou_scores.data)
      }
    
      const clamp = (x: number, min = 0, max = 1) => {
        return Math.max(Math.min(x, max), min)
      }
    
      const clickImage = (e: MouseEvent) => {
        if (segmentationStatus !== SEGMENTATION_STATUS.PROCESSING_SUCCESS) {
          return
        }
    
        const { clientX, clientY, currentTarget } = e
        const { left, top } = currentTarget.getBoundingClientRect()
    
        const x = clamp(
          (clientX - left + currentTarget.scrollLeft) / currentTarget.scrollWidth
        )
        const y = clamp(
          (clientY - top + currentTarget.scrollTop) / currentTarget.scrollHeight
        )
    
        const existingPointIndex = markPoints.findIndex(
          (point) =>
            Math.abs(point.position[0] - x) < 0.01 &&
            Math.abs(point.position[1] - y) < 0.01 &&
            point.label === (pointStatus ? 1 : 0)
        )
    
        const newPoints = [...markPoints]
        if (existingPointIndex !== -1) {
          // If there is a marker in the currently clicked area, it is deleted.
          newPoints.splice(existingPointIndex, 1)
        } else {
          newPoints.push({
            position: [x, y],
            label: pointStatus ? 1 : 0
          })
        }
    
        setMarkPoints(newPoints)
        decode(newPoints)
      }
    
      const clearPoints = () => {
        setMarkPoints([])
        decode([])
      }
    
      // The markup below is simplified; see the project source for the full UI
      return (
        <div className="image-segmentation">
          <p>{segmentationTip}</p>
          {segmentationStatus !== SEGMENTATION_STATUS.PROCESSING_SUCCESS && (
            <LoadingIcon />
          )}
          <div className="image-wrapper" onClick={clickImage}>
            <img src={url} alt="Original image" />
            <canvas ref={maskCanvasRef} className="mask-canvas" />
            {markPoints.map((point, index) => {
              switch (point.label) {
                case 1:
                  return <PositiveIcon key={index} />
                case 0:
                  return <NegativeIcon key={index} />
                default:
                  return null
              }
            })}
          </div>
          <button onClick={() => setPointStatus(!pointStatus)}>
            {pointStatus ? 'Positive marker' : 'Negative marker'}
          </button>
        </div>
      )
    }

    export default ImageSegmentation
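
The save/export step (step 5 in the list above) is not included in this excerpt. As a minimal sketch of the idea, assuming you keep the best single-channel mask as a per-pixel array (1 inside the segmented area), you can reuse the same canvas and alpha-channel trick as the background-removal code:

    import { RawImage } from '@huggingface/transformers'

    // Hypothetical helper (not in the component above): export only the segmented
    // area by drawing the original image and making every pixel outside the mask
    // fully transparent.
    const exportSegmentedImage = (img: RawImage, maskData: Uint8Array): string => {
      const canvas = document.createElement('canvas')
      canvas.width = img.width
      canvas.height = img.height
      const ctx = canvas.getContext('2d') as CanvasRenderingContext2D

      // Draw the original image first.
      ctx.drawImage(img.toCanvas(), 0, 0)

      // Clear the alpha channel wherever the mask is 0.
      const pixelData = ctx.getImageData(0, 0, img.width, img.height)
      for (let i = 0; i < maskData.length; ++i) {
        if (maskData[i] !== 1) {
          pixelData.data[4 * i + 3] = 0
        }
      }
      ctx.putImageData(pixelData, 0, 0)

      return canvas.toDataURL('image/png')
    }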

Conclusion

Thanks for reading. That's all for this article; I hope it was helpful, and likes and bookmarks are appreciated. If you have any questions, feel free to discuss them in the comments!