qwerrwe / tests /utils /test_models.py
jrc's picture
Add shifted sparse attention (#973) [skip-ci]
1d70f24 unverified
raw
history blame
No virus
1.12 kB
"""Module for testing models utils file."""
import unittest
from unittest.mock import patch
import pytest
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model
class ModelsUtilsTest(unittest.TestCase):
    """Tests for helpers in ``axolotl.utils.models``."""

    def test_cfg_throws_error_with_s2_attention_and_sample_packing(self):
        """``load_model`` should reject s2_attention combined with sample_packing.

        Both flags are enabled in the config; the call is expected to raise a
        ``ValueError`` whose message states that shifted-sparse attention does
        not currently support sample packing.
        """
        expected_msg = (
            "shifted-sparse attention does not currently support sample packing"
        )
        config = DictDefault(
            {
                "s2_attention": True,
                "sample_packing": True,
                "base_model": "",
                "model_type": "LlamaForCausalLM",
            }
        )

        # Mock out the call to the HF hub: load_model_config is replaced by a
        # mock that returns an empty dict.
        with patch("axolotl.utils.models.load_model_config", return_value={}):
            # ``match`` is applied with re.search to str(exception); the
            # expected message contains no regex metacharacters, so this is
            # an exact substring check.
            with pytest.raises(ValueError, match=expected_msg):
                # Should error before hitting the tokenizer, so an empty str
                # is passed in its place.
                load_model(config, tokenizer="")