File size: 1,149 Bytes
0402d19
 
 
b3a61e8
0402d19
 
00568c1
0402d19
00568c1
b3a61e8
0402d19
 
 
 
 
 
 
 
 
6dc68a6
0402d19
 
 
 
 
b3a61e8
 
 
 
 
 
 
 
 
 
00568c1
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""
Helper utilities for tests.
"""
import os
import shutil
import tempfile
import unittest
from functools import wraps
from importlib.metadata import version
from pathlib import Path


def with_temp_dir(test_func):
    """
    Decorator that runs a test with a fresh temporary directory.

    The directory path is passed to the wrapped function as the ``temp_dir``
    keyword argument. The directory is removed when the call finishes,
    whether the function returns or raises. The wrapped function's return
    value is passed through to the caller.
    """

    @wraps(test_func)
    def wrapper(*args, **kwargs):
        # TemporaryDirectory owns both creation and cleanup, so the directory
        # is removed on every exit path from the `with` block.
        with tempfile.TemporaryDirectory() as temp_dir:
            return test_func(*args, temp_dir=temp_dir, **kwargs)

    return wrapper


def most_recent_subdir(path):
    """
    Return the immediate subdirectory of ``path`` with the latest ctime.

    The result is a ``pathlib.Path``, or ``None`` when ``path`` contains no
    subdirectories. Timestamps come from ``os.path.getctime``. On Unix that
    is the last metadata-change time; on Windows it is the creation time.
    Ties resolve to whichever entry ``iterdir`` yields first.
    """
    candidates = (entry for entry in Path(path).iterdir() if entry.is_dir())
    return max(candidates, key=os.path.getctime, default=None)


def require_torch_2_1_1(test_case):
    """
    Decorator marking a test that requires torch >= 2.1.1

    The check runs once, when the decorator is applied. The test is skipped,
    rather than failing at import time, when torch is not installed or its
    version string has no leading numeric release segment.
    """
    import re
    from importlib.metadata import PackageNotFoundError

    def is_min_2_1_1():
        try:
            torch_version = version("torch")
        except PackageNotFoundError:
            return False
        # Compare numeric release components, not raw strings: as strings,
        # "10.0.0" sorts before "2.1.1". Suffixes such as local build tags
        # ("+cu118") or ".devN" are ignored.
        match = re.match(r"\d+(?:\.\d+)*", torch_version)
        if match is None:
            return False
        release = tuple(int(part) for part in match.group(0).split("."))
        return release >= (2, 1, 1)

    return unittest.skipUnless(is_min_2_1_1(), "test torch 2.1.1")(test_case)