-
Notifications
You must be signed in to change notification settings - Fork 282
Expand file tree
/
Copy pathDatasetUtil.java
More file actions
484 lines (389 loc) · 15.2 KB
/
DatasetUtil.java
File metadata and controls
484 lines (389 loc) · 15.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
package apijson;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.text.SimpleDateFormat;
import java.time.LocalDate;
import java.util.*;
/**
 * Utility for assembling COCO-style annotation datasets (detection, segmentation, keypoints,
 * OCR, ...) and serializing them to JSON with Gson.
 *
 * <p>{@link #generate(String, Set)} and {@link #main(String[])} are worked examples built from
 * hard-coded sample data; real callers are expected to populate a {@link DatasetBuilder} from
 * their own data sources (XML, CSV, ...).
 */
public class DatasetUtil {

    /**
     * Runs the bundled examples, writing one JSON file per example under {@code ./output}.
     * Exits with status 1 if writing any example fails.
     */
    public static void main(String[] args) {
        try {
            // --- Usage examples ---
            // Example 1: detection dataset only
            System.out.println("Generating DETECTION dataset...");
            generate("./output/detection_dataset", EnumSet.of(TaskType.DETECTION));
            // Example 2: segmentation dataset
            System.out.println("\nGenerating SEGMENTATION dataset...");
            generate("./output/segmentation_dataset", EnumSet.of(TaskType.SEGMENTATION));
            // Example 3: pose keypoint dataset
            System.out.println("\nGenerating POSE_KEYPOINTS dataset...");
            generate("./output/keypoints_dataset", EnumSet.of(TaskType.POSE_KEYPOINTS));
            // Example 4: OCR dataset
            System.out.println("\nGenerating OCR dataset...");
            generate("./output/ocr_dataset", EnumSet.of(TaskType.OCR));
            // Example 5: detection and keypoint annotations in one JSON file
            System.out.println("\nGenerating combined DETECTION and KEYPOINTS dataset...");
            generate("./output/combined_dataset",
                    EnumSet.of(TaskType.DETECTION, TaskType.POSE_KEYPOINTS));
        } catch (IOException e) {
            e.printStackTrace();
            // A non-zero exit status lets scripts invoking this tool detect that output was not written.
            System.exit(1);
        }
    }

    /**
     * Supported annotation task types. Declaration order is also the order used when several
     * tasks are combined into one output file name.
     */
    public enum TaskType {
        CLASSIFICATION,
        DETECTION,
        SEGMENTATION,
        POSE_KEYPOINTS,
        FACE_KEYPOINTS,
        ROTATED_DETECTION,
        OCR
    }

    /**
     * Fluent builder for a {@link CocoDataset}.
     *
     * <p>Image ids and annotation ids are assigned automatically, starting at 1, in the order
     * images and annotations are added. Category ids are supplied by the caller and must be
     * unique within one dataset.
     */
    public static class DatasetBuilder {
        private final CocoDataset dataset;
        private int imageIdCounter = 1;
        private int annotationIdCounter = 1;

        /** Creates a builder over an empty dataset (all sections initialized to empty). */
        public DatasetBuilder() {
            this.dataset = new CocoDataset();
            this.dataset.setInfo(new HashMap<>());
            this.dataset.setLicenses(new ArrayList<>());
            this.dataset.setImages(new ArrayList<>());
            this.dataset.setCategories(new ArrayList<>());
            this.dataset.setAnnotations(new ArrayList<>());
        }

        /**
         * Replaces the dataset's {@code info} section. {@code date_created} is set to today's
         * date (system default time zone) in ISO {@code yyyy-MM-dd} form.
         */
        public DatasetBuilder withInfo(String description, String version, String year) {
            Map<String, String> info = new HashMap<>();
            info.put("description", description);
            info.put("version", version);
            info.put("year", year);
            // LocalDate#toString is ISO-8601 (yyyy-MM-dd), matching the previous SimpleDateFormat pattern.
            info.put("date_created", LocalDate.now().toString());
            this.dataset.setInfo(info);
            return this;
        }

        /**
         * Adds a plain category (no keypoint schema).
         *
         * @throws IllegalArgumentException if a category with the same id was already added
         */
        public DatasetBuilder withCategory(int id, String name, String supercategory) {
            return addCategory(id, name, supercategory, null, null);
        }

        /**
         * Adds a category carrying a keypoint schema, for keypoint tasks.
         *
         * @param keypoints keypoint names, in the order used by annotation keypoint triplets
         * @param skeleton  pairs of keypoint indices to connect; kept as given (the sample data
         *                  in {@link DatasetUtil#generate} uses 1-based indices)
         * @throws IllegalArgumentException if a category with the same id was already added
         */
        public DatasetBuilder withKeypointCategory(int id, String name, String supercategory,
                                                   List<String> keypoints, List<List<Integer>> skeleton) {
            return addCategory(id, name, supercategory, keypoints, skeleton);
        }

        /** Shared implementation of the category adders; rejects duplicate ids. */
        private DatasetBuilder addCategory(int id, String name, String supercategory,
                                           List<String> keypoints, List<List<Integer>> skeleton) {
            for (Category existing : this.dataset.getCategories()) {
                if (existing.getId() == id) {
                    // Duplicate ids make category_id references in annotations ambiguous.
                    throw new IllegalArgumentException("Duplicate category id " + id
                            + ": '" + existing.getName() + "' is already registered");
                }
            }
            Category cat = new Category();
            cat.setId(id);
            cat.setName(name);
            cat.setSupercategory(supercategory);
            cat.setKeypoints(keypoints);
            cat.setSkeleton(skeleton);
            this.dataset.getCategories().add(cat);
            return this;
        }

        /** Adds an image; its id is the next value of the internal image counter (starting at 1). */
        public DatasetBuilder addImage(String fileName, int width, int height) {
            ImageInfo img = new ImageInfo();
            img.setId(imageIdCounter++);
            img.setFile_name(fileName);
            img.setWidth(width);
            img.setHeight(height);
            this.dataset.getImages().add(img);
            return this;
        }

        /**
         * Adds an annotation, overwriting its id with the next value of the internal annotation
         * counter so ids are unique within the dataset.
         */
        public DatasetBuilder addAnnotation(Annotation annotation) {
            annotation.setId(annotationIdCounter++);
            this.dataset.getAnnotations().add(annotation);
            return this;
        }

        /** Returns the dataset being built (the live instance, not a copy). */
        public CocoDataset build() {
            return this.dataset;
        }
    }

    /**
     * Writes a COCO dataset to a pretty-printed UTF-8 JSON file, creating parent directories
     * as needed.
     *
     * @param dataset    COCO dataset object
     * @param outputPath output file path (e.g. /path/to/annotations/instances_train2017.json)
     * @throws IOException if the directories or the file cannot be created or written
     */
    public static void writeToFile(CocoDataset dataset, String outputPath) throws IOException {
        Path target = Paths.get(outputPath);
        Path parentDir = target.getParent();
        if (parentDir != null) {
            // createDirectories is a no-op when the directory already exists.
            Files.createDirectories(parentDir);
        }
        Gson gson = new GsonBuilder().setPrettyPrinting().create();
        // Explicit UTF-8: FileWriter would use the platform charset and could mangle non-ASCII text.
        try (Writer writer = Files.newBufferedWriter(target, StandardCharsets.UTF_8)) {
            gson.toJson(dataset, writer);
        }
        System.out.println("Successfully generated COCO JSON file at: " + outputPath);
    }

    /**
     * Example generator: builds a small sample dataset containing annotations for the requested
     * tasks and writes it to {@code outputDir/annotations/instances_<tasks>.json}.
     *
     * <p>{@code <tasks>} is the lower-cased names of all requested tasks in {@link TaskType}
     * declaration order, joined with {@code _} (e.g. {@code detection_pose_keypoints}).
     * In real use, replace the hard-coded sample annotations with data loaded from your own
     * source (XML, CSV, ...).
     *
     * @throws IllegalArgumentException if {@code tasks} is empty
     * @throws IOException              if the output file cannot be written
     */
    public static void generate(String outputDir, Set<TaskType> tasks) throws IOException {
        Objects.requireNonNull(tasks, "tasks");
        if (tasks.isEmpty()) {
            throw new IllegalArgumentException("tasks must contain at least one TaskType");
        }

        // --- 1. Builder and shared metadata ---
        // "person" is registered once, together with its keypoint schema: COCO category ids
        // must be unique, so a separate plain "person" category with the same id is not allowed.
        DatasetBuilder builder = new DatasetBuilder()
                .withInfo("My Custom Dataset", "1.0", "2025")
                .withKeypointCategory(1, "person", "person",
                        Arrays.asList("nose", "left_eye", "right_eye"), // simplified keypoint set
                        Arrays.asList(Arrays.asList(1, 2), Arrays.asList(1, 3)))
                .withCategory(2, "car", "vehicle")
                .withCategory(3, "dog", "animal");

        // --- 2. Images ---
        builder.addImage("00001.jpg", 640, 480); // image_id 1
        builder.addImage("00002.jpg", 800, 600); // image_id 2

        // --- 3. Task-specific annotations (sample data; replace with real loading logic) ---
        // Annotations for image 1
        if (tasks.contains(TaskType.DETECTION) || tasks.contains(TaskType.SEGMENTATION)
                || tasks.contains(TaskType.ROTATED_DETECTION)) {
            DetectionAnnotation detAnn = new DetectionAnnotation();
            detAnn.setImage_id(1);
            detAnn.setCategory_id(3); // dog
            detAnn.setBbox(Arrays.asList(100.0, 50.0, 80.0, 120.0));
            detAnn.setArea(80.0 * 120.0);
            if (tasks.contains(TaskType.SEGMENTATION)) {
                detAnn.setSegmentation(Arrays.asList(
                        Arrays.asList(100.0, 50.0, 180.0, 50.0, 180.0, 170.0, 100.0, 170.0)));
            }
            if (tasks.contains(TaskType.ROTATED_DETECTION)) {
                // Rotated boxes are stored as a four-point polygon in the segmentation field;
                // when SEGMENTATION is also requested, this quad replaces the polygon set above.
                detAnn.setSegmentation(Arrays.asList(
                        Arrays.asList(110.0, 55.0, 175.0, 60.0, 170.0, 165.0, 105.0, 160.0)));
            }
            builder.addAnnotation(detAnn);
        }
        if (tasks.contains(TaskType.POSE_KEYPOINTS)) {
            KeypointAnnotation kpAnn = new KeypointAnnotation();
            kpAnn.setImage_id(1);
            kpAnn.setCategory_id(1); // person
            kpAnn.setBbox(Arrays.asList(200.0, 100.0, 50.0, 150.0));
            kpAnn.setArea(50.0 * 150.0);
            kpAnn.setNum_keypoints(3);
            // Flat triplets: [x, y, v, x, y, v, ...]
            kpAnn.setKeypoints(Arrays.asList(225.0, 110.0, 2.0, 215.0, 105.0, 2.0, 235.0, 105.0, 2.0));
            builder.addAnnotation(kpAnn);
        }
        // Annotations for image 2
        if (tasks.contains(TaskType.OCR)) {
            OcrAnnotation ocrAnn = new OcrAnnotation();
            ocrAnn.setImage_id(2);
            ocrAnn.setCategory_id(2); // car; the licence plate is the OCR target in this sample
            ocrAnn.setBbox(Arrays.asList(300.0, 400.0, 120.0, 30.0));
            ocrAnn.setArea(120.0 * 30.0);
            // OCR regions are given as a quadrilateral in the segmentation field.
            ocrAnn.setSegmentation(Arrays.asList(
                    Arrays.asList(300.0, 400.0, 420.0, 400.0, 420.0, 430.0, 300.0, 430.0)));
            Map<String, Object> attrs = new HashMap<>();
            attrs.put("transcription", "AB-1234");
            attrs.put("legible", true);
            ocrAnn.setAttributes(attrs);
            builder.addAnnotation(ocrAnn);
        }

        // --- 4. Build and write ---
        CocoDataset cocoDataset = builder.build();
        String outputJsonPath = Paths.get(outputDir, "annotations",
                "instances_" + describeTasks(tasks) + ".json").toString();
        writeToFile(cocoDataset, outputJsonPath);
        // The image files themselves (00001.jpg, 00002.jpg) still have to be copied into an
        // image directory, e.g. outputDir/images/.
    }

    /**
     * Builds a deterministic file-name fragment from a non-empty task set: lower-cased task
     * names in enum declaration order, joined with "_".
     */
    private static String describeTasks(Set<TaskType> tasks) {
        StringJoiner joiner = new StringJoiner("_");
        // EnumSet iterates in declaration order, independent of the caller's Set implementation.
        for (TaskType task : EnumSet.copyOf(tasks)) {
            // Locale.ROOT avoids locale-specific casing (e.g. Turkish dotless i).
            joiner.add(task.name().toLowerCase(Locale.ROOT));
        }
        return joiner.toString();
    }

    /** COCO {@code images} entry. Field names match the JSON keys Gson emits. */
    public static class ImageInfo {
        private int id;
        private String file_name;
        private int width;
        private int height;

        public int getId() {
            return id;
        }

        public void setId(int id) {
            this.id = id;
        }

        public String getFile_name() {
            return file_name;
        }

        public void setFile_name(String file_name) {
            this.file_name = file_name;
        }

        public int getWidth() {
            return width;
        }

        public void setWidth(int width) {
            this.width = width;
        }

        public int getHeight() {
            return height;
        }

        public void setHeight(int height) {
            this.height = height;
        }
    }

    /**
     * COCO {@code categories} entry. {@code keypoints} and {@code skeleton} are only set for
     * keypoint categories; when null, Gson's default configuration omits them from the output.
     */
    public static class Category {
        private int id;
        private String name;
        private String supercategory;
        // Keypoint-task fields
        private List<String> keypoints;
        private List<List<Integer>> skeleton;

        public int getId() {
            return id;
        }

        public void setId(int id) {
            this.id = id;
        }

        public String getName() {
            return name;
        }

        public void setName(String name) {
            this.name = name;
        }

        public String getSupercategory() {
            return supercategory;
        }

        public void setSupercategory(String supercategory) {
            this.supercategory = supercategory;
        }

        public List<String> getKeypoints() {
            return keypoints;
        }

        public void setKeypoints(List<String> keypoints) {
            this.keypoints = keypoints;
        }

        public List<List<Integer>> getSkeleton() {
            return skeleton;
        }

        public void setSkeleton(List<List<Integer>> skeleton) {
            this.skeleton = skeleton;
        }
    }

    /** Base COCO annotation: the fields shared by every task type. */
    public static class Annotation {
        private int id;
        private int image_id;
        private int category_id;

        public int getId() {
            return id;
        }

        public void setId(int id) {
            this.id = id;
        }

        public int getImage_id() {
            return image_id;
        }

        public void setImage_id(int image_id) {
            this.image_id = image_id;
        }

        public int getCategory_id() {
            return category_id;
        }

        public void setCategory_id(int category_id) {
            this.category_id = category_id;
        }
    }

    /**
     * Detection annotation. It also carries segmentation polygons, which are used for
     * segmentation, rotated boxes and OCR quads.
     */
    public static class DetectionAnnotation extends Annotation {
        private List<Double> bbox; // [x, y, width, height]
        private double area;
        private int iscrowd = 0;
        private List<List<Double>> segmentation; // for segmentation & rotated box

        public List<Double> getBbox() {
            return bbox;
        }

        public void setBbox(List<Double> bbox) {
            this.bbox = bbox;
        }

        public double getArea() {
            return area;
        }

        public void setArea(double area) {
            this.area = area;
        }

        public int getIscrowd() {
            return iscrowd;
        }

        public void setIscrowd(int iscrowd) {
            this.iscrowd = iscrowd;
        }

        public List<List<Double>> getSegmentation() {
            return segmentation;
        }

        public void setSegmentation(List<List<Double>> segmentation) {
            this.segmentation = segmentation;
        }
    }

    /** Keypoint annotation: a detection plus flat {@code [x, y, v, ...]} keypoint triplets. */
    public static class KeypointAnnotation extends DetectionAnnotation {
        private int num_keypoints;
        private List<Double> keypoints; // [x1, y1, v1, x2, y2, v2, ...]

        public int getNum_keypoints() {
            return num_keypoints;
        }

        public void setNum_keypoints(int num_keypoints) {
            this.num_keypoints = num_keypoints;
        }

        public List<Double> getKeypoints() {
            return keypoints;
        }

        public void setKeypoints(List<Double> keypoints) {
            this.keypoints = keypoints;
        }
    }

    /**
     * OCR annotation: a detection plus free-form attributes. The sample data in
     * {@link DatasetUtil#generate} uses {"transcription": "TEXT", "legible": true}.
     */
    public static class OcrAnnotation extends DetectionAnnotation {
        private Map<String, Object> attributes;

        public Map<String, Object> getAttributes() {
            return attributes;
        }

        public void setAttributes(Map<String, Object> attributes) {
            this.attributes = attributes;
        }
    }

    /**
     * Root COCO document. Annotations are held as the {@link Annotation} base type;
     * subclass-specific fields are still emitted because Gson serializes collection elements by
     * their runtime type.
     */
    public static class CocoDataset {
        private Map<String, String> info;
        private List<Map<String, String>> licenses;
        private List<ImageInfo> images;
        private List<Category> categories;
        private List<Annotation> annotations;

        public Map<String, String> getInfo() {
            return info;
        }

        public void setInfo(Map<String, String> info) {
            this.info = info;
        }

        public List<Map<String, String>> getLicenses() {
            return licenses;
        }

        public void setLicenses(List<Map<String, String>> licenses) {
            this.licenses = licenses;
        }

        public List<ImageInfo> getImages() {
            return images;
        }

        public void setImages(List<ImageInfo> images) {
            this.images = images;
        }

        public List<Category> getCategories() {
            return categories;
        }

        public void setCategories(List<Category> categories) {
            this.categories = categories;
        }

        public List<Annotation> getAnnotations() {
            return annotations;
        }

        public void setAnnotations(List<Annotation> annotations) {
            this.annotations = annotations;
        }
    }
}