Benjamin Bossan commited on
Commit
31a1df6
1 Parent(s): 73bf817

Initial commit

Browse files
Files changed (10) hide show
  1. __init__.py +0 -0
  2. app.py +12 -0
  3. cat.png +0 -0
  4. edit.py +255 -0
  5. make-data.py +26 -0
  6. packages.txt +1 -0
  7. requirements.txt +7 -0
  8. start.py +184 -0
  9. tasks.py +189 -0
  10. utils.py +87 -0
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ from start import start_input_form
4
+ from edit import edit_input_form
5
+
6
+ st.header("Skops model card creator")
7
+ st.markdown("---")
8
+
9
+ if not st.session_state.get("model_card"):
10
+ start_input_form()
11
+ else:
12
+ edit_input_form()
cat.png ADDED
edit.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import reprlib
4
+ from pathlib import Path
5
+ from tempfile import mkdtemp
6
+
7
+ import streamlit as st
8
+ from huggingface_hub import hf_hub_download
9
+ from skops import card
10
+ from skops.card._model_card import PlotSection, split_subsection_names
11
+
12
+ from utils import iterate_key_section_content, process_card_for_rendering
13
+ from tasks import AddSectionTask, AddFigureTask, DeleteSectionTask, TaskState, UpdateFigureTask, UpdateSectionTask
14
+
15
+
16
+ arepr = reprlib.Repr()
17
+ arepr.maxstring = 24
18
+ tmp_path = Path(mkdtemp(prefix="skops-")) # temporary files
19
+ hf_path = Path(mkdtemp(prefix="skops-")) # hf repo
20
+
21
+
22
+ def load_model_card_from_repo(repo_id: str) -> card.Card:
23
+ print("downloading model card")
24
+ path = hf_hub_download(repo_id, "README.md")
25
+ model_card = card.parse_modelcard(path)
26
+ return model_card
27
+
28
+
29
+ def _update_model_card(
30
+ model_card: card.Card, key: str, section_name: str, content: str, is_fig: bool,
31
+ ) -> None:
32
+ # This is a very roundabout way to update the model card but it's necessary
33
+ # because of how streamlit handles session state. Basically, there have to
34
+ # be "key" arguments, which have to be retrieved from the session_state, as
35
+ # they are up-to-date. Just getting the Python variables is not enough, as
36
+ # they can be out of date.
37
+
38
+ # key names must match with those used in form
39
+ new_title = st.session_state[f"{key}.title"]
40
+ new_content = st.session_state[f"{key}.content"]
41
+
42
+ # determine if title is the same
43
+ old_title_split = split_subsection_names(section_name)
44
+ new_title_split = old_title_split[:-1] + [new_title]
45
+ is_title_same = old_title_split == new_title_split
46
+
47
+ # determine if content is the same
48
+ if is_fig:
49
+ if isinstance(new_content, PlotSection):
50
+ is_content_same = content == new_content
51
+ else:
52
+ is_content_same = not bool(new_content)
53
+ else:
54
+ is_content_same = content == new_content
55
+
56
+ if is_title_same and is_content_same:
57
+ return
58
+
59
+ if is_fig:
60
+ fpath = None
61
+ if new_content: # new figure uploaded
62
+ fname = new_content.name.replace(" ", "_")
63
+ fpath = tmp_path / fname
64
+ task = UpdateFigureTask(
65
+ model_card,
66
+ key=key,
67
+ old_name=section_name,
68
+ new_name=new_title,
69
+ data=new_content,
70
+ path=fpath,
71
+ )
72
+ else:
73
+ task = UpdateSectionTask(
74
+ model_card,
75
+ key=key,
76
+ old_name=section_name,
77
+ new_name=new_title,
78
+ old_content=content,
79
+ new_content=new_content,
80
+ )
81
+ st.session_state.task_state.add(task)
82
+
83
+
84
+ def _add_section(model_card: card.Card, key: str) -> None:
85
+ section_name = f"{key}/Untitled"
86
+ task = AddSectionTask(model_card, title=section_name, content="[More Information Needed]")
87
+ st.session_state.task_state.add(task)
88
+
89
+
90
+ def _add_figure(model_card: card.Card, key: str) -> None:
91
+ section_name = f"{key}/Untitled"
92
+ task = AddFigureTask(model_card, title=section_name, content="cat.png")
93
+ st.session_state.task_state.add(task)
94
+
95
+
96
+ def _delete_section(model_card: card.Card, key: str) -> None:
97
+ task = DeleteSectionTask(model_card, key=key)
98
+ st.session_state.task_state.add(task)
99
+
100
+
101
+ def _add_section_form(
102
+ model_card: card.Card, key: str, section_name: str, old_title: str, content: str
103
+ ) -> None:
104
+ with st.form(key, clear_on_submit=False):
105
+ st.header(section_name)
106
+ # setting the 'key' argument below to update the session_state
107
+ st.text_input("Section name", value=old_title, key=f"{key}.title")
108
+ st.text_area("Content", value=content, key=f"{key}.content")
109
+ is_fig = False
110
+ st.form_submit_button(
111
+ "Update",
112
+ on_click=_update_model_card,
113
+ args=(model_card, key, section_name, content, is_fig),
114
+ )
115
+
116
+
117
+ def _add_fig_form(
118
+ model_card: card.Card, key: str, section_name: str, old_title: str, content: str
119
+ ) -> None:
120
+ with st.form(key, clear_on_submit=False):
121
+ st.header(section_name)
122
+ # setting the 'key' argument below to update the session_state
123
+ st.text_input("Section name", value=old_title, key=f"{key}.title")
124
+ st.file_uploader("Upload image", key=f"{key}.content")
125
+ is_fig = True
126
+ st.form_submit_button(
127
+ "Update",
128
+ on_click=_update_model_card,
129
+ args=(model_card, key, section_name, content, is_fig),
130
+ )
131
+
132
+
133
+ def create_form_from_section(
134
+ model_card: card.Card, key: str, section_name: str, content: str, is_fig: bool = False
135
+ ) -> None:
136
+ split_sections = split_subsection_names(section_name)
137
+ old_title = split_sections[-1]
138
+ if is_fig:
139
+ _add_fig_form(
140
+ model_card=model_card,
141
+ key=key,
142
+ section_name=section_name,
143
+ old_title=old_title,
144
+ content=content,
145
+ )
146
+ else:
147
+ _add_section_form(
148
+ model_card=model_card,
149
+ key=key,
150
+ section_name=section_name,
151
+ old_title=old_title,
152
+ content=content,
153
+ )
154
+
155
+ col_0, col_1, col_2 = st.columns([4, 2, 2])
156
+ with col_0:
157
+ st.button(
158
+ f"delete '{arepr.repr(old_title)}'",
159
+ on_click=_delete_section,
160
+ args=(model_card, key),
161
+ key=f"{key}.delete",
162
+ )
163
+ with col_1:
164
+ st.button(
165
+ "add section below",
166
+ on_click=_add_section,
167
+ args=(model_card, key),
168
+ key=f"{key}.add",
169
+ )
170
+ with col_2:
171
+ st.button(
172
+ "add figure below",
173
+ on_click=_add_figure,
174
+ args=(model_card, key),
175
+ key=f"{key}.fig",
176
+ )
177
+
178
+
179
+ def display_sections(model_card: card.Card) -> None:
180
+ for key, section_name, content, is_fig in iterate_key_section_content(model_card._data):
181
+ create_form_from_section(model_card, key, section_name, content, is_fig)
182
+
183
+
184
+ def display_model_card(model_card: card.Card) -> None:
185
+ rendered = model_card.render()
186
+ metadata, rendered = process_card_for_rendering(rendered)
187
+
188
+ # strip metadata
189
+ with st.expander("show metadata"):
190
+ st.text(metadata)
191
+ st.markdown(rendered, unsafe_allow_html=True)
192
+
193
+
194
+ def reset_model_card() -> None:
195
+ if "task_state" not in st.session_state:
196
+ return
197
+ if "model_card" not in st.session_state:
198
+ del st.session_state["model_card"]
199
+
200
+ while st.session_state.task_state.done_list:
201
+ st.session_state.task_state.undo()
202
+
203
+
204
+ def delete_model_card() -> None:
205
+ if "model_card" in st.session_state:
206
+ del st.session_state["model_card"]
207
+ if "task_state" in st.session_state:
208
+ st.session_state.task_state.reset()
209
+
210
+
211
+ def undo_last():
212
+ st.session_state.task_state.undo()
213
+ display_model_card(st.session_state.model_card)
214
+
215
+
216
+ def redo_last():
217
+ st.session_state.task_state.redo()
218
+ display_model_card(st.session_state.model_card)
219
+
220
+
221
+ def add_download_model_card_button():
222
+ model_card = st.session_state.get("model_card")
223
+ download_disabled = not bool(model_card)
224
+ data = model_card.render()
225
+ st.download_button(
226
+ "Save (md)", data=data, disabled=download_disabled
227
+ )
228
+
229
+
230
+ def edit_input_form():
231
+ if "task_state" not in st.session_state:
232
+ st.session_state.task_state = TaskState()
233
+
234
+ with st.sidebar:
235
+ col_0, col_1, col_2, col_3, col_4 = st.columns([1.6, 1.5, 1.2, 2, 1.5])
236
+ undo_disabled = not bool(st.session_state.task_state.done_list)
237
+ redo_disabled = not bool(st.session_state.task_state.undone_list)
238
+ with col_0:
239
+ name = f"UNDO ({len(st.session_state.task_state.done_list)})"
240
+ st.button(name, on_click=undo_last, disabled=undo_disabled)
241
+ with col_1:
242
+ name = f"REDO ({len(st.session_state.task_state.undone_list)})"
243
+ st.button(name, on_click=redo_last, disabled=redo_disabled)
244
+ with col_2:
245
+ st.button("Reset", on_click=reset_model_card)
246
+ with col_3:
247
+ add_download_model_card_button()
248
+ with col_4:
249
+ st.button("Delete", on_click=delete_model_card)
250
+
251
+ if "model_card" in st.session_state:
252
+ display_sections(st.session_state.model_card)
253
+
254
+ if "model_card" in st.session_state:
255
+ display_model_card(st.session_state.model_card)
make-data.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # companion script to st-space-creator.py
2
+
3
+ import pickle
4
+
5
+ import pandas as pd
6
+ from sklearn.datasets import make_classification
7
+ from sklearn.linear_model import LogisticRegression
8
+ from sklearn.pipeline import Pipeline
9
+ from sklearn.preprocessing import StandardScaler
10
+
11
+ X, y = make_classification()
12
+ df = pd.DataFrame(X)
13
+
14
+ clf = Pipeline(
15
+ [
16
+ ("scale", StandardScaler()),
17
+ ("clf", LogisticRegression(random_state=0)),
18
+ ]
19
+ )
20
+ clf.fit(X, y)
21
+
22
+ with open("logreg.pkl", "wb") as f:
23
+ pickle.dump(clf, f)
24
+
25
+
26
+ df.to_csv("data.csv", index=False)
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ pandoc
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ catboost
2
+ huggingface_hub
3
+ lightgbm
4
+ pandas
5
+ scikit-learn
6
+ xgboost
7
+ git+https://github.com/skops-dev/skops.git
start.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import io
3
+ import os
4
+ import pickle
5
+ import shutil
6
+ from pathlib import Path
7
+ from tempfile import mkdtemp
8
+
9
+ import pandas as pd
10
+ import sklearn
11
+ import streamlit as st
12
+ from huggingface_hub import hf_hub_download
13
+ from sklearn.base import BaseEstimator
14
+ from sklearn.dummy import DummyClassifier
15
+
16
+ import skops.io as sio
17
+ from skops import card, hub_utils
18
+
19
+
20
+ hf_path = Path(mkdtemp(prefix="skops-")) # hf repo
21
+ tmp_path = Path(mkdtemp(prefix="skops-")) # temporary files
22
+ description = """Create an sklearn model card
23
+
24
+ This Hugging Face Space that aims to provide a simple interface to use the `skops` model card creation utilities.
25
+
26
+ """
27
+
28
+
29
+
30
+ def load_model() -> None:
31
+ if st.session_state.get("model_file") is None:
32
+ st.session_state.model = DummyClassifier()
33
+ return
34
+
35
+ bytes_data = st.session_state.model_file.getvalue()
36
+ model = pickle.loads(bytes_data)
37
+ assert isinstance(model, BaseEstimator), "model must be an sklearn model"
38
+
39
+ st.session_state.model = model
40
+
41
+
42
+ def load_data() -> None:
43
+ if st.session_state.get("data_file"):
44
+ bytes_data = io.BytesIO(st.session_state.data_file.getvalue())
45
+ df = pd.read_csv(bytes_data)
46
+ else:
47
+ df = pd.DataFrame([])
48
+
49
+ st.session_state.data = df
50
+
51
+
52
+ def _clear_repo(path: str) -> None:
53
+ for file_path in glob.glob(str(Path(path) / "*")):
54
+ if os.path.isfile(file_path) or os.path.islink(file_path):
55
+ os.unlink(file_path)
56
+ elif os.path.isdir(file_path):
57
+ shutil.rmtree(file_path)
58
+
59
+
60
+ def init_repo(path: str) -> None:
61
+ _clear_repo(path)
62
+ requirements = []
63
+ task = "tabular-classification"
64
+ data = pd.DataFrame([])
65
+
66
+ if "requirements" in st.session_state:
67
+ requirements = st.session_state.requirements.splitlines()
68
+ if "task" in st.session_state:
69
+ task = st.session_state.task
70
+ if "data_file" in st.session_state:
71
+ load_data()
72
+ data = st.session_state.data
73
+
74
+ if task.startswith("text") and isinstance(data, pd.DataFrame):
75
+ data = data.values.tolist()
76
+
77
+ try:
78
+ file_name = tmp_path / "model.skops"
79
+ sio.dump(st.session_state.model, file_name)
80
+
81
+ hub_utils.init(
82
+ model=file_name,
83
+ dst=path,
84
+ task=task,
85
+ data=data,
86
+ requirements=requirements,
87
+ )
88
+ 1
89
+ except Exception as exc:
90
+ print("Uh oh, something went wrong when initializing the repo:", exc)
91
+
92
+
93
+ def create_skops_model_card() -> None:
94
+ init_repo(hf_path)
95
+ metadata = card.metadata_from_config(hf_path)
96
+ model_card = card.Card(model=st.session_state.model, metadata=metadata)
97
+ st.session_state.model_card = model_card
98
+
99
+
100
+ def create_empty_model_card() -> None:
101
+ init_repo(hf_path)
102
+ metadata = card.metadata_from_config(hf_path)
103
+ model_card = card.Card(model=st.session_state.model, metadata=metadata, template=None)
104
+ model_card.add(**{"Untitled": "[More Information Needed]"})
105
+ st.session_state.model_card = model_card
106
+
107
+
108
+ def create_hf_model_card() -> None:
109
+ repo_id = st.session_state.get("hf_repo_id", "").strip("'").strip('"')
110
+ if not repo_id:
111
+ return
112
+
113
+ print("downloading model card")
114
+ path = hf_hub_download(repo_id, "README.md")
115
+ model_card = card.parse_modelcard(path)
116
+ st.session_state.model_card = model_card
117
+
118
+
119
+ def start_input_form():
120
+ if "model" not in st.session_state:
121
+ st.session_state.model = DummyClassifier()
122
+
123
+ if "data" not in st.session_state:
124
+ st.session_state.data = pd.DataFrame([])
125
+
126
+ if "model_card" not in st.session_state:
127
+ st.session_state.model_card = None
128
+
129
+ st.markdown(description)
130
+ st.markdown("---")
131
+
132
+ st.text(
133
+ "Upload an sklearn model (strongly recommended)\n"
134
+ "The model can be used to automatically populate fields in the model card."
135
+ )
136
+ st.file_uploader("Upload a model*", on_change=load_model, key="model_file")
137
+ st.markdown("---")
138
+
139
+ st.text(
140
+ "Upload samples from your data (in csv format)\n"
141
+ "This sample data can be attached to the metadata of the model card"
142
+ )
143
+ st.file_uploader(
144
+ "Upload X data (csv)*", type=["csv"], on_change=load_data, key="data_file"
145
+ )
146
+ st.markdown("---")
147
+
148
+ st.selectbox(
149
+ label="Choose the task type*",
150
+ options=[
151
+ "tabular-classification",
152
+ "tabular-regression",
153
+ "text-classification",
154
+ "text-regression",
155
+ ],
156
+ key="task",
157
+ on_change=init_repo,
158
+ args=(hf_path,)
159
+ )
160
+ st.markdown("---")
161
+
162
+ st.text_area(
163
+ label="Requirements*",
164
+ value=f"scikit-learn=={sklearn.__version__}\n",
165
+ key="requirements",
166
+ on_change=init_repo,
167
+ args=(hf_path,)
168
+ )
169
+ st.markdown("---")
170
+
171
+ col_0, col_1, col_2 = st.columns([2, 2, 2])
172
+ with col_0:
173
+ st.button("Create a new skops model card", on_click=create_skops_model_card)
174
+
175
+ with col_1:
176
+ st.button("Create a new empty model card", on_click=create_empty_model_card)
177
+
178
+ with col_2:
179
+ with st.form("Load existing model card from HF Hub", clear_on_submit=False):
180
+ st.text_input("Repo name (e.g. 'gpt2')", key="hf_repo_id")
181
+ st.form_submit_button("Load", on_click=create_hf_model_card)
182
+
183
+
184
+ start_input_form()
tasks.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from uuid import uuid4
5
+
6
+ from skops import card
7
+ from skops.card._model_card import PlotSection, split_subsection_names
8
+ from streamlit.runtime.uploaded_file_manager import UploadedFile
9
+
10
+
11
+ class Task:
12
+ def __init__(self, model_card: card.Card) -> None:
13
+ self.model_card = model_card
14
+
15
+ def do(self) -> None:
16
+ raise NotImplementedError
17
+
18
+ def undo(self) -> None:
19
+ raise NotImplementedError
20
+
21
+
22
+ class TaskState:
23
+ def __init__(self) -> None:
24
+ self.done_list: list[Task] = []
25
+ self.undone_list: list[Task] = []
26
+
27
+ def undo(self) -> None:
28
+ if not self.done_list:
29
+ return
30
+
31
+ task = self.done_list.pop(-1)
32
+ task.undo()
33
+ self.undone_list.append(task)
34
+
35
+ def redo(self) -> None:
36
+ if not self.undone_list:
37
+ return
38
+
39
+ task = self.undone_list.pop(-1)
40
+ task.do()
41
+ self.done_list.append(task)
42
+
43
+ def add(self, task: Task) -> None:
44
+ task.do()
45
+ self.done_list.append(task)
46
+ self.undone_list.clear()
47
+
48
+ def reset(self) -> None:
49
+ self.done_list.clear()
50
+ self.undone_list.clear()
51
+
52
+
53
+ class AddSectionTask(Task):
54
+ def __init__(
55
+ self,
56
+ model_card: card.Card,
57
+ title: str,
58
+ content: str,
59
+ ) -> None:
60
+ self.model_card = model_card
61
+ self.title = title
62
+ self.key = title + " " + str(uuid4())[:6]
63
+ self.content = content
64
+
65
+ def do(self) -> None:
66
+ self.model_card.add(**{self.key: self.content})
67
+ section = self.model_card.select(self.key)
68
+ section.title = split_subsection_names(self.title)[-1]
69
+
70
+ def undo(self) -> None:
71
+ self.model_card.delete(self.key)
72
+
73
+
74
+ class AddFigureTask(Task):
75
+ def __init__(
76
+ self,
77
+ model_card: card.Card,
78
+ title: str,
79
+ content: str,
80
+ ) -> None:
81
+ self.model_card = model_card
82
+ self.title = title
83
+ self.key = title + " " + str(uuid4())[:6]
84
+ self.content = content
85
+
86
+ def do(self) -> None:
87
+ self.model_card.add_plot(**{self.key: self.content})
88
+ section = self.model_card.select(self.key)
89
+ section.title = split_subsection_names(self.title)[-1]
90
+ section.is_fig = True # type: ignore
91
+
92
+ def undo(self) -> None:
93
+ self.model_card.delete(self.key)
94
+
95
+
96
+ class DeleteSectionTask(Task):
97
+ def __init__(
98
+ self,
99
+ model_card: card.Card,
100
+ key: str,
101
+ ) -> None:
102
+ self.model_card = model_card
103
+ self.key = key
104
+
105
+ def do(self) -> None:
106
+ self.model_card.select(self.key).visible = False
107
+
108
+ def undo(self) -> None:
109
+ self.model_card.select(self.key).visible = True
110
+
111
+
112
+ class UpdateSectionTask(Task):
113
+ def __init__(
114
+ self,
115
+ model_card: card.Card,
116
+ key: str,
117
+ old_name: str,
118
+ new_name: str,
119
+ old_content: str,
120
+ new_content: str,
121
+ ) -> None:
122
+ self.model_card = model_card
123
+ self.key = key
124
+ self.old_name = old_name
125
+ self.new_name = new_name
126
+ self.old_content = old_content
127
+ self.new_content = new_content
128
+
129
+ def do(self) -> None:
130
+ section = self.model_card.select(self.key)
131
+ new_title = split_subsection_names(self.new_name)[-1]
132
+ section.title = new_title
133
+ section.content = self.new_content
134
+
135
+ def undo(self) -> None:
136
+ section = self.model_card.select(self.key)
137
+ old_title = split_subsection_names(self.old_name)[-1]
138
+ section.title = old_title
139
+ section.content = self.old_content
140
+
141
+
142
+ class UpdateFigureTask(Task):
143
+ def __init__(
144
+ self,
145
+ model_card: card.Card,
146
+ key: str,
147
+ old_name: str,
148
+ new_name: str,
149
+ data: UploadedFile | None,
150
+ path: Path | None,
151
+ ) -> None:
152
+ self.model_card = model_card
153
+ self.key = key
154
+ self.old_name = old_name
155
+ self.new_name = new_name
156
+ self.old_data = self.model_card.select(self.key).content
157
+ self.path = path
158
+
159
+ if not data:
160
+ self.new_data = self.old_data
161
+ else:
162
+ self.new_data = data
163
+
164
+ def do(self) -> None:
165
+ section = self.model_card.select(self.key)
166
+ new_title = split_subsection_names(self.new_name)[-1]
167
+ section.title = self.title = new_title
168
+ if self.new_data == self.old_data: # image is same
169
+ return
170
+
171
+ # write figure
172
+ # note: this can still be the same image if the image is a file, there
173
+ # is no test to check, e.g., the hash of the image
174
+ with open(self.path, "wb") as f:
175
+ f.write(self.new_data.getvalue())
176
+ section.content = PlotSection(
177
+ alt_text=self.new_data.name,
178
+ path=self.path,
179
+ ).format()
180
+
181
+ def undo(self) -> None:
182
+ section = self.model_card.select(self.key)
183
+ old_title = split_subsection_names(self.old_name)[-1]
184
+ section.title = old_title
185
+ if self.new_data == self.old_data: # image is same
186
+ return
187
+
188
+ self.path.unlink(missing_ok=True)
189
+ section.content = self.old_data
utils.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ import re
5
+ from pathlib import Path
6
+ from typing import Iterator
7
+
8
+ from skops.card._model_card import Section
9
+
10
+
11
+ def process_card_for_rendering(rendered: str) -> tuple[str, str]:
12
+ idx = rendered[1:].index("\n---") + 1
13
+ metadata = rendered[3:idx]
14
+ rendered = rendered[idx + 4 :] # noqa: E203
15
+
16
+ # below is a hack to display the images in streamlit
17
+ # https://discuss.streamlit.io/t/image-in-markdown/13274/10 The problem is
18
+
19
+ # that streamlit does not display images in markdown, so we need to replace
20
+ # them with html. However, we only want that in the rendered markdown, not
21
+ # in the card that is produced for the hub
22
+ def markdown_images(markdown):
23
+ # example image markdown:
24
+ # ![Test image](images/test.png "Alternate text")
25
+ images = re.findall(
26
+ r'(!\[(?P<image_title>[^\]]+)\]\((?P<image_path>[^\)"\s]+)\s*([^\)]*)\))',
27
+ markdown,
28
+ )
29
+ return images
30
+
31
+ def img_to_bytes(img_path):
32
+ img_bytes = Path(img_path).read_bytes()
33
+ encoded = base64.b64encode(img_bytes).decode()
34
+ return encoded
35
+
36
+ def img_to_html(img_path, img_alt):
37
+ img_format = img_path.split(".")[-1]
38
+ img_html = (
39
+ f'<img src="data:image/{img_format.lower()};'
40
+ f'base64,{img_to_bytes(img_path)}" '
41
+ f'alt="{img_alt}" '
42
+ 'style="max-width: 100%;">'
43
+ )
44
+ return img_html
45
+
46
+ def markdown_insert_images(markdown):
47
+ images = markdown_images(markdown)
48
+
49
+ for image in images:
50
+ image_markdown = image[0]
51
+ image_alt = image[1]
52
+ image_path = image[2]
53
+ markdown = markdown.replace(
54
+ image_markdown, img_to_html(image_path, image_alt)
55
+ )
56
+ return markdown
57
+
58
+ rendered_with_img = markdown_insert_images(rendered)
59
+ return metadata, rendered_with_img
60
+
61
+
62
+ def iterate_key_section_content(
63
+ data: dict[str, Section],
64
+ parent_section: str = "",
65
+ parent_keys: list[str] | None = None,
66
+ ) -> Iterator[tuple[str, str, str, bool]]:
67
+ parent_keys = parent_keys or []
68
+
69
+ for key, val in data.items():
70
+ if parent_section:
71
+ title = "/".join((parent_section, val.title))
72
+ else:
73
+ title = val.title
74
+
75
+ if not val.visible:
76
+ continue
77
+
78
+ return_key = key if not parent_keys else "/".join(parent_keys + [key])
79
+ is_fig = getattr(val, "is_fig", False)
80
+ yield return_key, title, val.content, is_fig
81
+
82
+ if val.subsections:
83
+ yield from iterate_key_section_content(
84
+ val.subsections,
85
+ parent_section=title,
86
+ parent_keys=parent_keys + [key],
87
+ )