keeganskeate committed on
Commit
afe3140
1 Parent(s): fe49c28

Added current draft of the 🦨 SkunkFx Effects and Aromas Prediction Model

Browse files
Files changed (1) hide show
  1. effects_and_aromas.py +689 -0
effects_and_aromas.py ADDED
@@ -0,0 +1,689 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Reported Effects and Aromas Prediction Model
3
+ Copyright (c) 2022 Cannlytics
4
+
5
+ Authors: Keegan Skeate <https://github.com/keeganskeate>
6
+ Created: 5/13/2022
7
+ Updated: 6/1/2022
8
+ License: MIT License <https://opensource.org/licenses/MIT>
9
+
10
+ Description:
11
+
12
+ This methodology estimates the probability of a review containing
13
+ a specific aroma or effect. The methodology is then saved in
14
+ a re-usable model that can predict potential aromas and effects
15
+ given lab results for strains, flower products, etc.
16
+
17
+ Data Sources:
18
+
19
+ - Data from: Over eight hundred cannabis strains characterized
20
+ by the relationship between their subjective effects, perceptual
21
+ profiles, and chemical compositions
22
+ URL: <https://data.mendeley.com/datasets/6zwcgrttkp/1>
23
+ License: CC BY 4.0. <https://creativecommons.org/licenses/by/4.0/>
24
+
25
+ Resources:
26
+
27
+ - Over eight hundred cannabis strains characterized by the
28
+ relationship between their psychoactive effects, perceptual
29
+ profiles, and chemical compositions
30
+ URL: <https://www.biorxiv.org/content/10.1101/759696v1.abstract>
31
+
32
+ - Effects of cannabidiol in cannabis flower:
33
+ Implications for harm reduction
34
+ URL: <https://pubmed.ncbi.nlm.nih.gov/34467598/>
35
+
36
+ """
37
+ # Standard imports.
38
+ from datetime import datetime
39
+ import os
40
+ from typing import Any, Optional
41
+
42
+ # External imports.
43
+ from dotenv import dotenv_values
44
+ import pandas as pd
45
+
46
+ # Internal imports.
47
+ from cannlytics.firebase import (
48
+ initialize_firebase,
49
+ update_documents,
50
+ )
51
+ from cannlytics.stats import (
52
+ calculate_model_statistics,
53
+ estimate_discrete_model,
54
+ get_stats_model,
55
+ predict_stats_model,
56
+ upload_stats_model,
57
+ )
58
+ from cannlytics.utils import (
59
+ snake_case,
60
+ combine_columns,
61
+ nonzero_columns,
62
+ nonzero_rows,
63
+ sum_columns,
64
+ download_file_from_url,
65
+ unzip_files,
66
+ )
67
+
68
+ # Ignore convergence errors.
69
+ import warnings
70
+ from statsmodels.tools.sm_exceptions import ConvergenceWarning
71
+ warnings.simplefilter('ignore', ConvergenceWarning)
72
+ warnings.simplefilter('ignore', RuntimeWarning)
73
+
74
+
75
+ # Decarboxylation factor (mass ratio of THC to THCA). Source: <https://www.conflabs.com/why-0-877/>
76
+ DECARB = 0.877
77
+
78
+ # TODO: It would be worthwhile to parse effects and aromas
79
+ # ourselves with NLP. Sometimes an effect is mentioned in a review
80
+ # without being experienced, for example, "helped with my anxiety."
81
+
82
+
83
def download_strain_review_data(
        data_dir: str,
        url: Optional[str] = 'https://md-datasets-cache-zipfiles-prod.s3.eu-west-1.amazonaws.com/6zwcgrttkp-1.zip',
    ):
    """Download historic strain review data.
    First, creates the data directory if it doesn't already exist.
    Second, downloads the data to the given directory.
    Third, unzips the data and returns the names of the data folders.
    Source: "Data from: Over eight hundred cannabis strains characterized
    by the relationship between their subjective effects, perceptual
    profiles, and chemical compositions".
    URL: <https://data.mendeley.com/datasets/6zwcgrttkp/1>
    License: CC BY 4.0. <https://creativecommons.org/licenses/by/4.0/>
    Args:
        data_dir (str): The directory where the data is downloaded
            and unzipped.
        url (str): The URL of the zipped dataset.
    Returns:
        (dict): The folder names, as hard-coded below (not joined with
            `data_dir`), keyed by `strains` and `compounds`.
    """
    # `exist_ok` avoids the check-then-create race of testing for the
    # directory before making it.
    os.makedirs(data_dir, exist_ok=True)
    download_file_from_url(url, destination=data_dir)
    unzip_files(data_dir)
    # Optional: Get the directories programmatically.
    return {
        'strains': 'Strain data/strains',
        'compounds': 'Terpene and Cannabinoid data',
    }
105
+
106
+
107
def curate_lab_results(
        data_dir: str,
        compound_folder: Optional[str] = 'Terpene and Cannabinoid data',
        cannabinoid_file: Optional[str] = 'rawDATACana',
        terpene_file: Optional[str] = 'rawDATATerp',
        max_cannabinoids: Optional[int] = 35,
        max_terpenes: Optional[int] = 8,
    ):
    """Curate lab results for effects prediction model.
    Args:
        data_dir (str): The directory where the raw lab results live.
        compound_folder (str): The folder where the cannabinoid and terpene data live.
        cannabinoid_file (str): The name of the raw cannabinoid text file.
        terpene_file (str): The name of the raw terpene text file.
        max_cannabinoids (int): The maximum value for permissible cannabinoid tests.
        max_terpenes (int): The maximum value for permissible terpene tests.
    Returns:
        (DataFrame): Returns the lab results.
    Raises:
        ValueError: If either `cannabinoid_file` or `terpene_file` is
            not given. The curation below needs columns from both files.
    """

    # Fail early with a clear message; the steps below reference both
    # the cannabinoid and the terpene columns unconditionally.
    if not cannabinoid_file or not terpene_file:
        raise ValueError(
            'Both `cannabinoid_file` and `terpene_file` are required.'
        )

    # Rename any oddly named columns.
    rename = {
        'cb_da': 'cbda',
        'cb_ga': 'cbga',
        'delta_9_th_ca': 'delta_9_thca',
        'th_ca': 'thca',
        'caryophylleneoxide': 'caryophyllene_oxide',
        '3_carene': 'carene',
    }

    def read_compounds(file_name):
        """Read a raw data file, returning the data and its analyte names."""
        file_path = os.path.join(data_dir, compound_folder, file_name)
        data = pd.read_csv(file_path, index_col=0)
        # Note: `strip('x_')` removes any leading / trailing `x` or `_`
        # characters (a character set, not a prefix).
        data.columns = [snake_case(x).strip('x_') for x in data.columns]
        data.rename(columns=rename, inplace=True)
        # The analytes are the columns after the first 3 columns.
        return data, list(data.columns[3:])

    # Read terpenes and cannabinoids.
    terpenes, terpene_names = read_compounds(terpene_file)
    cannabinoids, cannabinoid_names = read_compounds(cannabinoid_file)

    # Merge terpenes and cannabinoids.
    compounds = pd.merge(
        left=cannabinoids,
        right=terpenes,
        left_on='file',
        right_on='file',
        how='left',
        suffixes=['', '_terpene']
    )

    # Combine identical cannabinoids.
    compounds = combine_columns(compounds, 'thca', 'delta_9_thca')
    cannabinoid_names.remove('delta_9_thca')

    # Combine identical terpenes.
    compounds = combine_columns(compounds, 'p_cymene', 'pcymene')
    compounds = combine_columns(compounds, 'beta_caryophyllene', 'caryophyllene')
    compounds = combine_columns(compounds, 'humulene', 'alpha_humulene')
    for duplicate in ('pcymene', 'caryophyllene', 'alpha_humulene'):
        terpene_names.remove(duplicate)

    # Sum ocimene.
    analytes = ['ocimene', 'beta_ocimene', 'trans_ocimene']
    compounds = sum_columns(compounds, 'ocimene', analytes, drop=False)
    compounds.drop(columns=['beta_ocimene', 'trans_ocimene'], inplace=True)
    for isomer in ('beta_ocimene', 'trans_ocimene'):
        terpene_names.remove(isomer)

    # Sum nerolidol.
    analytes = ['trans_nerolidol', 'cis_nerolidol', 'transnerolidol_1',
                'transnerolidol_2']
    compounds = sum_columns(compounds, 'nerolidol', analytes)
    for isomer in analytes:
        terpene_names.remove(isomer)
    terpene_names.append('nerolidol')

    # Code missing values as 0.
    compounds = compounds.fillna(0)

    # Calculate totals.
    compounds['total_terpenes'] = compounds[terpene_names].sum(axis=1).round(2)
    compounds['total_cannabinoids'] = compounds[cannabinoid_names].sum(axis=1).round(2)

    # Calculate total THC, CBD, and CBG. Missing values are already 0,
    # so an absent acidic form simply contributes nothing to the total.
    for neutral, acidic, total in (
        ('delta_9_thc', 'thca', 'total_thc'),
        ('cbd', 'cbda', 'total_cbd'),
        ('cbg', 'cbga', 'total_cbg'),
    ):
        compounds[total] = (compounds[neutral] + compounds[acidic].mul(DECARB)).round(2)

    # Calculate terpinenes total.
    analytes = ['alpha_terpinene', 'gamma_terpinene', 'terpinolene', 'terpinene']
    compounds = sum_columns(compounds, 'terpinenes', analytes, drop=False)

    # Exclude outliers, copying so that the in-place cleaning below
    # does not operate on a slice of the unfiltered data.
    compounds = compounds.loc[
        (compounds['total_cannabinoids'] < max_cannabinoids) &
        (compounds['total_terpenes'] < max_terpenes)
    ].copy()

    # Clean and return the data.
    extraneous = ['type', 'file', 'tag_terpene', 'type_terpene']
    compounds.drop(columns=extraneous, inplace=True)
    compounds.rename(columns={'tag': 'strain_name'}, inplace=True)
    compounds['strain_name'] = compounds['strain_name'].str.replace('-', ' ').str.title()
    return compounds
229
+
230
+
231
def curate_strain_reviews(
        data_dir: str,
        results: Any,
        strain_folder: Optional[str] = 'Strain data/strains',
    ):
    """Curate cannabis strain reviews.
    Each review of a strain becomes one observation that carries the
    strain's lab results plus indicators for each reported aroma
    (`aroma_*`) and effect (`effect_*`).
    Args:
        data_dir (str): The directory where the data lives.
        results (DataFrame): The curated lab result data, indexed by
            strain name.
        strain_folder (str): The folder where the review data lives.
    Returns:
        (DataFrame): Returns the strain reviews, with one row per review
            and unreported effects and aromas coded as 0. Returns an
            empty DataFrame if no reviews are found.
    """

    # Create a panel of reviews of strain lab results. Observations are
    # collected and concatenated once, rather than growing the panel
    # with a copy for every review.
    observations = []
    for strain_name, row in results.iterrows():

        # Read the strain's effects and aromas data.
        # Note: Unpickling can execute arbitrary code, so only read
        # review files that come from a trusted source.
        review_file = strain_name.lower().replace(' ', '-') + '.p'
        file_path = os.path.join(data_dir, strain_folder, review_file)
        try:
            strain = pd.read_pickle(file_path)
        except FileNotFoundError:
            print("Couldn't find:", file_path)
            continue

        # Assign dummy variables for effects and aromas.
        name = strain['strain']
        category = list(strain['categorias'])[0]
        for n, review in enumerate(strain['data_strain']):

            # Create panel observation, combining prior compound data.
            obs = row.copy()
            for aroma in review['sabores']:
                obs['aroma_' + snake_case(aroma)] = 1
            for effect in review['efectos']:
                obs['effect_' + snake_case(effect)] = 1

            # Assign category determined from original authors NLP.
            obs['category'] = category
            obs['strain_name'] = strain_name
            obs['review'] = review['reporte']
            obs['user'] = review['usuario']

            # Record the observation.
            obs.name = name + '-' + str(n)
            observations.append(obs.to_frame().transpose())

    # Return the panel with null effects and aromas coded as 0.
    if not observations:
        return pd.DataFrame()
    return pd.concat(observations).fillna(0)
286
+
287
+
288
def download_dataset(name, destination):
    """Fetch a Cannlytics dataset through its short link.
    Args:
        name (str): The short name of the dataset, which is appended to
            the Cannlytics short-link domain to form the URL.
        destination (str): Where the downloaded data should be saved.
    """
    download_file_from_url(
        f'https://cannlytics.page.link/{name}',
        destination=destination,
    )
296
+
297
+
298
+ #-----------------------------------------------------------------------
299
+ # Tests
300
+ #-----------------------------------------------------------------------
301
+
302
if __name__ == '__main__':

    # Script entry point: curates the lab results and strain reviews and
    # saves the reviews to a workbook. The model fitting, prediction,
    # upload, and plotting steps are kept below as commented-out
    # reference code.

    #-------------------------------------------------------------------
    # Curate the strain lab result data.
    #-------------------------------------------------------------------

    print('Testing...')
    DATA_DIR = '../../../.datasets/subjective-effects'

    # Optional: Download the original data.
    # download_strain_review_data(DATA_DIR)

    # Curate the lab results.
    print('Curating strain lab results...')
    results = curate_lab_results(DATA_DIR)

    # Average results by strain, counting the number of tests per strain.
    strain_data = results.groupby('strain_name').mean()
    strain_data['tests'] = results.groupby('strain_name')['cbd'].count()
    strain_data['strain_name'] = strain_data.index

    # Save the lab results and strain data.
    # today = datetime.now().isoformat()[:10]
    # results.to_excel(DATA_DIR + f'/psi-labs-results-{today}.xlsx')
    # strain_data.to_excel(DATA_DIR + f'/strain-avg-results-{today}.xlsx')

    #-------------------------------------------------------------------

    # # Initialize Firebase.
    # env_file = '../../../.env'
    # config = dotenv_values(env_file)
    # bucket_name = config['FIREBASE_STORAGE_BUCKET']
    # db = initialize_firebase(
    #     env_file=env_file,
    #     bucket_name=bucket_name,
    # )

    # Upload the strain data to Firestore.
    # docs = strain_data.to_dict(orient='records')
    # refs = [f'public/data/strains/{x}' for x in strain_data.index]
    # update_documents(refs, docs, database=db)
    # print('Updated %i strains.' % len(docs))

    # Upload individual lab results for each strain.
    # Future work: Format the lab results as metrics with CAS, etc.
    # results['id'] = results.index
    # results['lab_id'] = 'SC-000005'
    # results['lab_name'] = 'PSI Labs'
    # docs = results.to_dict(orient='records')
    # refs = [f'public/data/strains/{x[0]}/strain_lab_results/lab_result_{x[1]}' for x in results[['strain_name', 'id']].values]
    # update_documents(refs, docs, database=db)
    # print('Updated %i strain lab results.' % len(docs))

    #-------------------------------------------------------------------
    # Curate the strain review data.
    #-------------------------------------------------------------------

    # Curate the reviews, using the per-strain average lab results.
    print('Curating reviews...')
    reviews = curate_strain_reviews(DATA_DIR, strain_data)

    # Combine `effect_anxiety` and `effect_anxious`.
    reviews = combine_columns(reviews, 'effect_anxious', 'effect_anxiety')

    # Save the reviews to a workbook named with today's date.
    today = datetime.now().isoformat()[:10]
    datafile = DATA_DIR + f'/strain-reviews-{today}.xlsx'
    reviews.to_excel(datafile)

    # Optional: Read previously saved reviews back in.
    # datafile = DATA_DIR + '/strain-reviews-2022-06-01.xlsx'
    # reviews = pd.read_excel(datafile, index_col=0)

    # # Optional: Upload strain review data to Firestore.
    # reviews['id'] = reviews.index
    # docs = reviews.to_dict(orient='records')
    # refs = [f'public/data/strains/{x[0]}/strain_reviews/strain_review_{x[1]}' for x in reviews[['strain_name', 'id']].values]
    # # update_documents(refs, docs, database=db)

    #-------------------------------------------------------------------

    # Future work: Programmatically upload the datasets to Storage.

    # Optional: Download the pre-compiled data from Cannlytics.
    # NOTE(review): `download_dataset` as defined in this file has no
    # return statement, so these assignments would yield `None`.
    # strain_data = download_dataset('strains', DATA_DIR)
    # reviews = download_dataset('strain-reviews', DATA_DIR)

    #-------------------------------------------------------------------
    # Fit the model with the training data.
    #-------------------------------------------------------------------

    # Specify different prediction models.
    # Future work: Logit, cannabinoid / terpene ratios, and bayesian models.
    # Handle `minor` cannabinoids in `totals` and perhaps `simple` models
    # (i.e. `total_cannabinoids` - `total_thc` - `total_cbd`).
    # Note: `variates` is only consumed by the commented-out model
    # fitting code below.
    variates = {
        'full': [
            'delta_9_thc',
            'cbd',
            'cbn',
            'cbg',
            'cbc',
            'thcv',
            'cbda',
            'delta_8_thc',
            'cbga',
            'thca',
            'd_limonene',
            'beta_myrcene',
            'beta_pinene',
            'linalool',
            'alpha_pinene',
            'camphene',
            'carene',
            'alpha_terpinene',
            'ocimene',
            'eucalyptol',
            'gamma_terpinene',
            'terpinolene',
            'isopulegol',
            'geraniol',
            'humulene',
            'guaiol',
            'caryophyllene_oxide',
            'alpha_bisabolol',
            'beta_caryophyllene',
            'p_cymene',
            'terpinene',
            'nerolidol',
        ],
        'terpene_only': [
            'd_limonene',
            'beta_myrcene',
            'beta_pinene',
            'linalool',
            'alpha_pinene',
            'camphene',
            'carene',
            'alpha_terpinene',
            'ocimene',
            'eucalyptol',
            'gamma_terpinene',
            'terpinolene',
            'isopulegol',
            'geraniol',
            'humulene',
            'guaiol',
            'caryophyllene_oxide',
            'alpha_bisabolol',
            'beta_caryophyllene',
            'p_cymene',
            'terpinene',
            'nerolidol',
        ],
        'cannabinoid_only': [
            'delta_9_thc',
            'cbd',
            'cbn',
            'cbg',
            'cbc',
            'thcv',
            'cbda',
            'delta_8_thc',
            'cbga',
            'thca',
        ],
        'totals': [
            'total_terpenes',
            'total_thc',
            'total_cbd',
        ],
        'simple': [
            'total_thc',
            'total_cbd',
        ],
    }

    # # Use the data to create an effect prediction model.
    # model_name = 'full'
    # aromas = [x for x in reviews.columns if x.startswith('aroma')]
    # effects = [x for x in reviews.columns if x.startswith('effect')]
    # Y = reviews[aromas + effects]
    # X = reviews[variates[model_name]]
    # print('Estimating model:', model_name)
    # effects_model = estimate_discrete_model(X, Y)

    # # Calculate statistics for the model.
    # model_stats = calculate_model_statistics(effects_model, Y, X)

    # # Look at the expected probability of an informed decision.
    # stat = 'informedness'
    # print(
    #     f'Mean {stat}:',
    #     round(model_stats.loc[model_stats[stat] < 1][stat].mean(), 4)
    # )

    # # Save the model.
    # ref = f'public/models/effects/{model_name}'
    # model_data = upload_stats_model(
    #     effects_model,
    #     ref,
    #     name=model_name,
    #     stats=model_stats,
    #     data_dir=DATA_DIR,
    # )
    # print('Effects prediction model saved:', ref)

    #-------------------------------------------------------------------
    # Optional: Use the model to predict the sample and save the
    # predictions for easy access in the future.
    #-------------------------------------------------------------------

    # # Optional: Save the official strain predictions.
    # predictions = predict_stats_model(effects_model, X, model_stats['threshold'])
    # predicted_effects = predictions.apply(nonzero_rows, axis=1)
    # strain_effects = predicted_effects.to_frame()
    # strain_effects['strain_name'] = reviews['strain_name']
    # strain_effects = strain_effects.groupby('strain_name').first()
    # refs = [f'public/data/strains/{x}' for x in strain_effects.index]
    # docs = [{
    #     'predicted_effects': [y for y in x[0] if y.startswith('effect')],
    #     'predicted_aromas': [y for y in x[0] if y.startswith('aroma')],
    # } for x in strain_effects.values]
    # for i, doc in enumerate(docs):
    #     stats = {}
    #     outcomes = doc['predicted_effects'] + doc['predicted_aromas']
    #     for outcome in outcomes:
    #         stats[outcome] = model_stats.loc[outcome].to_dict()
    #     docs[i]['model_stats'] = stats
    #     docs[i]['model'] = model_name
    # update_documents(refs, docs)
    # print('Updated %i strain predictions.' % len(docs))

    #-------------------------------------------------------------------
    # How to use the model in the wild: `full` model.
    #-------------------------------------------------------------------

    # # 1. Get the model and its statistics.
    # model_name = 'full'
    # model_ref = f'public/models/effects/{model_name}'
    # model_data = get_stats_model(model_ref, data_dir=DATA_DIR)
    # model_stats = model_data['model_stats']
    # models = model_data['model']
    # thresholds = model_stats['threshold']

    # # 2. Predict a single sample (below are mean concentrations).
    # strain_name = 'Test Sample'
    # x = pd.DataFrame([{
    #     'delta_9_thc': 10.85,
    #     'cbd': 0.29,
    #     'cbn': 0.06,
    #     'cbg': 0.54,
    #     'cbc': 0.15,
    #     'thcv': 0.07,
    #     'cbda': 0.40,
    #     'delta_8_thc': 0.00,
    #     'cbga': 0.40,
    #     'thca': 8.64,
    #     'd_limonene': 0.22,
    #     'beta_ocimene': 0.05,
    #     'beta_myrcene': 0.35,
    #     'beta_pinene': 0.12,
    #     'linalool': 0.07,
    #     'alpha_pinene': 0.10,
    #     'camphene': 0.01,
    #     'carene': 0.00,
    #     'alpha_terpinene': 0.00,
    #     'ocimene': 0.00,
    #     'cymene': 0.00,
    #     'eucalyptol': 0.00,
    #     'gamma_terpinene': 0.00,
    #     'terpinolene': 0.80,
    #     'isopulegol': 0.00,
    #     'geraniol': 0.00,
    #     'humulene': 0.06,
    #     'nerolidol': 0.01,
    #     'guaiol': 0.01,
    #     'caryophyllene_oxide': 0.00,
    #     'alpha_bisabolol': 0.03,
    #     'beta_caryophyllene': 0.18,
    #     'alpha_humulene': 0.03,
    #     'p_cymene': 0.00,
    #     'terpinene': 0.00,
    # }])
    # prediction = predict_stats_model(models, x, thresholds)
    # outcomes = nonzero_columns(prediction)
    # effects = [x for x in outcomes if x.startswith('effect')]
    # aromas = [x for x in outcomes if x.startswith('aroma')]
    # print(f'Predicted effects:', effects)
    # print(f'Predicted aromas:', aromas)

    # # 3. Save / log the prediction and model stats.
    # timestamp = datetime.now().isoformat()[:19]
    # data = {
    #     'predicted_effects': effects,
    #     'predicted_aromas': aromas,
    #     'lab_results': x.to_dict(orient='records')[0],
    #     'strain_name': strain_name,
    #     'timestamp': timestamp,
    #     'model': model_name,
    #     'model_stats': model_stats,
    # }
    # ref = 'models/effects/model_predictions/%s' % (timestamp.replace(':', '-'))
    # update_documents([ref], [data])

    #-------------------------------------------------------------------
    # How to use the model in the wild: `simple` model.
    #-------------------------------------------------------------------

    # # 1. Get the model and its statistics.
    # model_name = 'simple'
    # model_ref = f'public/models/effects/{model_name}'
    # model_data = get_stats_model(model_ref, data_dir=DATA_DIR)
    # model_stats = model_data['model_stats']
    # models = model_data['model']
    # thresholds = model_stats['threshold']

    # # 2. Predict samples.
    # x = pd.DataFrame([
    #     {'total_cbd': 1.8, 'total_thc': 18.0},
    #     {'total_cbd': 1.0, 'total_thc': 20.0},
    #     {'total_cbd': 1.0, 'total_thc': 30.0},
    #     {'total_cbd': 7.0, 'total_thc': 7.0},
    # ])
    # prediction = predict_stats_model(models, x, thresholds)
    # outcomes = pd.DataFrame()
    # for index, row in prediction.iterrows():
    #     print(f'\nSample {index}')
    #     print('-----------------')
    #     for i, key in enumerate(row['predicted_effects']):
    #         tpr = round(model_stats['true_positive_rate'][key] * 100, 2)
    #         fpr = round(model_stats['false_positive_rate'][key] * 100, 2)
    #         title = key.replace('effect_', '').replace('_', ' ').title()
    #         print(title, f'(TPR: {tpr}%, FPR: {fpr}%)')
    #         outcomes = pd.concat([outcomes, pd.DataFrame([{
    #             'tpr': tpr,
    #             'fpr': fpr,
    #             'name': title,
    #             'strain_name': index,
    #         }])])

    #-------------------------------------------------------------------
    # Example visualization of the predicted outcomes.
    #-------------------------------------------------------------------

    # # Setup plotting style.
    # import seaborn as sns
    # import matplotlib.pyplot as plt
    # import matplotlib.patches as mpatches
    # plt.style.use('fivethirtyeight')
    # plt.rcParams.update({
    #     'font.family': 'Times New Roman',
    # })

    # # Create the plot.
    # outcomes.sort_values('tpr', ascending=False, inplace=True)
    # colors = sns.color_palette('Spectral', n_colors=12)
    # colors = [colors[x] for x in [9, 3, 1, 10]]
    # sns.catplot(
    #     x='name',
    #     y='tpr',
    #     hue='strain_name',
    #     data=outcomes,
    #     kind='bar',
    #     legend=False,
    #     aspect=12/8,
    #     palette=colors,
    # )
    # handles = []
    # ratios = ['10:1', '20:1', '30:1', '1:1']
    # for i, ratio in enumerate(ratios):
    #     patch = mpatches.Patch(color=colors[i], label=ratio)
    #     handles.append(patch)
    # plt.legend(
    #     loc='upper right',
    #     title='THC:CBD',
    #     handles=handles,
    # )
    # plt.title('Predicted Effects That May be Reported')
    # plt.ylabel('True Positive Rate')
    # plt.xlabel('Predicted Effect')
    # plt.xticks(rotation=90)
    # plt.show()

    #-------------------------------------------------------------------
    # Fin.
    #-------------------------------------------------------------------

    print('Test finished.')