Update apis/chat_api.py
Browse files- apis/chat_api.py +16 -21
apis/chat_api.py
CHANGED
@@ -188,35 +188,30 @@ class ChatAPIApp:
|
|
188 |
data_response = streamer.chat_return_dict(stream_response)
|
189 |
return data_response
|
190 |
|
191 |
-
async def embedding(request: QueryRequest, api_key: str = Depends(extract_api_key)):
|
192 |
api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{request.model}"
|
193 |
headers = {"Authorization": f"Bearer {api_key}"}
|
194 |
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
raise HTTPException(status_code=500, detail="Unexpected response format.")
|
206 |
-
|
207 |
-
# Assuming each embedding is a list of lists of floats, flatten it
|
208 |
-
flattened_embeddings = [sum(embedding, []) for embedding in result]
|
209 |
data = [{"object": "embedding", "index": i, "embedding": embedding} for i, embedding in enumerate(flattened_embeddings)]
|
210 |
-
|
211 |
return EmbeddingResponse(
|
212 |
object="list",
|
213 |
data=data,
|
214 |
-
model=request.
|
215 |
-
usage={"prompt_tokens": len(request.
|
216 |
)
|
217 |
-
|
218 |
-
logging.error(f"
|
219 |
-
raise HTTPException(status_code=500, detail=
|
220 |
|
221 |
def setup_routes(self):
|
222 |
for prefix in ["", "/v1", "/api", "/api/v1"]:
|
|
|
188 |
data_response = streamer.chat_return_dict(stream_response)
|
189 |
return data_response
|
190 |
|
191 |
+
async def embedding(self, request: QueryRequest, api_key: str = Depends(extract_api_key)):
    """Proxy an embedding request to the Hugging Face feature-extraction pipeline.

    Posts ``request.input`` to the hosted pipeline for ``request.model`` and
    wraps the returned vectors in an ``EmbeddingResponse`` with
    ``object="list"`` and one ``{"object": "embedding", ...}`` entry per vector.

    Args:
        request: Carries ``model`` (used in the pipeline URL) and ``input``
            (sent upstream as the ``inputs`` field).
        api_key: Bearer token forwarded in the ``Authorization`` header.

    Returns:
        An ``EmbeddingResponse`` built from the upstream vectors.

    Raises:
        HTTPException: 502 when the upstream call fails or returns a body that
            is not JSON; 503 when the upstream JSON object contains an
            ``error`` key; 500 when the payload is not a non-empty list of
            lists.
    """
    import asyncio
    import functools

    api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{request.model}"
    headers = {"Authorization": f"Bearer {api_key}"}

    # requests.post is blocking: run it in the default executor so this
    # coroutine does not stall the event loop, and bound it with a timeout so
    # a hung upstream connection cannot block a worker forever.
    post_call = functools.partial(
        requests.post,
        api_url,
        headers=headers,
        json={"inputs": request.input},
        timeout=60,
    )
    try:
        response = await asyncio.get_running_loop().run_in_executor(None, post_call)
    except requests.RequestException as exc:
        logging.error(f"Request to Hugging Face API failed: {exc}")
        raise HTTPException(status_code=502, detail="Could not reach the Hugging Face API.") from exc

    try:
        result = response.json()
    except ValueError as exc:
        # Gateway/HTML error pages are not JSON; report them instead of crashing.
        logging.error(f"Non-JSON response from Hugging Face API (HTTP {response.status_code})")
        raise HTTPException(status_code=502, detail="Invalid response from the Hugging Face API.") from exc

    # Only a JSON object can carry an "error" key; checking the type first
    # avoids TypeError/AttributeError on list or scalar payloads.
    if isinstance(result, dict) and "error" in result:
        error_detail = result["error"]
        logging.error(f"Error from Hugging Face API: {error_detail}")
        raise HTTPException(status_code=503, detail=f"The model is currently loading, please re-run the query. Detail: {error_detail}")

    is_list_of_lists = (
        isinstance(result, list)
        and len(result) > 0
        and all(isinstance(row, list) for row in result)
    )
    if not is_list_of_lists:
        logging.error(f"Unexpected response format: {result}")
        raise HTTPException(status_code=500, detail="Unexpected response format.")

    rows_are_vectors = all(
        isinstance(value, (int, float)) for row in result for value in row
    )
    if rows_are_vectors:
        # Each row is already one flat vector; flattening further would emit
        # every scalar as its own "embedding".
        vectors = result
    else:
        # Deeper nesting: strip one level, as before.
        # NOTE(review): presumably this is per-token output; confirm whether
        # one pooled vector per input is wanted instead.
        vectors = [item for sublist in result for item in sublist]

    data = [{"object": "embedding", "index": i, "embedding": vector} for i, vector in enumerate(vectors)]
    return EmbeddingResponse(
        object="list",
        data=data,
        model=request.model,
        # NOTE(review): len(request.input) counts items (or characters for a
        # plain string), not tokens — confirm this is acceptable for usage.
        usage={"prompt_tokens": len(request.input), "total_tokens": len(request.input)}
    )
|
215 |
|
216 |
def setup_routes(self):
|
217 |
for prefix in ["", "/v1", "/api", "/api/v1"]:
|