Exploring the Canvas Series: Combining Transformers.js for Intelligent Image Processing
Introduction
I maintain a powerful open-source creative drawing board that integrates a wide range of interesting brushes and auxiliary drawing tools, giving users a brand-new drawing experience with better interaction and rendering on both mobile and PC.
In this article I will explain in detail how to combine Transformers.js to implement background removal and image marker segmentation. The results look like this:

Link: https://songlh.top/paint-board/
GitHub: https://github.com/LHRUN/paint-board. A Star ⭐️ is welcome!
Transformers.js
Transformers.js is a powerful JavaScript library based on Hugging Face's Transformers that runs directly in the browser, with no reliance on server-side computation. This means you can run models locally, which improves efficiency and cuts deployment and maintenance costs.
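If you have never used the library, the high-level pipeline API shows how little code a task takes. A minimal sketch (the sentiment-analysis task here is just an illustration, not something this article uses):

```ts
import { pipeline } from '@huggingface/transformers'

// Downloads (and caches) a default model for the task,
// then runs inference entirely in the browser
const classifier = await pipeline('sentiment-analysis')
const result = await classifier('I love drawing on canvas!')
// e.g. [{ label: 'POSITIVE', score: 0.99 }]
```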
Transformers.js currently offers 1000+ models on Hugging Face, covering many domains and most needs you might have: image processing, text generation, translation, sentiment analysis, and other tasks can all be handled easily with Transformers.js. Searching for models looks like this:

The main version of Transformers.js has been updated to V3, which adds many exciting features; for details see: Transformers.js v3: WebGPU Support, New Models and Tasks, and More…
Both features I add in this article rely on the WebGPU support introduced in V3, which greatly improves processing speed: inference can now finish in milliseconds. Note, however, that few browsers support WebGPU yet, so it is recommended to use the latest version of Chrome.
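Feature detection is simple: WebGPU exposes itself as navigator.gpu, and both components below fall back to an error state when it is missing. A minimal sketch of the check:

```ts
// navigator.gpu only exists in browsers that support WebGPU
// (e.g. a recent version of Chrome)
if (!navigator?.gpu) {
  console.warn('WebGPU is not supported in this browser')
} else {
  // Optionally confirm that an adapter can actually be acquired
  const adapter = await navigator.gpu.requestAdapter()
  console.log('WebGPU adapter available:', adapter !== null)
}
```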
Feature 1: Background Removal
To remove the background I use the Xenova/modnet model, which looks like this:

The processing logic can be divided into three steps:
1. Check WebGPU support and load the Xenova/modnet model and its processor.
2. Load the image, preprocess it, and run the model to generate a mask.
3. Draw the original image onto a canvas, write the mask into the alpha channel, and export the result as a new image.
The code logic is as follows (`React + TS`); for the full version, see my project's source at src/components/boardOperation/uploadImage/index.tsx.
```tsx
import { useState, FC, useRef, useEffect, useMemo } from 'react'
import {
  env,
  AutoModel,
  AutoProcessor,
  RawImage,
  PreTrainedModel,
  Processor
} from '@huggingface/transformers'

const REMOVE_BACKGROUND_STATUS = {
  LOADING: 0,
  NO_SUPPORT_WEBGPU: 1,
  LOAD_ERROR: 2,
  LOAD_SUCCESS: 3,
  PROCESSING: 4,
  PROCESSING_SUCCESS: 5
}

type RemoveBackgroundStatusType =
  (typeof REMOVE_BACKGROUND_STATUS)[keyof typeof REMOVE_BACKGROUND_STATUS]

const UploadImage: FC<{ url: string }> = ({ url }) => {
  const [removeBackgroundStatus, setRemoveBackgroundStatus] =
    useState<RemoveBackgroundStatusType>()
  const [processedImage, setProcessedImage] = useState('')

  const modelRef = useRef<PreTrainedModel>()
  const processorRef = useRef<Processor>()

  const removeBackgroundBtnTip = useMemo(() => {
    switch (removeBackgroundStatus) {
      case REMOVE_BACKGROUND_STATUS.LOADING:
        return 'Remove background function loading'
      case REMOVE_BACKGROUND_STATUS.NO_SUPPORT_WEBGPU:
        return 'WebGPU is not supported in this browser, to use the remove background function, please use the latest version of Google Chrome'
      case REMOVE_BACKGROUND_STATUS.LOAD_ERROR:
        return 'Remove background function failed to load'
      case REMOVE_BACKGROUND_STATUS.LOAD_SUCCESS:
        return 'Remove background function loaded successfully'
      case REMOVE_BACKGROUND_STATUS.PROCESSING:
        return 'Remove Background Processing'
      case REMOVE_BACKGROUND_STATUS.PROCESSING_SUCCESS:
        return 'Remove Background Processing Success'
      default:
        return ''
    }
  }, [removeBackgroundStatus])

  useEffect(() => {
    ;(async () => {
      try {
        if (removeBackgroundStatus === REMOVE_BACKGROUND_STATUS.LOADING) {
          return
        }
        setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.LOADING)

        // Checking WebGPU support
        if (!navigator?.gpu) {
          setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.NO_SUPPORT_WEBGPU)
          return
        }

        const model_id = 'Xenova/modnet'
        if (env.backends.onnx.wasm) {
          env.backends.onnx.wasm.proxy = false
        }

        // Load model and processor
        modelRef.current ??= await AutoModel.from_pretrained(model_id, {
          device: 'webgpu'
        })
        processorRef.current ??= await AutoProcessor.from_pretrained(model_id)

        setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.LOAD_SUCCESS)
      } catch (err) {
        console.log('err', err)
        setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.LOAD_ERROR)
      }
    })()
  }, [])

  const processImages = async () => {
    const model = modelRef.current
    const processor = processorRef.current
    if (!model || !processor) {
      return
    }

    setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.PROCESSING)

    // Load image
    const img = await RawImage.fromURL(url)

    // Preprocess image
    const { pixel_values } = await processor(img)

    // Generate image mask
    const { output } = await model({ input: pixel_values })
    const maskData = (
      await RawImage.fromTensor(output[0].mul(255).to('uint8')).resize(
        img.width,
        img.height
      )
    ).data

    // Create a new canvas
    const canvas = document.createElement('canvas')
    canvas.width = img.width
    canvas.height = img.height
    const ctx = canvas.getContext('2d') as CanvasRenderingContext2D

    // Draw the original image
    ctx.drawImage(img.toCanvas(), 0, 0)

    // Update the mask area
    const pixelData = ctx.getImageData(0, 0, img.width, img.height)
    for (let i = 0; i < maskData.length; ++i) {
      pixelData.data[4 * i + 3] = maskData[i]
    }
    ctx.putImageData(pixelData, 0, 0)

    // Save the new image
    setProcessedImage(canvas.toDataURL('image/png'))
    setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.PROCESSING_SUCCESS)
  }

  // JSX omitted: it renders the tip text (removeBackgroundBtnTip) and, once
  // ready, a preview of the processed image. See the project source for the
  // full markup.
  return null
}

export default UploadImage
```
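The heart of the technique is the final loop: MODNet produces a single-channel matte, and writing each matte value into the alpha byte of the corresponding pixel is what makes the background transparent. Pulled out as a standalone helper (a sketch distilled from the code above; the name applyMaskToAlpha is mine, not part of the project):

```ts
// Replace every pixel's alpha channel with the model's mask value:
// 0 = fully transparent (background), 255 = fully opaque (foreground).
function applyMaskToAlpha(
  ctx: CanvasRenderingContext2D,
  maskData: Uint8Array | Uint8ClampedArray,
  width: number,
  height: number
): void {
  const pixelData = ctx.getImageData(0, 0, width, height)
  for (let i = 0; i < maskData.length; ++i) {
    // RGBA layout: byte 4 * i + 3 is the alpha of pixel i
    pixelData.data[4 * i + 3] = maskData[i]
  }
  ctx.putImageData(pixelData, 0, 0)
}
```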
Feature 2: Image Marker Segmentation
Image marker segmentation is implemented with the Xenova/slimsam-77-uniform model. The effect looks like this: once the image has loaded, you can click on it, and a segmentation mask is generated from the coordinates of your clicks.

The processing logic can be divided into five steps:
1. Check WebGPU support and load the Xenova/slimsam-77-uniform model and its processor.
2. Load and preprocess the image, then compute its embeddings once up front.
3. Listen for clicks on the image and record each click as a positive or negative marker point with normalized coordinates.
4. Decode the marker points against the cached embeddings to generate candidate masks and keep the best-scoring one.
5. Draw the selected mask onto an overlay canvas above the image.
The code logic is as follows (`React + TS`); for the full version, see my project's source at src/components/boardOperation/uploadImage/imageSegmentation.tsx.
```tsx
import { useState, useRef, useEffect, useMemo, MouseEvent, FC } from 'react'
import {
  SamModel,
  AutoProcessor,
  RawImage,
  PreTrainedModel,
  Processor,
  Tensor,
  SamImageProcessorResult
} from '@huggingface/transformers'

import LoadingIcon from '@/components/icons/loading.svg?react'
import PositiveIcon from '@/components/icons/boardOperation/image-segmentation-positive.svg?react'
import NegativeIcon from '@/components/icons/boardOperation/image-segmentation-negative.svg?react'

interface MarkPoint {
  position: number[]
  label: number
}

const SEGMENTATION_STATUS = {
  LOADING: 0,
  NO_SUPPORT_WEBGPU: 1,
  LOAD_ERROR: 2,
  LOAD_SUCCESS: 3,
  PROCESSING: 4,
  PROCESSING_SUCCESS: 5
}

type SegmentationStatusType =
  (typeof SEGMENTATION_STATUS)[keyof typeof SEGMENTATION_STATUS]

const ImageSegmentation: FC<{ url: string }> = ({ url }) => {
  const [markPoints, setMarkPoints] = useState<MarkPoint[]>([])
  const [segmentationStatus, setSegmentationStatus] =
    useState<SegmentationStatusType>()
  const [pointStatus, setPointStatus] = useState<boolean>(true)

  const maskCanvasRef = useRef<HTMLCanvasElement>(null) // Segmentation mask
  const modelRef = useRef<PreTrainedModel>() // Model
  const processorRef = useRef<Processor>() // Processor
  const imageInputRef = useRef<RawImage>() // Original image
  const imageProcessed = useRef<SamImageProcessorResult>() // Processed image
  const imageEmbeddings = useRef<Record<string, Tensor>>() // Embedding data

  const segmentationTip = useMemo(() => {
    switch (segmentationStatus) {
      case SEGMENTATION_STATUS.LOADING:
        return 'Image Segmentation function Loading'
      case SEGMENTATION_STATUS.NO_SUPPORT_WEBGPU:
        return 'WebGPU is not supported in this browser, to use the image segmentation function, please use the latest version of Google Chrome.'
      case SEGMENTATION_STATUS.LOAD_ERROR:
        return 'Image Segmentation function failed to load'
      case SEGMENTATION_STATUS.LOAD_SUCCESS:
        return 'Image Segmentation function loaded successfully'
      case SEGMENTATION_STATUS.PROCESSING:
        return 'Image Processing...'
      case SEGMENTATION_STATUS.PROCESSING_SUCCESS:
        return 'The image has been processed successfully, you can click on the image to mark it, the green mask area is the segmentation area.'
      default:
        return ''
    }
  }, [segmentationStatus])

  // 1. Load model and processor
  useEffect(() => {
    ;(async () => {
      try {
        if (segmentationStatus === SEGMENTATION_STATUS.LOADING) {
          return
        }
        setSegmentationStatus(SEGMENTATION_STATUS.LOADING)

        if (!navigator?.gpu) {
          setSegmentationStatus(SEGMENTATION_STATUS.NO_SUPPORT_WEBGPU)
          return
        }

        const model_id = 'Xenova/slimsam-77-uniform'
        modelRef.current ??= await SamModel.from_pretrained(model_id, {
          dtype: 'fp16', // or "fp32"
          device: 'webgpu'
        })
        processorRef.current ??= await AutoProcessor.from_pretrained(model_id)

        setSegmentationStatus(SEGMENTATION_STATUS.LOAD_SUCCESS)
      } catch (err) {
        console.log('err', err)
        setSegmentationStatus(SEGMENTATION_STATUS.LOAD_ERROR)
      }
    })()
  }, [])

  // 2. Process image
  useEffect(() => {
    ;(async () => {
      try {
        if (
          !modelRef.current ||
          !processorRef.current ||
          !url ||
          segmentationStatus === SEGMENTATION_STATUS.PROCESSING
        ) {
          return
        }
        setSegmentationStatus(SEGMENTATION_STATUS.PROCESSING)
        clearPoints()

        imageInputRef.current = await RawImage.fromURL(url)
        imageProcessed.current = await processorRef.current(
          imageInputRef.current
        )
        imageEmbeddings.current = await (
          modelRef.current as any
        ).get_image_embeddings(imageProcessed.current)

        setSegmentationStatus(SEGMENTATION_STATUS.PROCESSING_SUCCESS)
      } catch (err) {
        console.log('err', err)
      }
    })()
  }, [url, modelRef.current, processorRef.current])

  // Update the mask overlay
  function updateMaskOverlay(mask: RawImage, scores: Float32Array) {
    const maskCanvas = maskCanvasRef.current
    if (!maskCanvas) {
      return
    }
    const maskContext = maskCanvas.getContext('2d') as CanvasRenderingContext2D

    // Update canvas dimensions (if different)
    if (maskCanvas.width !== mask.width || maskCanvas.height !== mask.height) {
      maskCanvas.width = mask.width
      maskCanvas.height = mask.height
    }

    // Allocate buffer for pixel data
    const imageData = maskContext.createImageData(
      maskCanvas.width,
      maskCanvas.height
    )

    // Select best mask
    const numMasks = scores.length // 3
    let bestIndex = 0
    for (let i = 1; i < numMasks; ++i) {
      if (scores[i] > scores[bestIndex]) {
        bestIndex = i
      }
    }

    // Fill mask with colour
    const pixelData = imageData.data
    for (let i = 0; i < pixelData.length; ++i) {
      if (mask.data[numMasks * i + bestIndex] === 1) {
        const offset = 4 * i
        pixelData[offset] = 101 // r
        pixelData[offset + 1] = 204 // g
        pixelData[offset + 2] = 138 // b
        pixelData[offset + 3] = 255 // a
      }
    }

    // Draw image data to context
    maskContext.putImageData(imageData, 0, 0)
  }

  // 3. Decode based on click data
  const decode = async (markPoints: MarkPoint[]) => {
    if (
      !modelRef.current ||
      !imageEmbeddings.current ||
      !processorRef.current ||
      !imageProcessed.current
    ) {
      return
    }

    // No click data: clear the segmentation effect directly
    if (!markPoints.length && maskCanvasRef.current) {
      const maskContext = maskCanvasRef.current.getContext(
        '2d'
      ) as CanvasRenderingContext2D
      maskContext.clearRect(
        0,
        0,
        maskCanvasRef.current.width,
        maskCanvasRef.current.height
      )
      return
    }

    // Prepare inputs for decoding
    const reshaped = imageProcessed.current.reshaped_input_sizes[0]
    const points = markPoints
      .map((x) => [x.position[0] * reshaped[1], x.position[1] * reshaped[0]])
      .flat(Infinity)
    const labels = markPoints.map((x) => BigInt(x.label)).flat(Infinity)

    const num_points = markPoints.length
    const input_points = new Tensor('float32', points, [1, 1, num_points, 2])
    const input_labels = new Tensor('int64', labels, [1, 1, num_points])

    // Generate the mask
    const { pred_masks, iou_scores } = await modelRef.current({
      ...imageEmbeddings.current,
      input_points,
      input_labels
    })

    // Post-process the mask
    const masks = await (processorRef.current as any).post_process_masks(
      pred_masks,
      imageProcessed.current.original_sizes,
      imageProcessed.current.reshaped_input_sizes
    )

    updateMaskOverlay(RawImage.fromTensor(masks[0][0]), iou_scores.data)
  }

  const clamp = (x: number, min = 0, max = 1) => {
    return Math.max(Math.min(x, max), min)
  }

  const clickImage = (e: MouseEvent) => {
    if (segmentationStatus !== SEGMENTATION_STATUS.PROCESSING_SUCCESS) {
      return
    }

    const { clientX, clientY, currentTarget } = e
    const { left, top } = currentTarget.getBoundingClientRect()

    const x = clamp(
      (clientX - left + currentTarget.scrollLeft) / currentTarget.scrollWidth
    )
    const y = clamp(
      (clientY - top + currentTarget.scrollTop) / currentTarget.scrollHeight
    )

    const existingPointIndex = markPoints.findIndex(
      (point) =>
        Math.abs(point.position[0] - x) < 0.01 &&
        Math.abs(point.position[1] - y) < 0.01 &&
        point.label === (pointStatus ? 1 : 0)
    )

    const newPoints = [...markPoints]
    if (existingPointIndex !== -1) {
      // If there is already a marker in the clicked area, delete it
      newPoints.splice(existingPointIndex, 1)
    } else {
      newPoints.push({
        position: [x, y],
        label: pointStatus ? 1 : 0
      })
    }

    setMarkPoints(newPoints)
    decode(newPoints)
  }

  const clearPoints = () => {
    setMarkPoints([])
    decode([])
  }

  // JSX omitted: it renders the tip text (segmentationTip), a loading state,
  // the image with the mask canvas overlaid, and a positive/negative icon for
  // each marker point. See the project source for the full markup.
  return null
}

export default ImageSegmentation
```
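The trickiest part of decode is the tensor layout the SAM decoder expects: click coordinates are stored normalized to [0, 1], scaled up to the reshaped input size, and packed as a float32 tensor of shape [1, 1, num_points, 2], while the labels (1 = positive, 0 = negative) become an int64 tensor of shape [1, 1, num_points]. Isolated from the component above (the helper name buildPointInputs is mine, not part of the project):

```ts
import { Tensor } from '@huggingface/transformers'

interface MarkPoint {
  position: number[] // [x, y], each normalized to [0, 1]
  label: number // 1 = positive (include), 0 = negative (exclude)
}

// Build the input_points / input_labels tensors for the SAM decoder.
// `reshaped` is [height, width], taken from reshaped_input_sizes[0].
function buildPointInputs(markPoints: MarkPoint[], reshaped: number[]) {
  const points = markPoints.flatMap((p) => [
    p.position[0] * reshaped[1], // x scaled to the reshaped width
    p.position[1] * reshaped[0] // y scaled to the reshaped height
  ])
  const labels = markPoints.map((p) => BigInt(p.label))

  const n = markPoints.length
  return {
    input_points: new Tensor('float32', points, [1, 1, n, 2]),
    input_labels: new Tensor('int64', labels, [1, 1, n])
  }
}
```

The component then spreads the cached image embeddings together with these two tensors into the model call, which returns pred_masks and iou_scores.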
Conclusion
Thanks for reading. That is all for this article; I hope it helps you. Feel free to like and bookmark it, and if you have any questions, let's discuss them in the comments!