use Fooocus' facexlib

This commit is contained in:
lllyasviel 2023-11-12 03:45:29 -08:00
parent ffd5eabe08
commit b8a035dc15
17 changed files with 2728 additions and 3 deletions

View File

@ -25,7 +25,7 @@ def crop_image(img_rgb):
global faceRestoreHelper global faceRestoreHelper
if faceRestoreHelper is None: if faceRestoreHelper is None:
from facexlib.utils.face_restoration_helper import FaceRestoreHelper from fooocus_extras.facexlib.utils.face_restoration_helper import FaceRestoreHelper
faceRestoreHelper = FaceRestoreHelper( faceRestoreHelper = FaceRestoreHelper(
upscale_factor=1, upscale_factor=1,
model_rootpath=modules.config.path_controlnet, model_rootpath=modules.config.path_controlnet,

View File

@ -0,0 +1,31 @@
import torch
from copy import deepcopy
from fooocus_extras.facexlib.utils import load_file_from_url
from .retinaface import RetinaFace
def init_detection_model(model_name, half=False, device='cuda', model_rootpath=None):
if model_name == 'retinaface_resnet50':
model = RetinaFace(network_name='resnet50', half=half, device=device)
model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth'
elif model_name == 'retinaface_mobile0.25':
model = RetinaFace(network_name='mobile0.25', half=half, device=device)
model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth'
else:
raise NotImplementedError(f'{model_name} is not implemented.')
model_path = load_file_from_url(
url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
# TODO: clean pretrained model
load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
# remove unnecessary 'module.'
for k, v in deepcopy(load_net).items():
if k.startswith('module.'):
load_net[k[7:]] = v
load_net.pop(k)
model.load_state_dict(load_net, strict=True)
model.eval()
model = model.to(device)
return model

View File

@ -0,0 +1,219 @@
import cv2
import numpy as np
from .matlab_cp2tform import get_similarity_transform_for_cv2
# reference facial points, a list of coordinates (x,y)
REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], [65.53179932, 51.50139999], [48.02519989, 71.73660278],
[33.54930115, 92.3655014], [62.72990036, 92.20410156]]
DEFAULT_CROP_SIZE = (96, 112)
class FaceWarpException(Exception):
def __str__(self):
return 'In File {}:{}'.format(__file__, super.__str__(self))
def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), default_square=False):
"""
Function:
----------
get reference 5 key points according to crop settings:
0. Set default crop_size:
if default_square:
crop_size = (112, 112)
else:
crop_size = (96, 112)
1. Pad the crop_size by inner_padding_factor in each side;
2. Resize crop_size into (output_size - outer_padding*2),
pad into output_size with outer_padding;
3. Output reference_5point;
Parameters:
----------
@output_size: (w, h) or None
size of aligned face image
@inner_padding_factor: (w_factor, h_factor)
padding factor for inner (w, h)
@outer_padding: (w_pad, h_pad)
each row is a pair of coordinates (x, y)
@default_square: True or False
if True:
default crop_size = (112, 112)
else:
default crop_size = (96, 112);
!!! make sure, if output_size is not None:
(output_size - outer_padding)
= some_scale * (default crop_size * (1.0 +
inner_padding_factor))
Returns:
----------
@reference_5point: 5x2 np.array
each row is a pair of transformed coordinates (x, y)
"""
tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
# 0) make the inner region a square
if default_square:
size_diff = max(tmp_crop_size) - tmp_crop_size
tmp_5pts += size_diff / 2
tmp_crop_size += size_diff
if (output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]):
return tmp_5pts
if (inner_padding_factor == 0 and outer_padding == (0, 0)):
if output_size is None:
return tmp_5pts
else:
raise FaceWarpException('No paddings to do, output_size must be None or {}'.format(tmp_crop_size))
# check output size
if not (0 <= inner_padding_factor <= 1.0):
raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')
if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None):
output_size = tmp_crop_size * \
(1 + inner_padding_factor * 2).astype(np.int32)
output_size += np.array(outer_padding)
if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]):
raise FaceWarpException('Not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1])')
# 1) pad the inner region according inner_padding_factor
if inner_padding_factor > 0:
size_diff = tmp_crop_size * inner_padding_factor * 2
tmp_5pts += size_diff / 2
tmp_crop_size += np.round(size_diff).astype(np.int32)
# 2) resize the padded inner region
size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
raise FaceWarpException('Must have (output_size - outer_padding)'
'= some_scale * (crop_size * (1.0 + inner_padding_factor)')
scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
tmp_5pts = tmp_5pts * scale_factor
# size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
# tmp_5pts = tmp_5pts + size_diff / 2
tmp_crop_size = size_bf_outer_pad
# 3) add outer_padding to make output_size
reference_5point = tmp_5pts + np.array(outer_padding)
tmp_crop_size = output_size
return reference_5point
def get_affine_transform_matrix(src_pts, dst_pts):
"""
Function:
----------
get affine transform matrix 'tfm' from src_pts to dst_pts
Parameters:
----------
@src_pts: Kx2 np.array
source points matrix, each row is a pair of coordinates (x, y)
@dst_pts: Kx2 np.array
destination points matrix, each row is a pair of coordinates (x, y)
Returns:
----------
@tfm: 2x3 np.array
transform matrix from src_pts to dst_pts
"""
tfm = np.float32([[1, 0, 0], [0, 1, 0]])
n_pts = src_pts.shape[0]
ones = np.ones((n_pts, 1), src_pts.dtype)
src_pts_ = np.hstack([src_pts, ones])
dst_pts_ = np.hstack([dst_pts, ones])
A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_)
if rank == 3:
tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]])
elif rank == 2:
tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]])
return tfm
def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type='smilarity'):
"""
Function:
----------
apply affine transform 'trans' to uv
Parameters:
----------
@src_img: 3x3 np.array
input image
@facial_pts: could be
1)a list of K coordinates (x,y)
or
2) Kx2 or 2xK np.array
each row or col is a pair of coordinates (x, y)
@reference_pts: could be
1) a list of K coordinates (x,y)
or
2) Kx2 or 2xK np.array
each row or col is a pair of coordinates (x, y)
or
3) None
if None, use default reference facial points
@crop_size: (w, h)
output face image size
@align_type: transform type, could be one of
1) 'similarity': use similarity transform
2) 'cv2_affine': use the first 3 points to do affine transform,
by calling cv2.getAffineTransform()
3) 'affine': use all points to do affine transform
Returns:
----------
@face_img: output face image with size (w, h) = @crop_size
"""
if reference_pts is None:
if crop_size[0] == 96 and crop_size[1] == 112:
reference_pts = REFERENCE_FACIAL_POINTS
else:
default_square = False
inner_padding_factor = 0
outer_padding = (0, 0)
output_size = crop_size
reference_pts = get_reference_facial_points(output_size, inner_padding_factor, outer_padding,
default_square)
ref_pts = np.float32(reference_pts)
ref_pts_shp = ref_pts.shape
if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
raise FaceWarpException('reference_pts.shape must be (K,2) or (2,K) and K>2')
if ref_pts_shp[0] == 2:
ref_pts = ref_pts.T
src_pts = np.float32(facial_pts)
src_pts_shp = src_pts.shape
if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
raise FaceWarpException('facial_pts.shape must be (K,2) or (2,K) and K>2')
if src_pts_shp[0] == 2:
src_pts = src_pts.T
if src_pts.shape != ref_pts.shape:
raise FaceWarpException('facial_pts and reference_pts must have the same shape')
if align_type == 'cv2_affine':
tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
elif align_type == 'affine':
tfm = get_affine_transform_matrix(src_pts, ref_pts)
else:
tfm = get_similarity_transform_for_cv2(src_pts, ref_pts)
face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]))
return face_img

View File

@ -0,0 +1,317 @@
import numpy as np
from numpy.linalg import inv, lstsq
from numpy.linalg import matrix_rank as rank
from numpy.linalg import norm
class MatlabCp2tormException(Exception):
def __str__(self):
return 'In File {}:{}'.format(__file__, super.__str__(self))
def tformfwd(trans, uv):
"""
Function:
----------
apply affine transform 'trans' to uv
Parameters:
----------
@trans: 3x3 np.array
transform matrix
@uv: Kx2 np.array
each row is a pair of coordinates (x, y)
Returns:
----------
@xy: Kx2 np.array
each row is a pair of transformed coordinates (x, y)
"""
uv = np.hstack((uv, np.ones((uv.shape[0], 1))))
xy = np.dot(uv, trans)
xy = xy[:, 0:-1]
return xy
def tforminv(trans, uv):
"""
Function:
----------
apply the inverse of affine transform 'trans' to uv
Parameters:
----------
@trans: 3x3 np.array
transform matrix
@uv: Kx2 np.array
each row is a pair of coordinates (x, y)
Returns:
----------
@xy: Kx2 np.array
each row is a pair of inverse-transformed coordinates (x, y)
"""
Tinv = inv(trans)
xy = tformfwd(Tinv, uv)
return xy
def findNonreflectiveSimilarity(uv, xy, options=None):
options = {'K': 2}
K = options['K']
M = xy.shape[0]
x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
X = np.vstack((tmp1, tmp2))
u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
U = np.vstack((u, v))
# We know that X * r = U
if rank(X) >= 2 * K:
r, _, _, _ = lstsq(X, U, rcond=-1)
r = np.squeeze(r)
else:
raise Exception('cp2tform:twoUniquePointsReq')
sc = r[0]
ss = r[1]
tx = r[2]
ty = r[3]
Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]])
T = inv(Tinv)
T[:, 2] = np.array([0, 0, 1])
return T, Tinv
def findSimilarity(uv, xy, options=None):
options = {'K': 2}
# uv = np.array(uv)
# xy = np.array(xy)
# Solve for trans1
trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)
# Solve for trans2
# manually reflect the xy data across the Y-axis
xyR = xy
xyR[:, 0] = -1 * xyR[:, 0]
trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options)
# manually reflect the tform to undo the reflection done on xyR
TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
trans2 = np.dot(trans2r, TreflectY)
# Figure out if trans1 or trans2 is better
xy1 = tformfwd(trans1, uv)
norm1 = norm(xy1 - xy)
xy2 = tformfwd(trans2, uv)
norm2 = norm(xy2 - xy)
if norm1 <= norm2:
return trans1, trans1_inv
else:
trans2_inv = inv(trans2)
return trans2, trans2_inv
def get_similarity_transform(src_pts, dst_pts, reflective=True):
"""
Function:
----------
Find Similarity Transform Matrix 'trans':
u = src_pts[:, 0]
v = src_pts[:, 1]
x = dst_pts[:, 0]
y = dst_pts[:, 1]
[x, y, 1] = [u, v, 1] * trans
Parameters:
----------
@src_pts: Kx2 np.array
source points, each row is a pair of coordinates (x, y)
@dst_pts: Kx2 np.array
destination points, each row is a pair of transformed
coordinates (x, y)
@reflective: True or False
if True:
use reflective similarity transform
else:
use non-reflective similarity transform
Returns:
----------
@trans: 3x3 np.array
transform matrix from uv to xy
trans_inv: 3x3 np.array
inverse of trans, transform matrix from xy to uv
"""
if reflective:
trans, trans_inv = findSimilarity(src_pts, dst_pts)
else:
trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts)
return trans, trans_inv
def cvt_tform_mat_for_cv2(trans):
"""
Function:
----------
Convert Transform Matrix 'trans' into 'cv2_trans' which could be
directly used by cv2.warpAffine():
u = src_pts[:, 0]
v = src_pts[:, 1]
x = dst_pts[:, 0]
y = dst_pts[:, 1]
[x, y].T = cv_trans * [u, v, 1].T
Parameters:
----------
@trans: 3x3 np.array
transform matrix from uv to xy
Returns:
----------
@cv2_trans: 2x3 np.array
transform matrix from src_pts to dst_pts, could be directly used
for cv2.warpAffine()
"""
cv2_trans = trans[:, 0:2].T
return cv2_trans
def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
"""
Function:
----------
Find Similarity Transform Matrix 'cv2_trans' which could be
directly used by cv2.warpAffine():
u = src_pts[:, 0]
v = src_pts[:, 1]
x = dst_pts[:, 0]
y = dst_pts[:, 1]
[x, y].T = cv_trans * [u, v, 1].T
Parameters:
----------
@src_pts: Kx2 np.array
source points, each row is a pair of coordinates (x, y)
@dst_pts: Kx2 np.array
destination points, each row is a pair of transformed
coordinates (x, y)
reflective: True or False
if True:
use reflective similarity transform
else:
use non-reflective similarity transform
Returns:
----------
@cv2_trans: 2x3 np.array
transform matrix from src_pts to dst_pts, could be directly used
for cv2.warpAffine()
"""
trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective)
cv2_trans = cvt_tform_mat_for_cv2(trans)
return cv2_trans
if __name__ == '__main__':
"""
u = [0, 6, -2]
v = [0, 3, 5]
x = [-1, 0, 4]
y = [-1, -10, 4]
# In Matlab, run:
#
# uv = [u'; v'];
# xy = [x'; y'];
# tform_sim=cp2tform(uv,xy,'similarity');
#
# trans = tform_sim.tdata.T
# ans =
# -0.0764 -1.6190 0
# 1.6190 -0.0764 0
# -3.2156 0.0290 1.0000
# trans_inv = tform_sim.tdata.Tinv
# ans =
#
# -0.0291 0.6163 0
# -0.6163 -0.0291 0
# -0.0756 1.9826 1.0000
# xy_m=tformfwd(tform_sim, u,v)
#
# xy_m =
#
# -3.2156 0.0290
# 1.1833 -9.9143
# 5.0323 2.8853
# uv_m=tforminv(tform_sim, x,y)
#
# uv_m =
#
# 0.5698 1.3953
# 6.0872 2.2733
# -2.6570 4.3314
"""
u = [0, 6, -2]
v = [0, 3, 5]
x = [-1, 0, 4]
y = [-1, -10, 4]
uv = np.array((u, v)).T
xy = np.array((x, y)).T
print('\n--->uv:')
print(uv)
print('\n--->xy:')
print(xy)
trans, trans_inv = get_similarity_transform(uv, xy)
print('\n--->trans matrix:')
print(trans)
print('\n--->trans_inv matrix:')
print(trans_inv)
print('\n---> apply transform to uv')
print('\nxy_m = uv_augmented * trans')
uv_aug = np.hstack((uv, np.ones((uv.shape[0], 1))))
xy_m = np.dot(uv_aug, trans)
print(xy_m)
print('\nxy_m = tformfwd(trans, uv)')
xy_m = tformfwd(trans, uv)
print(xy_m)
print('\n---> apply inverse transform to xy')
print('\nuv_m = xy_augmented * trans_inv')
xy_aug = np.hstack((xy, np.ones((xy.shape[0], 1))))
uv_m = np.dot(xy_aug, trans_inv)
print(uv_m)
print('\nuv_m = tformfwd(trans_inv, xy)')
uv_m = tformfwd(trans_inv, xy)
print(uv_m)
uv_m = tforminv(trans, xy)
print('\nuv_m = tforminv(trans, xy)')
print(uv_m)

View File

@ -0,0 +1,366 @@
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision.models._utils import IntermediateLayerGetter as IntermediateLayerGetter
from fooocus_extras.facexlib.detection.align_trans import get_reference_facial_points, warp_and_crop_face
from fooocus_extras.facexlib.detection.retinaface_net import FPN, SSH, MobileNetV1, make_bbox_head, make_class_head, make_landmark_head
from fooocus_extras.facexlib.detection.retinaface_utils import (PriorBox, batched_decode, batched_decode_landm, decode, decode_landm,
py_cpu_nms)
def generate_config(network_name):
cfg_mnet = {
'name': 'mobilenet0.25',
'min_sizes': [[16, 32], [64, 128], [256, 512]],
'steps': [8, 16, 32],
'variance': [0.1, 0.2],
'clip': False,
'loc_weight': 2.0,
'gpu_train': True,
'batch_size': 32,
'ngpu': 1,
'epoch': 250,
'decay1': 190,
'decay2': 220,
'image_size': 640,
'return_layers': {
'stage1': 1,
'stage2': 2,
'stage3': 3
},
'in_channel': 32,
'out_channel': 64
}
cfg_re50 = {
'name': 'Resnet50',
'min_sizes': [[16, 32], [64, 128], [256, 512]],
'steps': [8, 16, 32],
'variance': [0.1, 0.2],
'clip': False,
'loc_weight': 2.0,
'gpu_train': True,
'batch_size': 24,
'ngpu': 4,
'epoch': 100,
'decay1': 70,
'decay2': 90,
'image_size': 840,
'return_layers': {
'layer2': 1,
'layer3': 2,
'layer4': 3
},
'in_channel': 256,
'out_channel': 256
}
if network_name == 'mobile0.25':
return cfg_mnet
elif network_name == 'resnet50':
return cfg_re50
else:
raise NotImplementedError(f'network_name={network_name}')
class RetinaFace(nn.Module):
def __init__(self, network_name='resnet50', half=False, phase='test', device=None):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
super(RetinaFace, self).__init__()
self.half_inference = half
cfg = generate_config(network_name)
self.backbone = cfg['name']
self.model_name = f'retinaface_{network_name}'
self.cfg = cfg
self.phase = phase
self.target_size, self.max_size = 1600, 2150
self.resize, self.scale, self.scale1 = 1., None, None
self.mean_tensor = torch.tensor([[[[104.]], [[117.]], [[123.]]]], device=self.device)
self.reference = get_reference_facial_points(default_square=True)
# Build network.
backbone = None
if cfg['name'] == 'mobilenet0.25':
backbone = MobileNetV1()
self.body = IntermediateLayerGetter(backbone, cfg['return_layers'])
elif cfg['name'] == 'Resnet50':
import torchvision.models as models
backbone = models.resnet50(weights=None)
self.body = IntermediateLayerGetter(backbone, cfg['return_layers'])
in_channels_stage2 = cfg['in_channel']
in_channels_list = [
in_channels_stage2 * 2,
in_channels_stage2 * 4,
in_channels_stage2 * 8,
]
out_channels = cfg['out_channel']
self.fpn = FPN(in_channels_list, out_channels)
self.ssh1 = SSH(out_channels, out_channels)
self.ssh2 = SSH(out_channels, out_channels)
self.ssh3 = SSH(out_channels, out_channels)
self.ClassHead = make_class_head(fpn_num=3, inchannels=cfg['out_channel'])
self.BboxHead = make_bbox_head(fpn_num=3, inchannels=cfg['out_channel'])
self.LandmarkHead = make_landmark_head(fpn_num=3, inchannels=cfg['out_channel'])
self.to(self.device)
self.eval()
if self.half_inference:
self.half()
def forward(self, inputs):
out = self.body(inputs)
if self.backbone == 'mobilenet0.25' or self.backbone == 'Resnet50':
out = list(out.values())
# FPN
fpn = self.fpn(out)
# SSH
feature1 = self.ssh1(fpn[0])
feature2 = self.ssh2(fpn[1])
feature3 = self.ssh3(fpn[2])
features = [feature1, feature2, feature3]
bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1)
classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1)
tmp = [self.LandmarkHead[i](feature) for i, feature in enumerate(features)]
ldm_regressions = (torch.cat(tmp, dim=1))
if self.phase == 'train':
output = (bbox_regressions, classifications, ldm_regressions)
else:
output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions)
return output
def __detect_faces(self, inputs):
# get scale
height, width = inputs.shape[2:]
self.scale = torch.tensor([width, height, width, height], dtype=torch.float32, device=self.device)
tmp = [width, height, width, height, width, height, width, height, width, height]
self.scale1 = torch.tensor(tmp, dtype=torch.float32, device=self.device)
# forawrd
inputs = inputs.to(self.device)
if self.half_inference:
inputs = inputs.half()
loc, conf, landmarks = self(inputs)
# get priorbox
priorbox = PriorBox(self.cfg, image_size=inputs.shape[2:])
priors = priorbox.forward().to(self.device)
return loc, conf, landmarks, priors
# single image detection
def transform(self, image, use_origin_size):
# convert to opencv format
if isinstance(image, Image.Image):
image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
image = image.astype(np.float32)
# testing scale
im_size_min = np.min(image.shape[0:2])
im_size_max = np.max(image.shape[0:2])
resize = float(self.target_size) / float(im_size_min)
# prevent bigger axis from being more than max_size
if np.round(resize * im_size_max) > self.max_size:
resize = float(self.max_size) / float(im_size_max)
resize = 1 if use_origin_size else resize
# resize
if resize != 1:
image = cv2.resize(image, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
# convert to torch.tensor format
# image -= (104, 117, 123)
image = image.transpose(2, 0, 1)
image = torch.from_numpy(image).unsqueeze(0)
return image, resize
def detect_faces(
self,
image,
conf_threshold=0.8,
nms_threshold=0.4,
use_origin_size=True,
):
image, self.resize = self.transform(image, use_origin_size)
image = image.to(self.device)
if self.half_inference:
image = image.half()
image = image - self.mean_tensor
loc, conf, landmarks, priors = self.__detect_faces(image)
boxes = decode(loc.data.squeeze(0), priors.data, self.cfg['variance'])
boxes = boxes * self.scale / self.resize
boxes = boxes.cpu().numpy()
scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
landmarks = decode_landm(landmarks.squeeze(0), priors, self.cfg['variance'])
landmarks = landmarks * self.scale1 / self.resize
landmarks = landmarks.cpu().numpy()
# ignore low scores
inds = np.where(scores > conf_threshold)[0]
boxes, landmarks, scores = boxes[inds], landmarks[inds], scores[inds]
# sort
order = scores.argsort()[::-1]
boxes, landmarks, scores = boxes[order], landmarks[order], scores[order]
# do NMS
bounding_boxes = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
keep = py_cpu_nms(bounding_boxes, nms_threshold)
bounding_boxes, landmarks = bounding_boxes[keep, :], landmarks[keep]
# self.t['forward_pass'].toc()
# print(self.t['forward_pass'].average_time)
# import sys
# sys.stdout.flush()
return np.concatenate((bounding_boxes, landmarks), axis=1)
def __align_multi(self, image, boxes, landmarks, limit=None):
if len(boxes) < 1:
return [], []
if limit:
boxes = boxes[:limit]
landmarks = landmarks[:limit]
faces = []
for landmark in landmarks:
facial5points = [[landmark[2 * j], landmark[2 * j + 1]] for j in range(5)]
warped_face = warp_and_crop_face(np.array(image), facial5points, self.reference, crop_size=(112, 112))
faces.append(warped_face)
return np.concatenate((boxes, landmarks), axis=1), faces
def align_multi(self, img, conf_threshold=0.8, limit=None):
rlt = self.detect_faces(img, conf_threshold=conf_threshold)
boxes, landmarks = rlt[:, 0:5], rlt[:, 5:]
return self.__align_multi(img, boxes, landmarks, limit)
# batched detection
def batched_transform(self, frames, use_origin_size):
"""
Arguments:
frames: a list of PIL.Image, or torch.Tensor(shape=[n, h, w, c],
type=np.float32, BGR format).
use_origin_size: whether to use origin size.
"""
from_PIL = True if isinstance(frames[0], Image.Image) else False
# convert to opencv format
if from_PIL:
frames = [cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR) for frame in frames]
frames = np.asarray(frames, dtype=np.float32)
# testing scale
im_size_min = np.min(frames[0].shape[0:2])
im_size_max = np.max(frames[0].shape[0:2])
resize = float(self.target_size) / float(im_size_min)
# prevent bigger axis from being more than max_size
if np.round(resize * im_size_max) > self.max_size:
resize = float(self.max_size) / float(im_size_max)
resize = 1 if use_origin_size else resize
# resize
if resize != 1:
if not from_PIL:
frames = F.interpolate(frames, scale_factor=resize)
else:
frames = [
cv2.resize(frame, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
for frame in frames
]
# convert to torch.tensor format
if not from_PIL:
frames = frames.transpose(1, 2).transpose(1, 3).contiguous()
else:
frames = frames.transpose((0, 3, 1, 2))
frames = torch.from_numpy(frames)
return frames, resize
def batched_detect_faces(self, frames, conf_threshold=0.8, nms_threshold=0.4, use_origin_size=True):
"""
Arguments:
frames: a list of PIL.Image, or np.array(shape=[n, h, w, c],
type=np.uint8, BGR format).
conf_threshold: confidence threshold.
nms_threshold: nms threshold.
use_origin_size: whether to use origin size.
Returns:
final_bounding_boxes: list of np.array ([n_boxes, 5],
type=np.float32).
final_landmarks: list of np.array ([n_boxes, 10], type=np.float32).
"""
# self.t['forward_pass'].tic()
frames, self.resize = self.batched_transform(frames, use_origin_size)
frames = frames.to(self.device)
frames = frames - self.mean_tensor
b_loc, b_conf, b_landmarks, priors = self.__detect_faces(frames)
final_bounding_boxes, final_landmarks = [], []
# decode
priors = priors.unsqueeze(0)
b_loc = batched_decode(b_loc, priors, self.cfg['variance']) * self.scale / self.resize
b_landmarks = batched_decode_landm(b_landmarks, priors, self.cfg['variance']) * self.scale1 / self.resize
b_conf = b_conf[:, :, 1]
# index for selection
b_indice = b_conf > conf_threshold
# concat
b_loc_and_conf = torch.cat((b_loc, b_conf.unsqueeze(-1)), dim=2).float()
for pred, landm, inds in zip(b_loc_and_conf, b_landmarks, b_indice):
# ignore low scores
pred, landm = pred[inds, :], landm[inds, :]
if pred.shape[0] == 0:
final_bounding_boxes.append(np.array([], dtype=np.float32))
final_landmarks.append(np.array([], dtype=np.float32))
continue
# sort
# order = score.argsort(descending=True)
# box, landm, score = box[order], landm[order], score[order]
# to CPU
bounding_boxes, landm = pred.cpu().numpy(), landm.cpu().numpy()
# NMS
keep = py_cpu_nms(bounding_boxes, nms_threshold)
bounding_boxes, landmarks = bounding_boxes[keep, :], landm[keep]
# append
final_bounding_boxes.append(bounding_boxes)
final_landmarks.append(landmarks)
# self.t['forward_pass'].toc(average=True)
# self.batch_time += self.t['forward_pass'].diff
# self.total_frame += len(frames)
# print(self.batch_time / self.total_frame)
return final_bounding_boxes, final_landmarks

View File

@ -0,0 +1,196 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
def conv_bn(inp, oup, stride=1, leaky=0):
return nn.Sequential(
nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup),
nn.LeakyReLU(negative_slope=leaky, inplace=True))
def conv_bn_no_relu(inp, oup, stride):
return nn.Sequential(
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
nn.BatchNorm2d(oup),
)
def conv_bn1X1(inp, oup, stride, leaky=0):
return nn.Sequential(
nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), nn.BatchNorm2d(oup),
nn.LeakyReLU(negative_slope=leaky, inplace=True))
def conv_dw(inp, oup, stride, leaky=0.1):
return nn.Sequential(
nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
nn.BatchNorm2d(inp),
nn.LeakyReLU(negative_slope=leaky, inplace=True),
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.LeakyReLU(negative_slope=leaky, inplace=True),
)
class SSH(nn.Module):
def __init__(self, in_channel, out_channel):
super(SSH, self).__init__()
assert out_channel % 4 == 0
leaky = 0
if (out_channel <= 64):
leaky = 0.1
self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1)
self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky)
self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)
self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky)
self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)
def forward(self, input):
conv3X3 = self.conv3X3(input)
conv5X5_1 = self.conv5X5_1(input)
conv5X5 = self.conv5X5_2(conv5X5_1)
conv7X7_2 = self.conv7X7_2(conv5X5_1)
conv7X7 = self.conv7x7_3(conv7X7_2)
out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
out = F.relu(out)
return out
class FPN(nn.Module):
def __init__(self, in_channels_list, out_channels):
super(FPN, self).__init__()
leaky = 0
if (out_channels <= 64):
leaky = 0.1
self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky)
self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky)
self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky)
self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)
def forward(self, input):
# names = list(input.keys())
# input = list(input.values())
output1 = self.output1(input[0])
output2 = self.output2(input[1])
output3 = self.output3(input[2])
up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode='nearest')
output2 = output2 + up3
output2 = self.merge2(output2)
up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode='nearest')
output1 = output1 + up2
output1 = self.merge1(output1)
out = [output1, output2, output3]
return out
class MobileNetV1(nn.Module):
def __init__(self):
super(MobileNetV1, self).__init__()
self.stage1 = nn.Sequential(
conv_bn(3, 8, 2, leaky=0.1), # 3
conv_dw(8, 16, 1), # 7
conv_dw(16, 32, 2), # 11
conv_dw(32, 32, 1), # 19
conv_dw(32, 64, 2), # 27
conv_dw(64, 64, 1), # 43
)
self.stage2 = nn.Sequential(
conv_dw(64, 128, 2), # 43 + 16 = 59
conv_dw(128, 128, 1), # 59 + 32 = 91
conv_dw(128, 128, 1), # 91 + 32 = 123
conv_dw(128, 128, 1), # 123 + 32 = 155
conv_dw(128, 128, 1), # 155 + 32 = 187
conv_dw(128, 128, 1), # 187 + 32 = 219
)
self.stage3 = nn.Sequential(
conv_dw(128, 256, 2), # 219 +3 2 = 241
conv_dw(256, 256, 1), # 241 + 64 = 301
)
self.avg = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(256, 1000)
def forward(self, x):
x = self.stage1(x)
x = self.stage2(x)
x = self.stage3(x)
x = self.avg(x)
# x = self.model(x)
x = x.view(-1, 256)
x = self.fc(x)
return x
class ClassHead(nn.Module):
def __init__(self, inchannels=512, num_anchors=3):
super(ClassHead, self).__init__()
self.num_anchors = num_anchors
self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0)
def forward(self, x):
out = self.conv1x1(x)
out = out.permute(0, 2, 3, 1).contiguous()
return out.view(out.shape[0], -1, 2)
class BboxHead(nn.Module):
def __init__(self, inchannels=512, num_anchors=3):
super(BboxHead, self).__init__()
self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0)
def forward(self, x):
out = self.conv1x1(x)
out = out.permute(0, 2, 3, 1).contiguous()
return out.view(out.shape[0], -1, 4)
class LandmarkHead(nn.Module):
def __init__(self, inchannels=512, num_anchors=3):
super(LandmarkHead, self).__init__()
self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0)
def forward(self, x):
out = self.conv1x1(x)
out = out.permute(0, 2, 3, 1).contiguous()
return out.view(out.shape[0], -1, 10)
def make_class_head(fpn_num=3, inchannels=64, anchor_num=2):
classhead = nn.ModuleList()
for i in range(fpn_num):
classhead.append(ClassHead(inchannels, anchor_num))
return classhead
def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2):
bboxhead = nn.ModuleList()
for i in range(fpn_num):
bboxhead.append(BboxHead(inchannels, anchor_num))
return bboxhead
def make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2):
landmarkhead = nn.ModuleList()
for i in range(fpn_num):
landmarkhead.append(LandmarkHead(inchannels, anchor_num))
return landmarkhead

View File

@ -0,0 +1,421 @@
import numpy as np
import torch
import torchvision
from itertools import product as product
from math import ceil
class PriorBox(object):
def __init__(self, cfg, image_size=None, phase='train'):
super(PriorBox, self).__init__()
self.min_sizes = cfg['min_sizes']
self.steps = cfg['steps']
self.clip = cfg['clip']
self.image_size = image_size
self.feature_maps = [[ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)] for step in self.steps]
self.name = 's'
def forward(self):
anchors = []
for k, f in enumerate(self.feature_maps):
min_sizes = self.min_sizes[k]
for i, j in product(range(f[0]), range(f[1])):
for min_size in min_sizes:
s_kx = min_size / self.image_size[1]
s_ky = min_size / self.image_size[0]
dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]]
dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]]
for cy, cx in product(dense_cy, dense_cx):
anchors += [cx, cy, s_kx, s_ky]
# back to torch land
output = torch.Tensor(anchors).view(-1, 4)
if self.clip:
output.clamp_(max=1, min=0)
return output
def py_cpu_nms(dets, thresh):
"""Pure Python NMS baseline."""
keep = torchvision.ops.nms(
boxes=torch.Tensor(dets[:, :4]),
scores=torch.Tensor(dets[:, 4]),
iou_threshold=thresh,
)
return list(keep)
def point_form(boxes):
""" Convert prior_boxes to (xmin, ymin, xmax, ymax)
representation for comparison to point form ground truth data.
Args:
boxes: (tensor) center-size default boxes from priorbox layers.
Return:
boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
"""
return torch.cat(
(
boxes[:, :2] - boxes[:, 2:] / 2, # xmin, ymin
boxes[:, :2] + boxes[:, 2:] / 2),
1) # xmax, ymax
def center_size(boxes):
""" Convert prior_boxes to (cx, cy, w, h)
representation for comparison to center-size form ground truth data.
Args:
boxes: (tensor) point_form boxes
Return:
boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
"""
return torch.cat(
(boxes[:, 2:] + boxes[:, :2]) / 2, # cx, cy
boxes[:, 2:] - boxes[:, :2],
1) # w, h
def intersect(box_a, box_b):
""" We resize both tensors to [A,B,2] without new malloc:
[A,2] -> [A,1,2] -> [A,B,2]
[B,2] -> [1,B,2] -> [A,B,2]
Then we compute the area of intersect between box_a and box_b.
Args:
box_a: (tensor) bounding boxes, Shape: [A,4].
box_b: (tensor) bounding boxes, Shape: [B,4].
Return:
(tensor) intersection area, Shape: [A,B].
"""
A = box_a.size(0)
B = box_b.size(0)
max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2))
inter = torch.clamp((max_xy - min_xy), min=0)
return inter[:, :, 0] * inter[:, :, 1]
def jaccard(box_a, box_b):
"""Compute the jaccard overlap of two sets of boxes. The jaccard overlap
is simply the intersection over union of two boxes. Here we operate on
ground truth boxes and default boxes.
E.g.:
A B / A B = A B / (area(A) + area(B) - A B)
Args:
box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
Return:
jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
"""
inter = intersect(box_a, box_b)
area_a = ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B]
area_b = ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B]
union = area_a + area_b - inter
return inter / union # [A,B]
def matrix_iou(a, b):
"""
return iou of a and b, numpy version for data augenmentation
"""
lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
return area_i / (area_a[:, np.newaxis] + area_b - area_i)
def matrix_iof(a, b):
"""
return iof of a and b, numpy version for data augenmentation
"""
lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
return area_i / np.maximum(area_a[:, np.newaxis], 1)
def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx):
"""Match each prior box with the ground truth box of the highest jaccard
overlap, encode the bounding boxes, then return the matched indices
corresponding to both confidence and location preds.
Args:
threshold: (float) The overlap threshold used when matching boxes.
truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
variances: (tensor) Variances corresponding to each prior coord,
Shape: [num_priors, 4].
labels: (tensor) All the class labels for the image, Shape: [num_obj].
landms: (tensor) Ground truth landms, Shape [num_obj, 10].
loc_t: (tensor) Tensor to be filled w/ encoded location targets.
conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
landm_t: (tensor) Tensor to be filled w/ encoded landm targets.
idx: (int) current batch index
Return:
The matched indices corresponding to 1)location 2)confidence
3)landm preds.
"""
# jaccard index
overlaps = jaccard(truths, point_form(priors))
# (Bipartite Matching)
# [1,num_objects] best prior for each ground truth
best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
# ignore hard gt
valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
if best_prior_idx_filter.shape[0] <= 0:
loc_t[idx] = 0
conf_t[idx] = 0
return
# [1,num_priors] best ground truth for each prior
best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
best_truth_idx.squeeze_(0)
best_truth_overlap.squeeze_(0)
best_prior_idx.squeeze_(1)
best_prior_idx_filter.squeeze_(1)
best_prior_overlap.squeeze_(1)
best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior
# TODO refactor: index best_prior_idx with long tensor
# ensure every gt matches with its prior of max overlap
for j in range(best_prior_idx.size(0)): # 判别此anchor是预测哪一个boxes
best_truth_idx[best_prior_idx[j]] = j
matches = truths[best_truth_idx] # Shape: [num_priors,4] 此处为每一个anchor对应的bbox取出来
conf = labels[best_truth_idx] # Shape: [num_priors] 此处为每一个anchor对应的label取出来
conf[best_truth_overlap < threshold] = 0 # label as background overlap<0.35的全部作为负样本
loc = encode(matches, priors, variances)
matches_landm = landms[best_truth_idx]
landm = encode_landm(matches_landm, priors, variances)
loc_t[idx] = loc # [num_priors,4] encoded offsets to learn
conf_t[idx] = conf # [num_priors] top class label for each prior
landm_t[idx] = landm
def encode(matched, priors, variances):
"""Encode the variances from the priorbox layers into the ground truth boxes
we have matched (based on jaccard overlap) with the prior boxes.
Args:
matched: (tensor) Coords of ground truth for each prior in point-form
Shape: [num_priors, 4].
priors: (tensor) Prior boxes in center-offset form
Shape: [num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
encoded boxes (tensor), Shape: [num_priors, 4]
"""
# dist b/t match center and prior's center
g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
# encode variance
g_cxcy /= (variances[0] * priors[:, 2:])
# match wh / prior wh
g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
g_wh = torch.log(g_wh) / variances[1]
# return target for smooth_l1_loss
return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
def encode_landm(matched, priors, variances):
"""Encode the variances from the priorbox layers into the ground truth boxes
we have matched (based on jaccard overlap) with the prior boxes.
Args:
matched: (tensor) Coords of ground truth for each prior in point-form
Shape: [num_priors, 10].
priors: (tensor) Prior boxes in center-offset form
Shape: [num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
encoded landm (tensor), Shape: [num_priors, 10]
"""
# dist b/t match center and prior's center
matched = torch.reshape(matched, (matched.size(0), 5, 2))
priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2)
g_cxcy = matched[:, :, :2] - priors[:, :, :2]
# encode variance
g_cxcy /= (variances[0] * priors[:, :, 2:])
# g_cxcy /= priors[:, :, 2:]
g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1)
# return target for smooth_l1_loss
return g_cxcy
# Adapted from https://github.com/Hakuyume/chainer-ssd
def decode(loc, priors, variances):
"""Decode locations from predictions using priors to undo
the encoding we did for offset regression at train time.
Args:
loc (tensor): location predictions for loc layers,
Shape: [num_priors,4]
priors (tensor): Prior boxes in center-offset form.
Shape: [num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
decoded bounding box predictions
"""
boxes = torch.cat((priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
boxes[:, :2] -= boxes[:, 2:] / 2
boxes[:, 2:] += boxes[:, :2]
return boxes
def decode_landm(pre, priors, variances):
"""Decode landm from predictions using priors to undo
the encoding we did for offset regression at train time.
Args:
pre (tensor): landm predictions for loc layers,
Shape: [num_priors,10]
priors (tensor): Prior boxes in center-offset form.
Shape: [num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
decoded landm predictions
"""
tmp = (
priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:],
priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:],
priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:],
priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:],
priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:],
)
landms = torch.cat(tmp, dim=1)
return landms
def batched_decode(b_loc, priors, variances):
"""Decode locations from predictions using priors to undo
the encoding we did for offset regression at train time.
Args:
b_loc (tensor): location predictions for loc layers,
Shape: [num_batches,num_priors,4]
priors (tensor): Prior boxes in center-offset form.
Shape: [1,num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
decoded bounding box predictions
"""
boxes = (
priors[:, :, :2] + b_loc[:, :, :2] * variances[0] * priors[:, :, 2:],
priors[:, :, 2:] * torch.exp(b_loc[:, :, 2:] * variances[1]),
)
boxes = torch.cat(boxes, dim=2)
boxes[:, :, :2] -= boxes[:, :, 2:] / 2
boxes[:, :, 2:] += boxes[:, :, :2]
return boxes
def batched_decode_landm(pre, priors, variances):
"""Decode landm from predictions using priors to undo
the encoding we did for offset regression at train time.
Args:
pre (tensor): landm predictions for loc layers,
Shape: [num_batches,num_priors,10]
priors (tensor): Prior boxes in center-offset form.
Shape: [1,num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
decoded landm predictions
"""
landms = (
priors[:, :, :2] + pre[:, :, :2] * variances[0] * priors[:, :, 2:],
priors[:, :, :2] + pre[:, :, 2:4] * variances[0] * priors[:, :, 2:],
priors[:, :, :2] + pre[:, :, 4:6] * variances[0] * priors[:, :, 2:],
priors[:, :, :2] + pre[:, :, 6:8] * variances[0] * priors[:, :, 2:],
priors[:, :, :2] + pre[:, :, 8:10] * variances[0] * priors[:, :, 2:],
)
landms = torch.cat(landms, dim=2)
return landms
def log_sum_exp(x):
"""Utility function for computing log_sum_exp while determining
This will be used to determine unaveraged confidence loss across
all examples in a batch.
Args:
x (Variable(tensor)): conf_preds from conf layers
"""
x_max = x.data.max()
return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max
# Original author: Francisco Massa:
# https://github.com/fmassa/object-detection.torch
# Ported to PyTorch by Max deGroot (02/01/2017)
def nms(boxes, scores, overlap=0.5, top_k=200):
"""Apply non-maximum suppression at test time to avoid detecting too many
overlapping bounding boxes for a given object.
Args:
boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
scores: (tensor) The class predscores for the img, Shape:[num_priors].
overlap: (float) The overlap thresh for suppressing unnecessary boxes.
top_k: (int) The Maximum number of box preds to consider.
Return:
The indices of the kept boxes with respect to num_priors.
"""
keep = torch.Tensor(scores.size(0)).fill_(0).long()
if boxes.numel() == 0:
return keep
x1 = boxes[:, 0]
y1 = boxes[:, 1]
x2 = boxes[:, 2]
y2 = boxes[:, 3]
area = torch.mul(x2 - x1, y2 - y1)
v, idx = scores.sort(0) # sort in ascending order
# I = I[v >= 0.01]
idx = idx[-top_k:] # indices of the top-k largest vals
xx1 = boxes.new()
yy1 = boxes.new()
xx2 = boxes.new()
yy2 = boxes.new()
w = boxes.new()
h = boxes.new()
# keep = torch.Tensor()
count = 0
while idx.numel() > 0:
i = idx[-1] # index of current largest val
# keep.append(i)
keep[count] = i
count += 1
if idx.size(0) == 1:
break
idx = idx[:-1] # remove kept element from view
# load bboxes of next highest vals
torch.index_select(x1, 0, idx, out=xx1)
torch.index_select(y1, 0, idx, out=yy1)
torch.index_select(x2, 0, idx, out=xx2)
torch.index_select(y2, 0, idx, out=yy2)
# store element-wise max with next highest score
xx1 = torch.clamp(xx1, min=x1[i])
yy1 = torch.clamp(yy1, min=y1[i])
xx2 = torch.clamp(xx2, max=x2[i])
yy2 = torch.clamp(yy2, max=y2[i])
w.resize_as_(xx2)
h.resize_as_(yy2)
w = xx2 - xx1
h = yy2 - yy1
# check sizes of xx1 and xx2.. after each iteration
w = torch.clamp(w, min=0.0)
h = torch.clamp(h, min=0.0)
inter = w * h
# IoU = i / (area(a) + area(b) - i)
rem_areas = torch.index_select(area, 0, idx) # load remaining areas)
union = (rem_areas - inter) + area[i]
IoU = inter / union # store result in iou
# keep only elements with an IoU <= overlap
idx = idx[IoU.le(overlap)]
return keep, count

View File

@ -0,0 +1,24 @@
import torch
from fooocus_extras.facexlib.utils import load_file_from_url
from .bisenet import BiSeNet
from .parsenet import ParseNet
def init_parsing_model(model_name='bisenet', half=False, device='cuda', model_rootpath=None):
if model_name == 'bisenet':
model = BiSeNet(num_class=19)
model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.0/parsing_bisenet.pth'
elif model_name == 'parsenet':
model = ParseNet(in_size=512, out_size=512, parsing_ch=19)
model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth'
else:
raise NotImplementedError(f'{model_name} is not implemented.')
model_path = load_file_from_url(
url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
model.load_state_dict(load_net, strict=True)
model.eval()
model = model.to(device)
return model

View File

@ -0,0 +1,140 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .resnet import ResNet18
class ConvBNReLU(nn.Module):
def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1):
super(ConvBNReLU, self).__init__()
self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False)
self.bn = nn.BatchNorm2d(out_chan)
def forward(self, x):
x = self.conv(x)
x = F.relu(self.bn(x))
return x
class BiSeNetOutput(nn.Module):
def __init__(self, in_chan, mid_chan, num_class):
super(BiSeNetOutput, self).__init__()
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
self.conv_out = nn.Conv2d(mid_chan, num_class, kernel_size=1, bias=False)
def forward(self, x):
feat = self.conv(x)
out = self.conv_out(feat)
return out, feat
class AttentionRefinementModule(nn.Module):
def __init__(self, in_chan, out_chan):
super(AttentionRefinementModule, self).__init__()
self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
self.bn_atten = nn.BatchNorm2d(out_chan)
self.sigmoid_atten = nn.Sigmoid()
def forward(self, x):
feat = self.conv(x)
atten = F.avg_pool2d(feat, feat.size()[2:])
atten = self.conv_atten(atten)
atten = self.bn_atten(atten)
atten = self.sigmoid_atten(atten)
out = torch.mul(feat, atten)
return out
class ContextPath(nn.Module):
def __init__(self):
super(ContextPath, self).__init__()
self.resnet = ResNet18()
self.arm16 = AttentionRefinementModule(256, 128)
self.arm32 = AttentionRefinementModule(512, 128)
self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
def forward(self, x):
feat8, feat16, feat32 = self.resnet(x)
h8, w8 = feat8.size()[2:]
h16, w16 = feat16.size()[2:]
h32, w32 = feat32.size()[2:]
avg = F.avg_pool2d(feat32, feat32.size()[2:])
avg = self.conv_avg(avg)
avg_up = F.interpolate(avg, (h32, w32), mode='nearest')
feat32_arm = self.arm32(feat32)
feat32_sum = feat32_arm + avg_up
feat32_up = F.interpolate(feat32_sum, (h16, w16), mode='nearest')
feat32_up = self.conv_head32(feat32_up)
feat16_arm = self.arm16(feat16)
feat16_sum = feat16_arm + feat32_up
feat16_up = F.interpolate(feat16_sum, (h8, w8), mode='nearest')
feat16_up = self.conv_head16(feat16_up)
return feat8, feat16_up, feat32_up # x8, x8, x16
class FeatureFusionModule(nn.Module):
def __init__(self, in_chan, out_chan):
super(FeatureFusionModule, self).__init__()
self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False)
self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False)
self.relu = nn.ReLU(inplace=True)
self.sigmoid = nn.Sigmoid()
def forward(self, fsp, fcp):
fcat = torch.cat([fsp, fcp], dim=1)
feat = self.convblk(fcat)
atten = F.avg_pool2d(feat, feat.size()[2:])
atten = self.conv1(atten)
atten = self.relu(atten)
atten = self.conv2(atten)
atten = self.sigmoid(atten)
feat_atten = torch.mul(feat, atten)
feat_out = feat_atten + feat
return feat_out
class BiSeNet(nn.Module):
def __init__(self, num_class):
super(BiSeNet, self).__init__()
self.cp = ContextPath()
self.ffm = FeatureFusionModule(256, 256)
self.conv_out = BiSeNetOutput(256, 256, num_class)
self.conv_out16 = BiSeNetOutput(128, 64, num_class)
self.conv_out32 = BiSeNetOutput(128, 64, num_class)
def forward(self, x, return_feat=False):
h, w = x.size()[2:]
feat_res8, feat_cp8, feat_cp16 = self.cp(x) # return res3b1 feature
feat_sp = feat_res8 # replace spatial path feature with res3b1 feature
feat_fuse = self.ffm(feat_sp, feat_cp8)
out, feat = self.conv_out(feat_fuse)
out16, feat16 = self.conv_out16(feat_cp8)
out32, feat32 = self.conv_out32(feat_cp16)
out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True)
out16 = F.interpolate(out16, (h, w), mode='bilinear', align_corners=True)
out32 = F.interpolate(out32, (h, w), mode='bilinear', align_corners=True)
if return_feat:
feat = F.interpolate(feat, (h, w), mode='bilinear', align_corners=True)
feat16 = F.interpolate(feat16, (h, w), mode='bilinear', align_corners=True)
feat32 = F.interpolate(feat32, (h, w), mode='bilinear', align_corners=True)
return out, out16, out32, feat, feat16, feat32
else:
return out, out16, out32

View File

@ -0,0 +1,194 @@
"""Modified from https://github.com/chaofengc/PSFRGAN
"""
import numpy as np
import torch.nn as nn
from torch.nn import functional as F
class NormLayer(nn.Module):
"""Normalization Layers.
Args:
channels: input channels, for batch norm and instance norm.
input_size: input shape without batch size, for layer norm.
"""
def __init__(self, channels, normalize_shape=None, norm_type='bn'):
super(NormLayer, self).__init__()
norm_type = norm_type.lower()
self.norm_type = norm_type
if norm_type == 'bn':
self.norm = nn.BatchNorm2d(channels, affine=True)
elif norm_type == 'in':
self.norm = nn.InstanceNorm2d(channels, affine=False)
elif norm_type == 'gn':
self.norm = nn.GroupNorm(32, channels, affine=True)
elif norm_type == 'pixel':
self.norm = lambda x: F.normalize(x, p=2, dim=1)
elif norm_type == 'layer':
self.norm = nn.LayerNorm(normalize_shape)
elif norm_type == 'none':
self.norm = lambda x: x * 1.0
else:
assert 1 == 0, f'Norm type {norm_type} not support.'
def forward(self, x, ref=None):
if self.norm_type == 'spade':
return self.norm(x, ref)
else:
return self.norm(x)
class ReluLayer(nn.Module):
"""Relu Layer.
Args:
relu type: type of relu layer, candidates are
- ReLU
- LeakyReLU: default relu slope 0.2
- PRelu
- SELU
- none: direct pass
"""
def __init__(self, channels, relu_type='relu'):
super(ReluLayer, self).__init__()
relu_type = relu_type.lower()
if relu_type == 'relu':
self.func = nn.ReLU(True)
elif relu_type == 'leakyrelu':
self.func = nn.LeakyReLU(0.2, inplace=True)
elif relu_type == 'prelu':
self.func = nn.PReLU(channels)
elif relu_type == 'selu':
self.func = nn.SELU(True)
elif relu_type == 'none':
self.func = lambda x: x * 1.0
else:
assert 1 == 0, f'Relu type {relu_type} not support.'
def forward(self, x):
return self.func(x)
class ConvLayer(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
scale='none',
norm_type='none',
relu_type='none',
use_pad=True,
bias=True):
super(ConvLayer, self).__init__()
self.use_pad = use_pad
self.norm_type = norm_type
if norm_type in ['bn']:
bias = False
stride = 2 if scale == 'down' else 1
self.scale_func = lambda x: x
if scale == 'up':
self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest')
self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.) / 2)))
self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias)
self.relu = ReluLayer(out_channels, relu_type)
self.norm = NormLayer(out_channels, norm_type=norm_type)
def forward(self, x):
out = self.scale_func(x)
if self.use_pad:
out = self.reflection_pad(out)
out = self.conv2d(out)
out = self.norm(out)
out = self.relu(out)
return out
class ResidualBlock(nn.Module):
"""
Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html
"""
def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'):
super(ResidualBlock, self).__init__()
if scale == 'none' and c_in == c_out:
self.shortcut_func = lambda x: x
else:
self.shortcut_func = ConvLayer(c_in, c_out, 3, scale)
scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']}
scale_conf = scale_config_dict[scale]
self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type)
self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type='none')
def forward(self, x):
identity = self.shortcut_func(x)
res = self.conv1(x)
res = self.conv2(res)
return identity + res
class ParseNet(nn.Module):
def __init__(self,
in_size=128,
out_size=128,
min_feat_size=32,
base_ch=64,
parsing_ch=19,
res_depth=10,
relu_type='LeakyReLU',
norm_type='bn',
ch_range=[32, 256]):
super().__init__()
self.res_depth = res_depth
act_args = {'norm_type': norm_type, 'relu_type': relu_type}
min_ch, max_ch = ch_range
ch_clip = lambda x: max(min_ch, min(x, max_ch)) # noqa: E731
min_feat_size = min(in_size, min_feat_size)
down_steps = int(np.log2(in_size // min_feat_size))
up_steps = int(np.log2(out_size // min_feat_size))
# =============== define encoder-body-decoder ====================
self.encoder = []
self.encoder.append(ConvLayer(3, base_ch, 3, 1))
head_ch = base_ch
for i in range(down_steps):
cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2)
self.encoder.append(ResidualBlock(cin, cout, scale='down', **act_args))
head_ch = head_ch * 2
self.body = []
for i in range(res_depth):
self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args))
self.decoder = []
for i in range(up_steps):
cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2)
self.decoder.append(ResidualBlock(cin, cout, scale='up', **act_args))
head_ch = head_ch // 2
self.encoder = nn.Sequential(*self.encoder)
self.body = nn.Sequential(*self.body)
self.decoder = nn.Sequential(*self.decoder)
self.out_img_conv = ConvLayer(ch_clip(head_ch), 3)
self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch)
def forward(self, x):
feat = self.encoder(x)
x = feat + self.body(feat)
x = self.decoder(x)
out_img = self.out_img_conv(x)
out_mask = self.out_mask_conv(x)
return out_mask, out_img

View File

@ -0,0 +1,69 @@
import torch.nn as nn
import torch.nn.functional as F
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
class BasicBlock(nn.Module):
def __init__(self, in_chan, out_chan, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(in_chan, out_chan, stride)
self.bn1 = nn.BatchNorm2d(out_chan)
self.conv2 = conv3x3(out_chan, out_chan)
self.bn2 = nn.BatchNorm2d(out_chan)
self.relu = nn.ReLU(inplace=True)
self.downsample = None
if in_chan != out_chan or stride != 1:
self.downsample = nn.Sequential(
nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(out_chan),
)
def forward(self, x):
residual = self.conv1(x)
residual = F.relu(self.bn1(residual))
residual = self.conv2(residual)
residual = self.bn2(residual)
shortcut = x
if self.downsample is not None:
shortcut = self.downsample(x)
out = shortcut + residual
out = self.relu(out)
return out
def create_layer_basic(in_chan, out_chan, bnum, stride=1):
layers = [BasicBlock(in_chan, out_chan, stride=stride)]
for i in range(bnum - 1):
layers.append(BasicBlock(out_chan, out_chan, stride=1))
return nn.Sequential(*layers)
class ResNet18(nn.Module):
def __init__(self):
super(ResNet18, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
def forward(self, x):
x = self.conv1(x)
x = F.relu(self.bn1(x))
x = self.maxpool(x)
x = self.layer1(x)
feat8 = self.layer2(x) # 1/8
feat16 = self.layer3(feat8) # 1/16
feat32 = self.layer4(feat16) # 1/32
return feat8, feat16, feat32

View File

@ -0,0 +1,7 @@
from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back
from .misc import img2tensor, load_file_from_url, scandir
__all__ = [
'align_crop_face_landmarks', 'compute_increased_bbox', 'get_valid_bboxes', 'load_file_from_url', 'paste_face_back',
'img2tensor', 'scandir'
]

View File

@ -0,0 +1,374 @@
import cv2
import numpy as np
import os
import torch
from torchvision.transforms.functional import normalize
from fooocus_extras.facexlib.detection import init_detection_model
from fooocus_extras.facexlib.parsing import init_parsing_model
from fooocus_extras.facexlib.utils.misc import img2tensor, imwrite
def get_largest_face(det_faces, h, w):
def get_location(val, length):
if val < 0:
return 0
elif val > length:
return length
else:
return val
face_areas = []
for det_face in det_faces:
left = get_location(det_face[0], w)
right = get_location(det_face[2], w)
top = get_location(det_face[1], h)
bottom = get_location(det_face[3], h)
face_area = (right - left) * (bottom - top)
face_areas.append(face_area)
largest_idx = face_areas.index(max(face_areas))
return det_faces[largest_idx], largest_idx
def get_center_face(det_faces, h=0, w=0, center=None):
if center is not None:
center = np.array(center)
else:
center = np.array([w / 2, h / 2])
center_dist = []
for det_face in det_faces:
face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2])
dist = np.linalg.norm(face_center - center)
center_dist.append(dist)
center_idx = center_dist.index(min(center_dist))
return det_faces[center_idx], center_idx
class FaceRestoreHelper(object):
"""Helper for the face restoration pipeline (base class)."""
def __init__(self,
upscale_factor,
face_size=512,
crop_ratio=(1, 1),
det_model='retinaface_resnet50',
save_ext='png',
template_3points=False,
pad_blur=False,
use_parse=False,
device=None,
model_rootpath=None):
self.template_3points = template_3points # improve robustness
self.upscale_factor = upscale_factor
# the cropped face ratio based on the square face
self.crop_ratio = crop_ratio # (h, w)
assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1'
self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
if self.template_3points:
self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
else:
# standard 5 landmarks for FFHQ faces with 512 x 512
self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
[201.26117, 371.41043], [313.08905, 371.15118]])
self.face_template = self.face_template * (face_size / 512.0)
if self.crop_ratio[0] > 1:
self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
if self.crop_ratio[1] > 1:
self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
self.save_ext = save_ext
self.pad_blur = pad_blur
if self.pad_blur is True:
self.template_3points = False
self.all_landmarks_5 = []
self.det_faces = []
self.affine_matrices = []
self.inverse_affine_matrices = []
self.cropped_faces = []
self.restored_faces = []
self.pad_input_imgs = []
if device is None:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
self.device = device
# init face detection model
self.face_det = init_detection_model(det_model, half=False, device=self.device, model_rootpath=model_rootpath)
# init face parsing model
self.use_parse = use_parse
self.face_parse = init_parsing_model(model_name='parsenet', device=self.device, model_rootpath=model_rootpath)
def set_upscale_factor(self, upscale_factor):
self.upscale_factor = upscale_factor
def read_image(self, img):
"""img can be image path or cv2 loaded image."""
# self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
if isinstance(img, str):
img = cv2.imread(img)
if np.max(img) > 256: # 16-bit image
img = img / 65535 * 255
if len(img.shape) == 2: # gray image
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif img.shape[2] == 4: # RGBA image with alpha channel
img = img[:, :, 0:3]
self.input_img = img
def get_face_landmarks_5(self,
only_keep_largest=False,
only_center_face=False,
resize=None,
blur_ratio=0.01,
eye_dist_threshold=None):
if resize is None:
scale = 1
input_img = self.input_img
else:
h, w = self.input_img.shape[0:2]
scale = min(h, w) / resize
h, w = int(h / scale), int(w / scale)
input_img = cv2.resize(self.input_img, (w, h), interpolation=cv2.INTER_LANCZOS4)
with torch.no_grad():
bboxes = self.face_det.detect_faces(input_img, 0.97) * scale
for bbox in bboxes:
# remove faces with too small eye distance: side faces or too small faces
eye_dist = np.linalg.norm([bbox[5] - bbox[7], bbox[6] - bbox[8]])
if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
continue
if self.template_3points:
landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
else:
landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
self.all_landmarks_5.append(landmark)
self.det_faces.append(bbox[0:5])
if len(self.det_faces) == 0:
return 0
if only_keep_largest:
h, w, _ = self.input_img.shape
self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
elif only_center_face:
h, w, _ = self.input_img.shape
self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]
# pad blurry images
if self.pad_blur:
self.pad_input_imgs = []
for landmarks in self.all_landmarks_5:
# get landmarks
eye_left = landmarks[0, :]
eye_right = landmarks[1, :]
eye_avg = (eye_left + eye_right) * 0.5
mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
eye_to_eye = eye_right - eye_left
eye_to_mouth = mouth_avg - eye_avg
# Get the oriented crop rectangle
# x: half width of the oriented crop rectangle
x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
# - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
# norm with the hypotenuse: get the direction
x /= np.hypot(*x) # get the hypotenuse of a right triangle
rect_scale = 1.5
x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
# y: half height of the oriented crop rectangle
y = np.flipud(x) * [-1, 1]
# c: center
c = eye_avg + eye_to_mouth * 0.1
# quad: (left_top, left_bottom, right_bottom, right_top)
quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
# qsize: side length of the square
qsize = np.hypot(*x) * 2
border = max(int(np.rint(qsize * 0.1)), 3)
# get pad
# pad: (width_left, height_top, width_right, height_bottom)
pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
int(np.ceil(max(quad[:, 1]))))
pad = [
max(-pad[0] + border, 1),
max(-pad[1] + border, 1),
max(pad[2] - self.input_img.shape[0] + border, 1),
max(pad[3] - self.input_img.shape[1] + border, 1)
]
if max(pad) > 1:
# pad image
pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
# modify landmark coords
landmarks[:, 0] += pad[0]
landmarks[:, 1] += pad[1]
# blur pad images
h, w, _ = pad_img.shape
y, x, _ = np.ogrid[:h, :w, :1]
mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
np.float32(w - 1 - x) / pad[2]),
1.0 - np.minimum(np.float32(y) / pad[1],
np.float32(h - 1 - y) / pad[3]))
blur = int(qsize * blur_ratio)
if blur % 2 == 0:
blur += 1
blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
# blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)
pad_img = pad_img.astype('float32')
pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255]
self.pad_input_imgs.append(pad_img)
else:
self.pad_input_imgs.append(np.copy(self.input_img))
return len(self.all_landmarks_5)
def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
"""Align and warp faces with face template.
"""
if self.pad_blur:
assert len(self.pad_input_imgs) == len(
self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
for idx, landmark in enumerate(self.all_landmarks_5):
# use 5 landmarks to get affine matrix
# use cv2.LMEDS method for the equivalence to skimage transform
# ref: https://blog.csdn.net/yichxi/article/details/115827338
affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
self.affine_matrices.append(affine_matrix)
# warp and crop faces
if border_mode == 'constant':
border_mode = cv2.BORDER_CONSTANT
elif border_mode == 'reflect101':
border_mode = cv2.BORDER_REFLECT101
elif border_mode == 'reflect':
border_mode = cv2.BORDER_REFLECT
if self.pad_blur:
input_img = self.pad_input_imgs[idx]
else:
input_img = self.input_img
cropped_face = cv2.warpAffine(
input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)) # gray
self.cropped_faces.append(cropped_face)
# save the cropped face
if save_cropped_path is not None:
path = os.path.splitext(save_cropped_path)[0]
save_path = f'{path}_{idx:02d}.{self.save_ext}'
imwrite(cropped_face, save_path)
def get_inverse_affine(self, save_inverse_affine_path=None):
"""Get inverse affine matrix."""
for idx, affine_matrix in enumerate(self.affine_matrices):
inverse_affine = cv2.invertAffineTransform(affine_matrix)
inverse_affine *= self.upscale_factor
self.inverse_affine_matrices.append(inverse_affine)
# save inverse affine matrices
if save_inverse_affine_path is not None:
path, _ = os.path.splitext(save_inverse_affine_path)
save_path = f'{path}_{idx:02d}.pth'
torch.save(inverse_affine, save_path)
def add_restored_face(self, face):
self.restored_faces.append(face)
def paste_faces_to_input_image(self, save_path=None, upsample_img=None):
h, w, _ = self.input_img.shape
h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
if upsample_img is None:
# simply resize the background
upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
else:
upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
assert len(self.restored_faces) == len(
self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.')
for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
# Add an offset to inverse affine matrix, for more precise back alignment
if self.upscale_factor > 1:
extra_offset = 0.5 * self.upscale_factor
else:
extra_offset = 0
inverse_affine[:, 2] += extra_offset
inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))
if self.use_parse:
# inference
face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
face_input = torch.unsqueeze(face_input, 0).to(self.device)
with torch.no_grad():
out = self.face_parse(face_input)[0]
out = out.argmax(dim=1).squeeze().cpu().numpy()
mask = np.zeros(out.shape)
MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
for idx, color in enumerate(MASK_COLORMAP):
mask[out == idx] = color
# blur the mask
mask = cv2.GaussianBlur(mask, (101, 101), 11)
mask = cv2.GaussianBlur(mask, (101, 101), 11)
# remove the black borders
thres = 10
mask[:thres, :] = 0
mask[-thres:, :] = 0
mask[:, :thres] = 0
mask[:, -thres:] = 0
mask = mask / 255.
mask = cv2.resize(mask, restored_face.shape[:2])
mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up), flags=3)
inv_soft_mask = mask[:, :, None]
pasted_face = inv_restored
else: # use square parse maps
mask = np.ones(self.face_size, dtype=np.float32)
inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
# remove the black borders
inv_mask_erosion = cv2.erode(
inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
pasted_face = inv_mask_erosion[:, :, None] * inv_restored
total_face_area = np.sum(inv_mask_erosion) # // 3
# compute the fusion edge based on the area of face
w_edge = int(total_face_area**0.5) // 20
erosion_radius = w_edge * 2
inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
blur_size = w_edge * 2
inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
if len(upsample_img.shape) == 2: # upsample_img is gray image
upsample_img = upsample_img[:, :, None]
inv_soft_mask = inv_soft_mask[:, :, None]
if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4: # alpha channel
alpha = upsample_img[:, :, 3:]
upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
upsample_img = np.concatenate((upsample_img, alpha), axis=2)
else:
upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
if np.max(upsample_img) > 256: # 16-bit image
upsample_img = upsample_img.astype(np.uint16)
else:
upsample_img = upsample_img.astype(np.uint8)
if save_path is not None:
path = os.path.splitext(save_path)[0]
save_path = f'{path}.{self.save_ext}'
imwrite(upsample_img, save_path)
return upsample_img
def clean_all(self):
self.all_landmarks_5 = []
self.restored_faces = []
self.affine_matrices = []
self.cropped_faces = []
self.inverse_affine_matrices = []
self.det_faces = []
self.pad_input_imgs = []

View File

@ -0,0 +1,250 @@
import cv2
import numpy as np
import torch
def compute_increased_bbox(bbox, increase_area, preserve_aspect=True):
left, top, right, bot = bbox
width = right - left
height = bot - top
if preserve_aspect:
width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width))
height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height))
else:
width_increase = height_increase = increase_area
left = int(left - width_increase * width)
top = int(top - height_increase * height)
right = int(right + width_increase * width)
bot = int(bot + height_increase * height)
return (left, top, right, bot)
def get_valid_bboxes(bboxes, h, w):
left = max(bboxes[0], 0)
top = max(bboxes[1], 0)
right = min(bboxes[2], w)
bottom = min(bboxes[3], h)
return (left, top, right, bottom)
def align_crop_face_landmarks(img,
landmarks,
output_size,
transform_size=None,
enable_padding=True,
return_inverse_affine=False,
shrink_ratio=(1, 1)):
"""Align and crop face with landmarks.
The output_size and transform_size are based on width. The height is
adjusted based on shrink_ratio_h/shring_ration_w.
Modified from:
https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py
Args:
img (Numpy array): Input image.
landmarks (Numpy array): 5 or 68 or 98 landmarks.
output_size (int): Output face size.
transform_size (ing): Transform size. Usually the four time of
output_size.
enable_padding (float): Default: True.
shrink_ratio (float | tuple[float] | list[float]): Shring the whole
face for height and width (crop larger area). Default: (1, 1).
Returns:
(Numpy array): Cropped face.
"""
lm_type = 'retinaface_5' # Options: dlib_5, retinaface_5
if isinstance(shrink_ratio, (float, int)):
shrink_ratio = (shrink_ratio, shrink_ratio)
if transform_size is None:
transform_size = output_size * 4
# Parse landmarks
lm = np.array(landmarks)
if lm.shape[0] == 5 and lm_type == 'retinaface_5':
eye_left = lm[0]
eye_right = lm[1]
mouth_avg = (lm[3] + lm[4]) * 0.5
elif lm.shape[0] == 5 and lm_type == 'dlib_5':
lm_eye_left = lm[2:4]
lm_eye_right = lm[0:2]
eye_left = np.mean(lm_eye_left, axis=0)
eye_right = np.mean(lm_eye_right, axis=0)
mouth_avg = lm[4]
elif lm.shape[0] == 68:
lm_eye_left = lm[36:42]
lm_eye_right = lm[42:48]
eye_left = np.mean(lm_eye_left, axis=0)
eye_right = np.mean(lm_eye_right, axis=0)
mouth_avg = (lm[48] + lm[54]) * 0.5
elif lm.shape[0] == 98:
lm_eye_left = lm[60:68]
lm_eye_right = lm[68:76]
eye_left = np.mean(lm_eye_left, axis=0)
eye_right = np.mean(lm_eye_right, axis=0)
mouth_avg = (lm[76] + lm[82]) * 0.5
eye_avg = (eye_left + eye_right) * 0.5
eye_to_eye = eye_right - eye_left
eye_to_mouth = mouth_avg - eye_avg
# Get the oriented crop rectangle
# x: half width of the oriented crop rectangle
x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
# - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
# norm with the hypotenuse: get the direction
x /= np.hypot(*x) # get the hypotenuse of a right triangle
rect_scale = 1 # TODO: you can edit it to get larger rect
x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
# y: half height of the oriented crop rectangle
y = np.flipud(x) * [-1, 1]
x *= shrink_ratio[1] # width
y *= shrink_ratio[0] # height
# c: center
c = eye_avg + eye_to_mouth * 0.1
# quad: (left_top, left_bottom, right_bottom, right_top)
quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
# qsize: side length of the square
qsize = np.hypot(*x) * 2
quad_ori = np.copy(quad)
# Shrink, for large face
# TODO: do we really need shrink
shrink = int(np.floor(qsize / output_size * 0.5))
if shrink > 1:
h, w = img.shape[0:2]
rsize = (int(np.rint(float(w) / shrink)), int(np.rint(float(h) / shrink)))
img = cv2.resize(img, rsize, interpolation=cv2.INTER_AREA)
quad /= shrink
qsize /= shrink
# Crop
h, w = img.shape[0:2]
border = max(int(np.rint(qsize * 0.1)), 3)
crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
int(np.ceil(max(quad[:, 1]))))
crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, w), min(crop[3] + border, h))
if crop[2] - crop[0] < w or crop[3] - crop[1] < h:
img = img[crop[1]:crop[3], crop[0]:crop[2], :]
quad -= crop[0:2]
# Pad
# pad: (width_left, height_top, width_right, height_bottom)
h, w = img.shape[0:2]
pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
int(np.ceil(max(quad[:, 1]))))
pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - w + border, 0), max(pad[3] - h + border, 0))
if enable_padding and max(pad) > border - 4:
pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
img = np.pad(img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
h, w = img.shape[0:2]
y, x, _ = np.ogrid[:h, :w, :1]
mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
np.float32(w - 1 - x) / pad[2]),
1.0 - np.minimum(np.float32(y) / pad[1],
np.float32(h - 1 - y) / pad[3]))
blur = int(qsize * 0.02)
if blur % 2 == 0:
blur += 1
blur_img = cv2.boxFilter(img, 0, ksize=(blur, blur))
img = img.astype('float32')
img += (blur_img - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
img = np.clip(img, 0, 255) # float32, [0, 255]
quad += pad[:2]
# Transform use cv2
h_ratio = shrink_ratio[0] / shrink_ratio[1]
dst_h, dst_w = int(transform_size * h_ratio), transform_size
template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
# use cv2.LMEDS method for the equivalence to skimage transform
# ref: https://blog.csdn.net/yichxi/article/details/115827338
affine_matrix = cv2.estimateAffinePartial2D(quad, template, method=cv2.LMEDS)[0]
cropped_face = cv2.warpAffine(
img, affine_matrix, (dst_w, dst_h), borderMode=cv2.BORDER_CONSTANT, borderValue=(135, 133, 132)) # gray
if output_size < transform_size:
cropped_face = cv2.resize(
cropped_face, (output_size, int(output_size * h_ratio)), interpolation=cv2.INTER_LINEAR)
if return_inverse_affine:
dst_h, dst_w = int(output_size * h_ratio), output_size
template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
# use cv2.LMEDS method for the equivalence to skimage transform
# ref: https://blog.csdn.net/yichxi/article/details/115827338
affine_matrix = cv2.estimateAffinePartial2D(
quad_ori, np.array([[0, 0], [0, output_size], [dst_w, dst_h], [dst_w, 0]]), method=cv2.LMEDS)[0]
inverse_affine = cv2.invertAffineTransform(affine_matrix)
else:
inverse_affine = None
return cropped_face, inverse_affine
def paste_face_back(img, face, inverse_affine):
h, w = img.shape[0:2]
face_h, face_w = face.shape[0:2]
inv_restored = cv2.warpAffine(face, inverse_affine, (w, h))
mask = np.ones((face_h, face_w, 3), dtype=np.float32)
inv_mask = cv2.warpAffine(mask, inverse_affine, (w, h))
# remove the black borders
inv_mask_erosion = cv2.erode(inv_mask, np.ones((2, 2), np.uint8))
inv_restored_remove_border = inv_mask_erosion * inv_restored
total_face_area = np.sum(inv_mask_erosion) // 3
# compute the fusion edge based on the area of face
w_edge = int(total_face_area**0.5) // 20
erosion_radius = w_edge * 2
inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
blur_size = w_edge * 2
inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * img
# float32, [0, 255]
return img
if __name__ == '__main__':
import os
from fooocus_extras.facexlib.detection import init_detection_model
from fooocus_extras.facexlib.utils.face_restoration_helper import get_largest_face
from fooocus_extras.facexlib.visualization import visualize_detection
img_path = '/home/wxt/datasets/ffhq/ffhq_wild/00009.png'
img_name = os.splitext(os.path.basename(img_path))[0]
# initialize model
det_net = init_detection_model('retinaface_resnet50', half=False)
img_ori = cv2.imread(img_path)
h, w = img_ori.shape[0:2]
# if larger than 800, scale it
scale = max(h / 800, w / 800)
if scale > 1:
img = cv2.resize(img_ori, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_LINEAR)
with torch.no_grad():
bboxes = det_net.detect_faces(img, 0.97)
if scale > 1:
bboxes *= scale # the score is incorrect
bboxes = get_largest_face(bboxes, h, w)[0]
visualize_detection(img_ori, [bboxes], f'tmp/{img_name}_det.png')
landmarks = np.array([[bboxes[i], bboxes[i + 1]] for i in range(5, 15, 2)])
cropped_face, inverse_affine = align_crop_face_landmarks(
img_ori,
landmarks,
output_size=512,
transform_size=None,
enable_padding=True,
return_inverse_affine=True,
shrink_ratio=(1, 1))
cv2.imwrite(f'tmp/{img_name}_cropeed_face.png', cropped_face)
img = paste_face_back(img_ori, cropped_face, inverse_affine)
cv2.imwrite(f'tmp/{img_name}_back.png', img)

View File

@ -0,0 +1,118 @@
import cv2
import os
import os.path as osp
import torch
from torch.hub import download_url_to_file, get_dir
from urllib.parse import urlparse
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
def imwrite(img, file_path, params=None, auto_mkdir=True):
"""Write image to file.
Args:
img (ndarray): Image array to be written.
file_path (str): Image file path.
params (None or list): Same as opencv's :func:`imwrite` interface.
auto_mkdir (bool): If the parent folder of `file_path` does not exist,
whether to create it automatically.
Returns:
bool: Successful or not.
"""
if auto_mkdir:
dir_name = os.path.abspath(os.path.dirname(file_path))
os.makedirs(dir_name, exist_ok=True)
return cv2.imwrite(file_path, img, params)
def img2tensor(imgs, bgr2rgb=True, float32=True):
"""Numpy array to tensor.
Args:
imgs (list[ndarray] | ndarray): Input images.
bgr2rgb (bool): Whether to change bgr to rgb.
float32 (bool): Whether to change to float32.
Returns:
list[tensor] | tensor: Tensor images. If returned results only have
one element, just return tensor.
"""
def _totensor(img, bgr2rgb, float32):
if img.shape[2] == 3 and bgr2rgb:
if img.dtype == 'float64':
img = img.astype('float32')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = torch.from_numpy(img.transpose(2, 0, 1))
if float32:
img = img.float()
return img
if isinstance(imgs, list):
return [_totensor(img, bgr2rgb, float32) for img in imgs]
else:
return _totensor(imgs, bgr2rgb, float32)
def load_file_from_url(url, model_dir=None, progress=True, file_name=None, save_dir=None):
"""Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
"""
if model_dir is None:
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')
if save_dir is None:
save_dir = os.path.join(ROOT_DIR, model_dir)
os.makedirs(save_dir, exist_ok=True)
parts = urlparse(url)
filename = os.path.basename(parts.path)
if file_name is not None:
filename = file_name
cached_file = os.path.abspath(os.path.join(save_dir, filename))
if not os.path.exists(cached_file):
print(f'Downloading: "{url}" to {cached_file}\n')
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
return cached_file
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
"""Scan a directory to find the interested files.
Args:
dir_path (str): Path of the directory.
suffix (str | tuple(str), optional): File suffix that we are
interested in. Default: None.
recursive (bool, optional): If set to True, recursively scan the
directory. Default: False.
full_path (bool, optional): If set to True, include the dir_path.
Default: False.
Returns:
A generator for all the interested files with relative paths.
"""
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
raise TypeError('"suffix" must be a string or tuple of strings')
root = dir_path
def _scandir(dir_path, suffix, recursive):
for entry in os.scandir(dir_path):
if not entry.name.startswith('.') and entry.is_file():
if full_path:
return_path = entry.path
else:
return_path = osp.relpath(entry.path, root)
if suffix is None:
yield return_path
elif return_path.endswith(suffix):
yield return_path
else:
if recursive:
yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
else:
continue
return _scandir(dir_path, suffix=suffix, recursive=recursive)

View File

@ -1 +1 @@
version = '2.1.796' version = '2.1.797'

View File

@ -15,4 +15,3 @@ gradio==3.41.2
pygit2==1.12.2 pygit2==1.12.2
opencv-contrib-python==4.8.0.74 opencv-contrib-python==4.8.0.74
httpx==0.24.1 httpx==0.24.1
facexlib==0.3.0