Замена объектов на изображении с использованием Stable Diffusion, SAM и Grounding DINO

Современные технологии компьютерного зрения и генерации изображений позволяют обычному пользователю (немного разбирающемуся как использовать Google colab) заменять части объектов на изображении. Здесь будет показано, как заменить автомобиль на медведя.

Используемые технологии:

  • Stable Diffusion - модели и библиотека, создающая изображения по текстовым описаниям;
  • Segment Anything Model (SAM) - модель, позволяющая очень точно сегментировать практически любой объект;
  • Grounding DINO - модель, позволяющая по текстовому описанию детектировать объект (естественно вместо нее можно использовать например Yolo).
Установка и инициализация компонент

Будем использовать Google colab, поэтому предполагаем, что там на машине установлено уже стандартное для colab окружение. Загружаем описанные выше компоненты и инсталлируем:

!pip -q install diffusers transformers scipy segment_anything
!git clone https://github.com/IDEA-Research/GroundingDINO.git
%cd GroundingDINO
!pip -q install -e .

Импортируем нужные модули:

# ----SAM
from segment_anything import SamPredictor, sam_model_registry
# ----Stable Diffusion
from diffusers import StableDiffusionInpaintPipeline
# ----GroundingDINO
from groundingdino.util.inference import load_model, load_image, predict, annotate
from GroundingDINO.groundingdino.util import box_ops
# ----Extra Libraries
from PIL import Image
import torch
import cv2
import matplotlib.pyplot as plt
import numpy as np

Модели SAM и Grounding DINO нуждаются в ручной загрузке:

!wget https://huggingface.co/spaces/abhishek/StableSAM/resolve/main/sam_vit_h_4b8939.pth
!wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth

Загружаем изображение автомобиля (конечно оно может быть другим):

!wget https://www.avtomotoclub.ru/upload/medialibrary/d6e/d6ecdded7f4d6426fabb8241f73c1ebf.jpg

Указываем устройство и загружаем в память изображение:

device = "cuda"
img_path = 'd6ecdded7f4d6426fabb8241f73c1ebf.jpg'
src, img = load_image(img_path)

Инициализируем SAM:

model_type = "vit_h"
sam_checkpoint = "sam_vit_h_4b8939.pth"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)

Далее скачиваем модели для Stable Diffusion

pipe = StableDiffusionInpaintPipeline.from_pretrained("stabilityai/stable-diffusion-2-inpainting",
torch_dtype=torch.float16,).to(device)
2023-08-12_15-01-10

Инициализируем Grounding DINO и базовые настройки для детектирования объекта:

groundingdino_model = load_model("groundingdino/config/GroundingDINO_SwinT_OGC.py", "groundingdino_swint_ogc.pth")
# ---- prompt is for identifying target object in the image
TEXT_PROMPT = "car"
BOX_TRESHOLD = 0.3
TEXT_TRESHOLD = 0.25
 Детектирование объекта

Детектирование осуществляется с помощью функции predict:

boxes, logits, phrases = predict(
    model=groundingdino_model,
    image=img,
    caption=TEXT_PROMPT,
    box_threshold=BOX_TRESHOLD,
    text_threshold=TEXT_TRESHOLD
)
img_annnotated = annotate(image_source=src, boxes=boxes, logits=logits, phrases=phrases)[...,::-1]

Аннотированное изображение нужно для того, чтобы plt вместо нас нарисовал объекты на изображении.

Выводим изображение :

fig, axes = plt.subplots(1, 2, figsize=(30, 20))
plt.title("Annotated Image with Text", fontsize=30)
axes[0].imshow(src)
axes[0].axis('off')
axes[1].imshow(img_annnotated)
axes[1].axis('off')

plt.show()
2023-08-12_15-08-00
 Находим маску с помощью SAM

Служебная функция:

def show_mask(mask, image, random_color=True):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)

    annotated_frame_pil = Image.fromarray(image).convert("RGBA")
    mask_image_pil = Image.fromarray((mask_image.cpu().numpy() * 255).astype(np.uint8)).convert("RGBA")

    return np.array(Image.alpha_composite(annotated_frame_pil, mask_image_pil))

 

#Преобразуем изображение в нужный формат
predictor.set_image(src)
H, W, _ = src.shape
boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])
new_boxes = predictor.transform.apply_boxes_torch(boxes_xyxy, src.shape[:2]).to(device)

Находим маску:

masks, _, _ = predictor.predict_torch(
            point_coords = None,
            point_labels = None,
            boxes = new_boxes,
            multimask_output = False,
        )

Аннотируем маску для plt:

img_annotated_mask = Image.fromarray(show_mask(masks[0][0].cpu(), img_annnotated))
original_img = Image.fromarray(src).resize((512, 512))
img_annotated = Image.fromarray(img_annnotated)
only_mask = Image.fromarray(masks[0][0].cpu().numpy()).resize((512, 512))

Выводим  маску:

fig, axes = plt.subplots(1, 2, figsize=(30, 20))
plt.title("Segmented Object with Mask", fontsize=30)
axes[0].imshow(img_annnotated)
axes[0].axis('off')
axes[1].imshow(img_annotated_mask)
axes[1].axis('off')
plt.show()
2023-08-12_15-14-05

Только маску можно посмотреть:

fig, axes = plt.subplots(1, 2, figsize=(30, 20))
plt.title("Only Mask", fontsize=30)
axes[0].imshow(img_annotated_mask)
axes[0].axis('off')
axes[1].imshow(Image.fromarray(masks[0][0].cpu().numpy()))
axes[1].axis('off')

plt.show()
2023-08-12_15-15-09
 Замена объекта

Заменяем автомобиль на медведя:

prompt = "Bear" #replace with this object
edited = pipe(prompt=prompt, image=original_img, mask_image=only_mask).images[0]

 

И выводим изображения на экран (как вы поняли уже,  я везде показывал второе изображение, т.к. всегда рисовались 2 - исходное и резульат):

fig, axes = plt.subplots(1, 2, figsize=(30, 15))

axes[0].imshow(src)
axes[0].axis('off')
axes[0].set_title('Before', fontdict={'fontsize': 50})
axes[0].set_aspect('auto')
axes[1].imshow(edited)
axes[1].axis('off')
axes[1].set_title('After', fontdict={'fontsize': 50})
axes[1].set_aspect('auto')
plt.show()
2023-08-12_15-18-30

Вы можете поэкспериментировать с colab notebook:

https://colab.research.google.com/drive/1ViJ-jCPVnPZUcdoD6iW9O9p3udifzYmR?usp=sharing&ref=blog.roboflow.com

Только поправьте загрузку моделей, как указано выше.