Commit 2943a140 authored by it20118068

Model training

parent de205b7f
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "0dbc2465",
"metadata": {},
"outputs": [],
"source": [
"# Import and Install Dependencies\n",
"import cv2\n",
"import numpy as np\n",
"import os\n",
"from matplotlib import pyplot as plt\n",
"import time\n",
"import mediapipe as mp"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "9a7d095c",
"metadata": {},
"outputs": [],
"source": [
"#Keypoints using MP Holistic\n",
"mp_holistic = mp.solutions.holistic # Holistic model\n",
"mp_drawing = mp.solutions.drawing_utils # Drawing utilities"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "256a952f",
"metadata": {},
"outputs": [],
"source": [
"def mediapipe_detection(image, model):\n",
" image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # COLOR CONVERSION BGR 2 RGB\n",
" image.flags.writeable = False # Image is no longer writeable\n",
" results = model.process(image) # Make prediction\n",
" image.flags.writeable = True # Image is now writeable \n",
" image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # COLOR COVERSION RGB 2 BGR\n",
" return image, results"
]
},
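{
"cell_type": "code",
"execution_count": null,
"id": "3f1a2b4c",
"metadata": {},
"outputs": [],
"source": [
"# Minimal sanity check for mediapipe_detection (illustrative sketch, not part\n",
"# of the extraction pipeline below). Assumes a webcam is available at index 0;\n",
"# swap in a video file path if that assumption does not hold. The sample_*\n",
"# variable names are illustrative.\n",
"cap = cv2.VideoCapture(0)\n",
"ret, sample_frame = cap.read()\n",
"cap.release()\n",
"\n",
"if ret:\n",
"    with mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5) as holistic:\n",
"        sample_image, sample_results = mediapipe_detection(sample_frame, holistic)\n",
"    print('Pose landmarks detected:', sample_results.pose_landmarks is not None)"
]
},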
{
"cell_type": "code",
"execution_count": 4,
"id": "642e4cf5",
"metadata": {},
"outputs": [],
"source": [
"def draw_styled_landmarks(image, results):\n",
" # Draw face connections\n",
"# mp_drawing.draw_landmarks(image, results.face_landmarks, mp_holistic.FACEMESH_TESSELATION, \n",
"# mp_drawing.DrawingSpec(color=(80,110,10), thickness=1, circle_radius=1), \n",
"# mp_drawing.DrawingSpec(color=(80,256,121), thickness=1, circle_radius=1)\n",
"# ) \n",
" # Draw pose connections\n",
" mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_holistic.POSE_CONNECTIONS,\n",
" mp_drawing.DrawingSpec(color=(80,22,10), thickness=2, circle_radius=4), \n",
" mp_drawing.DrawingSpec(color=(80,44,121), thickness=2, circle_radius=2)\n",
" ) \n",
" # Draw left hand connections\n",
" mp_drawing.draw_landmarks(image, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS, \n",
" mp_drawing.DrawingSpec(color=(121,22,76), thickness=2, circle_radius=4), \n",
" mp_drawing.DrawingSpec(color=(121,44,250), thickness=2, circle_radius=2)\n",
" ) \n",
" # Draw right hand connections \n",
" mp_drawing.draw_landmarks(image, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS, \n",
" mp_drawing.DrawingSpec(color=(245,117,66), thickness=2, circle_radius=4), \n",
" mp_drawing.DrawingSpec(color=(245,66,230), thickness=2, circle_radius=2)\n",
" )"
]
},
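{
"cell_type": "code",
"execution_count": null,
"id": "7c9d1e5f",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative follow-up to the sanity check above: overlay the styled\n",
"# landmarks on the captured sample frame and display it inline with\n",
"# matplotlib (imported in the first cell). Assumes the previous sketch ran\n",
"# and a frame was captured successfully.\n",
"if ret:\n",
"    draw_styled_landmarks(sample_image, sample_results)\n",
"    plt.imshow(cv2.cvtColor(sample_image, cv2.COLOR_BGR2RGB))\n",
"    plt.axis('off')\n",
"    plt.show()"
]
},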
{
"cell_type": "code",
"execution_count": 5,
"id": "94165384",
"metadata": {},
"outputs": [],
"source": [
"# Extract Keypoint Values\n",
"def extract_keypoints(results):\n",
" pose = np.array([[res.x, res.y, res.z, res.visibility] for res in results.pose_landmarks.landmark]).flatten() if results.pose_landmarks else np.zeros(33*4)\n",
"# face = np.array([[res.x, res.y, res.z] for res in results.face_landmarks.landmark]).flatten() if results.face_landmarks else np.zeros(468*3)\n",
" lh = np.array([[res.x, res.y, res.z] for res in results.left_hand_landmarks.landmark]).flatten() if results.left_hand_landmarks else np.zeros(21*3)\n",
" rh = np.array([[res.x, res.y, res.z] for res in results.right_hand_landmarks.landmark]).flatten() if results.right_hand_landmarks else np.zeros(21*3)\n",
"# return np.concatenate([pose, face, lh, rh])\n",
" return np.concatenate([pose, lh, rh])"
]
},
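{
"cell_type": "code",
"execution_count": null,
"id": "b2c4d6e8",
"metadata": {},
"outputs": [],
"source": [
"# Sanity check (sketch): with the face mesh commented out, each frame's\n",
"# keypoint vector should hold 33*4 + 21*3 + 21*3 = 258 values. An empty\n",
"# result object (no detections) exercises the zero-fill branches; the\n",
"# SimpleNamespace below is just a stand-in, not a real MediaPipe result.\n",
"import types\n",
"empty_results = types.SimpleNamespace(pose_landmarks=None, left_hand_landmarks=None, right_hand_landmarks=None)\n",
"print(extract_keypoints(empty_results).shape)  # expected: (258,)"
]
},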
{
"cell_type": "code",
"execution_count": 6,
"id": "141bffd9",
"metadata": {},
"outputs": [],
"source": [
"# Setup Folders for Collection\n",
"\n",
"# Path for exported data, numpy arrays\n",
"DATA_PATH = os.path.join('VIDEO_DATA')\n",
"\n",
"VIDEO_PATH = 'VideoData'\n",
"\n",
"# Actions that we try to detect\n",
"actions = np.array(['tell','hello','mine','thankyou'])\n",
"\n",
"# Number of videos per action\n",
"no_videos = 10\n",
"\n",
"# Videos are going to be 30 frames in length\n",
"sequence_length = 30"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "730ffb12",
"metadata": {},
"outputs": [],
"source": [
"for action in actions: \n",
" for video in range(no_videos):\n",
" try: \n",
" os.makedirs(os.path.join(DATA_PATH, action, str(video)))\n",
" except:\n",
" pass"
]
},
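{
"cell_type": "code",
"execution_count": null,
"id": "c5e7f9a1",
"metadata": {},
"outputs": [],
"source": [
"# Optional check (sketch): confirm the expected source clips exist under\n",
"# VIDEO_PATH before running the extraction loop below. Assumes the layout\n",
"# VideoData/<action>/<video>.mp4 that the next cell reads from.\n",
"for action in actions:\n",
"    for video in range(no_videos):\n",
"        clip_path = os.path.join(VIDEO_PATH, action, str(video) + '.mp4')\n",
"        if not os.path.exists(clip_path):\n",
"            print('Missing source video:', clip_path)"
]
},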
{
"cell_type": "code",
"execution_count": 8,
"id": "e4384d62",
"metadata": {},
"outputs": [],
"source": [
"with mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5) as holistic:\n",
" \n",
" # Loop through actions\n",
" for action in actions:\n",
" # Loop through videos\n",
" for video in range(no_videos):\n",
" # Set path to the current video file\n",
" video_path = os.path.join(VIDEO_PATH, action, str(video) + \".mp4\")\n",
" cap = cv2.VideoCapture(video_path)\n",
" \n",
" # Get total number of frames in the video\n",
" total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n",
"\n",
" # Loop through video frames\n",
" for frame_num in range(sequence_length):\n",
"\n",
" # Calculate the frame index within the video\n",
" frame_index = int(frame_num * (total_frames / sequence_length))\n",
"\n",
" # Set the video frame position to the calculated index\n",
" cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)\n",
"\n",
" # Read frame from the video\n",
" ret, frame = cap.read()\n",
"\n",
" # Check if the video has ended\n",
" if not ret:\n",
" break\n",
"\n",
" # Make detections\n",
" image, results = mediapipe_detection(frame, holistic)\n",
" # print(results)\n",
"\n",
" # Draw landmarks\n",
" draw_styled_landmarks(image, results)\n",
"\n",
" # Apply wait logic\n",
" if frame_num == 0: \n",
" cv2.putText(image, 'STARTING COLLECTION', (120,200), \n",
" cv2.FONT_HERSHEY_SIMPLEX, 1, (0,255, 0), 4, cv2.LINE_AA)\n",
" cv2.putText(image, 'Collecting frames for {} Video Number {}'.format(action.encode('utf-8'), video), (15,12), \n",
" cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA)\n",
" # Show to screen\n",
" cv2.imshow('OpenCV Feed', image)\n",
" cv2.waitKey(2000)\n",
" else: \n",
" cv2.putText(image, 'Collecting frames for {} Video Number {}'.format(action, video), (15,12), \n",
" cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA)\n",
" # Show to screen\n",
" cv2.imshow('OpenCV Feed', image)\n",
"\n",
" # Export keypoints\n",
" keypoints = extract_keypoints(results)\n",
" npy_path = os.path.join(DATA_PATH, action, str(video), str(frame_num))\n",
" np.save(npy_path, keypoints)\n",
"\n",
" # Break gracefully\n",
" if cv2.waitKey(10) & 0xFF == ord('q'):\n",
" break\n",
" \n",
" cap.release()\n",
" cv2.destroyAllWindows()"
]
},
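{
"cell_type": "code",
"execution_count": null,
"id": "d8f0a2b4",
"metadata": {},
"outputs": [],
"source": [
"# Sketch of how the exported keypoints could be assembled into training\n",
"# arrays for the model: one sequence of shape (30, 258) per video, with an\n",
"# integer label per action. The label_map/X/y names are illustrative\n",
"# assumptions, and the sketch assumes all 30 frames were saved for every video.\n",
"label_map = {label: num for num, label in enumerate(actions)}\n",
"\n",
"sequences, labels = [], []\n",
"for action in actions:\n",
"    for video in range(no_videos):\n",
"        window = []\n",
"        for frame_num in range(sequence_length):\n",
"            res = np.load(os.path.join(DATA_PATH, action, str(video), '{}.npy'.format(frame_num)))\n",
"            window.append(res)\n",
"        sequences.append(window)\n",
"        labels.append(label_map[action])\n",
"\n",
"X = np.array(sequences)\n",
"y = np.array(labels)\n",
"print(X.shape, y.shape)  # expected: (len(actions) * no_videos, 30, 258) and (len(actions) * no_videos,)"
]
},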
{
"cell_type": "code",
"execution_count": null,
"id": "94b4b3f0-2b5a-49b1-a3d1-4a3d9114d4e5",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}