---
license: apache-2.0
language:
- en
- zh
tags:
- medical
- text-generation-inference
- image + text
---

```python
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig

# Local path or Hub repo id of the model.
# NOTE(review): when loading from the Hugging Face Hub this presumably needs the
# full "<owner>/med_tongue_vision-zh_V0.1" repo id — confirm.
model_id = "med_tongue_vision-zh_V0.1"

# 4-bit NF4 quantization with nested (double) quantization; computation runs in
# float16. No modules are excluded from quantization here — if some layers do
# not quantize well, list them in `llm_int8_skip_modules`.
# bitsandbytes 4-bit loading typically requires a CUDA GPU.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# Remove the `quantization_config` argument to load the unquantized float16 weights instead.
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    device_map="auto",
)
print(model)  # prints the module tree, handy for inspecting the architecture

processor = AutoProcessor.from_pretrained(model_id)

# TODO: replace with the system prompt the model was fine-tuned with (placeholder).
sys_prompt = "你是一名中医舌诊助手。"
# Path to the tongue photo to analyze (placeholder); converted to 3-channel RGB.
image = Image.open("tongue.jpg").convert("RGB")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": sys_prompt},
            {"type": "image"},  # image token slot; pixels come from `images=[image]` below
            {"type": "text", "text": "告诉我图片中的舌象指标有哪些"},
        ],
    }
]
text = processor.apply_chat_template(messages, add_generation_prompt=True)

# The processor returns CPU tensors; move them to the device holding the model
# weights (with device_map="auto" that is usually a GPU), otherwise generate()
# fails with a device-mismatch error.
inputs = processor(
    text=[text.strip()], images=[image], return_tensors="pt", padding=True
).to(model.device)

generated_ids = model.generate(**inputs, max_new_tokens=512)
# Slice off the prompt tokens so only the newly generated answer is decoded.
generated_texts = processor.batch_decode(
    generated_ids[:, inputs["input_ids"].size(1):], skip_special_tokens=True
)
print(generated_texts[0])
```

```text
===== 结果 ====================
客户提问: 请描述这张图片中的舌象细分类别。
客户舌图:
舌图15类标签: 绛_舌胖_有齿痕_有裂纹_有点刺_无瘀斑_无瘀点_无老嫩_无歪斜_黄_厚苔_无腐苔_有腻苔_润_无剥脱
大模型识别结果: ['淡红_舌胖_有齿痕_有裂纹_有点刺_无瘀斑_无瘀点_无老嫩_无歪斜_白_厚苔_无腐苔_有腻苔_润_无剥脱']
```

这里呈现的是一个低精度视觉多模态模型的能力演示：与标签相比，15 项中有 13 项识别一致，舌色（绛 → 淡红）与苔色（黄 → 白）两项识别有误。基于更大规模舌诊数据集训练的高精度模型，请联系：bxp2028@gmail.com