arthur-stackadoc-com commited on
Commit
760fae4
1 Parent(s): f426af8

requirements.txt : added

Browse files
Files changed (3) hide show
  1. app.py +26 -9
  2. requirements.txt +7 -0
  3. settings.py +16 -0
app.py CHANGED
@@ -1,22 +1,39 @@
1
  import gradio as gr
 
 
2
 
 
3
 
4
- class Hit():
5
- def __init__(self):
6
- self.score = 1
7
- self.payload = {
8
- "audio_path": "https://synthia-research.s3.amazonaws.com/music_db/7fbb0c4de0e4bdf5e1dec8f3e803174b.mp3"
9
- }
 
 
 
 
 
 
 
 
 
10
 
11
 
12
  def sound_search(query):
13
- hits = [Hit() for _ in range(3)]
 
 
 
 
 
14
  return [
15
  gr.Audio(
16
- hit.payload['audio_path'],
17
  label=f"score: {hit.score}")
18
  for hit in hits
19
- ]
20
 
21
 
22
  with gr.Blocks() as demo:
 
1
  import gradio as gr
2
+ import laion_clap
3
+ from qdrant_client import QdrantClient
4
 
5
+ from settings import QDRANT_KEY, QDRANT_URL, ENVIRONMENT
6
 
7
+ # Loading the Qdrant DB in local ###################################################################
8
+ if ENVIRONMENT == "PROD":
9
+ qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_KEY)
10
+ else:
11
+ qdrant_client = QdrantClient("localhost", port=6333)
12
+ print("[INFO] Client created...")
13
+
14
+ # loading the model
15
+ print("[INFO] Loading the model...")
16
+ model_name = "laion/larger_clap_music"
17
+ model = laion_clap.CLAP_Module(enable_fusion=False)
18
+ model.load_ckpt() # download the default pretrained checkpoint.
19
+
20
+ # Gradio Interface #################################################################################
21
+ max_results = 10
22
 
23
 
24
  def sound_search(query):
25
+ text_embed = model.get_text_embedding([query, ''])[0] # trick because can't accept singleton
26
+ hits = qdrant_client.search(
27
+ collection_name="music_db",
28
+ query_vector=text_embed,
29
+ limit=max_results,
30
+ )
31
  return [
32
  gr.Audio(
33
+ hit.payload['s3_url'],
34
  label=f"score: {hit.score}")
35
  for hit in hits
36
+ ] * 3
37
 
38
 
39
  with gr.Blocks() as demo:
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ python-dotenv
2
+ torch
3
+ gradio
4
+ laion_clap
5
+ qdrant_client
6
+ torchvision
7
+ transformers==4.30.0
settings.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as op
3
+ from pathlib import Path
4
+
5
+ from dotenv import load_dotenv
6
+
7
+ BASE_DIR_PATH = Path(op.dirname(Path(__file__).resolve()))
8
+ load_dotenv(BASE_DIR_PATH / ".env")
9
+
10
+ # QDrant
11
+ QDRANT_KEY = os.environ.get('QDRANT_KEY')
12
+ QDRANT_URL = os.environ.get('QDRANT_URL')
13
+ QDRANT_COLLECTION_NAME = 'music_db'
14
+
15
+ # Environment
16
+ ENVIRONMENT = "PROD" if os.environ.get('USERNAME') == "arthur" else "PROD"