Commit 958b0f37 authored by Paranagama R.P.S.D.'s avatar Paranagama R.P.S.D.

feat : Get predictions with precentage & threshold value

parent 14bf66ef
......@@ -13,3 +13,6 @@
2023-05-24 20:05:37,936 - INFO - Error.
2023-05-24 20:05:37,936 - INFO - Error.
2023-05-24 20:05:37,936 - INFO - Error.
2023-07-12 05:50:25,202 - INFO - Error. 'SignLanguagePredictionService' object has no attribute 'predict_sign_language_video2'
2023-07-12 05:50:25,202 - INFO - Error. 'SignLanguagePredictionService' object has no attribute 'predict_sign_language_video2'
2023-07-12 05:50:25,202 - INFO - Error. 'SignLanguagePredictionService' object has no attribute 'predict_sign_language_video2'
......@@ -53,7 +53,7 @@ def predict_using_image(image_request: UploadFile = File(...)):
@router.post('/predict-sign-language/video', tags=["Sign Language"])
def predict_using_video(video_request: UploadFile = File(...)):
try:
return prediction_service.predict_sign_language_video2(video_request)
return prediction_service.predict_sign_language_video_new(video_request)
except Exception as e:
logger.info(f"Error. {e}")
raise HTTPException(
......
......@@ -3,7 +3,7 @@ import cv2
import numpy as np
from fastapi import HTTPException, UploadFile
from typing import Dict
from typing import Counter, Dict
from core.logger import setup_logger
......@@ -95,7 +95,7 @@ class SignLanguagePredictionService:
)
def predict_sign_language_video(self, video_request: UploadFile) -> Dict[str, str]:
def predict_sign_language_video_new(self, video_request: UploadFile) -> Dict[str, str]:
try:
# Create a temporary file to save the video
video_location = f"files/{video_request.filename}"
......@@ -120,6 +120,8 @@ class SignLanguagePredictionService:
frame = extract_hand_shape(frame)
frame = np.array([frame], dtype=np.float32) / 255.0
# Make prediction
prediction = self.model.predict(frame)
class_index = np.argmax(prediction)
......@@ -136,7 +138,11 @@ class SignLanguagePredictionService:
# Delete the video file
os.remove(video_location)
return {'frame_count': frame_count, 'predictions': predictions}
threshold_percentage = 60
predictions = get_predicted_percentage(predictions, threshold_percentage)
return {'frame_count': frame_count, 'predictions': predictions }
except Exception as e:
logger.info(f"Failed to make predictions. {e}")
raise HTTPException(
......@@ -153,4 +159,18 @@ def extract_hand_shape(image):
hand_contour = contours[0]
hand_shape = np.zeros_like(image)
cv2.drawContours(hand_shape, [hand_contour], 0, (255, 255, 255), thickness=cv2.FILLED)
return hand_shape
\ No newline at end of file
return hand_shape
def get_predicted_percentage(array, threshold):
counts = Counter(array)
total_elements = len(array)
percentages = {}
for element, count in counts.items():
percentage = (count / total_elements) * 100
percentages[element] = percentage
elements_above_threshold = [element for element, percentage in percentages.items() if percentage > threshold]
return elements_above_threshold
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment