{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a9935ae2",
   "metadata": { "id": "a9935ae2" },
   "outputs": [],
   "source": [
    "import argparse\n",
    "import os\n",
    "\n",
    "import torch\n",
    "from torch.optim import AdamW\n",
    "from torch.utils.data import DataLoader\n",
    "import peft\n",
    "\n",
    "import evaluate\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e3b13308",
   "metadata": { "id": "e3b13308" },
   "outputs": [],
   "source": [
    "# Run configuration: the values a reader is most likely to tune.\n",
    "batch_size = 8  # examples per batch, shared by every DataLoader below\n",
    "model_name_or_path = \"roberta-large\"\n",
    "task = \"mrpc\"\n",
    "peft_type = peft.PeftType.IA3\n",
    "# Prefer the GPU, but fall back to CPU so the notebook still runs on a machine without CUDA.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "num_epochs = 12"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0526f571",
   "metadata": { "id": "0526f571" },
   "outputs": [],
   "source": [
    "# LoRA alternative, kept for reference (only `peft` itself is imported, so the class is reached via the module):\n",
    "# peft_config = peft.LoraConfig(task_type=\"SEQ_CLS\", inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1)\n",
    "\n",
    "# IA3 adapter configuration for a sequence-classification task (inference_mode=False, i.e. not inference-only).\n",
    "peft_config = peft.IA3Config(task_type=\"SEQ_CLS\", inference_mode=False)\n",
    "# Learning rate (presumably handed to the optimizer in a later cell).\n",
    "lr = 1e-3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c2697d07",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 489,
     "referenced_widgets": [
      "6ea6ff70fa164264aef9efce9f921f10",
      "ecc2102ede2d4c8b94ce66b247054c96",
      "e64ca48867434ca2944dcb2b1c70c02c",
      "7ac75048225f4a4bbedac97965cf9837",
      "332971de5e894c8ca866c311cc6e180c",
      "bdf0999dfddc43ca8e04ceaac628064c",
      "c289713b179641bd915b3c334208197a",
      "e3d36dd4ddcb4287b8d8c942a26dc478",
      "0b20fa2eb59749e1b564df70f2378984",
      "d8e645f1697d46e0ac23f70e14498fab",
      "8dc1453856d549c4a1a688818447dd59",
      "d267c0cf26a8466d9938861ff5272a1e",
      "561f61d1cec94f5abbff3f8746c3fd96",
      "a5c93fac02884914ae8e140e0c0d2a17",
      "96c4ecd7376c43ee9a4ba270851d6fff",
      "5062d6f6feb34835afaf6d452900d514",
      "a05844fa4675494cb7bbe48f40f1aaac",
      "b241da7afea94fc0a6b407a0a78a8355",
      "74e056313e9e4a92bfaa22019ac1e58e",
"938a5b44d90140e29ba33628a2215f2a", "1f27766bc2d941f6a1d03fc69a6026a3", "1c5e43a201f0460f88b53739bc1eaa43", "24e236a4360e416a8e5c20d887274bcc", "7011752f7e0a422883bf0f218f6941c3", "248fe45eee37449982520a890696e6d4", "5ec5d0d9191047608f61cb78563f7641", "d7c2b00fd90147528b00a29552f29b44", "ea307bb7a9ad46a484330b1daf708169", "27009ed667c242488773ce3fcc58360a", "faf0c550adc5402bbee71b027822ce9c", "9c57da41dfc04fe4a4437097285599bf", "d4453448f8b04f0ea76c71887bd33a8c", "ff5172169c794d40b6c433da642d54a9", "788a143aad46467789acf08762fbf39f", "32ba35144ab34b0ebb7cfb75b86ccdd4", "3ef39b223ab74d788223970ec0da21e6", "be40292e613949a0ba1bfd1846cdab92", "a6ab0e37d063407fa4adc11f2f0299da", "71c3a5a515a949bfb2290612bfb2d05f", "1232101b618d43f5889e4ad81fa24514", "2204ca891c3940879abde0f55fbadf03", "58f2da1b793b44b09cedb5e0a3a1ef02", "51ec9bdff1834e4f8390f0eaf7d4aeb8", "71b4a798dc374a72817f6c118f2f05b8", "a439e5ecb0a040ca8188c6d74ff643a9", "710f512e4a2b40cb85805b33a7dc42e4", "e5994b94b84143f4a06c49a45a078bd0", "e4f935e4e8e84320b2381aaf3493962c", "44eb5aeeb0004f799e79d3ed51cf37d4", "0c488289ee494a8d99d1f02a13729382", "49ccb14ddd874e798371494942725128", "015132b18da54a8f89a83cd9bf6fd17c", "39f26160766f48798761079527e14396", "5408dfdfc5624ff5b1c0b2e494dd4b35", "9c50277ed18a424a9da98707cf726d29", "11ba7f922c58474d9cb8b8c7a22caddc", "598fc42dfed04aadaeb38539dd259871", "fb05d1ece2ba4963bcbf99f89296c709", "e3fe73a6ffcf4d089911b4149b9b4512", "45e1b59770f946dba06511f9b5d1ad23", "2c512bc0a21a4f6684aa0593a969e6cb", "376f00f38b46434e922c3f6f7dc4a85c", "e815b0749289411eb680c56e4fd39dae", "33e9821903f84551a2e56c0586a5332b", "0b2e59615f80452cbe25dfd467099b84", "45128b02652b4c20bf448db8495d76d2", "d1997b08e5284008afce56a5ed6347be", "730111ee99b4470c81ab4ade07b30352", "562a55ffabef4d218c78c5dcc6665484", "59e1d01621f4404ca2102d00e2b01f87", "1896a8f160ea45e5ac2a88970feecdb9", "7d6d21774fec4eb0bab3bc2bfa2708d6", "9882a67915f8475c9f4b8a5b15d64e50", "47ffc35023144c4fa6b93e73b7a2ee60", 
"16f8d52980dc41e89ed3bfb9172420a6", "2bc5432ab9464f7b9192aa1f5a26a1b2", "80bcbdb92af34c9889b7d105affe34ba", "6786c3895b594301b2a7d4ed2767cb35", "bf35a5a5d9b840a18ef10938d23fce0c", "ad4872e198ee44f8b83323d891347ffa", "0f29f96aec7e44a1b97cda0fcbe17665", "dbb4032fab4d49a5b250be57a4d51fb0", "af185d9bedf64d54bc77f7f6e7c448f4", "4d9d717d2226444094e9753fdc843849", "3d323bc6e1fd4091a0ed2ffd521f2ec7", "2616b6102108476f8bf9e3d35b63494d", "844ea1050893482785c1b68150f4ab20", "92d023972a204222b59c12fc4b4d3bcf", "7900796766b946da886338653b495533", "84b50ae864e0463d98efa792e149e712", "802dc985a31e4febb8eafa4682e242ec", "f31d9d6e2d3140eeb821153a7b69b90a", "3f2d9eba845341da97fa1b957e72de5e", "d3638a2985fc4fdfbf239914dbc32fe8", "a9d237f2f62e4a839abdec541715a5de", "b451310e94b74abda4e795a59cdda9ab", "537c7d2f4260498a82132cb9fae5daa5", "6f9efa5f778d4f029dca7b4d6817b4c4", "d29a5f422214434c9be3886ce0d1e918", "e0e903b3ae8044f4a58ecd70825f36ec", "99e8679e71f4443c92f13971e8885d38", "55a0ca0754c94f0bba6ede50b7fb7ea8", "fae83f7a625448e788ffdb1a13d2d530", "3086db94647a41569b6f091e4f11f3cb", "24062742dc6749a2aab19cbfd1e11684", "1b6c8de76fd948b5b0eba8346f6e0bec", "276fc5825a624a3a9026301620682c6c", "5c10a16929ce40e48361efa44e96a7f8", "53f9d507f4024b2080591215d3c4bab4", "c77cdd4479294ce7bf3f3ef271e95b1d", "e66a0c2552ff46b5a41cf8d803114b8f", "897708ca40da48f3ae846710b4e7f7f0", "1aa34d9a0b16444baadf2ed7e43b72e5", "01f39b548eff49aea6ba0efabb4486de", "46fc201246804f28946fd0c59e09c4ac", "e6c085b451a346b6a3b82c67d7b9e9e2", "29a8cce77d2745038f2c742d17254139", "326b48d7d2e94defa52858171177e7d7", "9a3dcc1c2fe542b7b912419935d9dd9a", "b82b0c29bd53432baa5596b4f1aaf076", "a8ccd811bba848d2ad2f84d2cdf9ef03", "9ae4e6a6543e4da6bba06175aeef3ea1", "fa9675039cc443f791dcfbc0bdca065c", "5b0a0578afad40868bb71707c05e0335", "d1a82e42e866449cbb70c1f68ac1bc03", "bd12886eee644913ba83fd4ebb6b62ee", "fc11a9d0c0b7409cb2f61fcabac2bfa6", "9cf73d46c2f04bf398a4564a56f03bd6", "f46d3699775b4d7193adeebb4f18a34f", 
"5997f14e0a3044249b9cabc2b307e3a4", "5ba9f4bd64bd4e1e905887760f90ae3e", "1a8e75c718d14f4c8f99de1c8efa84b5", "786b2430a42942f1aeae253861820dcb", "e2741973e91745ea930d3cc23070cf52", "b9f18367c54b4203bee70c680ae9cdde", "7db82bb6fa2949b3a6bc9eb152fc3af1", "136a32dbeef646df944a8b59cb00c0c4", "c2fa0ef9f45b447e9dfe5c576428c714", "241dd21eeb5843ef8433e47b415c5b62", "c9837f88650844ea94442ad1c5682972", "963da028a1594c1bbd223e93832f44bd", "d981db72ab2745a598ed45a8762d5fcd", "73379eb4954d4be6ab90008addd7d3bd", "007f2f512694405c9245edd5e1f58551", "9aaeeb1213854768a6af51f8db54f6b6", "e6e2192c0a904bb6850c6bc68b579995", "608a68c6f651419d8fd210044b4561cf", "ac9d6255521f428fa7fda5735d71dbec", "a39ea243b7964c298eae71e9eaf32e17", "bf7f7cb363df4384ade192666b77c715", "841f3571fcec42ad8e1e43997567788c", "a246e6c60c054b4383c510c31b255a87", "da3e335ddc9846ea832c75d69684575f", "daa8ce62f0204a11bde7cd9e31cc2fa1", "6b937f0053b64c67af74d44bcb6e51b7", "72f55658a145483cb90668bf7b6f6c8a", "357c082907e140b58c807a394446d811", "f127594b694f4d2fafb3872e8190b1d6", "f9c8bd12201a4c8a86b2bb3c1a14d74e", "dc55ec8e57984efcb0f269ef0ef41c02", "cd07261ef21a4b859de53f244df33f2a", "c47de12dd692434ab9497e8a0d2a19ba", "963b22b2ef1e45f5ba8320f72ea6a83b", "3e2ca33b7473499d8314238ec245cd90", "e3313a7fb356432bada289ad346f1bdb", "75ee099566ca46eb8a75475246aaab01", "e94dd1ea928f43b880af21ced7f11d14", "ad84cdcfa07f4f608dccd10395d35e79", "8bd3a235ae654b1cbaae36aeb1a62e70", "ebe2f83a5783401ab501cdaf9c4e2ad5", "fe20644617844cf7a471e22a99ca7b5e", "d5502b7dcf784d289fc748881924334c", "6bf5b27d3c3c407eb10f8d9d6e8e8d22", "47480982ce83440997297e605cdf8a31", "32764784cc9644198c622f6706db6836", "83bae5949ec448c6ad6f68a7b8d3d436", "a018f7c156584eb3833b6aba3710b3c0", "c2a078c69d3f45ca9895e0e7f95aaf2b", "8c55316c54d44bd4b75a8457e2d8c595", "1824b7b748ee4763b5098cd915107f63", "87b93c14cdae477ab49522354631e82b", "a8b5d642ed654c0f85f3a9610c68b754", "8750ab964b5444d49db4fe8542964d8c", "c994a28df165445bbf0d80c165bc5a0b", 
"1a2897c84d454c1a9c115aef178f4fcf", "b668d56ab549478984ad14c22e040e47", "bac855f4c7044fe88bcd74170e13f103", "f871eeaa54cf451b8a2de64eed90d5b6", "39492d8b91f64e97bc8caead77957508", "1487d054f1824f739c93c00621591bd0", "cca43d628ff94369bc9713cbad616adb", "6aa4308f0cb348419f1217e48ef2dbd7", "d9db044ff31f4856a4d25e57fe9882bc", "f76b651681d74b2eab0ab07f43983a2a", "1e3adfe9c9e34d4abe74b65e298c0e7b", "f91ed6a931c247b5903304a28998633d", "34e576af7aab4ef795239b1d9a281b15", "9d8e3a4fe5864a4ca229ba6e092181e4", "e41888b1dea140d18be25efc6d99d0c3", "d63577f0ae724938952f9b681a89512a", "83623082d4e9457a9173db3129154f94", "4c1798c3d0cb43949a1ed7519b57d3fc", "6513f2b58d654a488b132abd8c83c9b6", "c4aa7cee6e5d4a89a2c33b9ce2e9c489", "b42e835e0ba24a08b5815234a22b4da6", "e8a53956f6ad46d58ec3a9a039f2303e", "91f07c6089354ceeb4e628ef1791d2d6", "c137c84ae40049ffba7d1f4d77b21de2", "81f8db0dcdbc4e5593fa3fbb6c2b9361" ] }, "id": "c2697d07", "outputId": "c8318b3d-b6a0-4f0d-9903-54beb2baac75" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "Downloading (…)lve/main/config.json: 0%| | 0.00/482 [00:00 use the model max length (it's actually the default)\n", " outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n", " return outputs\n", "\n", "\n", "tokenized_datasets = datasets.map(\n", " tokenize_function,\n", " batched=True,\n", " remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n", ")\n", "\n", "# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n", "# transformers library\n", "tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n", "\n", "\n", "def collate_fn(examples):\n", " return tokenizer.pad(examples, padding=\"longest\", return_tensors=\"pt\")\n", "\n", "\n", "# Instantiate dataloaders.\n", "train_dataloader = DataLoader(tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size)\n", 
"eval_dataloader = DataLoader(\n", " tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size\n", ")\n", "test_dataloader = DataLoader(tokenized_datasets[\"test\"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size)" ] }, { "cell_type": "code", "execution_count": 6, "id": "2ed5ac74", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000, "referenced_widgets": [ "0cecb897c86c4892b94a1990ab08a926", "b8af0294819e4280ad41fa1c11006adf", "c7530d63b2f745e799713284abacbd2c", "12a1e302a69543c5bc0e0a66be008ca0", "9c372e9e9b20433faed8530ca0f4424c", "b07f26a21325493cac19113f1aa1ee96", "fbdf6c544fb54294903524a69384e773", "6c6f2223243b4a7485aef7fbbfe07668", "a93509d61ac94628a74bc0f98c0eec06", "3a8de0eb7db44647a734590b6b351b44", "64d8affd2e854a1c9043fec7ca8a2796" ] }, "id": "2ed5ac74", "outputId": "18ea15ac-ed8d-4d80-b166-706681ee49ab" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "Downloading model.safetensors: 0%| | 0.00/1.42G [00:00