File size: 545 Bytes
eeeb3c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
#!/usr/bin/env python3

import shutil
import sys
from collections import OrderedDict
from pathlib import Path

import torch

# Checkpoint to patch in place (path is relative to the current working directory).
MODEL_PATH = Path("2_Dense/pytorch_model.bin")

# Mappings from old state-dict key names to their new names.
rename = {"layer.weight": "linear.weight"}


def rename_keys(state_dict, mapping):
    """Return a new OrderedDict with the keys of *state_dict* renamed via *mapping*.

    Keys absent from *mapping* are copied unchanged, insertion order is
    preserved, and *state_dict* itself is not modified. Every applied
    mapping is reported on stderr.

    Raises:
        ValueError: if two source keys would end up under the same output
            name (e.g. the checkpoint already holds both the old and the new
            key). Previously one of the tensors was silently overwritten.
    """
    renamed = OrderedDict()
    for key, params in state_dict.items():
        dst = key
        if key in mapping:
            print(f"Mapping {key} to {mapping[key]}", file=sys.stderr)
            dst = mapping[key]
        # Catches collisions regardless of which of the two keys comes first.
        if dst in renamed:
            raise ValueError(
                f"Renaming {key!r} to {dst!r} collides with an existing key"
            )
        renamed[dst] = params
    return renamed


def main(model_path=MODEL_PATH):
    """Rename the keys of the checkpoint at *model_path* in place using ``rename``.

    A backup named ``<file name>.bak`` is written next to the checkpoint, but
    only if it does not exist yet. Re-running the script therefore never
    replaces the original file's backup with an already-converted version.
    """
    model_path = Path(model_path)
    backup_path = model_path.with_name(model_path.name + ".bak")
    if backup_path.exists():
        print(
            f"Backup {backup_path} already exists; leaving it untouched",
            file=sys.stderr,
        )
    else:
        # Byte-for-byte copy of the file on disk, taken before anything is rewritten.
        shutil.copy2(model_path, backup_path)

    m_input = torch.load(model_path)
    m_output = rename_keys(m_input, rename)
    torch.save(m_output, model_path)


if __name__ == "__main__":
    main()