{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "58ff91ca-ce92-43d0-ae8b-4e9e89e193f6", "metadata": { "tags": [] }, "outputs": [], "source": [ "from datasets import load_dataset\n", "from transformers import set_seed, AutoModelForSeq2SeqLM, AutoTokenizer\n", "from peft import get_peft_model, MultitaskPromptTuningConfig, TaskType, MultitaskPromptTuningInit\n", "\n", "set_seed(42)\n", "\n", "model_name = \"google/flan-t5-base\"\n", "\n", "peft_config = MultitaskPromptTuningConfig(\n", " tokenizer_name_or_path=model_name,\n", " num_tasks=2,\n", " task_type=TaskType.SEQ_2_SEQ_LM,\n", " prompt_tuning_init=MultitaskPromptTuningInit.TEXT,\n", " num_virtual_tokens=50,\n", " num_transformer_submodules=1,\n", " prompt_tuning_init_text=\"classify the following into either positive or negative, or entailment, neutral or contradiction:\",\n", ")\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", "model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n", "model = get_peft_model(model, peft_config)\n", "\n", "model = model.cuda()\n", "\n", "\n", "def send_to_device(batch):\n", " for i in batch:\n", " batch[i] = batch[i].cuda()\n", " return batch" ] }, { "cell_type": "code", "execution_count": null, "id": "eb112bc1-ffaf-49fa-a216-0d601ec304ee", "metadata": { "tags": [] }, "outputs": [], "source": [ "def get_sst2(split: str):\n", " examples = load_dataset(\"sst2\")[split]\n", " result_examples = []\n", " for example in examples:\n", " result_examples.append({})\n", "\n", " result_examples[-1][\"input\"] = example[\"sentence\"].strip() + \"\"\n", " result_examples[-1][\"output\"] = (\n", " f\"positive{tokenizer.eos_token}\" if example[\"label\"] == 1 else f\"negative{tokenizer.eos_token}\"\n", " )\n", " result_examples[-1][\"task_id\"] = 0\n", "\n", " return result_examples\n", "\n", "\n", "def get_mnli(split: str):\n", " examples = load_dataset(\"multi_nli\")[split]\n", " result_examples = []\n", " for example in examples:\n", " result_examples.append({})\n", "\n", " result_examples[-1][\"input\"] = example[\"premise\"].strip() + \" \" + example[\"hypothesis\"].strip() + \"\"\n", "\n", " if example[\"label\"] == 0:\n", " result_examples[-1][\"output\"] = f\"entailment{tokenizer.eos_token}\"\n", " elif example[\"label\"] == 1:\n", " result_examples[-1][\"output\"] = f\"neutral{tokenizer.eos_token}\"\n", " else:\n", " result_examples[-1][\"output\"] = f\"contradiction{tokenizer.eos_token}\"\n", "\n", " result_examples[-1][\"task_id\"] = 1\n", "\n", " return result_examples" ] }, { "cell_type": "code", "execution_count": null, "id": "e5a16ec4-8fef-4ba9-95b6-a661eb51e50c", "metadata": { "tags": [] }, "outputs": [], "source": [ "from typing import Tuple\n", "from torch.utils.data import Dataset, DataLoader\n", "import torch\n", "\n", "\n", "class MyDataset(Dataset):\n", " def __init__(self, split: str, mode: str = \"source\") -> None:\n", " super().__init__()\n", "\n", " if split == \"train\":\n", " if mode == \"source\":\n", " self.examples = get_sst2(split) + get_mnli(split)\n", " elif mode == \"target\":\n", " self.examples = get_sst2(split)\n", " if split == \"val\":\n", " self.examples = get_sst2(\"validation\")\n", " if split == \"test\":\n", " self.examples = get_sst2(\"validation\")\n", "\n", " def __getitem__(self, index) -> dict:\n", " return self.examples[index]\n", "\n", " def __len__(self) -> int:\n", " return len(self.examples)\n", "\n", " def __getitem__(self, index) -> dict:\n", " return self.examples[index]\n", "\n", " def __len__(self) -> int:\n", " return len(self.examples)\n", "\n", "\n", "def collate_fn(batch: dict) -> Tuple[torch.Tensor, torch.Tensor]:\n", " input = [i[\"input\"] for i in batch]\n", " input = tokenizer(input, add_special_tokens=False, return_tensors=\"pt\", padding=True)\n", "\n", " output = [i[\"output\"] for i in batch]\n", " output = tokenizer(output, add_special_tokens=False, return_tensors=\"pt\", padding=True).input_ids\n", " output[output == tokenizer.pad_token_id] = -100\n", "\n", " task_ids = [i[\"task_id\"] for i in batch]\n", " task_ids = torch.tensor(task_ids)\n", "\n", " return {\n", " \"input_ids\": input.input_ids,\n", " \"attention_mask\": input.attention_mask,\n", " \"labels\": output,\n", " \"task_ids\": task_ids,\n", " }\n", "\n", "\n", "train = DataLoader(MyDataset(\"train\"), shuffle=True, batch_size=8, collate_fn=collate_fn)\n", "val = DataLoader(MyDataset(\"val\"), shuffle=False, batch_size=8, collate_fn=collate_fn)\n", "test = DataLoader(MyDataset(\"test\"), shuffle=False, batch_size=8, collate_fn=collate_fn)" ] }, { "cell_type": "markdown", "id": "fe0aec7b-f61e-4b00-a90e-c1201dc1f84c", "metadata": {}, "source": [ "## source training" ] }, { "cell_type": "code", "execution_count": null, "id": "cceecc94-f43a-4f62-8d45-926f2f02f36d", "metadata": { "tags": [] }, "outputs": [], "source": [ "from torch.optim.adamw import AdamW\n", "from transformers import get_cosine_schedule_with_warmup\n", "from tqdm import tqdm\n", "from sklearn.metrics import f1_score" ] }, { "cell_type": "code", "execution_count": null, "id": "eae5516b-73ab-44a8-a083-4e8de6127f30", "metadata": { "tags": [] }, "outputs": [], "source": [ "POSITIVE_TOKEN_ID = tokenizer(\" positive\", add_special_tokens=False)[\"input_ids\"][0]\n", "NEGATIVE_TOKEN_ID = tokenizer(\" negative\", add_special_tokens=False)[\"input_ids\"][0]\n", "\n", "\n", "def classify(batch):\n", " batch = send_to_device(batch)\n", " # we pass labels here since we need to generate and peft doesn't support generation yet.\n", " # No clue how to get around this\n", " scores = model(**batch).logits\n", " preds = []\n", " for i in range(scores.shape[0]):\n", " if scores[i, 0, POSITIVE_TOKEN_ID] > scores[i, 0, NEGATIVE_TOKEN_ID]:\n", " preds.append(POSITIVE_TOKEN_ID)\n", " else:\n", " preds.append(NEGATIVE_TOKEN_ID)\n", " return preds\n", "\n", "\n", "@torch.inference_mode()\n", "def evaluate(model, data):\n", " loss = 0\n", " preds = []\n", " golds = []\n", "\n", " for batch in tqdm(data):\n", " batch = send_to_device(batch)\n", " loss += model(**batch).loss\n", " golds.extend(batch[\"labels\"][:, 0].tolist())\n", " preds.extend(classify(batch))\n", "\n", " return loss / len(val), f1_score(golds, preds, pos_label=POSITIVE_TOKEN_ID)\n", "\n", "\n", "optimizer = AdamW(model.parameters(), lr=1e-4)\n", "scheduler = get_cosine_schedule_with_warmup(optimizer, 200, len(train))\n", "\n", "n = 1000\n", "step = 0\n", "train_ = tqdm(train)\n", "\n", "val_loss, f1 = evaluate(model, val)\n", "print(\n", " f\"\"\"\n", "before source training\n", "val loss = {val_loss}\n", "f1 = {f1}\"\"\"\n", ")\n", "\n", "for batch in train_:\n", " if step % n == 0:\n", " val_loss, f1 = evaluate(model, val)\n", " print(\n", " f\"\"\"\n", "step = {step}\n", "val loss = {val_loss}\n", "f1 = {f1}\"\"\"\n", " )\n", " model.save_pretrained(f\"checkpoints_source/{step}\")\n", "\n", " step += 1\n", " batch = send_to_device(batch)\n", " loss = model(**batch).loss\n", " loss.backward()\n", " optimizer.step()\n", " scheduler.step()\n", " train_.set_postfix(train_loss=loss)" ] }, { "cell_type": "markdown", "id": "74168ef3-66f3-41a7-a40b-7840b103fbf9", "metadata": {}, "source": [ "## target training" ] }, { "cell_type": "code", "execution_count": null, "id": "b09fd456-163e-4dc1-b24d-f2d0d349036c", "metadata": { "tags": [] }, "outputs": [], "source": [ "train = DataLoader(MyDataset(\"train\", \"target\"), shuffle=True, batch_size=8, collate_fn=collate_fn)\n", "val = DataLoader(MyDataset(\"val\", \"target\"), shuffle=False, batch_size=8, collate_fn=collate_fn)\n", "test = DataLoader(MyDataset(\"test\", \"target\"), shuffle=False, batch_size=8, collate_fn=collate_fn)" ] }, { "cell_type": "markdown", "id": "4a539944-f16c-4c3f-bb4a-7b5d9a6042e2", "metadata": {}, "source": [ "#### create a fresh model" ] }, { "cell_type": "code", "execution_count": null, "id": "5520d904-aa6c-4654-9335-ed4e7d76cba2", "metadata": { "tags": [] }, "outputs": [], "source": [ "peft_config = MultitaskPromptTuningConfig(\n", " tokenizer_name_or_path=model_name,\n", " num_tasks=1,\n", " task_type=TaskType.SEQ_2_SEQ_LM,\n", " prompt_tuning_init=MultitaskPromptTuningInit.EXACT_SOURCE_TASK,\n", " prompt_tuning_init_state_dict_path=\"checkpoints_source/50000/adapter_model.bin\",\n", " num_virtual_tokens=50,\n", " num_transformer_submodules=1,\n", ")\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", "model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n", "model = get_peft_model(model, peft_config)\n", "\n", "model = model.cuda()" ] }, { "cell_type": "code", "execution_count": null, "id": "dfa39c2d-d1c5-4ed4-90f8-26e8e324371c", "metadata": { "tags": [] }, "outputs": [], "source": [ "optimizer = AdamW(model.parameters(), lr=1e-4)\n", "scheduler = get_cosine_schedule_with_warmup(optimizer, 200, len(train))\n", "\n", "n = 1000\n", "step = 0\n", "train_ = tqdm(train)\n", "\n", "val_loss, f1 = evaluate(model, val)\n", "print(\n", " f\"\"\"\n", "before target training\n", "val loss = {val_loss}\n", "f1 = {f1}\"\"\"\n", ")\n", "\n", "for batch in train_:\n", " if step % n == 0:\n", " val_loss, f1 = evaluate(model, val)\n", " print(\n", " f\"\"\"\n", "step = {step}\n", "val loss = {val_loss}\n", "f1 = {f1}\"\"\"\n", " )\n", " model.save_pretrained(f\"checkpoints_target/{step}\")\n", "\n", " step += 1\n", " batch = send_to_device(batch)\n", " loss = model(**batch).loss\n", " loss.backward()\n", " optimizer.step()\n", " scheduler.step()\n", " train_.set_postfix(train_loss=loss)" ] }, { "cell_type": "code", "execution_count": null, "id": "b6a6eeda-1e09-49a6-8845-cd96c8573145", "metadata": { "tags": [] }, "outputs": [], "source": [ "# load last checkpoint for now\n", "from peft import set_peft_model_state_dict\n", "\n", "sd_6000 = torch.load(\"checkpoints_target/6000/adapter_model.bin\")\n", "set_peft_model_state_dict(model, sd_6000)\n", "\n", "# evaluate val\n", "val_loss, f1 = evaluate(model, val)\n", "print(\n", " f\"\"\"\n", "final\n", "val loss = {val_loss}\n", "f1 = {f1}\"\"\"\n", ")\n", "\n", "# evaluate test\n", "test_loss, f1 = evaluate(model, test)\n", "print(\n", " f\"\"\"\n", "final\n", "test loss = {test_loss}\n", "f1 = {f1}\"\"\"\n", ")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 5 }