Inference - Python Backend

parent ba0ca7e8
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import os, glob\n",
"import cv2 as cv\n",
"import numpy as np\n",
"import pandas as pd\n",
"import tensorflow as tf\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image, ImageDraw, ImageFont\n",
"from transformers import pipeline, AutoProcessor, AutoModelForZeroShotObjectDetection"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"id2label\"]` will be overriden.\n",
"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"bos_token_id\"]` will be overriden.\n",
"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"eos_token_id\"]` will be overriden.\n",
"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"id2label\"]` will be overriden.\n",
"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"bos_token_id\"]` will be overriden.\n",
"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"eos_token_id\"]` will be overriden.\n",
"Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.\n"
]
}
],
"source": [
"width = 224\n",
"height = 224\n",
"\n",
"target_size = (width, height)\n",
"input_shape = (width, height, 3)\n",
"\n",
"item_checkpoint = 'models/item identification.h5'\n",
"detection_checkpoint = \"google/owlvit-base-patch32\"\n",
"quality_checkpoint_hf = \"openai/clip-vit-large-patch14-336\"\n",
"quality_checkpoint = 'models/fresh-vs-rotten identification.h5'\n",
"\n",
"quality_classifier = pipeline(\n",
" \"zero-shot-image-classification\", \n",
" model = quality_checkpoint_hf\n",
" )\n",
"\n",
"detection_processor = AutoProcessor.from_pretrained(detection_checkpoint)\n",
"detection_model = AutoModelForZeroShotObjectDetection.from_pretrained(detection_checkpoint)\n",
"\n",
"quality_model = tf.keras.models.load_model(quality_checkpoint)\n",
"quality_model.compile(\n",
" optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),\n",
" loss='binary_crossentropy',\n",
" metrics=[\n",
" tf.keras.metrics.BinaryAccuracy(name='accuracy'),\n",
" tf.keras.metrics.Precision(name='precision'),\n",
" tf.keras.metrics.Recall(name='recall'),\n",
" tf.keras.metrics.AUC(name='auc')\n",
" ])\n",
"\n",
"item_model = tf.keras.models.load_model(item_checkpoint)\n",
"item_model.compile(\n",
" optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),\n",
" loss='categorical_crossentropy',\n",
" metrics=[\n",
" tf.keras.metrics.CategoricalAccuracy(name='accuracy'),\n",
" tf.keras.metrics.Precision(name='precision'),\n",
" tf.keras.metrics.Recall(name='recall'),\n",
" tf.keras.metrics.AUC(name='auc')\n",
" ])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def inference_quality(\n",
" img,\n",
" text_queries = [\n",
" \"Fresh\", \n",
" \"Rotten\"\n",
" ]\n",
" ):\n",
" # img = cv.resize(img, target_size)\n",
" # img = tf.keras.applications.mobilenet_v2.preprocess_input(img)\n",
" # img = np.expand_dims(img, axis=0)\n",
" # pred = quality_model.predict(img)\n",
"\n",
" img = Image.fromarray(img)\n",
" response = quality_classifier(img, candidate_labels = text_queries)\n",
" scores = [r['score'] for r in response]\n",
" labels = [r['label'] for r in response]\n",
"\n",
" max_id = np.argmax(scores)\n",
" label = labels[max_id]\n",
" return label\n",
"\n",
"def inference_item(\n",
" img,\n",
" text_queries = [\n",
" \"apple\", \n",
" \"banana\", \n",
" \"grape\", \n",
" \"guava\", \n",
" \"jujube\", \n",
" \"orange\", \n",
" \"pomegranate\", \n",
" \"strawberry\"\n",
" ]\n",
" ):\n",
" img = Image.fromarray(img)\n",
" response = quality_classifier(img, candidate_labels = text_queries)\n",
" scores = [r['score'] for r in response]\n",
" labels = [r['label'] for r in response]\n",
"\n",
" max_id = np.argmax(scores)\n",
" label = labels[max_id]\n",
" return label"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def process_objects(\n",
" img_path,\n",
" text_queries = [\n",
" \"apple\", \n",
" \"banana\", \n",
" \"grape\", \n",
" \"guava\", \n",
" \"jujube\", \n",
" \"orange\", \n",
" \"pomegranate\", \n",
" \"strawberry\"\n",
" ]):\n",
" image = Image.open(img_path)\n",
" image_np = np.asarray(image)\n",
" inputs = detection_processor(\n",
" text=text_queries, \n",
" images=image, \n",
" return_tensors=\"pt\"\n",
" )\n",
"\n",
" with torch.no_grad():\n",
" outputs = detection_model(**inputs)\n",
" target_sizes = torch.tensor([image.size[::-1]])\n",
" results = detection_processor.post_process_object_detection(outputs, threshold=0.1, target_sizes=target_sizes)[0]\n",
"\n",
" scores = results[\"scores\"].tolist()\n",
" labels = results[\"labels\"].tolist()\n",
" boxes = results[\"boxes\"].tolist()\n",
"\n",
" output = []\n",
" for box, _, _ in zip(boxes, scores, labels):\n",
" xmin, ymin, xmax, ymax = box\n",
" roi_item = image_np[int(ymin):int(ymax), int(xmin):int(xmax)]\n",
" quality = inference_quality(roi_item)\n",
" item = inference_item(roi_item)\n",
" output.append({\n",
" \"item\": item,\n",
" \"quality\": quality,\n",
" })\n",
" return output"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'item': 'apple', 'quality': 'Rotten'},\n",
" {'item': 'apple', 'quality': 'Rotten'}]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output = process_objects('data/item-detection/Apple/RottenApple (351).jpg')\n",
"output"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import firebase_admin\n",
"from firebase_admin import db\n",
"from firebase_admin import credentials"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"node_blind_voice = 'blind_voice'\n",
"node_cart = 'cart'\n",
"cred = credentials.Certificate(\"models/my-eyes-8d270-firebase-adminsdk-uc43e-53f727863b.json\")\n",
"if firebase_admin._apps:\n",
" firebase_admin.delete_app(firebase_admin._apps['[DEFAULT]'])\n",
"default_app = firebase_admin.initialize_app(cred, {\n",
" 'databaseURL':'https://my-eyes-8d270-default-rtdb.firebaseio.com/'\n",
" })\n",
"\n",
"ref_node_blind_voice = db.reference(node_blind_voice)\n",
"ref_node_cart = db.reference(node_cart)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"# while True:\n",
"# if ref_node_cart.get() is not None:\n",
"# data_cart = ref_node_cart.get()\n",
"# data_cart = data_cart['products']\n",
"# data_cart = list(data_cart.values())\n",
" \n",
"# cart_details = {}\n",
"# for item in data_cart:\n",
"# item_name = item['name'].strip().lower()\n",
"# quantity = item['qty']\n",
"\n",
"# item_name = item_name[0].upper() + item_name[1:]\n",
"# if item_name not in cart_details:\n",
"# cart_details[item_name] = quantity\n",
"\n",
"# cart_details_cp = cart_details.copy()\n",
"# for item, quantity in cart_details.items():\n",
"# item_images = glob.glob(f'data/item-detection/{item}/*.jpg')\n",
"# for i in range(quantity):\n",
"# while cart_details_cp[item] > 0:\n",
"# rand_img = np.random.choice(item_images)\n",
"# class_name = \"Fresh\" if \"Fresh\" in rand_img else \"Rotten\"\n",
"# speech_string = f\"selected {item} is {class_name}\"\n",
"# print(speech_string)\n",
"\n",
"# # update speech in ref_node_blind_voice node\n",
"# ref_node_blind_voice.update({\n",
"# 'speech': speech_string,\n",
"# 'location' : ref_node_blind_voice.get()['location']\n",
"# })\n",
"\n",
"\n",
"# if class_name == \"Fresh\":\n",
"# cart_details_cp[item] -= 1\n",
"\n",
" \n",
"# # delete the cart node\n",
"# ref_node_cart.delete()\n",
"# break"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"while True:\n",
" if ref_node_cart.get() is not None:\n",
" data_cart = ref_node_cart.get()\n",
" data_cart = data_cart['products']\n",
" data_cart = list(data_cart.values())\n",
" \n",
" cart_details = {}\n",
" for item in data_cart:\n",
" item_name = item['name'].strip().lower()\n",
" quantity = item['qty']\n",
"\n",
" if item_name not in cart_details:\n",
" cart_details[item_name] = quantity\n",
"\n",
" cap = cv.VideoCapture(0)\n",
" while True:\n",
" cart_details_cp = cart_details.copy()\n",
" for i in range(quantity):\n",
" while cart_details_cp[item] > 0:\n",
"\n",
" image = cap.read()[1]\n",
" image = cv.cvtColor(image, cv.COLOR_BGR2RGB)\n",
" cv.imwrite('tmp.jpg', image)\n",
"\n",
" output = process_objects('tmp.jpg')\n",
"\n",
" if len(output) > 0:\n",
" for o in output:\n",
" item = o['item']\n",
" quality = o['quality']\n",
" if quality == \"Fresh\":\n",
" cart_details_cp[item] -= 1\n",
" speech_string = f\"selected {item} is {quality}\"\n",
" print(speech_string)\n",
"\n",
" # update speech in ref_node_blind_voice node\n",
" ref_node_blind_voice.update({\n",
" 'speech': speech_string,\n",
" 'location' : ref_node_blind_voice.get()['location']\n",
" })\n",
" \n",
"\n",
" for item, quantity in cart_details_cp.items(): \n",
" if quantity == 0:\n",
" del cart_details_cp[item]\n",
"\n",
" if len(cart_details_cp) == 0:\n",
" ref_node_cart.delete()\n",
" break"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "tf210",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment