diff --git "a/IA3.ipynb" "b/IA3.ipynb" new file mode 100644--- /dev/null +++ "b/IA3.ipynb" @@ -0,0 +1,8070 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "id": "a9935ae2", + "metadata": { + "id": "a9935ae2" + }, + "outputs": [], + "source": [ + "import argparse\n", + "import os\n", + "\n", + "import torch\n", + "from torch.optim import AdamW\n", + "from torch.utils.data import DataLoader\n", + "import peft\n", + "\n", + "import evaluate\n", + "from datasets import load_dataset\n", + "from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n", + "from tqdm import tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "e3b13308", + "metadata": { + "id": "e3b13308" + }, + "outputs": [], + "source": [ + "batch_size = 8\n", + "model_name_or_path = \"roberta-large\"\n", + "task = \"mrpc\"\n", + "peft_type = peft.PeftType.IA3\n", + "device = \"cuda\"\n", + "num_epochs = 12" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "0526f571", + "metadata": { + "id": "0526f571" + }, + "outputs": [], + "source": [ + "# peft_config = LoraConfig(task_type=\"SEQ_CLS\", inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1)\n", + "peft_config = peft.IA3Config(task_type=\"SEQ_CLS\", inference_mode=False)\n", + "lr = 1e-3" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "c2697d07", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 489, + "referenced_widgets": [ + "6ea6ff70fa164264aef9efce9f921f10", + "ecc2102ede2d4c8b94ce66b247054c96", + "e64ca48867434ca2944dcb2b1c70c02c", + "7ac75048225f4a4bbedac97965cf9837", + "332971de5e894c8ca866c311cc6e180c", + "bdf0999dfddc43ca8e04ceaac628064c", + "c289713b179641bd915b3c334208197a", + "e3d36dd4ddcb4287b8d8c942a26dc478", + "0b20fa2eb59749e1b564df70f2378984", + "d8e645f1697d46e0ac23f70e14498fab", + "8dc1453856d549c4a1a688818447dd59", + "d267c0cf26a8466d9938861ff5272a1e", + "561f61d1cec94f5abbff3f8746c3fd96", + "a5c93fac02884914ae8e140e0c0d2a17", + "96c4ecd7376c43ee9a4ba270851d6fff", + "5062d6f6feb34835afaf6d452900d514", + "a05844fa4675494cb7bbe48f40f1aaac", + "b241da7afea94fc0a6b407a0a78a8355", + "74e056313e9e4a92bfaa22019ac1e58e", + "938a5b44d90140e29ba33628a2215f2a", + "1f27766bc2d941f6a1d03fc69a6026a3", + "1c5e43a201f0460f88b53739bc1eaa43", + "24e236a4360e416a8e5c20d887274bcc", + "7011752f7e0a422883bf0f218f6941c3", + "248fe45eee37449982520a890696e6d4", + "5ec5d0d9191047608f61cb78563f7641", + "d7c2b00fd90147528b00a29552f29b44", + "ea307bb7a9ad46a484330b1daf708169", + "27009ed667c242488773ce3fcc58360a", + "faf0c550adc5402bbee71b027822ce9c", + "9c57da41dfc04fe4a4437097285599bf", + "d4453448f8b04f0ea76c71887bd33a8c", + "ff5172169c794d40b6c433da642d54a9", + "788a143aad46467789acf08762fbf39f", + "32ba35144ab34b0ebb7cfb75b86ccdd4", + "3ef39b223ab74d788223970ec0da21e6", + "be40292e613949a0ba1bfd1846cdab92", + "a6ab0e37d063407fa4adc11f2f0299da", + "71c3a5a515a949bfb2290612bfb2d05f", + "1232101b618d43f5889e4ad81fa24514", + "2204ca891c3940879abde0f55fbadf03", + "58f2da1b793b44b09cedb5e0a3a1ef02", + "51ec9bdff1834e4f8390f0eaf7d4aeb8", + "71b4a798dc374a72817f6c118f2f05b8", + "a439e5ecb0a040ca8188c6d74ff643a9", + "710f512e4a2b40cb85805b33a7dc42e4", + "e5994b94b84143f4a06c49a45a078bd0", + "e4f935e4e8e84320b2381aaf3493962c", + "44eb5aeeb0004f799e79d3ed51cf37d4", + "0c488289ee494a8d99d1f02a13729382", + "49ccb14ddd874e798371494942725128", + "015132b18da54a8f89a83cd9bf6fd17c", + "39f26160766f48798761079527e14396", + "5408dfdfc5624ff5b1c0b2e494dd4b35", + "9c50277ed18a424a9da98707cf726d29", + "11ba7f922c58474d9cb8b8c7a22caddc", + "598fc42dfed04aadaeb38539dd259871", + "fb05d1ece2ba4963bcbf99f89296c709", + "e3fe73a6ffcf4d089911b4149b9b4512", + "45e1b59770f946dba06511f9b5d1ad23", + "2c512bc0a21a4f6684aa0593a969e6cb", + "376f00f38b46434e922c3f6f7dc4a85c", + "e815b0749289411eb680c56e4fd39dae", + "33e9821903f84551a2e56c0586a5332b", + "0b2e59615f80452cbe25dfd467099b84", + "45128b02652b4c20bf448db8495d76d2", + "d1997b08e5284008afce56a5ed6347be", + "730111ee99b4470c81ab4ade07b30352", + "562a55ffabef4d218c78c5dcc6665484", + "59e1d01621f4404ca2102d00e2b01f87", + "1896a8f160ea45e5ac2a88970feecdb9", + "7d6d21774fec4eb0bab3bc2bfa2708d6", + "9882a67915f8475c9f4b8a5b15d64e50", + "47ffc35023144c4fa6b93e73b7a2ee60", + "16f8d52980dc41e89ed3bfb9172420a6", + "2bc5432ab9464f7b9192aa1f5a26a1b2", + "80bcbdb92af34c9889b7d105affe34ba", + "6786c3895b594301b2a7d4ed2767cb35", + "bf35a5a5d9b840a18ef10938d23fce0c", + "ad4872e198ee44f8b83323d891347ffa", + "0f29f96aec7e44a1b97cda0fcbe17665", + "dbb4032fab4d49a5b250be57a4d51fb0", + "af185d9bedf64d54bc77f7f6e7c448f4", + "4d9d717d2226444094e9753fdc843849", + "3d323bc6e1fd4091a0ed2ffd521f2ec7", + "2616b6102108476f8bf9e3d35b63494d", + "844ea1050893482785c1b68150f4ab20", + "92d023972a204222b59c12fc4b4d3bcf", + "7900796766b946da886338653b495533", + "84b50ae864e0463d98efa792e149e712", + "802dc985a31e4febb8eafa4682e242ec", + "f31d9d6e2d3140eeb821153a7b69b90a", + "3f2d9eba845341da97fa1b957e72de5e", + "d3638a2985fc4fdfbf239914dbc32fe8", + "a9d237f2f62e4a839abdec541715a5de", + "b451310e94b74abda4e795a59cdda9ab", + "537c7d2f4260498a82132cb9fae5daa5", + "6f9efa5f778d4f029dca7b4d6817b4c4", + "d29a5f422214434c9be3886ce0d1e918", + "e0e903b3ae8044f4a58ecd70825f36ec", + "99e8679e71f4443c92f13971e8885d38", + "55a0ca0754c94f0bba6ede50b7fb7ea8", + "fae83f7a625448e788ffdb1a13d2d530", + "3086db94647a41569b6f091e4f11f3cb", + "24062742dc6749a2aab19cbfd1e11684", + "1b6c8de76fd948b5b0eba8346f6e0bec", + "276fc5825a624a3a9026301620682c6c", + "5c10a16929ce40e48361efa44e96a7f8", + "53f9d507f4024b2080591215d3c4bab4", + "c77cdd4479294ce7bf3f3ef271e95b1d", + "e66a0c2552ff46b5a41cf8d803114b8f", + "897708ca40da48f3ae846710b4e7f7f0", + "1aa34d9a0b16444baadf2ed7e43b72e5", + "01f39b548eff49aea6ba0efabb4486de", + "46fc201246804f28946fd0c59e09c4ac", + "e6c085b451a346b6a3b82c67d7b9e9e2", + "29a8cce77d2745038f2c742d17254139", + "326b48d7d2e94defa52858171177e7d7", + "9a3dcc1c2fe542b7b912419935d9dd9a", + "b82b0c29bd53432baa5596b4f1aaf076", + "a8ccd811bba848d2ad2f84d2cdf9ef03", + "9ae4e6a6543e4da6bba06175aeef3ea1", + "fa9675039cc443f791dcfbc0bdca065c", + "5b0a0578afad40868bb71707c05e0335", + "d1a82e42e866449cbb70c1f68ac1bc03", + "bd12886eee644913ba83fd4ebb6b62ee", + "fc11a9d0c0b7409cb2f61fcabac2bfa6", + "9cf73d46c2f04bf398a4564a56f03bd6", + "f46d3699775b4d7193adeebb4f18a34f", + "5997f14e0a3044249b9cabc2b307e3a4", + "5ba9f4bd64bd4e1e905887760f90ae3e", + "1a8e75c718d14f4c8f99de1c8efa84b5", + "786b2430a42942f1aeae253861820dcb", + "e2741973e91745ea930d3cc23070cf52", + "b9f18367c54b4203bee70c680ae9cdde", + "7db82bb6fa2949b3a6bc9eb152fc3af1", + "136a32dbeef646df944a8b59cb00c0c4", + "c2fa0ef9f45b447e9dfe5c576428c714", + "241dd21eeb5843ef8433e47b415c5b62", + "c9837f88650844ea94442ad1c5682972", + "963da028a1594c1bbd223e93832f44bd", + "d981db72ab2745a598ed45a8762d5fcd", + "73379eb4954d4be6ab90008addd7d3bd", + "007f2f512694405c9245edd5e1f58551", + "9aaeeb1213854768a6af51f8db54f6b6", + "e6e2192c0a904bb6850c6bc68b579995", + "608a68c6f651419d8fd210044b4561cf", + "ac9d6255521f428fa7fda5735d71dbec", + "a39ea243b7964c298eae71e9eaf32e17", + "bf7f7cb363df4384ade192666b77c715", + "841f3571fcec42ad8e1e43997567788c", + "a246e6c60c054b4383c510c31b255a87", + "da3e335ddc9846ea832c75d69684575f", + "daa8ce62f0204a11bde7cd9e31cc2fa1", + "6b937f0053b64c67af74d44bcb6e51b7", + "72f55658a145483cb90668bf7b6f6c8a", + "357c082907e140b58c807a394446d811", + "f127594b694f4d2fafb3872e8190b1d6", + "f9c8bd12201a4c8a86b2bb3c1a14d74e", + "dc55ec8e57984efcb0f269ef0ef41c02", + "cd07261ef21a4b859de53f244df33f2a", + "c47de12dd692434ab9497e8a0d2a19ba", + "963b22b2ef1e45f5ba8320f72ea6a83b", + "3e2ca33b7473499d8314238ec245cd90", + "e3313a7fb356432bada289ad346f1bdb", + "75ee099566ca46eb8a75475246aaab01", + "e94dd1ea928f43b880af21ced7f11d14", + "ad84cdcfa07f4f608dccd10395d35e79", + "8bd3a235ae654b1cbaae36aeb1a62e70", + "ebe2f83a5783401ab501cdaf9c4e2ad5", + "fe20644617844cf7a471e22a99ca7b5e", + "d5502b7dcf784d289fc748881924334c", + "6bf5b27d3c3c407eb10f8d9d6e8e8d22", + "47480982ce83440997297e605cdf8a31", + "32764784cc9644198c622f6706db6836", + "83bae5949ec448c6ad6f68a7b8d3d436", + "a018f7c156584eb3833b6aba3710b3c0", + "c2a078c69d3f45ca9895e0e7f95aaf2b", + "8c55316c54d44bd4b75a8457e2d8c595", + "1824b7b748ee4763b5098cd915107f63", + "87b93c14cdae477ab49522354631e82b", + "a8b5d642ed654c0f85f3a9610c68b754", + "8750ab964b5444d49db4fe8542964d8c", + "c994a28df165445bbf0d80c165bc5a0b", + "1a2897c84d454c1a9c115aef178f4fcf", + "b668d56ab549478984ad14c22e040e47", + "bac855f4c7044fe88bcd74170e13f103", + "f871eeaa54cf451b8a2de64eed90d5b6", + "39492d8b91f64e97bc8caead77957508", + "1487d054f1824f739c93c00621591bd0", + "cca43d628ff94369bc9713cbad616adb", + "6aa4308f0cb348419f1217e48ef2dbd7", + "d9db044ff31f4856a4d25e57fe9882bc", + "f76b651681d74b2eab0ab07f43983a2a", + "1e3adfe9c9e34d4abe74b65e298c0e7b", + "f91ed6a931c247b5903304a28998633d", + "34e576af7aab4ef795239b1d9a281b15", + "9d8e3a4fe5864a4ca229ba6e092181e4", + "e41888b1dea140d18be25efc6d99d0c3", + "d63577f0ae724938952f9b681a89512a", + "83623082d4e9457a9173db3129154f94", + "4c1798c3d0cb43949a1ed7519b57d3fc", + "6513f2b58d654a488b132abd8c83c9b6", + "c4aa7cee6e5d4a89a2c33b9ce2e9c489", + "b42e835e0ba24a08b5815234a22b4da6", + "e8a53956f6ad46d58ec3a9a039f2303e", + "91f07c6089354ceeb4e628ef1791d2d6", + "c137c84ae40049ffba7d1f4d77b21de2", + "81f8db0dcdbc4e5593fa3fbb6c2b9361" + ] + }, + "id": "c2697d07", + "outputId": "c8318b3d-b6a0-4f0d-9903-54beb2baac75" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Downloading (…)lve/main/config.json: 0%| | 0.00/482 [00:00 use the model max length (it's actually the default)\n", + " outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n", + " return outputs\n", + "\n", + "\n", + "tokenized_datasets = datasets.map(\n", + " tokenize_function,\n", + " batched=True,\n", + " remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n", + ")\n", + "\n", + "# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n", + "# transformers library\n", + "tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n", + "\n", + "\n", + "def collate_fn(examples):\n", + " return tokenizer.pad(examples, padding=\"longest\", return_tensors=\"pt\")\n", + "\n", + "\n", + "# Instantiate dataloaders.\n", + "train_dataloader = DataLoader(tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size)\n", + "eval_dataloader = DataLoader(\n", + " tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size\n", + ")\n", + "test_dataloader = DataLoader(tokenized_datasets[\"test\"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "2ed5ac74", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000, + "referenced_widgets": [ + "0cecb897c86c4892b94a1990ab08a926", + "b8af0294819e4280ad41fa1c11006adf", + "c7530d63b2f745e799713284abacbd2c", + "12a1e302a69543c5bc0e0a66be008ca0", + "9c372e9e9b20433faed8530ca0f4424c", + "b07f26a21325493cac19113f1aa1ee96", + "fbdf6c544fb54294903524a69384e773", + "6c6f2223243b4a7485aef7fbbfe07668", + "a93509d61ac94628a74bc0f98c0eec06", + "3a8de0eb7db44647a734590b6b351b44", + "64d8affd2e854a1c9043fec7ca8a2796" + ] + }, + "id": "2ed5ac74", + "outputId": "18ea15ac-ed8d-4d80-b166-706681ee49ab" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Downloading model.safetensors: 0%| | 0.00/1.42G [00:00