Exploring the Canvas Series: Intelligent Image Processing with Transformers.js
Introduction
I currently maintain a powerful open-source creative drawing board that bundles a large collection of fun brushes and auxiliary drawing features, giving users a fresh drawing experience and smooth interaction on both mobile and desktop.
In this article I will walk through how to use Transformers.js to implement background removal and image marker segmentation. The results look like this:

Link: https://songlh.top/paint-board/
GitHub: https://github.com/LHRUN/paint-board (Stars ⭐️ welcome)
Transformers.js
Transformers.js is a powerful JavaScript library built on Hugging Face Transformers that runs directly in the browser, with no server-side computation required. This means you can run models locally, improving responsiveness and cutting deployment and maintenance costs.
Transformers.js currently offers 1000+ models on Hugging Face, covering a wide range of domains. It can handle most of your needs: image processing, text generation, translation, sentiment analysis and other tasks can all be implemented easily with Transformers.js. Searching for models looks like this:

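Getting started usually takes only a few lines. Below is a minimal sketch using the sentiment-analysis task; the default model for the task is downloaded and cached on first use, and the input sentence is just an example:

```ts
import { pipeline } from '@huggingface/transformers'

// Create a pipeline for the task; the model is fetched and cached locally on first use
const classifier = await pipeline('sentiment-analysis')

// Run inference entirely in the browser
const result = await classifier('Transformers.js makes in-browser ML easy!')
console.log(result) // e.g. [{ label: 'POSITIVE', score: 0.99 }]
```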
The main version of Transformers.js has been updated to V3, which adds many great features; for details see: Transformers.js v3: WebGPU Support, New Models and Tasks, and More…
Both features added in this article rely on the WebGPU support introduced in V3, which speeds up processing considerably, down to millisecond-level parsing. Note, however, that WebGPU is supported by relatively few browsers, so I recommend using the latest version of Google Chrome.
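Before opting into WebGPU I check for navigator.gpu. Here is a minimal sketch of that check; falling back to the default WASM backend when WebGPU is unavailable is my assumption in this sketch, not something the library does for you:

```ts
import { AutoModel } from '@huggingface/transformers'

// navigator.gpu is only defined in browsers with WebGPU enabled
const device = 'gpu' in navigator ? 'webgpu' : 'wasm'

// Load the model on the best available backend
const model = await AutoModel.from_pretrained('Xenova/modnet', { device })
```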
Feature 1: Background Removal
To remove the background I use the Xenova/modnet model, which looks like this:

The processing logic can be broken down into three steps:
1. Check WebGPU support, then load the Xenova/modnet model and its processor.
2. Run the processor and then the model over the image to generate a mask.
3. Draw the original image onto a canvas, write the mask into the alpha channel, and export the result as a PNG.
The code logic is shown below (`React + TS`). For the full implementation see my project source at src/components/boardOperation/uploadImage/index.tsx.
import { useState, FC, useRef, useEffect, useMemo } from 'react'
import {
env,
AutoModel,
AutoProcessor,
RawImage,
PreTrainedModel,
Processor
} from '@huggingface/transformers'
const REMOVE_BACKGROUND_STATUS = {
LOADING: 0,
NO_SUPPORT_WEBGPU: 1,
LOAD_ERROR: 2,
LOAD_SUCCESS: 3,
PROCESSING: 4,
PROCESSING_SUCCESS: 5
}
type RemoveBackgroundStatusType =
(typeof REMOVE_BACKGROUND_STATUS)[keyof typeof REMOVE_BACKGROUND_STATUS]
const UploadImage: FC<{ url: string }> = ({ url }) => {
const [removeBackgroundStatus, setRemoveBackgroundStatus] =
useState<RemoveBackgroundStatusType>()
const [processedImage, setProcessedImage] = useState('')
const modelRef = useRef<PreTrainedModel>()
const processorRef = useRef<Processor>()
const removeBackgroundBtnTip = useMemo(() => {
switch (removeBackgroundStatus) {
case REMOVE_BACKGROUND_STATUS.LOADING:
return 'Remove background function loading'
case REMOVE_BACKGROUND_STATUS.NO_SUPPORT_WEBGPU:
return 'WebGPU is not supported in this browser, to use the remove background function, please use the latest version of Google Chrome'
case REMOVE_BACKGROUND_STATUS.LOAD_ERROR:
return 'Remove background function failed to load'
case REMOVE_BACKGROUND_STATUS.LOAD_SUCCESS:
return 'Remove background function loaded successfully'
case REMOVE_BACKGROUND_STATUS.PROCESSING:
return 'Remove Background Processing'
case REMOVE_BACKGROUND_STATUS.PROCESSING_SUCCESS:
return 'Remove Background Processing Success'
default:
return ''
}
}, [removeBackgroundStatus])
useEffect(() => {
;(async () => {
try {
if (removeBackgroundStatus === REMOVE_BACKGROUND_STATUS.LOADING) {
return
}
setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.LOADING)
// Checking WebGPU Support
if (!navigator?.gpu) {
setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.NO_SUPPORT_WEBGPU)
return
}
const model_id = 'Xenova/modnet'
if (env.backends.onnx.wasm) {
env.backends.onnx.wasm.proxy = false
}
// Load model and processor
modelRef.current ??= await AutoModel.from_pretrained(model_id, {
device: 'webgpu'
})
processorRef.current ??= await AutoProcessor.from_pretrained(model_id)
setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.LOAD_SUCCESS)
} catch (err) {
console.log('err', err)
setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.LOAD_ERROR)
}
})()
}, [])
const processImages = async () => {
const model = modelRef.current
const processor = processorRef.current
if (!model || !processor) {
return
}
setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.PROCESSING)
// load image
const img = await RawImage.fromURL(url)
// Pre-processed image
const { pixel_values } = await processor(img)
// Generate image mask
const { output } = await model({ input: pixel_values })
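// Convert the model output to a 0-255 uint8 mask and resize it back to the original image dimensions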
const maskData = (
await RawImage.fromTensor(output[0].mul(255).to('uint8')).resize(
img.width,
img.height
)
).data
// Create a new canvas
const canvas = document.createElement('canvas')
canvas.width = img.width
canvas.height = img.height
const ctx = canvas.getContext('2d') as CanvasRenderingContext2D
// Draw the original image
ctx.drawImage(img.toCanvas(), 0, 0)
// Updating the mask area
const pixelData = ctx.getImageData(0, 0, img.width, img.height)
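// Write each mask value into the pixel's alpha channel (every 4th byte) so background pixels become transparent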
for (let i = 0; i < maskData.length; ++i) {
pixelData.data[4 * i + 3] = maskData[i]
}
ctx.putImageData(pixelData, 0, 0)
// Save new image
setProcessedImage(canvas.toDataURL('image/png'))
setRemoveBackgroundStatus(REMOVE_BACKGROUND_STATUS.PROCESSING_SUCCESS)
}
// The markup below is simplified for this article; see the project source for the full JSX
return (
  <div>
    <button onClick={processImages}>Remove background</button>
    <p>{removeBackgroundBtnTip}</p>
    {processedImage && <img src={processedImage} alt="Processed result" />}
  </div>
)
}
export default UploadImage
Feature 2: Image Marker Segmentation
Image marker segmentation is implemented with the Xenova/slimsam-77-uniform model. The effect is shown below: once the image has loaded, you can click on it and a segmentation mask is generated from the coordinates you click.

The processing logic can be broken down into five steps:
1. Check WebGPU support, then load the Xenova/slimsam-77-uniform model and its processor.
2. Load and preprocess the image, then compute its embeddings.
3. Record a mark point (positive or negative) each time the image is clicked.
4. Decode a mask from the mark points using the cached embeddings.
5. Draw the best mask as a green overlay on the canvas.
The code logic is shown below (`React + TS`). For the full implementation see my project source at src/components/boardOperation/uploadImage/imageSegmentation.tsx.
import { useState, useRef, useEffect, useMemo, MouseEvent, FC } from 'react'
import {
SamModel,
AutoProcessor,
RawImage,
PreTrainedModel,
Processor,
Tensor,
SamImageProcessorResult
} from '@huggingface/transformers'
import LoadingIcon from '@/components/icons/loading.svg?react'
import PositiveIcon from '@/components/icons/boardOperation/image-segmentation-positive.svg?react'
import NegativeIcon from '@/components/icons/boardOperation/image-segmentation-negative.svg?react'
interface MarkPoint {
position: number[] // normalised [x, y] coordinates in the range [0, 1]
label: number // 1 = positive point (include region), 0 = negative point (exclude region)
}
const SEGMENTATION_STATUS = {
LOADING: 0,
NO_SUPPORT_WEBGPU: 1,
LOAD_ERROR: 2,
LOAD_SUCCESS: 3,
PROCESSING: 4,
PROCESSING_SUCCESS: 5
}
type SegmentationStatusType =
(typeof SEGMENTATION_STATUS)[keyof typeof SEGMENTATION_STATUS]
const ImageSegmentation: FC<{ url: string }> = ({ url }) => {
const [markPoints, setMarkPoints] = useState<MarkPoint[]>([])
const [segmentationStatus, setSegmentationStatus] =
useState<SegmentationStatusType>()
const [pointStatus, setPointStatus] = useState(true) // true = positive point, false = negative point
const maskCanvasRef = useRef<HTMLCanvasElement>(null) // Segmentation mask
const modelRef = useRef<PreTrainedModel>() // model
const processorRef = useRef<Processor>() // processor
const imageInputRef = useRef<RawImage>() // original image
const imageProcessed = useRef<SamImageProcessorResult>() // Processed image
const imageEmbeddings = useRef<Record<string, Tensor>>() // Embedding data
const segmentationTip = useMemo(() => {
switch (segmentationStatus) {
case SEGMENTATION_STATUS.LOADING:
return 'Image Segmentation function Loading'
case SEGMENTATION_STATUS.NO_SUPPORT_WEBGPU:
return 'WebGPU is not supported in this browser, to use the image segmentation function, please use the latest version of Google Chrome.'
case SEGMENTATION_STATUS.LOAD_ERROR:
return 'Image Segmentation function failed to load'
case SEGMENTATION_STATUS.LOAD_SUCCESS:
return 'Image Segmentation function loaded successfully'
case SEGMENTATION_STATUS.PROCESSING:
return 'Image Processing...'
case SEGMENTATION_STATUS.PROCESSING_SUCCESS:
return 'The image has been processed successfully, you can click on the image to mark it, the green mask area is the segmentation area.'
default:
return ''
}
}, [segmentationStatus])
// 1. load model and processor
useEffect(() => {
;(async () => {
try {
if (segmentationStatus === SEGMENTATION_STATUS.LOADING) {
return
}
setSegmentationStatus(SEGMENTATION_STATUS.LOADING)
if (!navigator?.gpu) {
setSegmentationStatus(SEGMENTATION_STATUS.NO_SUPPORT_WEBGPU)
return
}
const model_id = 'Xenova/slimsam-77-uniform'
modelRef.current ??= await SamModel.from_pretrained(model_id, {
dtype: 'fp16', // or "fp32"
device: 'webgpu'
})
processorRef.current ??= await AutoProcessor.from_pretrained(model_id)
setSegmentationStatus(SEGMENTATION_STATUS.LOAD_SUCCESS)
} catch (err) {
console.log('err', err)
setSegmentationStatus(SEGMENTATION_STATUS.LOAD_ERROR)
}
})()
}, [])
// 2. process image
useEffect(() => {
;(async () => {
try {
if (
!modelRef.current ||
!processorRef.current ||
!url ||
segmentationStatus === SEGMENTATION_STATUS.PROCESSING
) {
return
}
setSegmentationStatus(SEGMENTATION_STATUS.PROCESSING)
clearPoints()
imageInputRef.current = await RawImage.fromURL(url)
imageProcessed.current = await processorRef.current(
imageInputRef.current
)
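// Compute the image embeddings once; each subsequent click only needs the lightweight mask decoder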
imageEmbeddings.current = await (
modelRef.current as any
).get_image_embeddings(imageProcessed.current)
setSegmentationStatus(SEGMENTATION_STATUS.PROCESSING_SUCCESS)
} catch (err) {
console.log('err', err)
}
})()
}, [url, modelRef.current, processorRef.current])
// Updating the mask effect
function updateMaskOverlay(mask: RawImage, scores: Float32Array) {
const maskCanvas = maskCanvasRef.current
if (!maskCanvas) {
return
}
const maskContext = maskCanvas.getContext('2d') as CanvasRenderingContext2D
// Update canvas dimensions (if different)
if (maskCanvas.width !== mask.width || maskCanvas.height !== mask.height) {
maskCanvas.width = mask.width
maskCanvas.height = mask.height
}
// Allocate buffer for pixel data
const imageData = maskContext.createImageData(
maskCanvas.width,
maskCanvas.height
)
// Select best mask
const numMasks = scores.length // 3
let bestIndex = 0
for (let i = 1; i < numMasks; ++i) {
if (scores[i] > scores[bestIndex]) {
bestIndex = i
}
}
// Fill mask with colour
const pixelData = imageData.data
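// mask.data holds numMasks values per pixel, so the value for pixel i in the best mask sits at numMasks * i + bestIndex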
for (let i = 0; i < pixelData.length; ++i) {
if (mask.data[numMasks * i + bestIndex] === 1) {
const offset = 4 * i
pixelData[offset] = 101 // r
pixelData[offset + 1] = 204 // g
pixelData[offset + 2] = 138 // b
pixelData[offset + 3] = 255 // a
}
}
// Draw image data to context
maskContext.putImageData(imageData, 0, 0)
}
// 3. Decoding based on click data
const decode = async (markPoints: MarkPoint[]) => {
if (
!modelRef.current ||
!imageEmbeddings.current ||
!processorRef.current ||
!imageProcessed.current
) {
return
}
// No click on the data directly clears the segmentation effect
if (!markPoints.length && maskCanvasRef.current) {
const maskContext = maskCanvasRef.current.getContext(
'2d'
) as CanvasRenderingContext2D
maskContext.clearRect(
0,
0,
maskCanvasRef.current.width,
maskCanvasRef.current.height
)
return
}
// Prepare inputs for decoding
const reshaped = imageProcessed.current.reshaped_input_sizes[0]
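// Mark points are stored as normalised [0, 1] coordinates; scale them to the model's reshaped input size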
const points = markPoints
.map((x) => [x.position[0] * reshaped[1], x.position[1] * reshaped[0]])
.flat(Infinity)
const labels = markPoints.map((x) => BigInt(x.label)).flat(Infinity)
const num_points = markPoints.length
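// SAM expects points shaped [batch, point_batch, num_points, 2] and labels shaped [batch, point_batch, num_points]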
const input_points = new Tensor('float32', points, [1, 1, num_points, 2])
const input_labels = new Tensor('int64', labels, [1, 1, num_points])
// Generate the mask
const { pred_masks, iou_scores } = await modelRef.current({
...imageEmbeddings.current,
input_points,
input_labels
})
// Post-process the mask
const masks = await (processorRef.current as any).post_process_masks(
pred_masks,
imageProcessed.current.original_sizes,
imageProcessed.current.reshaped_input_sizes
)
updateMaskOverlay(RawImage.fromTensor(masks[0][0]), iou_scores.data)
}
const clamp = (x: number, min = 0, max = 1) => {
return Math.max(Math.min(x, max), min)
}
const clickImage = (e: MouseEvent) => {
if (segmentationStatus !== SEGMENTATION_STATUS.PROCESSING_SUCCESS) {
return
}
const { clientX, clientY, currentTarget } = e
const { left, top } = currentTarget.getBoundingClientRect()
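// Convert the click position into normalised [0, 1] coordinates relative to the scrollable image container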
const x = clamp(
(clientX - left + currentTarget.scrollLeft) / currentTarget.scrollWidth
)
const y = clamp(
(clientY - top + currentTarget.scrollTop) / currentTarget.scrollHeight
)
const existingPointIndex = markPoints.findIndex(
(point) =>
Math.abs(point.position[0] - x) < 0.01 &&
Math.abs(point.position[1] - y) < 0.01 &&
point.label === (pointStatus ? 1 : 0)
)
const newPoints = [...markPoints]
if (existingPointIndex !== -1) {
// If there is a marker in the currently clicked area, it is deleted.
newPoints.splice(existingPointIndex, 1)
} else {
newPoints.push({
position: [x, y],
label: pointStatus ? 1 : 0
})
}
setMarkPoints(newPoints)
decode(newPoints)
}
const clearPoints = () => {
setMarkPoints([])
decode([])
}
// The markup below is simplified for this article; see the project source for the full JSX
return (
  <div>
    <p>{segmentationTip}</p>
    {segmentationStatus !== SEGMENTATION_STATUS.PROCESSING_SUCCESS && <LoadingIcon />}
    <div onClick={clickImage}>
      <img src={url} alt="" />
      <canvas ref={maskCanvasRef} />
      {markPoints.map((point, index) => {
        switch (point.label) {
          case 1:
            return <PositiveIcon key={index} />
          case 0:
            return <NegativeIcon key={index} />
          default:
            return null
        }
      })}
    </div>
    <button onClick={() => setPointStatus(!pointStatus)}>
      {pointStatus ? 'Positive' : 'Negative'}
    </button>
    <button onClick={clearPoints}>Clear points</button>
  </div>
)
}
export default ImageSegmentation
Conclusion
Thanks for reading. That is all for this article; I hope it helps you. Likes and bookmarks are welcome, and if you have any questions, feel free to discuss them in the comments!