Mirror of https://github.com/OpenBMB/MiniCPM-V.git, synced 2026-02-04 17:59:18 +08:00

Update to MiniCPM-o 2.6
@@ -7,7 +7,6 @@ import re
import random
from dataclasses import dataclass, field
from typing import Dict, List, Optional
from decord import VideoReader, cpu # pip install decord

import numpy as np
import torch
@@ -21,26 +20,6 @@ logger = logging.getLogger(__name__)

llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
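For readers unfamiliar with this string: it is a standard Jinja2 chat template. A minimal sketch of rendering it with plain Jinja2, outside of `transformers` (the messages and `bos_token` below are illustrative, not taken from the repository):

```python
# Sketch: render the llama3 chat template with plain Jinja2 (pip install jinja2).
from jinja2 import Template

llama3_chat_template = (
    "{% set loop_messages = messages %}{% for message in loop_messages %}"
    "{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'"
    "+ message['content'] | trim + '<|eot_id|>' %}"
    "{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}"
    "{{ content }}{% endfor %}"
)

messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi, how can I help?"},
]

# Each turn is wrapped in <|start_header_id|>role<|end_header_id|> ... <|eot_id|>,
# and the BOS token is prepended to the first turn only.
prompt = Template(llama3_chat_template).render(messages=messages, bos_token="<|begin_of_text|>")
print(prompt)
```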
MAX_NUM_FRAMES = 64

def encode_video(video_path, max_num_frames=64):
    max_num_frames = min(max_num_frames, MAX_NUM_FRAMES)

    def uniform_sample(l, n):
        gap = len(l) / n
        idxs = [int(i * gap + gap / 2) for i in range(n)]
        return [l[i] for i in idxs]

    vr = VideoReader(video_path, ctx=cpu(0))
    sample_fps = round(vr.get_avg_fps() / 1)  # FPS
    frame_idx = [i for i in range(0, len(vr), sample_fps)]
    if len(frame_idx) > max_num_frames:
        if max_num_frames == 1:
            frame_idx = [frame_idx[len(frame_idx) // 2]]
        else:
            frame_idx = uniform_sample(frame_idx, max_num_frames)
    frames = vr.get_batch(frame_idx).asnumpy()
    frames = [Image.fromarray(v.astype('uint8')) for v in frames]
    return frames
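As a quick illustration of the sampling logic above (a standalone sketch, not part of the diff): `uniform_sample` picks `n` indices evenly spread across a list, and `encode_video` applies it to the roughly 1-fps frame index list.

```python
# Standalone sketch of the uniform sampling used above; no video file is needed.
def uniform_sample(l, n):
    gap = len(l) / n
    idxs = [int(i * gap + gap / 2) for i in range(n)]
    return [l[i] for i in idxs]

# e.g. a 300-frame clip sampled every 25 frames -> 12 candidate frame indices
frame_idx = list(range(0, 300, 25))
print(uniform_sample(frame_idx, 4))  # 4 evenly spread indices: [25, 100, 175, 250]
```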
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

@@ -55,8 +34,6 @@ class SupervisedDataset(Dataset):
        query_nums=64,
        batch_vision=False,
        max_length=2048,
        video_max_slice_nums=2,
        max_num_frames=1,
    ):
        super(SupervisedDataset, self).__init__()
        self.raw_data = raw_data
@@ -68,58 +45,17 @@ class SupervisedDataset(Dataset):
        self.query_nums=query_nums
        self.batch_vision = batch_vision
        self.max_length = max_length
        # video config
        self.video_slice_config = copy.deepcopy(slice_config)
        self.video_slice_config['max_slice_nums'] = video_max_slice_nums
        self.max_num_frames = max_num_frames

    def __len__(self):
        return len(self.raw_data)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        try:
            # default: sft image
            use_image_id = True
            slice_config = self.slice_config
            if "image" in self.raw_data[i]:
                if isinstance(self.raw_data[i]["image"], str):
                    images_dict = { "<image>" : Image.open(self.raw_data[i]["image"]).convert("RGB") }
                elif isinstance(self.raw_data[i]["image"], Dict):
                    ### for multi-images input, the template for every image is <image_xx>, such as <image_00>, <image_01>
                    images_dict = {img_name : Image.open(img_path).convert("RGB") for img_name, img_path in self.raw_data[i]["image"].items()}
            elif "video" in self.raw_data[i]:
                if isinstance(self.raw_data[i]["video"], str):
                    frames = encode_video(self.raw_data[i]["video"], max_num_frames=self.max_num_frames)
                    image_names = []
                    images_dict = {}
                    for j, frame in enumerate(frames):
                        image_name = "<image_{:02d}>".format(j)
                        images_dict[image_name] = frame
                        image_names.append(image_name)
                    for j in range(len(self.raw_data[i]["conversations"])):
                        content = self.raw_data[i]["conversations"][j]['content']
                        self.raw_data[i]["conversations"][j]['content'] = content.replace("<video>", "".join(image_names))
                elif isinstance(self.raw_data[i]["video"], Dict):
                    videos = self.raw_data[i]["video"]
                    images_dict = {}
                    video_names = {}
                    cnt = 0
                    for video_name in videos:
                        video_id = video_name.split("_")[-1].strip(">")
                        video = videos[video_name]
                        frames = encode_video(video, max_num_frames=self.max_num_frames)
                        image_names = []
                        for j, frame in enumerate(frames):
                            image_name = "<image_{:02d}>".format(cnt)
                            cnt += 1
                            images_dict[image_name] = frame
                            image_names.append(image_name)
                        for j in range(len(self.raw_data[i]["conversations"])):
                            content = self.raw_data[i]["conversations"][j]['content']
                            self.raw_data[i]["conversations"][j]['content'] = content.replace(video_name, "".join(image_names))
                # video: modify config
                slice_config = self.video_slice_config
                use_image_id = False
            if isinstance(self.raw_data[i]["image"], str):
                images_dict = { "<image>" : Image.open(self.raw_data[i]["image"]).convert("RGB") }
            elif isinstance(self.raw_data[i]["image"], Dict):
                ### for multi-images input, the template for every image is <image_xx>, such as <image_00>, <image_01>
                images_dict = {img_name : Image.open(img_path).convert("RGB") for img_name, img_path in self.raw_data[i]["image"].items()}

            ret = preprocess(
                images_dict,
@@ -131,8 +67,7 @@ class SupervisedDataset(Dataset):
                llm_type=self.llm_type,
                patch_size=self.patch_size,
                batch_vision=self.batch_vision,
                max_length=self.max_length,
                use_image_id=use_image_id
                max_length=self.max_length
            )
            ret = dict(
                input_ids=ret["input_ids"],
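The `content.replace(...)` calls above implement a simple placeholder rewrite: each `<video>` (or `<video_xx>`) tag in the conversation is replaced by one `<image_XX>` tag per sampled frame. A minimal standalone sketch of that rewrite (the conversation and frame count are invented for illustration):

```python
# Sketch of the placeholder rewrite performed in __getitem__ above.
num_frames = 3
image_names = ["<image_{:02d}>".format(j) for j in range(num_frames)]

conversations = [{"role": "user", "content": "<video>\nWhat happens in this clip?"}]
for turn in conversations:
    turn["content"] = turn["content"].replace("<video>", "".join(image_names))

print(conversations[0]["content"])
# <image_00><image_01><image_02>
# What happens in this clip?
```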
@@ -197,7 +132,7 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False
        input_ids, context, raw_msg = conversation_to_ids_llama3(
            conversation, tokenizer
        )
    elif llm_type == "qwen2":
    elif llm_type == "qwen":
        input_ids, context, raw_msg = conversation_to_ids_qwen2(
            conversation, tokenizer
        )
@@ -383,7 +318,6 @@ def preprocess(
    patch_size=14,
    batch_vision=False,
    max_length=2048,
    use_image_id=True
):
    """
    single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation
@@ -402,9 +336,9 @@ def preprocess(
    )
    new_schema = False
    use_image_id = False
    if llm_type=='qwen2':
    if llm_type=='qwen':
        new_schema = True
        use_image_id = use_image_id
        use_image_id = True
    image_placeholder_dict = {}
    images = []
    image_id_cnt = 0

@@ -14,7 +14,7 @@ from accelerate.utils import DistributedType
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

from transformers import AutoModel, AutoTokenizer, AutoProcessor
from transformers import AutoModel, AutoTokenizer
from transformers.integrations import deepspeed
from transformers import AutoModel, AutoTokenizer

@@ -53,8 +53,6 @@ class TrainingArguments(transformers.TrainingArguments):
    llm_type: str = field(default="minicpm")
    use_lora: Optional[bool] = field(default=False)
    max_slice_nums: Optional[int] = field(default=9)
    video_max_slice_nums: Optional[int] = field(default=2)
    max_num_frames: Optional[int] = field(default=1)


@dataclass
@@ -94,8 +92,6 @@ def make_supervised_data_module(
    query_nums=64,
    batch_vision=False,
    max_length=2048,
    video_max_slice_nums=2,
    max_num_frames=1,
) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    dataset_cls = SupervisedDataset
@@ -113,8 +109,6 @@ def make_supervised_data_module(
        query_nums=query_nums,
        batch_vision=batch_vision,
        max_length=max_length,
        video_max_slice_nums=video_max_slice_nums,
        max_num_frames=max_num_frames,
    )

    if data_args.eval_data_path:
@@ -129,8 +123,6 @@ def make_supervised_data_module(
            query_nums=query_nums,
            batch_vision=batch_vision,
            max_length=max_length,
            video_max_slice_nums=video_max_slice_nums,
            max_num_frames=max_num_frames,
        )
    else:
        eval_dataset = None
@@ -210,10 +202,10 @@ def train():
        trust_remote_code=True,
        torch_dtype=compute_dtype,
        device_map=device_map,
        init_vision=True,
        init_audio=False,
        init_tts=False,
    )
    model.__class__.register_for_auto_class()

    model.processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=True
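The hunk above shows the new multimodal init flags (`init_vision`, `init_audio`, `init_tts`) and the processor attached to the model. A minimal sketch of loading the model with the same flags outside the training script (the dtype choice here is an illustrative assumption, not prescribed by the diff):

```python
# Sketch: load MiniCPM-o 2.6 with the flags used in finetune.py above.
import torch
from transformers import AutoModel, AutoTokenizer, AutoProcessor

model_path = "openbmb/MiniCPM-o-2_6"
model = AutoModel.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,  # illustrative dtype choice
    init_vision=True,            # keep the vision tower
    init_audio=False,            # audio / TTS heads are not needed for visual SFT
    init_tts=False,
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
```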
@@ -287,8 +279,6 @@ def train():
        query_nums=model.config.query_num,
        batch_vision=batch_vision,
        max_length=training_args.model_max_length,
        video_max_slice_nums=training_args.video_max_slice_nums,
        max_num_frames=training_args.max_num_frames,
    )

    training_args.gradient_checkpointing_kwargs={"use_reentrant":False}

@@ -5,14 +5,17 @@ NNODES=1
NODE_RANK=0
MASTER_ADDR=localhost
MASTER_PORT=6001

MODEL="openbmb/MiniCPM-V-2_6"
# or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5
MODEL="openbmb/MiniCPM-o-2_6"
# or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5, openbmb/MiniCPM-V-2_6
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
# See the section for finetuning in README for more information.
DATA="path/to/training_data"
EVAL_DATA="path/to/test_data"
LLM_TYPE="qwen2" # if using openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm; if using openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3"
# if using openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm; if using openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3";
# if using openbmb/MiniCPM-o-2_6 or openbmb/MiniCPM-V-2_6, please set LLM_TYPE=qwen
LLM_TYPE="qwen"
MODEL_MAX_Length=2048 # if conducting multi-image SFT, please set MODEL_MAX_Length=4096

@@ -38,7 +41,7 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
    --do_train \
    --do_eval \
    --tune_vision true \
    --tune_llm true \
    --tune_llm false \
    --model_max_length $MODEL_MAX_Length \
    --max_slice_nums 9 \
    --max_steps 10000 \
@@ -60,5 +63,5 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --gradient_checkpointing true \
    --deepspeed ds_config_zero2.json \
    --deepspeed ds_config_zero3.json \
    --report_to "tensorboard"

@@ -5,16 +5,16 @@ NNODES=1
NODE_RANK=0
MASTER_ADDR=localhost
MASTER_PORT=6001

MODEL="openbmb/MiniCPM-V-2_6" # or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5
MODEL="openbmb/MiniCPM-o-2_6"
# or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5, openbmb/MiniCPM-V-2_6
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
# See the section for finetuning in README for more information.
DATA="path/to/training_data"
EVAL_DATA="path/to/test_data"
LLM_TYPE="qwen2"
# if using openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
# if using openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE=llama3
# if using openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm; if using openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3";
# if using openbmb/MiniCPM-o-2_6 or openbmb/MiniCPM-V-2_6, please set LLM_TYPE=qwen
LLM_TYPE="qwen"
MODEL_MAX_Length=2048 # if conducting multi-image SFT, please set MODEL_MAX_Length=4096

DISTRIBUTED_ARGS="
@@ -24,6 +24,7 @@ DISTRIBUTED_ARGS="
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT
"

torchrun $DISTRIBUTED_ARGS finetune.py \
    --model_name_or_path $MODEL \
    --llm_type $LLM_TYPE \

@@ -1,7 +1,7 @@
# MiniCPM-V Finetuning


We offer the official scripts for easy finetuning of the pretrained **MiniCPM-V-2_6**, **MiniCPM-Llama3-V 2.5** and **MiniCPM-V 2.0** on downstream tasks. Our finetune scripts use transformers Trainer and DeepSpeed by default.
We offer the official scripts for easy finetuning of the pretrained **MiniCPM-o-2_6**, **MiniCPM-V-2_6**, **MiniCPM-Llama3-V 2.5** and **MiniCPM-V 2.0** on downstream tasks. Our finetune scripts use transformers Trainer and DeepSpeed by default.

### Data preparation

@@ -20,30 +20,30 @@ If your input consists of a single image, you can use a single placeholder **\<i
[
    {
        "id": "0",
        "image": "path/to/image_0.jpg",
        "image": 'path/to/image_0.jpg',
        "conversations": [
            {
                "role": "user",
                "content": "<image>\nHow many desserts are on the white plate?"
                'role': 'user',
                'content': '<image>\nHow many desserts are on the white plate?'
            },
            {
                "role": "assistant",
                "content": "There are three desserts on the white plate."
                'role': 'assistant',
                'content': 'There are three desserts on the white plate.'
            },
            {
                "role": "user",
                "content": "What type of desserts are they?"
                'role': 'user',
                'content': 'What type of desserts are they?'
            },
            {
                "role": "assistant",
                "content": "The desserts are cakes with bananas and pecans on top. They share similarities with donuts, but the presence of bananas and pecans differentiates them."
                'role': 'assistant',
                'content': 'The desserts are cakes with bananas and pecans on top. They share similarities with donuts, but the presence of bananas and pecans differentiates them.'
            },
            {
                "role": "user",
                "content": "What is the setting of the image?"},
                'role': 'user',
                'content': 'What is the setting of the image?'},
            {
                "role": "assistant",
                "content": "The image is set on a table top with a plate containing the three desserts."
                'role': 'assistant',
                'content': 'The image is set on a table top with a plate containing the three desserts.'
            },
        ]
    },
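Since the training data must be a JSON file containing a list of conversation samples (as in the example above), a quick structural check before launching training can save a failed run. This is a standalone sketch; the file name is a placeholder, not a file from the repository.

```python
# Sketch: sanity-check a finetuning JSON file before training.
import json

with open("vl_finetune_data.json", "r", encoding="utf-8") as f:
    data = json.load(f)

assert isinstance(data, list), "top level must be a list of samples"
for sample in data:
    assert "conversations" in sample, "each sample needs a conversations list"
    for turn in sample["conversations"]:
        assert turn["role"] in ("user", "assistant")
        assert isinstance(turn["content"], str)
print(f"{len(data)} samples look structurally valid")
```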
@@ -91,81 +91,16 @@ If the total token count exceeds `max_length`, truncation will be applied. For m
```
</details>

#### Single Video Example
If your input consists of a single video, you can use a single placeholder **\<video\>** to indicate where the video should be inserted in the conversation.
<details>
<summary>
<b>Single video example (vl_finetune_video.json) with 1 sample.</b>
</summary>

```
[
    {
        "id": "0",
        "video": "path/to/video_0.mp4",
        "conversations": [
            {
                "role": "user",
                "content": "<video>\nHow many desserts are on the white plate?"
            },
            {
                "role": "assistant",
                "content": "There are three desserts on the white plate."
            }
        ]
    }
]
```

</details>

#### Multiple Videos Example
For inputs containing multiple videos, utilize a dictionary where each key represents a unique placeholder (e.g., **\<video_00\>**, **\<video_01\>**) with the corresponding video path as its value. These placeholders can then be used within the conversation to seamlessly insert videos at specific positions.

Additionally, to optimize resource management, especially when dealing with large batches of videos during training or inference, consider reducing `video_max_slice_nums` and `max_num_frames`. To minimize the number of tokens used per video, you can set `video_max_slice_nums=1` and `max_num_frames=1`, resulting in a single video being represented by 64 tokens.

If the total token count exceeds `max_length`, truncation will be applied. For multi-video supervised fine-tuning (SFT), it's recommended to set `MODEL_MAX_LENGTH=4096` in your script for better performance. A rough token-budget sketch is given after the example below.

<details>
<summary>
<b>Multiple videos example (vl_finetune_data.json) with 1 sample.</b>
</summary>

```
[
    {
        "id": "0",
        "video": {
            "<video_00>": "path/to/video_0.mp4",
            "<video_01>": "path/to/video_1.avi",
            "<video_02>": "path/to/video_2.mp4",
            "<video_03>": "path/to/video_3.avi"
        },
        "conversations": [
            {
                "role": "user",
                "content": "How to create such text-only videos using CapCut?\n<video_00>\n<video_03>\n<video_01>\n<video_02>\n"
            },
            {
                "role": "assistant",
                "content": "To create a text-only video as shown in the videos, follow these steps in CapCut..."
            }
        ]
    }
]
```
</details>
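As a rough illustration of the token budgeting discussed above (a back-of-envelope sketch based on the figure quoted in this README, not an exact accounting of the tokenizer's output):

```python
# Rough sketch: estimate visual tokens per video under the settings discussed above.
# Assumes roughly query_nums tokens per frame slice, which matches the quoted figure
# of 64 tokens for max_num_frames=1 and video_max_slice_nums=1; special and
# separator tokens are ignored.
def approx_video_tokens(max_num_frames, video_max_slice_nums, query_nums=64):
    return max_num_frames * video_max_slice_nums * query_nums

print(approx_video_tokens(1, 1))  # 64   -> the minimal setting recommended above
print(approx_video_tokens(8, 2))  # 1024 -> compare this budget against MODEL_MAX_LENGTH
```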

### Full-parameter finetuning

Full-parameter finetuning updates all parameters of the LLM throughout training. Please specify the correct MODEL path, DATA path and LLM_TYPE in the shell scripts.

```shell
MODEL="openbmb/MiniCPM-V-2_6" # or openbmb/MiniCPM-Llama3-V-2_5, openbmb/MiniCPM-V-2
MODEL="openbmb/MiniCPM-o-2_6" # or openbmb/MiniCPM-V-2_6, openbmb/MiniCPM-Llama3-V-2_5, openbmb/MiniCPM-V-2
DATA="path/to/training_data" # json file
EVAL_DATA="path/to/test_data" # json file
LLM_TYPE="qwen2" # if using openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm; if using openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3"
LLM_TYPE="qwen" # if using openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm; if using openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3";
# if using openbmb/MiniCPM-o-2_6 or openbmb/MiniCPM-V-2_6, please set LLM_TYPE=qwen
```

To launch your training, run the following script:
@@ -188,7 +123,7 @@ After training, you could load the model with the path to the adapter. We advise
```
from peft import PeftModel
from transformers import AutoModel
model_type = "openbmb/MiniCPM-V-2_6" # or openbmb/MiniCPM-Llama3-V-2_5, openbmb/MiniCPM-V-2
model_type = "openbmb/MiniCPM-o-2_6" # or "openbmb/MiniCPM-V-2_6", "openbmb/MiniCPM-Llama3-V-2_5", "openbmb/MiniCPM-V-2"
path_to_adapter = "path_to_your_fine_tuned_checkpoint"

model = AutoModel.from_pretrained(
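# --- editor's sketch: the diff is truncated above; an assumed continuation follows
# --- for illustration only (these lines are NOT part of this commit)
    model_type,
    trust_remote_code=True,  # required for the custom MiniCPM model code
)
# attach the fine-tuned LoRA adapter with peft
lora_model = PeftModel.from_pretrained(model, path_to_adapter).eval()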
finetune/requirements.txt (new file, 44 lines)
@@ -0,0 +1,44 @@
packaging==23.2
addict==2.4.0
editdistance==0.6.2
einops==0.7.0
fairscale==0.4.0
jsonlines==4.0.0
markdown2==2.4.10
matplotlib==3.7.4
more_itertools==10.1.0
nltk==3.8.1
numpy==1.24.4
opencv_python_headless==4.5.5.64
openpyxl==3.1.2
Pillow==10.1.0
sacrebleu==2.3.2
seaborn==0.13.0
shortuuid==1.0.11
spacy==3.7.2
torch==2.2.0
torchaudio==2.2.0
torchvision==0.17.0
timm==0.9.10
tqdm==4.66.1
protobuf==4.25.0
typing_extensions==4.8.0
uvicorn==0.24.0.post1
#xformers==0.0.22.post7
#flash_attn==2.3.4
sentencepiece==0.1.99
accelerate==0.30.1
socksio==1.0.0
gradio==4.41.0
gradio_client
http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/modelscope_studio-0.4.0.9-py3-none-any.whl
decord
aiosignal
tensorboard
deepspeed==0.12.3
transformers==4.44.2
librosa==0.9.0
soundfile==0.12.1
vector-quantize-pytorch==1.18.5
vocos==0.1.0
moviepy

@@ -170,7 +170,7 @@ class CPMTrainer(Trainer):

        return (loss, logits, labels)

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: int=None) -> torch.Tensor:
    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

@@ -245,9 +245,6 @@ class CPMTrainer(Trainer):

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        if self.model.processor is not None:
            self.model.processor.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))