ka1kuk committed
Commit 0e70dd6
1 Parent(s): d25ec2d

Update apis/chat_api.py

Files changed (1)
  1. apis/chat_api.py +20 -22
apis/chat_api.py CHANGED
@@ -188,36 +188,34 @@ class ChatAPIApp:
         data_response = streamer.chat_return_dict(stream_response)
         return data_response
 
-    async def chat_embedding(self, input_text: str, model_name: str, api_key: str):
-        api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_name}"
+    async def embedding(request: QueryRequest, api_key: str = Depends(extract_api_key)):
+        api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{request.model_name}"
         headers = {"Authorization": f"Bearer {api_key}"}
-        response = requests.post(api_url, headers=headers, json={"inputs": input_text})
-        result = response.json()
+
+        try:
+            response = requests.post(api_url, headers=headers, json={"inputs": request.input_text})
+            result = response.json()
 
-        # Improved error handling and logging
-        if "error" in result:
-            logging.error(f"Error from Hugging Face API: {result['error']}")
-            # More detailed error message
-            error_detail = result.get('error', 'No detailed error message provided.')
-            raise RuntimeError(f"The model is currently loading, please re-run the query. Detail: {error_detail}")
+            if "error" in result:
+                logging.error(f"Error from Hugging Face API: {result.get('error', 'No detailed error message provided.')}")
+                raise HTTPException(status_code=503, detail="The model is currently loading, please re-run the query.")
+
+            if not (isinstance(result, list) and len(result) > 0 and isinstance(result[0], list)):
+                logging.error(f"Unexpected response format: {result}")
+                raise HTTPException(status_code=500, detail="Unexpected response format.")
+
+            # Assuming each embedding is a list of lists of floats, flatten it
+            flattened_embeddings = [sum(embedding, []) for embedding in result]
+            data = [{"object": "embedding", "index": i, "embedding": embedding} for i, embedding in enumerate(flattened_embeddings)]
 
-        if isinstance(result, list) and len(result) > 0 and isinstance(result[0], list):
-            return [item for sublist in result for item in sublist]  # Flatten list of lists
-        else:
-            logging.error(f"Unexpected response format: {result}")
-            raise RuntimeError("Unexpected response format.")
-
-    async def embedding(self, request: QueryRequest, api_key: str = Depends(extract_api_key)):
-        try:
-            embeddings = await self.chat_embedding(request.input, request.model, api_key)
-            data = [{"object": "embedding", "index": i, "embedding": embedding} for i, embedding in enumerate(embeddings)]
             return EmbeddingResponse(
                 object="list",
                 data=data,
-                model=request.model,
-                usage={"prompt_tokens": len(request.input), "total_tokens": len(request.input)}
+                model=request.model_name,
+                usage={"prompt_tokens": len(request.input_text), "total_tokens": len(request.input_text)}
             )
         except Exception as e:
+            logging.error(f"An error occurred: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e))
 
     def setup_routes(self):
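
The flattening step in the new handler assumes the feature-extraction pipeline returns, for each input, a list of per-token vectors. Below is a minimal standalone sketch of that step; the payload shape and sample values are assumptions for illustration, not taken from the commit:

    # Sketch of the flattening logic in the new embedding handler.
    # Assumption: the feature-extraction pipeline returns one list of
    # per-token vectors (list[list[float]]) per input string; the sample
    # values are made up.
    sample_result = [
        [[0.1, 0.2], [0.3, 0.4]],  # input 0: two token vectors of dimension 2
        [[0.5, 0.6]],              # input 1: one token vector
    ]

    # Same expression as in the commit: sum(x, []) concatenates each
    # input's token vectors into a single flat list.
    flattened_embeddings = [sum(embedding, []) for embedding in sample_result]
    print(flattened_embeddings)  # [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6]]

Note that the guard above the flattening only checks one level of nesting (that result[0] is a list), so a model that returns one pooled vector per input would pass the check and then fail inside sum(..., []) with a TypeError.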