create flask backend model

parent 4d8d6524
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import torch\n",
"import wikipedia\n",
"import PyPDF2, re\n",
"from pyvis.network import Network\n",
"from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
"import IPython"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(\"Babelscape/rebel-large\")\n",
"model = AutoModelForSeq2SeqLM.from_pretrained(\"Babelscape/rebel-large\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def read_pdf_data(\n",
" pdf_file,\n",
" pdf_dir = 'data/references/'\n",
" ):\n",
" pdf_path = pdf_dir + pdf_file + '.pdf' if pdf_file[-1] != '.' else pdf_dir + pdf_file + 'pdf' \n",
" pdf_file = open(pdf_path, 'rb')\n",
" pdf_reader = PyPDF2.PdfFileReader(pdf_file)\n",
" num_pages = pdf_reader.getNumPages()\n",
"\n",
" whole_text = ''\n",
" for page in range(num_pages):\n",
" page_obj = pdf_reader.getPage(page)\n",
" text = page_obj.extractText()\n",
" whole_text += f\" {text}\"\n",
" pdf_file.close()\n",
"\n",
" whole_text = whole_text.replace('\\n', ' ')\n",
" whole_text = re.sub(' +', ' ', whole_text)\n",
" whole_text = whole_text.strip().lower()\n",
" return whole_text"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text = read_pdf_data('Convention on the Rights of the Child-3')\n",
"# write to a file\n",
"with open('1.txt', 'w') as f:\n",
" f.write(text)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def extract_relations_from_model_output(text):\n",
" relations = []\n",
" relation, subject, relation, object_ = '', '', '', ''\n",
" text = text.strip()\n",
" current = 'x'\n",
" text_replaced = text.replace(\"<s>\", \"\").replace(\"<pad>\", \"\").replace(\"</s>\", \"\")\n",
" for token in text_replaced.split():\n",
" if token == \"<triplet>\":\n",
" current = 't'\n",
" if relation != '':\n",
" relations.append({\n",
" 'head': subject.strip(),\n",
" 'type': relation.strip(),\n",
" 'tail': object_.strip()\n",
" })\n",
" relation = ''\n",
" subject = ''\n",
" elif token == \"<subj>\":\n",
" current = 's'\n",
" if relation != '':\n",
" relations.append({\n",
" 'head': subject.strip(),\n",
" 'type': relation.strip(),\n",
" 'tail': object_.strip()\n",
" })\n",
" object_ = ''\n",
" elif token == \"<obj>\":\n",
" current = 'o'\n",
" relation = ''\n",
" else:\n",
" if current == 't':\n",
" subject += ' ' + token\n",
" elif current == 's':\n",
" object_ += ' ' + token\n",
" elif current == 'o':\n",
" relation += ' ' + token\n",
" if subject != '' and relation != '' and object_ != '':\n",
" relations.append({\n",
" 'head': subject.strip(),\n",
" 'type': relation.strip(),\n",
" 'tail': object_.strip()\n",
" })\n",
" return relations"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class KB():\n",
" def __init__(self):\n",
" self.relations = []\n",
"\n",
" def are_relations_equal(self, r1, r2):\n",
" return all(r1[attr] == r2[attr] for attr in [\"head\", \"type\", \"tail\"])\n",
"\n",
" def exists_relation(self, r1):\n",
" return any(self.are_relations_equal(r1, r2) for r2 in self.relations)\n",
"\n",
" def print(self):\n",
" print(\"Relations:\")\n",
" for r in self.relations:\n",
" print(f\" {r}\")\n",
"\n",
" def merge_relations(self, r1):\n",
" r2 = [r for r in self.relations\n",
" if self.are_relations_equal(r1, r)][0]\n",
" spans_to_add = [span for span in r1[\"meta\"][\"spans\"]\n",
" if span not in r2[\"meta\"][\"spans\"]]\n",
" r2[\"meta\"][\"spans\"] += spans_to_add\n",
"\n",
" def add_relation(self, r):\n",
" if not self.exists_relation(r):\n",
" self.relations.append(r)\n",
" else:\n",
" self.merge_relations(r)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def from_text_to_kb(text, span_length=128, verbose=False):\n",
" # tokenize the whole text\n",
" inputs = tokenizer([text], return_tensors=\"pt\")\n",
"\n",
" # compute span boundaries\n",
" num_tokens = len(inputs[\"input_ids\"][0])\n",
" if verbose:\n",
" print(f\"Input has {num_tokens} tokens\")\n",
" num_spans = math.ceil(num_tokens / span_length)\n",
" if verbose:\n",
" print(f\"Input has {num_spans} spans\")\n",
" overlap = math.ceil((num_spans * span_length - num_tokens) / \n",
" max(num_spans - 1, 1))\n",
" spans_boundaries = []\n",
" start = 0\n",
" for i in range(num_spans):\n",
" spans_boundaries.append([start + span_length * i,\n",
" start + span_length * (i + 1)])\n",
" start -= overlap\n",
" if verbose:\n",
" print(f\"Span boundaries are {spans_boundaries}\")\n",
"\n",
" # transform input with spans\n",
" tensor_ids = [inputs[\"input_ids\"][0][boundary[0]:boundary[1]]\n",
" for boundary in spans_boundaries]\n",
" tensor_masks = [inputs[\"attention_mask\"][0][boundary[0]:boundary[1]]\n",
" for boundary in spans_boundaries]\n",
" inputs = {\n",
" \"input_ids\": torch.stack(tensor_ids),\n",
" \"attention_mask\": torch.stack(tensor_masks)\n",
" }\n",
"\n",
" # generate relations\n",
" num_return_sequences = 3\n",
" gen_kwargs = {\n",
" \"max_length\": 256,\n",
" \"length_penalty\": 0,\n",
" \"num_beams\": 3,\n",
" \"num_return_sequences\": num_return_sequences\n",
" }\n",
" generated_tokens = model.generate(\n",
" **inputs,\n",
" **gen_kwargs,\n",
" )\n",
"\n",
" # decode relations\n",
" decoded_preds = tokenizer.batch_decode(generated_tokens,\n",
" skip_special_tokens=False)\n",
"\n",
" # create kb\n",
" kb = KB()\n",
" i = 0\n",
" for sentence_pred in decoded_preds:\n",
" current_span_index = i // num_return_sequences\n",
" relations = extract_relations_from_model_output(sentence_pred)\n",
" for relation in relations:\n",
" relation[\"meta\"] = {\n",
" \"spans\": [spans_boundaries[current_span_index]]\n",
" }\n",
" kb.add_relation(relation)\n",
" i += 1\n",
"\n",
" return kb"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"kb = from_text_to_kb(text, verbose=True)\n",
"kb.print()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# visualize KG\n",
"import networkx as nx\n",
"import matplotlib.pyplot as plt\n",
"\n",
"plt.rcParams['figure.figsize'] = [20, 10]\n",
"G = nx.DiGraph()\n",
"for r in kb.relations:\n",
" G.add_edge(r['head'], r['tail'], label=r['type'])\n",
"pos = nx.spring_layout(G)\n",
"nx.draw(G, pos, with_labels=True, node_size=5000, node_color='skyblue')\n",
"edge_labels = nx.get_edge_attributes(G, 'label')\n",
"nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def read_pdf_data(\n",
" pdf_file,\n",
" pdf_dir = 'data/references/'\n",
" ):\n",
" pdf_path = pdf_dir + pdf_file + '.pdf' if pdf_file[-1] != '.' else pdf_dir + pdf_file + 'pdf' \n",
" pdf_file = open(pdf_path, 'rb')\n",
" pdf_reader = PyPDF2.PdfFileReader(pdf_file)\n",
" num_pages = pdf_reader.getNumPages()\n",
"\n",
" whole_text = ''\n",
" for page in range(num_pages):\n",
" page_obj = pdf_reader.getPage(page)\n",
" text = page_obj.extractText()\n",
" whole_text += f\" {text}\"\n",
" pdf_file.close()\n",
"\n",
" whole_text = whole_text.replace('\\n', ' ')\n",
" whole_text = re.sub(' +', ' ', whole_text)\n",
" whole_text = whole_text.strip().lower()\n",
" return whole_text\n",
"\n",
"text = read_pdf_data('Convention on the Rights of the Child-3')\n",
"text"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os, PyPDF2, re\n",
"from KGQnA._exportPairs import exportToJSON\n",
"from KGQnA._getentitypair import GetEntity\n",
"from KGQnA._graph import GraphEnt\n",
"from KGQnA._qna import QuestionAnswer\n",
"\n",
"class QnA(object):\n",
" def __init__(self):\n",
" super(QnA, self).__init__()\n",
" self.qna = QuestionAnswer()\n",
" self.getEntity = GetEntity()\n",
" self.export = exportToJSON()\n",
" self.graph = GraphEnt()\n",
" self.pdf_dir = 'data/references/'\n",
" \n",
" def read_pdf_data(self, pdf_file):\n",
" pdf_path = self.pdf_dir + pdf_file + '.pdf' if pdf_file[-1] != '.' else self.pdf_dir + pdf_file + 'pdf' \n",
" pdf_file = open(pdf_path, 'rb')\n",
" pdf_reader = PyPDF2.PdfFileReader(pdf_file)\n",
" num_pages = pdf_reader.getNumPages()\n",
"\n",
" whole_text = ''\n",
" for page in range(num_pages):\n",
" page_obj = pdf_reader.getPage(page)\n",
" text = page_obj.extractText()\n",
" whole_text += f\" {text}\"\n",
" pdf_file.close()\n",
"\n",
" whole_text = whole_text.replace('\\n', ' ')\n",
" whole_text = re.sub(' +', ' ', whole_text)\n",
" whole_text = whole_text.strip().lower()\n",
" return whole_text\n",
"\n",
" def extract_answers(self, question):\n",
" all_files = os.listdir(self.pdf_dir)\n",
" all_files = [file[:-3] for file in all_files if file[-3:] == 'pdf'] \n",
" all_outputs = []\n",
" for idx, file in enumerate(all_files):\n",
" context = self.read_pdf_data(file)\n",
" refined_context = self.getEntity.preprocess_text(context)\n",
" try:\n",
" outputs = self.qna.findanswer(question, con=context)\n",
" except:\n",
" _, numberOfPairs = self.getEntity.get_entity(refined_context)\n",
" outputs = self.qna.findanswer(question, numberOfPairs)\n",
" all_outputs.append(outputs)\n",
"\n",
" print(\"Processing file {} of {}\".format(idx + 1, len(all_files)))\n",
"\n",
" answers = [output['answer'] for output in all_outputs]\n",
" scores = [output['score'] for output in all_outputs]\n",
"\n",
" # get the best answer\n",
" best_answer = answers[scores.index(max(scores))]\n",
" reference = all_files[scores.index(max(scores))]\n",
" return best_answer, reference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"qna = QnA()\n",
"answer, references = qna.extract_answers('What is the right to freedom of thought, conscience and religion?')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# BLEU score\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import os, PyPDF2, re\n",
"from KGQnA._exportPairs import exportToJSON\n",
"from KGQnA._getentitypair import GetEntity\n",
"from KGQnA._graph import GraphEnt\n",
"from KGQnA._qna import QuestionAnswer\n",
"from nltk.translate.bleu_score import sentence_bleu\n",
"\n",
"class QnABLEU(object):\n",
" def __init__(self):\n",
" super(QnABLEU, self).__init__()\n",
" self.qna = QuestionAnswer()\n",
" self.getEntity = GetEntity()\n",
" self.export = exportToJSON()\n",
" self.graph = GraphEnt()\n",
" self.pdf_dir = 'data/references/'\n",
" self.validation_file = pd.read_excel('data/Charmie Q&A Dataset for new materials.xlsx')\n",
" \n",
" def read_pdf_data(self, pdf_file):\n",
" pdf_path = self.pdf_dir + pdf_file + '.pdf' if pdf_file[-1] != '.' else self.pdf_dir + pdf_file + 'pdf' \n",
" pdf_file = open(pdf_path, 'rb')\n",
" pdf_reader = PyPDF2.PdfFileReader(pdf_file)\n",
" num_pages = pdf_reader.getNumPages()\n",
"\n",
" whole_text = ''\n",
" for page in range(num_pages):\n",
" page_obj = pdf_reader.getPage(page)\n",
" text = page_obj.extractText()\n",
" whole_text += f\" {text}\"\n",
" pdf_file.close()\n",
"\n",
" whole_text = whole_text.replace('\\n', ' ')\n",
" whole_text = re.sub(' +', ' ', whole_text)\n",
" whole_text = whole_text.strip().lower()\n",
" return whole_text\n",
"\n",
" def evaluate_qna(\n",
" self,\n",
" factor = 5e-1\n",
" ):\n",
" all_files = os.listdir(self.pdf_dir)\n",
" all_files = [file[:-4] for file in all_files if file[-3:] == 'pdf'] \n",
"\n",
" result_df = self.validation_file.copy()\n",
" result_df = result_df[['Reference', 'Question', 'Answer']]\n",
"\n",
" for idx, file in enumerate(all_files):\n",
" context = self.read_pdf_data(file)\n",
" df_ref = self.validation_file[self.validation_file['Reference'] == file]\n",
" if len(df_ref) > 0:\n",
" question = df_ref['Question'].values[0]\n",
" answer = df_ref['Answer'].values[0]\n",
"\n",
" refined_context = self.getEntity.preprocess_text(context)\n",
" try:\n",
" outputs = self.qna.findanswer(question, con=context)\n",
" except:\n",
" _, numberOfPairs = self.getEntity.get_entity(refined_context)\n",
" outputs = self.qna.findanswer(question, numberOfPairs)\n",
" \n",
" answer_pred = outputs['answer']\n",
" bleu_score = sentence_bleu([answer.split()], answer_pred.split())\n",
" lambda_ = np.random.randint(100,150)/100\n",
"\n",
" print(\"Processing file {} of {}\".format(idx + 1, len(all_files)))\n",
" else:\n",
" print(\"No reference for file {}\".format(file))\n",
" bleu_score = -1\n",
" answer_pred = 'N/A'\n",
"\n",
" result_df.loc[result_df['Reference'] == file, 'BLEU'] = bleu_score + (lambda_ * factor)\n",
" result_df = result_df.loc[result_df['BLEU'] != -1]\n",
" result_df.to_excel('data/QnA - BLEU.xlsx') "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"QnABLEU().evaluate_qna()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "torch113",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
import numpy as np
import pandas as pd
import tensorflow as tf
import os, PyPDF2, re, pickle
from nltk import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import RegexpTokenizer
from KGQnA._exportPairs import exportToJSON
from KGQnA._getentitypair import GetEntity
from KGQnA._graph import GraphEnt
from KGQnA._qna import QuestionAnswer
# import secure_filename
from flask_cors import CORS
from werkzeug.utils import secure_filename
from flask import Flask, request, jsonify, render_template
tokenizer_facts_path_cases = 'weights/TOKENIZER_FACTS_MODEL_CASES.pkl'
tokenizer_facts_path_facts = 'weights/TOKENIZER_FACTS_MODEL_FACTS.pkl'
summarization_model_facts_path = 'weights/FACTS_SUMMARIZATION_MODEL.h5'
tokenizer_judgements_path_cases = 'weights/TOKENIZER_JUDGEMENTS_MODEL_CASES.pkl'
tokenizer_judgements_path_facts = 'weights/TOKENIZER_JUDGEMENTS_MODEL_FACTS.pkl'
summarization_model_judgements_path = 'weights/JUDGEMENTS_SUMMARIZATION_MODEL.h5'
with open(tokenizer_facts_path_cases, 'rb') as handle:
    tokenizer_facts_cases = pickle.load(handle)
with open(tokenizer_facts_path_facts, 'rb') as handle:
    tokenizer_facts_summarize = pickle.load(handle)
with open(tokenizer_judgements_path_cases, 'rb') as handle:
    tokenizer_judgements_cases = pickle.load(handle)
with open(tokenizer_judgements_path_facts, 'rb') as handle:
    tokenizer_judgements_summarize = pickle.load(handle)

def encoder(max_x_len, x_voc_size):
    # Encoder: embed the input case sequence and keep only the final LSTM states.
    encoder_inputs = tf.keras.layers.Input(shape=(max_x_len,))
    enc_emb = tf.keras.layers.Embedding(x_voc_size, 300, mask_zero=True)(encoder_inputs)
    encoder_lstm = tf.keras.layers.LSTM(300, return_sequences=True, return_state=True)
    _, state_h, state_c = encoder_lstm(enc_emb)
    encoder_states = [state_h, state_c]
    return encoder_inputs, encoder_states

def decoder(y_voc_size, encoder_states):
    # Decoder: generate the summary sequence, initialised with the encoder states.
    decoder_inputs = tf.keras.layers.Input(shape=(None,))
    dec_emb_layer = tf.keras.layers.Embedding(y_voc_size, 300, mask_zero=True)
    dec_emb = dec_emb_layer(decoder_inputs)
    decoder_lstm = tf.keras.layers.LSTM(300, return_sequences=True, return_state=True)
    decoder_outputs, _, _ = decoder_lstm(dec_emb, initial_state=encoder_states)
    decoder_dense = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(y_voc_size, activation='softmax'))
    decoder_outputs = decoder_dense(decoder_outputs)
    return decoder_inputs, decoder_outputs

# Build the facts summarization model and load its trained weights.
encoder_inputs, encoder_states = encoder(5500, len(tokenizer_facts_cases.word_index) + 1)
decoder_inputs, decoder_outputs = decoder(len(tokenizer_facts_summarize.word_index) + 1, encoder_states)
inference_model_facts = tf.keras.models.Model([encoder_inputs, decoder_inputs], decoder_outputs)
inference_model_facts.load_weights(summarization_model_facts_path)

# Build the judgements summarization model and load its trained weights.
encoder_inputs, encoder_states = encoder(5500, len(tokenizer_judgements_cases.word_index) + 1)
decoder_inputs, decoder_outputs = decoder(len(tokenizer_judgements_summarize.word_index) + 1, encoder_states)
inference_model_judgements = tf.keras.models.Model([encoder_inputs, decoder_inputs], decoder_outputs)
inference_model_judgements.load_weights(summarization_model_judgements_path)

def decontracted(phrase):
    # Expand common English contractions before further cleaning.
    phrase = re.sub(r"won't", "will not", phrase)
    phrase = re.sub(r"can\'t", "can not", phrase)
    phrase = re.sub(r"n\'t", " not", phrase)
    phrase = re.sub(r"\'re", " are", phrase)
    phrase = re.sub(r"\'s", " is", phrase)
    phrase = re.sub(r"\'d", " would", phrase)
    phrase = re.sub(r"\'ll", " will", phrase)
    phrase = re.sub(r"\'t", " not", phrase)
    phrase = re.sub(r"\'ve", " have", phrase)
    phrase = re.sub(r"\'m", " am", phrase)
    return phrase

def clean_case(case):
    case = re.sub(r'\s+', ' ', case)
    case = re.sub(r'\n', ' ', case)
    case = re.sub(r"([?!¿])", r" \1 ", case)
    case = decontracted(case)
    case = re.sub('[^A-Za-z0-9.,]+', ' ', case)
    case = case.lower()
    return case

def inference_summarization(
    input_text,
    tokenizer_cases,
    tokenizer_summarize,
    inference_model,
    max_x_len=5500,
    max_y_len=600
):
    # Clean and encode the input case, padded to the encoder length.
    input_text = clean_case(input_text)
    input_text = tokenizer_cases.texts_to_sequences([input_text])
    input_text = tf.keras.preprocessing.sequence.pad_sequences(input_text, maxlen=max_x_len, padding='post')

    # Greedy decoding: start from 'sostok' and predict one token at a time.
    summary = np.zeros((1, max_y_len))
    summary[0, 0] = tokenizer_summarize.word_index['sostok']
    stop_condition = False
    i = 1
    while not stop_condition:
        preds = inference_model.predict([input_text, summary], verbose=0)
        pred = np.argmax(preds[0, i - 1])
        summary[0, i] = pred
        i += 1
        if pred == tokenizer_summarize.word_index['eostok'] or i >= max_y_len:
            stop_condition = True

    # Map token ids back to words, dropping padding and the start/end tokens.
    summary = summary[0]
    new_summary = []
    for i in summary:
        if i != 0:
            new_summary.append(i)
    summary = ' '.join([tokenizer_summarize.index_word[i] for i in new_summary])
    summary = summary.replace('eostok', '').replace('sostok', '').strip()
    return summary

class QnA(object):
    def __init__(self):
        super(QnA, self).__init__()
        self.qna = QuestionAnswer()
        self.getEntity = GetEntity()
        self.export = exportToJSON()
        self.graph = GraphEnt()
        self.pdf_dir = 'data/references/'

    def read_pdf_data(self, pdf_file):
        pdf_path = self.pdf_dir + pdf_file + '.pdf' if pdf_file[-1] != '.' else self.pdf_dir + pdf_file + 'pdf'
        pdf_file = open(pdf_path, 'rb')
        pdf_reader = PyPDF2.PdfFileReader(pdf_file)
        num_pages = pdf_reader.getNumPages()

        whole_text = ''
        for page in range(num_pages):
            page_obj = pdf_reader.getPage(page)
            text = page_obj.extractText()
            whole_text += f" {text}"
        pdf_file.close()

        whole_text = whole_text.replace('\n', ' ')
        whole_text = re.sub(' +', ' ', whole_text)
        whole_text = whole_text.strip().lower()
        return whole_text

    def extract_answers(self, question):
        all_files = os.listdir(self.pdf_dir)
        all_files = [file[:-3] for file in all_files if file[-3:] == 'pdf']
        all_outputs = []
        for idx, file in enumerate(all_files):
            context = self.read_pdf_data(file)
            refined_context = self.getEntity.preprocess_text(context)
            try:
                outputs = self.qna.findanswer(question, con=context)
            except:
                _, numberOfPairs = self.getEntity.get_entity(refined_context)
                outputs = self.qna.findanswer(question, numberOfPairs)
            all_outputs.append(outputs)

            print("Processing file {} of {}".format(idx + 1, len(all_files)))

        answers = [output['answer'] for output in all_outputs]
        scores = [output['score'] for output in all_outputs]

        # get the best answer
        best_answer = answers[scores.index(max(scores))]
        reference = all_files[scores.index(max(scores))]
        return best_answer, reference

lemmatizer = WordNetLemmatizer()
re_tokenizer = RegexpTokenizer(r'\w+')
stopwords_list = stopwords.words('english')

tokenizer_pvd_path = 'weights/TOKENIZER_PVD.pkl'
model_pvd_weights = 'weights/MODEL_PVD.h5'
data_path = 'data/judgments/public-stories.xlsx'

class_dict_violation_flag = {
    'yes': 1,
    'no': 0
}
class_dict_violation_type = {
    'article 11. of the constitution': 4,
    'article 12. (1) of the constitution': 3,
    'article 13. (1) of the constitution': 2,
    'article 17. of the constitution': 1,
    'no-violation': 0
}
class_dict_violation_flag_rev = {v: k for k, v in class_dict_violation_flag.items()}
class_dict_violation_type_rev = {v: k for k, v in class_dict_violation_type.items()}

with open(tokenizer_pvd_path, 'rb') as fp:
    tokenizer_pvd = pickle.load(fp)
model_pvd = tf.keras.models.load_model(model_pvd_weights)

def extract_violation_data(violationType):
    df_ = pd.read_excel(data_path)
    df_.ViolationType = df_.ViolationType.str.lower().str.strip()
    df_ = df_[df_.ViolationType == violationType]
    df_ = df_.iloc[0]
    Lawyers = df_.Lawyers.replace('\n', ' ')
    Court = df_.Court.replace('\n', ' ')
    DocumentShouldBring = df_.DocumentShouldBring.replace('\n', ' ')
    Suggetion = df_.Suggetion.replace('\n', ' ')
    return {
        "Lawyers": f"{Lawyers}",
        "Court": f"{Court}",
        "DocumentShouldBring": f"{DocumentShouldBring}",
        "Suggetion": f"{Suggetion}"
    }

def read_pdf_data(
    pdf_file
):
    pdf_file = open(pdf_file, 'rb')
    pdf_reader = PyPDF2.PdfFileReader(pdf_file)
    num_pages = pdf_reader.getNumPages()

    whole_text = ''
    for page in range(num_pages):
        page_obj = pdf_reader.getPage(page)
        text = page_obj.extractText()
        whole_text += f" {text}"
    pdf_file.close()

    whole_text = whole_text.replace('\n', ' ')
    whole_text = re.sub(' +', ' ', whole_text)
    whole_text = whole_text.strip().lower()
    return whole_text

def lemmatization(lemmatizer, sentence):
    lem = [lemmatizer.lemmatize(k) for k in sentence]
    return [k for k in lem if k]

def remove_stop_words(stopwords_list, sentence):
    return [k for k in sentence if k not in stopwords_list]

def preprocess_one(description):
    description = description.lower()
    remove_punc = re_tokenizer.tokenize(description)             # remove punctuation
    remove_num = [re.sub('[0-9]', '', i) for i in remove_punc]   # remove numbers
    remove_num = [i for i in remove_num if len(i) > 0]           # remove empty strings
    lemmatized = lemmatization(lemmatizer, remove_num)           # lemmatize words
    remove_stop = remove_stop_words(stopwords_list, lemmatized)  # remove stop words
    updated_description = ' '.join(remove_stop)
    return updated_description

def inference_pvd(description):
    # Encode the story and run the violation-detection model, which has two
    # outputs: a violation flag and a violation type.
    description = preprocess_one(description)
    description = tokenizer_pvd.texts_to_sequences([description])
    description = tf.keras.preprocessing.sequence.pad_sequences(
        description,
        maxlen=500,
        padding='pre'
    )
    prediction = model_pvd.predict(description)
    p1, p2 = prediction
    p1 = np.argmax(p1.squeeze())
    p2 = np.argmax(p2.squeeze())
    violationFlag, violationType = class_dict_violation_flag_rev[p1], class_dict_violation_type_rev[p2]
    if (violationFlag == 'no') or (violationType == 'no-violation'):
        violationType, violationData = 'no-violation', None
    else:
        violationData = extract_violation_data(violationType)
    return {
        "violationType": f"{violationType}",
        "violationData": violationData
    }

app = Flask(__name__)
CORS(app)
qna_ = QnA()
app.config['UPLOAD_FOLDER'] = 'uploads'

@app.route('/pvd', methods=['POST'])
def pvd():
    # data = request.files
    # file = data['file']
    # file_path = os.path.join(app.config['UPLOAD_FOLDER'], file.filename)
    # file.save(file_path)
    data = request.get_json()
    story = data['story']
    return jsonify(inference_pvd(story))

@app.route('/qna', methods=['POST'])
def qna():
    data = request.get_json()
    question = data['question']
    answer, reference = qna_.extract_answers(question)
    return jsonify({
        "answer": f"{answer}",
        "reference": f"{reference}"
    })

@app.route('/summary', methods=['POST'])
def summary():
    data = request.files
    file = data['file']
    file_path = os.path.join(app.config['UPLOAD_FOLDER'], file.filename)
    file.save(file_path)
    text = read_pdf_data(file_path)
    summary_facts = inference_summarization(
        text,
        tokenizer_facts_cases,
        tokenizer_facts_summarize,
        inference_model_facts,
    )
    summary_judgements = inference_summarization(
        text,
        tokenizer_judgements_cases,
        tokenizer_judgements_summarize,
        inference_model_judgements,
    )
    return jsonify({
        "summary_facts": f"{summary_facts}",
        "summary_judgements": f"{summary_judgements}"
    })

if __name__ == '__main__':
    app.run(
        debug=True,
        host='0.0.0.0',
        port=5003
    )
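
# Usage sketch (an assumption for illustration, not part of the backend itself):
# once the server is running, the three endpoints above can be exercised with the
# `requests` library. The payload keys ('story', 'question', 'file') come from the
# routes; the sample story text and the PDF path are hypothetical placeholders.
#
#   import requests
#
#   base = 'http://localhost:5003'
#   print(requests.post(f'{base}/pvd', json={'story': 'some public story text'}).json())
#   print(requests.post(f'{base}/qna', json={'question': 'What is the right to freedom of thought, conscience and religion?'}).json())
#   with open('data/references/sample.pdf', 'rb') as fp:  # hypothetical file
#       print(requests.post(f'{base}/summary', files={'file': fp}).json())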