Skip to content

Commit a3a52d0

Browse files
Add new dataset (#1227)
1 parent ab51afd commit a3a52d0

File tree

2 files changed

+81
-7
lines changed

2 files changed

+81
-7
lines changed

swift/llm/utils/dataset.py

+76-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from swift.utils import get_logger, get_seed, is_dist, is_local_master, read_from_jsonl, transform_jsonl_to_df
2323
from swift.utils.torch_utils import _find_local_mac
24-
from .media import MediaCache
24+
from .media import MediaCache, MediaTag
2525
from .preprocess import (AlpacaPreprocessor, ClsPreprocessor, ComposePreprocessor, ConversationsPreprocessor,
2626
ListPreprocessor, PreprocessFunc, RenameColumnsPreprocessor, SmartPreprocessor,
2727
TextGenerationPreprocessor, preprocess_sharegpt)
@@ -162,6 +162,8 @@ class DatasetName:
162162
midefics = 'midefics'
163163
gqa = 'gqa'
164164
text_caps = 'text-caps'
165+
refcoco_unofficial_caption = 'refcoco-unofficial-caption'
166+
refcoco_unofficial_grounding = 'refcoco-unofficial-grounding'
165167
a_okvqa = 'a-okvqa'
166168
okvqa = 'okvqa'
167169
ocr_vqa = 'ocr-vqa'
@@ -1112,6 +1114,79 @@ def preprocess(row):
11121114
load_from_cache_file=False).filter(lambda row: row.get('response')).rename_columns({'image': 'images'})
11131115

11141116

1117+
def preprocess_refcoco_unofficial_caption(dataset):
1118+
1119+
cache_dir = MediaCache.download(
1120+
'https://www.modelscope.cn/api/v1/datasets/we_dont_produce_water/'
1121+
'coco_res/repo?Revision=master&FilePath=coco_2014.zip', 'coco2014')
1122+
1123+
def preprocess(row):
1124+
caption = row['captions'][0]
1125+
bbox = row['bbox']
1126+
image_path = os.path.join(cache_dir, row['image_path'].replace('coco/train2014', 'train2014'))
1127+
media_tag = MediaTag(media_type='image', task_type='grounding_caption')
1128+
for i in range(len(bbox)):
1129+
bbox[i] = round(float(bbox[i]))
1130+
res = {}
1131+
1132+
objects = [[caption, bbox]]
1133+
media_tag(res, [image_path])
1134+
res['images'] = [image_path]
1135+
res['objects'] = json.dumps(objects)
1136+
if not os.path.exists(image_path):
1137+
res['response'] = ''
1138+
return res
1139+
1140+
return dataset.map(preprocess, load_from_cache_file=False).filter(lambda row: row.get('response'))
1141+
1142+
1143+
register_dataset(
1144+
DatasetName.refcoco_unofficial_caption,
1145+
'swift/refcoco', [],
1146+
preprocess_func=preprocess_refcoco_unofficial_caption,
1147+
get_function=get_dataset_from_repo,
1148+
split=['train', 'validation'],
1149+
hf_dataset_id='jxu124/refcoco',
1150+
huge_dataset=True,
1151+
tags=['multi-modal', 'en', 'caption'])
1152+
1153+
1154+
def preprocess_refcoco_unofficial_grounding(dataset):
1155+
1156+
cache_dir = MediaCache.download(
1157+
'https://www.modelscope.cn/api/v1/datasets/we_dont_produce_water/'
1158+
'coco_res/repo?Revision=master&FilePath=coco_2014.zip', 'coco2014')
1159+
1160+
def preprocess(row):
1161+
caption = row['captions'][0]
1162+
bbox = row['bbox']
1163+
image_path = os.path.join(cache_dir, row['image_path'].replace('coco/train2014', 'train2014'))
1164+
media_tag = MediaTag(media_type='image', task_type='ref_grounding')
1165+
for i in range(len(bbox)):
1166+
bbox[i] = round(float(bbox[i]))
1167+
res = {}
1168+
1169+
objects = [[caption, bbox]]
1170+
media_tag(res, [image_path])
1171+
res['images'] = [image_path]
1172+
res['objects'] = json.dumps(objects)
1173+
if not os.path.exists(image_path):
1174+
res['response'] = ''
1175+
return res
1176+
1177+
return dataset.map(preprocess, load_from_cache_file=False).filter(lambda row: row.get('response'))
1178+
1179+
1180+
register_dataset(
1181+
DatasetName.refcoco_unofficial_grounding,
1182+
'swift/refcoco', [],
1183+
preprocess_func=preprocess_refcoco_unofficial_grounding,
1184+
get_function=get_dataset_from_repo,
1185+
split=['train', 'validation'],
1186+
hf_dataset_id='jxu124/refcoco',
1187+
huge_dataset=True,
1188+
tags=['multi-modal', 'en', 'grounding'])
1189+
11151190
register_dataset(
11161191
DatasetName.text_caps,
11171192
'swift/TextCaps', [],

swift/llm/utils/media.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class MediaTag:
2424
('<bbox>', '<ref-object>'),
2525
('The object at position <bbox>', '<ref-object>'),
2626
('This <bbox> is', '<ref-object>'),
27-
('What is the thing at <bbox>', '<ref-object>'),
27+
('What is the object at <bbox>', '<ref-object>'),
2828
('Describe <bbox>', '<ref-object>'),
2929
('<bbox> is', '<ref-object>'),
3030
('The bounding box coordinate <bbox> contains', '<ref-object>'),
@@ -62,14 +62,13 @@ def __init__(self,
6262
self.task_type = task_type
6363
self.media_tag = media_tag or '<unused_tag>'
6464

65-
def __call__(self, d: Dict[str, Any], medias: Union[tuple, list], objects: List = None) -> None:
65+
def __call__(self, d: Dict[str, Any], medias: Union[tuple, list]) -> None:
6666
"""Format the query/response/history with medias
6767
6868
Args:
6969
d: A dict contains history/query/response
7070
medias: A list of medias(one round, multiple medias),
7171
a single media(one round, one media), or a tuple of media list(multiple rounds)
72-
objects: A list of object-bbox pairs(one round), or a tuple of object-bbox lists(multiple rounds)
7372
"""
7473
if not self.media_type:
7574
return
@@ -83,7 +82,8 @@ def __call__(self, d: Dict[str, Any], medias: Union[tuple, list], objects: List
8382
pass
8483
elif self.task_type in ('ref_grounding', 'grounding_caption'):
8584
lang = np.random.choice(['en', 'zh'], p=[0.8, 0.2])
86-
query, response = np.random.choice(self.task_prompts[self.task_type][lang])
85+
prompts = self.task_prompts[self.task_type][lang]
86+
query, response = prompts[np.random.choice(range(len(prompts)))]
8787
elif self.task_type == 'ocr':
8888
raise NotImplementedError
8989
else:
@@ -101,8 +101,7 @@ def __call__(self, d: Dict[str, Any], medias: Union[tuple, list], objects: List
101101
if 'history' in d:
102102
d['history'] = history
103103
d['query'] = query
104-
if 'response' in d:
105-
d['response'] = response
104+
d['response'] = response
106105

107106

108107
class MediaCache:

0 commit comments

Comments
 (0)