类继承关系

ultralytics/data/base.py 有BaseDataset类,这个类是数据集处理的基类。

包括YOLODataset、SemanticDataset都继承这个类

基类方法分析

get_img_files

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def get_img_files(self, img_path):
"""Read image files.""" # 定义一个名为get_img_files的方法,用于读取图像文件

try:
f = [] # image files # 创建一个空列表f,用于存储找到的图像文件路径

# 下面的循环处理img_path,这可以是单个路径或路径列表
for p in img_path if isinstance(img_path, list) else [img_path]:
p = Path(p) # os-agnostic # 使用Path处理路径,以便跨操作系统兼容

if p.is_dir(): # dir # 如果p是一个目录
# 使用glob模块搜索所有子目录中的所有文件,并将它们添加到列表f中
f += glob.glob(str(p / '**' / '*.*'), recursive=True)
# F = list(p.rglob('*.*')) # pathlib # 这是另一种使用pathlib模块的方法

elif p.is_file(): # file # 如果p是一个文件
# 打开文件,并读取其中的内容,这些内容通常是图像文件的路径
with open(p) as t:
t = t.read().strip().splitlines()
parent = str(p.parent) + os.sep
# 对于每个文件路径,转换为绝对路径,并添加到列表f中
f += [x.replace('./', parent) if x.startswith('./') else x for x in t]

else:
# 如果路径既不是文件也不是目录,则抛出一个异常
raise FileNotFoundError(f'{self.prefix}{p} does not exist')

# 对找到的文件进行筛选,只保留图像格式的文件,这里IMG_FORMATS是一个包含图像文件扩展名的列表
im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)

# 确保找到了图像文件,否则抛出异常
assert im_files, f'{self.prefix}No images found in {img_path}'

except Exception as e:
# 如果在处理过程中发生任何异常,抛出一个异常,指明无法从给定路径加载数据
raise FileNotFoundError(f'{self.prefix}Error loading data from {img_path}\n{HELP_URL}') from e

if self.fraction < 1:
# 如果属性fraction小于1,那么只返回一部分图像文件
im_files = im_files[:round(len(im_files) * self.fraction)]

return im_files # 返回找到的图像文件列表

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29

def update_labels(self, include_class: Optional[list]):
"""Update labels to include only these classes (optional).""" # 定义方法update_labels,用于更新标签以仅包含指定的类别

include_class_array = np.array(include_class).reshape(1, -1) # 将包含类别的列表转换为numpy数组,并重塑为1行多列的形式

for i in range(len(self.labels)): # 遍历self.labels中的每个元素
if include_class is not None: # 如果指定了包含的类别
cls = self.labels[i]['cls'] # 获取当前标签的类别
bboxes = self.labels[i]['bboxes'] # 获取当前标签的边界框
segments = self.labels[i]['segments'] # 获取当前标签的分割信息
keypoints = self.labels[i]['keypoints'] # 获取当前标签的关键点

j = (cls == include_class_array).any(1) # 检查cls中的每个元素是否在include_class_array中,返回布尔数组

self.labels[i]['cls'] = cls[j] # 更新类别,只保留包含在include_class_array中的类别
self.labels[i]['bboxes'] = bboxes[j] # 更新边界框,与类别对应

if segments: # 如果有分割信息
# 只保留对应于include_class_array中类别的分割信息
self.labels[i]['segments'] = [segments[si] for si, idx in enumerate(j) if idx]

if keypoints is not None: # 如果有关键点信息
# 只保留对应于include_class_array中类别的关键点
self.labels[i]['keypoints'] = keypoints[j]

if self.single_cls: # 如果设置了single_cls属性
self.labels[i]['cls'][:, 0] = 0 # 将所有类别设置为0(可能用于单类别任务)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def load_image(self, i, rect_mode=True):
"""Loads 1 image from dataset index 'i', returns (im, resized hw).""" # 定义方法load_image,用于从数据集中加载一张图像

im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i] # 获取图像、图像文件路径和npy文件路径

if im is None: # not cached in RAM # 如果图像不在内存中缓存
if fn.exists(): # load npy # 如果npy文件存在
try:
im = np.load(fn) # 尝试从npy文件加载图像
except Exception as e: # 如果加载失败
# 记录警告信息,并删除损坏的npy文件
LOGGER.warning(f'{self.prefix}WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}')
Path(fn).unlink(missing_ok=True)
im = cv2.imread(f) # 尝试从原始图像文件读取,BGR格式
else: # 如果没有npy文件
im = cv2.imread(f) # 直接从图像文件读取,BGR格式

if im is None:
# 如果无法加载图像,抛出文件未找到的异常
raise FileNotFoundError(f'Image Not Found {f}')

h0, w0 = im.shape[:2] # orig hw # 获取原始图像的高度和宽度

# 根据rect_mode参数决定如何调整图像尺寸
if rect_mode: # resize long side to imgsz while maintaining aspect ratio
r = self.imgsz / max(h0, w0) # 计算缩放比例
if r != 1: # 如果需要缩放
# 计算新的高度和宽度,并保证它们不超过imgsz
w, h = (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz))
im = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR) # 调整图像尺寸
elif not (h0 == w0 == self.imgsz): # 如果不是方形或大小不等于imgsz
im = cv2.resize(im, (self.imgsz, self.imgsz), interpolation=cv2.INTER_LINEAR) # 拉伸图像到正方形

# 如果设置了增强,并且缓存不满,则将图像添加到缓存中
if self.augment:
self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2] # 缓存图像及其原始和调整后的尺寸
self.buffer.append(i) # 将索引添加到缓存缓冲区
if len(self.buffer) >= self.max_buffer_length: # 如果缓存满了
j = self.buffer.pop(0) # 移除最旧的索引
self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None # 清除对应的缓存

return im, (h0, w0), im.shape[:2] # 返回图像及其原始和调整后的尺寸

return self.ims[i], self.im_hw0[i], self.im_hw[i] # 如果图像已在缓存中,直接返回

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def load_image(self, i, rect_mode=True):
"""Loads 1 image from dataset index 'i', returns (im, resized hw).""" # 定义方法load_image,用于从数据集中加载一张图像

im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i] # 获取图像、图像文件路径和npy文件路径

if im is None: # not cached in RAM # 如果图像不在内存中缓存
if fn.exists(): # load npy # 如果npy文件存在
try:
im = np.load(fn) # 尝试从npy文件加载图像
except Exception as e: # 如果加载失败
# 记录警告信息,并删除损坏的npy文件
LOGGER.warning(f'{self.prefix}WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}')
Path(fn).unlink(missing_ok=True)
im = cv2.imread(f) # 尝试从原始图像文件读取,BGR格式
else: # 如果没有npy文件
im = cv2.imread(f) # 直接从图像文件读取,BGR格式

if im is None:
# 如果无法加载图像,抛出文件未找到的异常
raise FileNotFoundError(f'Image Not Found {f}')

h0, w0 = im.shape[:2] # orig hw # 获取原始图像的高度和宽度

# 根据rect_mode参数决定如何调整图像尺寸
if rect_mode: # resize long side to imgsz while maintaining aspect ratio
r = self.imgsz / max(h0, w0) # 计算缩放比例
if r != 1: # 如果需要缩放
# 计算新的高度和宽度,并保证它们不超过imgsz
w, h = (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz))
im = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR) # 调整图像尺寸
elif not (h0 == w0 == self.imgsz): # 如果不是方形或大小不等于imgsz
im = cv2.resize(im, (self.imgsz, self.imgsz), interpolation=cv2.INTER_LINEAR) # 拉伸图像到正方形

# 如果设置了增强,并且缓存不满,则将图像添加到缓存中
if self.augment:
self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2] # 缓存图像及其原始和调整后的尺寸
self.buffer.append(i) # 将索引添加到缓存缓冲区
if len(self.buffer) >= self.max_buffer_length: # 如果缓存满了
j = self.buffer.pop(0) # 移除最旧的索引
self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None # 清除对应的缓存

return im, (h0, w0), im.shape[:2] # 返回图像及其原始和调整后的尺寸

return self.ims[i], self.im_hw0[i], self.im_hw[i] # 如果图像已在缓存中,直接返回

cache_images

这段代码展示了如何使用Python的线程池(ThreadPool)来并行处理数据加载任务,这样可以显著提高数据处理的效率。同时,它还使用了TQDM库来显示一个友好的进度条,帮助用户了解当前缓存进程的状态。这种方法在处理大量数据时尤其有效,可以减少等待时间并提高用户体验。check_cache_ram:通过计算部分图像的平均内存占用并将其外推到整个数据集,来估算缓存整个数据集所需的内存大小。使用psutil 库来获取系统的内存信息,然后根据可用内存和所需内存来决定是否进行缓存。这样的方法有助于在资源有限的环境中有效管理内存使用,避免因内存不足而导致的程序崩溃。如果您对代码中的某个部分有疑问,欢迎继续询问。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def cache_images(self, cache):
"""Cache images to memory or disk.""" # 定义cache_images方法,用于将图像缓存到内存或磁盘

b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
# 初始化用于跟踪缓存图像大小的变量,gb是1GB的字节数

fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image
# 根据缓存类型选择相应的函数:缓存到磁盘或加载图像到内存

with ThreadPool(NUM_THREADS) as pool: # 创建一个线程池
results = pool.imap(fcn, range(self.ni)) # 使用线程池并行处理图像缓存任务
pbar = TQDM(enumerate(results), total=self.ni, disable=LOCAL_RANK > 0)
# 创建一个进度条来跟踪和显示缓存进程

for i, x in pbar: # 遍历每个图像的缓存结果
if cache == 'disk': # 如果缓存到磁盘
b += self.npy_files[i].stat().st_size # 更新已缓存图像的总字节数
else: # 'ram' 如果缓存到内存
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # 更新缓存的图像及其原始和调整后的尺寸
b += self.ims[i].nbytes # 更新已缓存图像的总字节数

pbar.desc = f'{self.prefix}Caching images ({b / gb:.1f}GB {cache})' # 更新进度条描述

pbar.close() # 关闭进度条
def cache_images_to_disk(self, i):
"""Saves an image as an *.npy file for faster loading.""" # 定义方法cache_images_to_disk,用于将图像保存为.npy文件以加快加载速度

f = self.npy_files[i] # 获取索引为i的.npy文件的路径

if not f.exists(): # 检查该.npy文件是否已存在
# 如果不存在,则读取对应的图像文件,并将其保存为.npy格式
np.save(f.as_posix(), cv2.imread(self.im_files[i]), allow_pickle=False)
def check_cache_ram(self, safety_margin=0.5):
"""Check image caching requirements vs available memory.""" # 定义方法check_cache_ram,用于检查缓存图像所需的内存

b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
# 初始化用于计算缓存图像大小的变量,gb是1GB的字节数

n = min(self.ni, 30) # extrapolate from 30 random images
# 选择30个随机图像(或者整个数据集的图像数量,如果它小于30)来估计所需内存

for _ in range(n):
im = cv2.imread(random.choice(self.im_files)) # sample image
# 随机选择一个图像进行读取

ratio = self.imgsz / max(im.shape[0], im.shape[1]) # max(h, w) # ratio
# 计算图像缩放比例(如果有缩放的话)

b += im.nbytes * ratio ** 2 # 计算这个图像缩放后占用的字节数,并累加到总量

mem_required = b * self.ni / n * (1 + safety_margin) # GB required to cache dataset into RAM
# 计算整个数据集缓存到RAM所需的内存大小,考虑安全余量

mem = psutil.virtual_memory() # 获取系统的虚拟内存信息

cache = mem_required < mem.available # to cache or not to cache, that is the question
# 判断是否有足够的内存进行缓存

if not cache:
# 如果内存不足以缓存所有图像,记录相关信息
LOGGER.info(f'{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images '
f'with {int(safety_margin * 100)}% safety margin but only '
f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, '
f"{'caching images ✅' if cache else 'not caching images ⚠️'}")

return cache # 返回是否可以缓存的决策


set_rectangle

这段代码的目的是在使用YOLO(You Only Look Once)检测算法时,设置边界框的形状为矩形。YOLO是一种流行的对象检测算法,它在处理图像时需要考虑边界框的形状。这段代码通过调整图像的排序和批处理形状来优化YOLO的性能。通过分析数据集中图像的宽高比,来优化批处理中的图像形状。这样的优化可以帮助YOLO算法更有效地处理不同形状和尺寸的图像,提高对象检测的准确性和效率。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def set_rectangle(self):
"""Sets the shape of bounding boxes for YOLO detections as rectangles.""" # 定义set_rectangle方法,用于设置YOLO检测的边界框形状为矩形
bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
# 计算每张图像所属的批次索引
nb = bi[-1] + 1 # number of batches
# 计算总批次数量
s = np.array([x.pop('shape') for x in self.labels]) # hw
# 提取每个标签中的图像尺寸(高度和宽度)
ar = s[:, 0] / s[:, 1] # aspect ratio
# 计算每个图像的宽高比
irect = ar.argsort()
# 获取宽高比排序的索引
self.im_files = [self.im_files[i] for i in irect]
self.labels = [self.labels[i] for i in irect]
# 根据宽高比重新排序图像文件和标签
ar = ar[irect]
# 更新宽高比数组为排序后的顺序
# Set training image shapes
shapes = [[1, 1]] * nb
# 初始化每个批次的图像形状数组
for i in range(nb):
ari = ar[bi == i]
# 获取每个批次的宽高比
mini, maxi = ari.min(), ari.max()
# 计算每个批次的最小和最大宽高比
# 设置每个批次的图像形状
if maxi < 1:
shapes[i] = [maxi, 1]
elif mini > 1:
shapes[i] = [1, 1 / mini]
# 计算每个批次的实际图像形状,考虑图像大小和步长
self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride
self.batch = bi # batch index of image
# 设置每张图像的批次索引

get_image_and_label

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def get_image_and_label(self, index):
"""Get and return label information from the dataset.""" # 定义get_image_and_label方法,用于获取并返回数据集中的标签信息

label = deepcopy(self.labels[index]) # requires deepcopy()
# 使用深拷贝来复制指定索引的标签信息,以避免修改原始数据

label.pop('shape', None) # shape is for rect, remove it
# 移除标签中的'shape'键值对,因为它仅用于矩形,不需要在这里

label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index)
# 调用load_image方法加载图像,并获取原始和调整后的尺寸

label['ratio_pad'] = (label['resized_shape'][0] / label['ori_shape'][0],
label['resized_shape'][1] / label['ori_shape'][1]) # for evaluation
# 计算原始尺寸和调整后尺寸之间的比例,用于后续评估

if self.rect:
label['rect_shape'] = self.batch_shapes[self.batch[index]]
# 如果设置了矩形模式,则获取对应批次的图像形状

return self.update_labels_info(label)
# 调用update_labels_info方法更新标签信息,并返回结果

其他未实现的方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def __len__(self):
"""Returns the length of the labels list for the dataset."""
return len(self.labels)

def update_labels_info(self, label):
"""Custom your label format here."""
return label

def build_transforms(self, hyp=None):
"""
Users can customize augmentations here.

Example:
```python
if self.augment:
# Training transforms
return Compose([])
else:
# Val transforms
return Compose([])
    """
    raise NotImplementedError

def get_labels(self):
    """
    Users can customize their own format here.

    Note:
        Ensure output is a dictionary with the following keys:
        
1
2
3
4
5
6
7
8
9
10
dict(
im_file=im_file,
shape=shape, # format: (height, width)
cls=cls,
bboxes=bboxes, # xywh
segments=segments, # xy
keypoints=keypoints, # xy
normalized=True, # or False
bbox_format="xyxy", # or xywh, ltwh
)
""" raise NotImplementedError
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127

# 数据增强

### 1. `augment.py`

这个文件包含了多种数据增强方法的实现,例如随机视角(`RandomPerspective`)、Mosaic增强(`Mosaic`)、MixUp增强(`MixUp`)、随机水平或垂直翻转(`RandomFlip`)等。可以在这里调整这些增强方法的参数或添加新的增强策略。

`BaseTransform`:

- 基础变换类,提供了图像处理的基本结构。
- 方法包括应用于图像、实例(对象实体)和语义分割的变换。

`Compose`:

- 组合多个图像变换的类。
- 允许将多个变换合并成一个单一的变换流程。

`BaseMixTransform` 和其子类 `Mosaic` 和 `MixUp`:

- 提供了混合图像的变换方法,如Mosaic和MixUp。
- 这些方法通过结合多个图像来增强训练数据。

`RandomPerspective`:

- 实现随机透视和仿射变换。
- 包括旋转、平移、缩放和剪切。

`RandomHSV`:

- 对图像的HSV(色调、饱和度、明度)通道进行随机调整。

`RandomFlip`:

- 应用随机水平或垂直翻转。

`LetterBox`:

- 调整图像大小并添加填充以用于检测、实例分割和姿态估计。

`CopyPaste`:

- 实现Copy-Paste增强,将图像的一部分复制并粘贴到另一个图像上。

`Albumentations`:

- 如果安装了Albumentations库,提供额外的图像增强功能。

`Format`:

- 格式化图像注释,以便用于对象检测、实例分割和姿态估计任务

### format类

```py
class Format:
# 这个类用于格式化用于对象检测、实例分割和姿态估计任务的图像标注。
# 它标准化了图像和实例标注,以便在PyTorch DataLoader的collate_fn中使用。
def __init__(self,
bbox_format='xywh',
normalize=True,
return_mask=False,
return_keypoint=False,
mask_ratio=4,
mask_overlap=True,
batch_idx=True):
# 初始化Format类并设置各种属性
self.bbox_format = bbox_format # 设置边界框的格式
self.normalize = normalize # 设置是否对边界框进行归一化
self.return_mask = return_mask # 设置是否返回分割掩码
self.return_keypoint = return_keypoint # 设置是否返回关键点
self.mask_ratio = mask_ratio # 设置掩码的下采样比率
self.mask_overlap = mask_overlap # 设置是否允许掩码重叠
self.batch_idx = batch_idx # 设置是否保留批次索引

def __call__(self, labels):
# 格式化图像、类别、边界框和关键点以供collate_fn使用
img = labels.pop('img') # 获取并移除图像
h, w = img.shape[:2] # 获取图像的高度和宽度
cls = labels.pop('cls') # 获取并移除类别
instances = labels.pop('instances') # 获取并移除实例
instances.convert_bbox(format=self.bbox_format) # 转换边界框格式
instances.denormalize(w, h) # 反归一化边界框
nl = len(instances) # 获取实例数量

if self.return_mask:
# 如果需要返回掩码
if nl:
masks, instances, cls = self._format_segments(instances, cls, w, h)
masks = torch.from_numpy(masks) # 将掩码转换为张量
else:
masks = torch.zeros(1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio,
img.shape[1] // self.mask_ratio) # 生成空掩码
labels['masks'] = masks # 添加掩码到标签中

if self.normalize:
instances.normalize(w, h) # 归一化实例
labels['img'] = self._format_img(img) # 格式化图像
labels['cls'] = torch.from_numpy(cls) if nl else torch.zeros(nl) # 格式化类别
labels['bboxes'] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4)) # 格式化边界框
if self.return_keypoint:
labels['keypoints'] = torch.from_numpy(instances.keypoints) # 格式化关键点
if self.batch_idx:
labels['batch_idx'] = torch.zeros(nl) # 添加批次索引
return labels

def _format_img(self, img):
# 将图像从Numpy数组格式转换为PyTorch张量
if len(img.shape) < 3:
img = np.expand_dims(img, -1) # 如果图像是单通道的,则添加一个维度
img = np.ascontiguousarray(img.transpose(2, 0, 1)[::-1]) # 将通道顺序从BGR转换为RGB,并调整维度顺序
img = torch.from_numpy(img) # 转换为张量
return img

def _format_segments(self, instances, cls, w, h):
# 将多边形点转换为位图掩码
segments = instances.segments
if self.mask_overlap:
# 如果允许掩码重叠
masks, sorted_idx = polygons2masks_overlap((h, w), segments, downsample_ratio=self.mask_ratio)
masks = masks[None] # 改变形状以适应张量格式
instances = instances[sorted_idx] # 根据索引排序实例
cls = cls[sorted_idx] # 根据索引排序类别
else:
masks = polygons2masks((h, w), segments, color=1, downsample_ratio=self.mask_ratio) # 不重叠的掩码生成

return masks, instances, cls

2. dataset.py

这个文件主要处理数据集的加载和预处理。它定义了如何应用augment.py中定义的增强策略到数据集上。例如,YOLODataset 类中的 build_transforms 方法用于构建并应用增强转换。您可以在这里根据需要调整增强的应用方式或顺序。

实现了类,继承base.py中的BaseDataset类