grabcut.py
1 | # -*- coding: utf-8 -*- |
---|---|
2 | |
3 | import cv2 |
4 | import scipy |
5 | import scipy.ndimage |
6 | from PIL import Image |
7 | |
import numpy as np


def load_rgb(path):
    """Load the image at *path* as an (H, W, 3) uint8 RGB array.

    Converting through PIL's ``'RGB'`` mode normalises palette, greyscale and
    RGBA PNGs; ``cv2.grabCut`` only accepts 8-bit 3-channel input.
    """
    return np.array(Image.open(path).convert('RGB'))


def mask_channel(mask):
    """Return the 2-D plane of *mask* that carries the GrabCut seed levels.

    A multi-channel mask contributes its first channel, as before. A
    single-channel (greyscale) mask is returned unchanged instead of being
    sliced along its column axis by ``mask[..., 0]``.
    """
    mask = np.asarray(mask)
    return mask[..., 0] if mask.ndim == 3 else mask


def grabcut_labels(channel):
    """Translate a 2-D mask plane into a uint8 GrabCut label map.

    Seed levels: 0 -> definite background, 64 -> probable background,
    128 -> probable foreground, 255 -> definite foreground. Any other value
    is treated as definite background (``cv2.GC_BGD``).
    """
    levels = {
        0: cv2.GC_BGD,
        64: cv2.GC_PR_BGD,
        128: cv2.GC_PR_FGD,
        255: cv2.GC_FGD,
    }
    # Fill with GC_BGD explicitly rather than relying on its numeric value.
    labels = np.full(channel.shape, cv2.GC_BGD, np.uint8)
    for level, label in levels.items():
        labels[channel == level] = label
    return labels


def main(image_path='image.png', mask_path='mask.png', iterations=1):
    """Segment *image_path* with GrabCut seeded from *mask_path* and show it.

    Args:
        image_path: image to segment.
        mask_path: image with the same height/width whose first channel (or
            only channel) holds the seed levels described in grabcut_labels.
        iterations: number of GrabCut iterations to run.

    Raises:
        ValueError: if the image and mask differ in height or width.
    """
    img = load_rgb(image_path)
    mask = np.array(Image.open(mask_path))
    channel = mask_channel(mask)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    # Only the spatial size has to agree; channel counts may differ.
    if img.shape[:2] != channel.shape:
        raise ValueError('image size %s does not match mask size %s'
                         % (img.shape[:2], channel.shape))

    labels = grabcut_labels(channel)
    print(mask.shape)
    print(labels.shape)

    # (1, 65) float64 work arrays OpenCV uses for its background/foreground
    # colour models; they must start zeroed for GC_INIT_WITH_MASK.
    bgd_model = np.zeros((1, 5 * 13), np.float64)
    fgd_model = np.zeros((1, 5 * 13), np.float64)
    # `labels` is refined in place; the rect argument is unused in mask mode.
    cv2.grabCut(img, labels, None, bgd_model, fgd_model, iterations,
                cv2.GC_INIT_WITH_MASK)

    foreground = (labels == cv2.GC_FGD) | (labels == cv2.GC_PR_FGD)
    masked = np.zeros_like(img)
    masked[foreground] = img[foreground]

    # PIL produced RGB, but cv2.imshow interprets channel data as BGR.
    cv2.imshow('test', cv2.cvtColor(masked, cv2.COLOR_RGB2BGR))
    cv2.waitKey(0)
    cv2.destroyAllWindows()


if __name__ == '__main__':
    main()