Commit 8f46cc0d authored by Eshan Sanjaya

inference file

parent 3e65cbfe
{
"cells": [
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import os, glob\n",
"import cv2 as cv\n",
"import numpy as np\n",
"import pandas as pd\n",
"import tensorflow as tf\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image, ImageDraw, ImageFont\n",
"from transformers import pipeline, AutoProcessor, AutoModelForZeroShotObjectDetection"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"id2label\"]` will be overriden.\n",
"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"bos_token_id\"]` will be overriden.\n",
"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"eos_token_id\"]` will be overriden.\n",
"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"id2label\"]` will be overriden.\n",
"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"bos_token_id\"]` will be overriden.\n",
"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"eos_token_id\"]` will be overriden.\n",
"Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.\n"
]
}
],
"source": [
"width = 224\n",
"height = 224\n",
"\n",
"target_size = (width, height)\n",
"input_shape = (width, height, 3)\n",
"\n",
"item_checkpoint = 'models/item identification.h5'\n",
"detection_checkpoint = \"google/owlvit-base-patch32\"\n",
"quality_checkpoint_hf = \"openai/clip-vit-large-patch14-336\"\n",
"quality_checkpoint = 'models/fresh-vs-rotten identification.h5'\n",
"\n",
"quality_classifier = pipeline(\n",
" \"zero-shot-image-classification\", \n",
" model = quality_checkpoint_hf\n",
" )\n",
"\n",
"detection_processor = AutoProcessor.from_pretrained(detection_checkpoint)\n",
"detection_model = AutoModelForZeroShotObjectDetection.from_pretrained(detection_checkpoint)\n",
"\n",
"quality_model = tf.keras.models.load_model(quality_checkpoint)\n",
"quality_model.compile(\n",
" optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),\n",
" loss='binary_crossentropy',\n",
" metrics=[\n",
" tf.keras.metrics.BinaryAccuracy(name='accuracy'),\n",
" tf.keras.metrics.Precision(name='precision'),\n",
" tf.keras.metrics.Recall(name='recall'),\n",
" tf.keras.metrics.AUC(name='auc')\n",
" ])\n",
"\n",
"item_model = tf.keras.models.load_model(item_checkpoint)\n",
"item_model.compile(\n",
" optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),\n",
" loss='categorical_crossentropy',\n",
" metrics=[\n",
" tf.keras.metrics.CategoricalAccuracy(name='accuracy'),\n",
" tf.keras.metrics.Precision(name='precision'),\n",
" tf.keras.metrics.Recall(name='recall'),\n",
" tf.keras.metrics.AUC(name='auc')\n",
" ])"
]
},
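{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell above wires up three models: a CLIP zero-shot classifier, an OWL-ViT zero-shot detector, and two local Keras checkpoints. As a minimal sanity check (a sketch, not part of the pipeline), the CLIP classifier can be run on a synthetic placeholder image to confirm it loaded and to show its output format."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check: run the zero-shot classifier on a synthetic placeholder image\n",
"# (a solid-green square, not real produce data).\n",
"probe = Image.new(\"RGB\", (224, 224), color=(40, 180, 60))\n",
"preds = quality_classifier(probe, candidate_labels=[\"Fresh\", \"Rotten\"])\n",
"print(preds)  # list of {'score': float, 'label': str}, sorted by score"
]
},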
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"def inference_quality(\n",
" img,\n",
" text_queries = [\n",
" \"Fresh\", \n",
" \"Rotten\"\n",
" ]\n",
" ):\n",
" # img = cv.resize(img, target_size)\n",
" # img = tf.keras.applications.mobilenet_v2.preprocess_input(img)\n",
" # img = np.expand_dims(img, axis=0)\n",
" # pred = quality_model.predict(img)\n",
"\n",
" img = Image.fromarray(img)\n",
" response = quality_classifier(img, candidate_labels = text_queries)\n",
" scores = [r['score'] for r in response]\n",
" labels = [r['label'] for r in response]\n",
"\n",
" max_id = np.argmax(scores)\n",
" label = labels[max_id]\n",
" return label\n",
"\n",
"def inference_item(\n",
" img,\n",
" text_queries = [\n",
" \"apple\", \n",
" \"banana\", \n",
" \"grape\", \n",
" \"guava\", \n",
" \"jujube\", \n",
" \"orange\", \n",
" \"pomegranate\", \n",
" \"strawberry\"\n",
" ]\n",
" ):\n",
" img = Image.fromarray(img)\n",
" response = quality_classifier(img, candidate_labels = text_queries)\n",
" scores = [r['score'] for r in response]\n",
" labels = [r['label'] for r in response]\n",
"\n",
" max_id = np.argmax(scores)\n",
" label = labels[max_id]\n",
" return label"
]
},
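{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both helpers call the same CLIP pipeline and differ only in their candidate labels. The pipeline already returns results sorted by score, so the `np.argmax` is a defensive step. A usage sketch follows; the path `sample.jpg` is hypothetical, so substitute any RGB image of a single fruit."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Usage sketch; 'sample.jpg' is a hypothetical path.\n",
"crop = np.asarray(Image.open('sample.jpg').convert('RGB'))\n",
"print(inference_item(crop))     # e.g. 'apple'\n",
"print(inference_quality(crop))  # 'Fresh' or 'Rotten'"
]
},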
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"def process_objects(\n",
" img_path,\n",
" text_queries = [\n",
" \"apple\", \n",
" \"banana\", \n",
" \"grape\", \n",
" \"guava\", \n",
" \"jujube\", \n",
" \"orange\", \n",
" \"pomegranate\", \n",
" \"strawberry\"\n",
" ]):\n",
" image = Image.open(img_path)\n",
" image_np = np.asarray(image)\n",
" inputs = detection_processor(\n",
" text=text_queries, \n",
" images=image, \n",
" return_tensors=\"pt\"\n",
" )\n",
"\n",
" with torch.no_grad():\n",
" outputs = detection_model(**inputs)\n",
" target_sizes = torch.tensor([image.size[::-1]])\n",
" results = detection_processor.post_process_object_detection(outputs, threshold=0.1, target_sizes=target_sizes)[0]\n",
"\n",
" scores = results[\"scores\"].tolist()\n",
" labels = results[\"labels\"].tolist()\n",
" boxes = results[\"boxes\"].tolist()\n",
"\n",
" output = []\n",
" for box, _, _ in zip(boxes, scores, labels):\n",
" xmin, ymin, xmax, ymax = box\n",
" roi_item = image_np[int(ymin):int(ymax), int(xmin):int(xmax)]\n",
" item = inference_item(roi_item)\n",
" quality = inference_quality(roi_item)\n",
"\n",
" output.append({\n",
" \"item\": item,\n",
" \"quality\": quality,\n",
" })\n",
" return output"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'item': 'apple', 'quality': 'Rotten'},\n",
" {'item': 'apple', 'quality': 'Rotten'}]"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output = process_objects('data/item-detection/Apple/RottenApple (351).jpg')\n",
"output"
]
},
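{
"cell_type": "markdown",
"metadata": {},
"source": [
"`process_objects` discards the OWL-ViT boxes after cropping, so nothing above visualizes the detections. The sketch below uses the `ImageDraw` import from the first cell to draw the raw boxes on the same demo image; it mirrors the detection settings of `process_objects` but is not part of the pipeline itself."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Visualization sketch: draw the raw OWL-ViT boxes on the demo image.\n",
"image = Image.open('data/item-detection/Apple/RottenApple (351).jpg').convert('RGB')\n",
"inputs = detection_processor(text=[\"apple\"], images=image, return_tensors=\"pt\")\n",
"with torch.no_grad():\n",
"    outputs = detection_model(**inputs)\n",
"target_sizes = torch.tensor([image.size[::-1]])\n",
"results = detection_processor.post_process_object_detection(\n",
"    outputs, threshold=0.1, target_sizes=target_sizes)[0]\n",
"\n",
"draw = ImageDraw.Draw(image)\n",
"for box, score in zip(results[\"boxes\"].tolist(), results[\"scores\"].tolist()):\n",
"    draw.rectangle(box, outline=\"red\", width=3)\n",
"    draw.text((box[0], box[1]), f\"{score:.2f}\", fill=\"red\")\n",
"plt.imshow(image)\n",
"plt.show()"
]
},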
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"import firebase_admin\n",
"from firebase_admin import db\n",
"from firebase_admin import credentials"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"node_blind_voice = 'blind_voice'\n",
"node_cart = 'cart'\n",
"cred = credentials.Certificate(\"models/my-eyes-8d270-firebase-adminsdk-uc43e-53f727863b.json\")\n",
"if firebase_admin._apps:\n",
" firebase_admin.delete_app(firebase_admin._apps['[DEFAULT]'])\n",
"default_app = firebase_admin.initialize_app(cred, {\n",
" 'databaseURL':'https://my-eyes-8d270-default-rtdb.firebaseio.com/'\n",
" })\n",
"\n",
"ref_node_blind_voice = db.reference(node_blind_voice)\n",
"ref_node_cart = db.reference(node_cart)"
]
},
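{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loops below read the `cart` node as `cart['products']`, a mapping of push IDs to `{name, qty}` objects. The payload below is an assumption inferred from that access pattern, shown only to document the expected shape; the push-ID keys are hypothetical placeholders."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Assumed shape of the 'cart' node, inferred from how the loops below read it.\n",
"# The push-ID keys ('-Nabc...') are hypothetical placeholders.\n",
"example_cart = {\n",
"    'products': {\n",
"        '-Nabc123': {'name': 'Apple', 'qty': 3},\n",
"        '-Nabc124': {'name': 'Grape', 'qty': 1},\n",
"    }\n",
"}\n",
"\n",
"# Defensive read that mirrors the loops' access pattern:\n",
"snapshot = ref_node_cart.get()\n",
"if snapshot is not None:\n",
"    print(list(snapshot.get('products', {}).values()))"
]
},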
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"# while True:\n",
"# if ref_node_cart.get() is not None:\n",
"# data_cart = ref_node_cart.get()\n",
"# data_cart = data_cart['products']\n",
"# data_cart = list(data_cart.values())\n",
" \n",
"# cart_details = {}\n",
"# for item in data_cart:\n",
"# item_name = item['name'].strip().lower()\n",
"# quantity = item['qty']\n",
"\n",
"# item_name = item_name[0].upper() + item_name[1:]\n",
"# if item_name not in cart_details:\n",
"# cart_details[item_name] = quantity\n",
"\n",
"# cart_details_cp = cart_details.copy()\n",
"# for item, quantity in cart_details.items():\n",
"# item_images = glob.glob(f'data/item-detection/{item}/*.jpg')\n",
"# for i in range(quantity):\n",
"# while cart_details_cp[item] > 0:\n",
"# rand_img = np.random.choice(item_images)\n",
"# class_name = \"Fresh\" if \"Fresh\" in rand_img else \"Rotten\"\n",
"# speech_string = f\"selected {item} is {class_name}\"\n",
"# print(speech_string)\n",
"\n",
"# # update speech in ref_node_blind_voice node\n",
"# ref_node_blind_voice.update({\n",
"# 'speech': speech_string,\n",
"# 'location' : ref_node_blind_voice.get()['location']\n",
"# })\n",
"\n",
"\n",
"# if class_name == \"Fresh\":\n",
"# cart_details_cp[item] -= 1\n",
"\n",
" \n",
"# # delete the cart node\n",
"# ref_node_cart.delete()\n",
"# break"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"# while True:\n",
"# if ref_node_cart.get() is not None:\n",
"# data_cart = ref_node_cart.get()\n",
"# print(data_cart)\n",
"# data_cart = data_cart['products']\n",
"# data_cart = list(data_cart.values())\n",
" \n",
"# cart_details = {}\n",
"# for item in data_cart:\n",
"# item_name = item['name'].strip().lower()\n",
"# quantity = item['qty']\n",
"\n",
"# if item_name not in cart_details:\n",
"# cart_details[item_name] = quantity\n",
"\n",
"# cap = cv.VideoCapture(0)\n",
"# while True:\n",
"# cart_details_cp = cart_details.copy() \n",
"# for i in range(cart_details_cp[item]):\n",
"# while cart_details_cp[item] > 0:\n",
"\n",
"# break \n"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'apple': 3, 'grape': 1}\n",
"selected apple is Fresh\n",
"selected apple is Fresh\n",
"selected grape is Fresh\n",
"selected grape is Fresh\n"
]
}
],
"source": [
"while True:\n",
" if ref_node_cart.get() is not None:\n",
" data_cart = ref_node_cart.get()\n",
" data_cart = data_cart['products']\n",
" data_cart = list(data_cart.values())\n",
" \n",
" cart_details = {}\n",
" for item in data_cart:\n",
" item_name = item['name'].strip().lower()\n",
" quantity = item['qty']\n",
"\n",
" if item_name not in cart_details:\n",
" cart_details[item_name] = quantity\n",
" print(cart_details)\n",
" cap = cv.VideoCapture(0)\n",
" while True:\n",
" cart_details_cp = cart_details.copy()\n",
" for item, quantity in cart_details_cp.items():\n",
" for i in range(cart_details_cp[item]):\n",
" while cart_details_cp[item] > 0:\n",
"\n",
" image = cap.read()[1]\n",
" # image = cv.cvtColor(image, cv.COLOR_BGR2RGB)\n",
" cv.imwrite('tmp.jpg', image)\n",
"\n",
" # plt.imshow(image)\n",
" # plt.show() \n",
"\n",
" cv2.imshow(\"CART\", image)\n",
"\n",
" k = cv2.waitKey(1)\n",
" if k%256 == 27:\n",
" print(\"Escape hit, closing...\")\n",
" break \n",
"\n",
" try:\n",
" output = process_objects('tmp.jpg')\n",
" except:\n",
" pass\n",
"\n",
" if len(output) > 0:\n",
" for o in output:\n",
" item_det = o['item']\n",
" quality_det = o['quality']\n",
" \n",
" if item_det in cart_details_cp:\n",
" if quality_det == \"Fresh\":\n",
" cart_details_cp[item] -= 1\n",
" speech_string = f\"selected {item_det} is {quality_det}\"\n",
" print(speech_string)\n",
"\n",
" # update speech in ref_node_blind_voice node\n",
" ref_node_blind_voice.update({\n",
" 'speech': speech_string,\n",
" 'location' : ref_node_blind_voice.get()['location']\n",
" })\n",
" \n",
"\n",
" for item, quantity in cart_details_cp.items(): \n",
" if quantity == 0:\n",
" del cart_details[item]\n",
"\n",
" if len(cart_details) == 0:\n",
" ref_node_cart.delete()\n",
" break\n",
"\n",
" cv2.destroyAllWindows()"
]
},
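{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the loop above is interrupted (e.g. Kernel > Interrupt) before reaching its own cleanup, the webcam handle and the OpenCV window can be left open. A small cleanup cell to run manually in that case:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Manual cleanup after an interrupted run.\n",
"try:\n",
"    cap.release()\n",
"except NameError:\n",
"    pass  # the loop never opened the camera\n",
"cv.destroyAllWindows()"
]
},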
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "tf210",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.17"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}