okeowo1014 commited on
Commit
4597f1c
1 Parent(s): d09aad8

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +71 -0
main.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ from fastapi import FastAPI, File, UploadFile
3
+ import torch
4
+ import torchvision.transforms as transforms
5
+ import cv2
6
+ import numpy as np
7
+ from PIL import Image
8
+ import torch.nn.functional as F
9
+ import torch.nn as nn
10
+ num_classes = 10
11
+
12
+ # Class definition for the model (same as in your code)
13
+ class FingerprintRecognitionModel(nn.Module):
14
+ def __init__(self, num_classes):
15
+ super(FingerprintRecognitionModel, self).__init__()
16
+ self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
17
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
18
+ self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
19
+ self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
20
+ self.fc1 = nn.Linear(128 * 28 * 28, 256)
21
+ self.fc2 = nn.Linear(256, num_classes)
22
+
23
+ def forward(self, x):
24
+ x = self.pool(F.relu(self.conv1(x)))
25
+ x = self.pool(F.relu(self.conv2(x)))
26
+ x = self.pool(F.relu(self.conv3(x)))
27
+ x = x.view(-1, 128 * 28 * 28)
28
+ x = F.relu(self.fc1(x))
29
+ x = F.softmax(self.fc2(x), dim=1)
30
+ return x
31
+
32
+ app = FastAPI()
33
+
34
+ # Load the model
35
+ model_path = 'fingerprint_recognition_model.pt'
36
+ model = FingerprintRecognitionModel(num_classes)
37
+ model.load_state_dict(torch.load(model_path))
38
+ model.eval()
39
+
40
+ def preprocess_image(image_bytes):
41
+ # Convert bytes to PIL Image
42
+ image = Image.open(io.BytesIO(image_bytes)).convert('L') # Convert to grayscale
43
+
44
+ # Resize to 224x224
45
+ img_resized = image.resize((224, 224))
46
+
47
+ transform = transforms.Compose([
48
+ transforms.ToTensor(),
49
+ transforms.Normalize((0.5,), (0.5,))
50
+ ])
51
+
52
+ # Apply transforms and add batch dimension
53
+ img_tensor = transform(img_resized).unsqueeze(0)
54
+
55
+ return img_tensor
56
+
57
+ def predict_class(image_bytes):
58
+ img_tensor = preprocess_image(image_bytes)
59
+ with torch.no_grad():
60
+ outputs = model(img_tensor)
61
+ _, predicted = torch.max(outputs.data, 1)
62
+ predicted_class = int(predicted.item())
63
+ return predicted_class
64
+
65
+ @app.post("/predict/")
66
+ async def predict_endpoint(file: UploadFile = File(...)):
67
+ contents = await file.read()
68
+ predicted_class = predict_class(contents)
69
+ class_labels = {0: 'Left_ring_fingers', 1: 'Left_thumb_fingers', 2: 'Right_index_fingers', 3: 'Right_little_fingers', 4: 'Right_middle_fingers', 5: 'Right_ring_fingers', 6: 'Right_thumb_fingers', 7: 'left_index_fingers', 8: 'left_little_fingers', 9: 'left_middle_fingers'}
70
+ return {"predicted_class": predicted_class, "class_label": class_labels[predicted_class]}
71
+