import gradio as gr
import torch
from transformers import BartTokenizer, BartForConditionalGeneration, AutoModel, AutoTokenizer

# Load the imitation-learning (IL) models: BART generates search queries,
# and a BERT-based choice model scores candidate click actions.
bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
bart_model = BartForConditionalGeneration.from_pretrained('webshop/il_search_bart')

bert_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', truncation_side='left')
bert_tokenizer.add_tokens(['[button]', '[button_]', '[clicked button]', '[clicked button_]'], special_tokens=True)
bert_model = AutoModel.from_pretrained('webshop/il-choice-bert-image_0', trust_remote_code=True)


def process_str(s):
    # Normalize an observation or action string before tokenization.
    s = s.lower().replace('"', '').replace("'", "").strip()
    s = s.replace('[sep]', '[SEP]')
    return s


def process_goal(state):
    # Strip the WebShop instruction boilerplate and the price constraint from the goal text.
    state = state.lower().replace('"', '').replace("'", "")
    state = state.replace('amazon shopping game\ninstruction:', '').replace('\n[button] search [button_]', '').strip()
    if ', and price lower than' in state:
        state = state.split(', and price lower than')[0]
    return state


def data_collator(batch):
    # Collate state/action encodings into padded tensors for the BERT choice model.
    state_input_ids, state_attention_mask, action_input_ids, action_attention_mask, sizes, labels, images = [], [], [], [], [], [], []
    for sample in batch:
        state_input_ids.append(sample['state_input_ids'])
        state_attention_mask.append(sample['state_attention_mask'])
        action_input_ids.extend(sample['action_input_ids'])
        action_attention_mask.extend(sample['action_attention_mask'])
        sizes.append(sample['sizes'])
        labels.append(sample['labels'])
        images.append(sample['images'])
    max_state_len = max(sum(x) for x in state_attention_mask)
    max_action_len = max(sum(x) for x in action_attention_mask)
    return {
        'state_input_ids': torch.tensor(state_input_ids)[:, :max_state_len],
        'state_attention_mask': torch.tensor(state_attention_mask)[:, :max_state_len],
        'action_input_ids': torch.tensor(action_input_ids)[:, :max_action_len],
        'action_attention_mask': torch.tensor(action_attention_mask)[:, :max_action_len],
        'sizes': torch.tensor(sizes),
        'images': torch.tensor(images),
        'labels': torch.tensor(labels),
    }


def bart_predict(input):
    # Generate candidate search queries with beam search and return the top beam.
    input_ids = bart_tokenizer(input)['input_ids']
    input_ids = torch.tensor(input_ids).unsqueeze(0)
    output = bart_model.generate(input_ids, max_length=512, num_return_sequences=5, num_beams=5)
    return bart_tokenizer.batch_decode(output.tolist(), skip_special_tokens=True)[0]


def bert_predict(obs, info, softmax=True):
    # Score each valid click action against the current observation and pick one
    # (sampled from the softmax distribution, or argmax if softmax=False).
    valid_acts = info['valid']
    assert valid_acts[0].startswith('click[')
    state_encodings = bert_tokenizer(process_str(obs), max_length=512, truncation=True, padding='max_length')
    action_encodings = bert_tokenizer(list(map(process_str, valid_acts)), max_length=512, truncation=True, padding='max_length')
    batch = {
        'state_input_ids': state_encodings['input_ids'],
        'state_attention_mask': state_encodings['attention_mask'],
        'action_input_ids': action_encodings['input_ids'],
        'action_attention_mask': action_encodings['attention_mask'],
        'sizes': len(valid_acts),
        'images': info['image_feat'].tolist(),
        'labels': 0,
    }
    batch = data_collator([batch])
    outputs = bert_model(**batch)
    if softmax:
        idx = torch.multinomial(torch.nn.functional.softmax(outputs.logits[0], dim=0), 1)[0].item()
    else:
        idx = outputs.logits[0].argmax(0).item()
    return valid_acts[idx]


def predict(obs, info):
    """
    Given a WebShop environment observation and info dict, predict an action:
    click actions are chosen by the BERT choice model, search queries by BART.
    """
    valid_acts = info['valid']
    if valid_acts[0].startswith('click['):
        return bert_predict(obs, info)
    else:
        return bart_predict(process_goal(obs))


def run_episode(goal):
    """
    Interact with Amazon to find a product matching the input goal.
    Input: text goal.
    Output: a URL of the item found on Amazon.
    """
    return bart_predict(goal)  # TODO: implement run_episode


gr.Interface(fn=run_episode,
             inputs=gr.Textbox(lines=7, label="Input Text"),
             outputs="text").launch(inline=False)
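
# Note: the run_episode stub above only returns a BART-generated search query
# rather than driving a full shopping episode. A fuller loop would alternate
# search and click actions via predict(). Minimal sketch, assuming a
# hypothetical gym-style WebShop environment `env` whose step() returns
# (obs, reward, done, info) and whose info dict carries the 'valid' action list
# and 'image_feat' features that predict() expects; none of these names are
# defined in this file:
#
#   def run_episode_sketch(env, goal, max_steps=100):
#       obs, info = env.reset(goal)                   # hypothetical reset signature
#       for _ in range(max_steps):
#           action = predict(obs, info)               # search (BART) or click (BERT)
#           obs, reward, done, info = env.step(action)
#           if done:
#               return obs                            # final page containing the item
#       return None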