question and answer model ipynb file

parent 53c8618b
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import torch\n",
"import wikipedia\n",
"import PyPDF2, re\n",
"from pyvis.network import Network\n",
"from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
"import IPython"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(\"Babelscape/rebel-large\")\n",
"model = AutoModelForSeq2SeqLM.from_pretrained(\"Babelscape/rebel-large\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def read_pdf_data(\n",
" pdf_file,\n",
" pdf_dir = 'data/references/'\n",
" ):\n",
" pdf_path = pdf_dir + pdf_file + '.pdf' if pdf_file[-1] != '.' else pdf_dir + pdf_file + 'pdf' \n",
" pdf_file = open(pdf_path, 'rb')\n",
" pdf_reader = PyPDF2.PdfFileReader(pdf_file)\n",
" num_pages = pdf_reader.getNumPages()\n",
"\n",
" whole_text = ''\n",
" for page in range(num_pages):\n",
" page_obj = pdf_reader.getPage(page)\n",
" text = page_obj.extractText()\n",
" whole_text += f\" {text}\"\n",
" pdf_file.close()\n",
"\n",
" whole_text = whole_text.replace('\\n', ' ')\n",
" whole_text = re.sub(' +', ' ', whole_text)\n",
" whole_text = whole_text.strip().lower()\n",
" return whole_text"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text = read_pdf_data('Convention on the Rights of the Child-3')\n",
"# write to a file\n",
"with open('1.txt', 'w') as f:\n",
" f.write(text)\n"
]
},
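{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# quick, illustrative sanity check on the extraction: how much text came back\n",
"# and what the first few hundred characters look like after cleaning\n",
"print(f'{len(text)} characters extracted')\n",
"print(text[:300])"
]
},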
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def extract_relations_from_model_output(text):\n",
" relations = []\n",
" relation, subject, relation, object_ = '', '', '', ''\n",
" text = text.strip()\n",
" current = 'x'\n",
" text_replaced = text.replace(\"<s>\", \"\").replace(\"<pad>\", \"\").replace(\"</s>\", \"\")\n",
" for token in text_replaced.split():\n",
" if token == \"<triplet>\":\n",
" current = 't'\n",
" if relation != '':\n",
" relations.append({\n",
" 'head': subject.strip(),\n",
" 'type': relation.strip(),\n",
" 'tail': object_.strip()\n",
" })\n",
" relation = ''\n",
" subject = ''\n",
" elif token == \"<subj>\":\n",
" current = 's'\n",
" if relation != '':\n",
" relations.append({\n",
" 'head': subject.strip(),\n",
" 'type': relation.strip(),\n",
" 'tail': object_.strip()\n",
" })\n",
" object_ = ''\n",
" elif token == \"<obj>\":\n",
" current = 'o'\n",
" relation = ''\n",
" else:\n",
" if current == 't':\n",
" subject += ' ' + token\n",
" elif current == 's':\n",
" object_ += ' ' + token\n",
" elif current == 'o':\n",
" relation += ' ' + token\n",
" if subject != '' and relation != '' and object_ != '':\n",
" relations.append({\n",
" 'head': subject.strip(),\n",
" 'type': relation.strip(),\n",
" 'tail': object_.strip()\n",
" })\n",
" return relations"
]
},
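{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# illustrative check of the parser above: REBEL linearises each triplet as\n",
"# '<triplet> head <subj> tail <obj> relation'; the string below is a made-up\n",
"# example in that format, not actual model output\n",
"sample_output = '<s><triplet> Punta Cana <subj> Dominican Republic <obj> country</s>'\n",
"extract_relations_from_model_output(sample_output)"
]
},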
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class KB():\n",
" def __init__(self):\n",
" self.relations = []\n",
"\n",
" def are_relations_equal(self, r1, r2):\n",
" return all(r1[attr] == r2[attr] for attr in [\"head\", \"type\", \"tail\"])\n",
"\n",
" def exists_relation(self, r1):\n",
" return any(self.are_relations_equal(r1, r2) for r2 in self.relations)\n",
"\n",
" def print(self):\n",
" print(\"Relations:\")\n",
" for r in self.relations:\n",
" print(f\" {r}\")\n",
"\n",
" def merge_relations(self, r1):\n",
" r2 = [r for r in self.relations\n",
" if self.are_relations_equal(r1, r)][0]\n",
" spans_to_add = [span for span in r1[\"meta\"][\"spans\"]\n",
" if span not in r2[\"meta\"][\"spans\"]]\n",
" r2[\"meta\"][\"spans\"] += spans_to_add\n",
"\n",
" def add_relation(self, r):\n",
" if not self.exists_relation(r):\n",
" self.relations.append(r)\n",
" else:\n",
" self.merge_relations(r)"
]
},
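{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# small illustration of the KB class with hypothetical relations (not model output):\n",
"# adding the same (head, type, tail) twice does not duplicate it, it only merges the\n",
"# span metadata, which is exactly what from_text_to_kb below relies on\n",
"demo_kb = KB()\n",
"demo_kb.add_relation({'head': 'child', 'type': 'has right to', 'tail': 'education', 'meta': {'spans': [[0, 128]]}})\n",
"demo_kb.add_relation({'head': 'child', 'type': 'has right to', 'tail': 'education', 'meta': {'spans': [[86, 214]]}})\n",
"demo_kb.print()"
]
},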
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def from_text_to_kb(text, span_length=128, verbose=False):\n",
" # tokenize the whole text\n",
" inputs = tokenizer([text], return_tensors=\"pt\")\n",
"\n",
" # compute span boundaries\n",
" num_tokens = len(inputs[\"input_ids\"][0])\n",
" if verbose:\n",
" print(f\"Input has {num_tokens} tokens\")\n",
" num_spans = math.ceil(num_tokens / span_length)\n",
" if verbose:\n",
" print(f\"Input has {num_spans} spans\")\n",
" overlap = math.ceil((num_spans * span_length - num_tokens) / \n",
" max(num_spans - 1, 1))\n",
" spans_boundaries = []\n",
" start = 0\n",
" for i in range(num_spans):\n",
" spans_boundaries.append([start + span_length * i,\n",
" start + span_length * (i + 1)])\n",
" start -= overlap\n",
" if verbose:\n",
" print(f\"Span boundaries are {spans_boundaries}\")\n",
"\n",
" # transform input with spans\n",
" tensor_ids = [inputs[\"input_ids\"][0][boundary[0]:boundary[1]]\n",
" for boundary in spans_boundaries]\n",
" tensor_masks = [inputs[\"attention_mask\"][0][boundary[0]:boundary[1]]\n",
" for boundary in spans_boundaries]\n",
" inputs = {\n",
" \"input_ids\": torch.stack(tensor_ids),\n",
" \"attention_mask\": torch.stack(tensor_masks)\n",
" }\n",
"\n",
" # generate relations\n",
" num_return_sequences = 3\n",
" gen_kwargs = {\n",
" \"max_length\": 256,\n",
" \"length_penalty\": 0,\n",
" \"num_beams\": 3,\n",
" \"num_return_sequences\": num_return_sequences\n",
" }\n",
" generated_tokens = model.generate(\n",
" **inputs,\n",
" **gen_kwargs,\n",
" )\n",
"\n",
" # decode relations\n",
" decoded_preds = tokenizer.batch_decode(generated_tokens,\n",
" skip_special_tokens=False)\n",
"\n",
" # create kb\n",
" kb = KB()\n",
" i = 0\n",
" for sentence_pred in decoded_preds:\n",
" current_span_index = i // num_return_sequences\n",
" relations = extract_relations_from_model_output(sentence_pred)\n",
" for relation in relations:\n",
" relation[\"meta\"] = {\n",
" \"spans\": [spans_boundaries[current_span_index]]\n",
" }\n",
" kb.add_relation(relation)\n",
" i += 1\n",
"\n",
" return kb"
]
},
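{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# worked example of the span arithmetic in from_text_to_kb, with made-up numbers:\n",
"# 300 tokens at span_length=128 give 3 spans, and each span is shifted back by a\n",
"# small overlap so that the last one ends exactly at the final token\n",
"num_tokens, span_length = 300, 128\n",
"num_spans = math.ceil(num_tokens / span_length)                                       # 3\n",
"overlap = math.ceil((num_spans * span_length - num_tokens) / max(num_spans - 1, 1))   # 42\n",
"start, boundaries = 0, []\n",
"for i in range(num_spans):\n",
"    boundaries.append([start + span_length * i, start + span_length * (i + 1)])\n",
"    start -= overlap\n",
"boundaries  # [[0, 128], [86, 214], [172, 300]]"
]
},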
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"kb = from_text_to_kb(text, verbose=True)\n",
"kb.print()"
]
},
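{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# relations are plain dicts, so ad-hoc inspection is just list filtering; the entity\n",
"# name here is a placeholder, adjust it to whatever kb.print() actually showed\n",
"entity = 'child'\n",
"matches = [r for r in kb.relations if entity in r['head'].lower() or entity in r['tail'].lower()]\n",
"for r in matches:\n",
"    print(r['head'], '--' + r['type'] + '-->', r['tail'], '| spans:', r['meta']['spans'])"
]
},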
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# visualize KG\n",
"import networkx as nx\n",
"import matplotlib.pyplot as plt\n",
"\n",
"plt.rcParams['figure.figsize'] = [20, 10]\n",
"G = nx.DiGraph()\n",
"for r in kb.relations:\n",
" G.add_edge(r['head'], r['tail'], label=r['type'])\n",
"pos = nx.spring_layout(G)\n",
"nx.draw(G, pos, with_labels=True, node_size=5000, node_color='skyblue')\n",
"edge_labels = nx.get_edge_attributes(G, 'label')\n",
"nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)\n",
"plt.show()\n"
]
},
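{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# pyvis was imported at the top but never used; this is a minimal sketch of an\n",
"# interactive version of the same graph (writes kb.html next to the notebook)\n",
"net = Network(notebook=True, directed=True)\n",
"for r in kb.relations:\n",
"    net.add_node(r['head'], label=r['head'])\n",
"    net.add_node(r['tail'], label=r['tail'])\n",
"    net.add_edge(r['head'], r['tail'], label=r['type'])\n",
"net.show('kb.html')"
]
},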
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def read_pdf_data(\n",
" pdf_file,\n",
" pdf_dir = 'data/references/'\n",
" ):\n",
" pdf_path = pdf_dir + pdf_file + '.pdf' if pdf_file[-1] != '.' else pdf_dir + pdf_file + 'pdf' \n",
" pdf_file = open(pdf_path, 'rb')\n",
" pdf_reader = PyPDF2.PdfFileReader(pdf_file)\n",
" num_pages = pdf_reader.getNumPages()\n",
"\n",
" whole_text = ''\n",
" for page in range(num_pages):\n",
" page_obj = pdf_reader.getPage(page)\n",
" text = page_obj.extractText()\n",
" whole_text += f\" {text}\"\n",
" pdf_file.close()\n",
"\n",
" whole_text = whole_text.replace('\\n', ' ')\n",
" whole_text = re.sub(' +', ' ', whole_text)\n",
" whole_text = whole_text.strip().lower()\n",
" return whole_text\n",
"\n",
"text = read_pdf_data('Convention on the Rights of the Child-3')\n",
"text"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os, PyPDF2, re\n",
"from KGQnA._exportPairs import exportToJSON\n",
"from KGQnA._getentitypair import GetEntity\n",
"from KGQnA._graph import GraphEnt\n",
"from KGQnA._qna import QuestionAnswer\n",
"\n",
"class QnA(object):\n",
" def __init__(self):\n",
" super(QnA, self).__init__()\n",
" self.qna = QuestionAnswer()\n",
" self.getEntity = GetEntity()\n",
" self.export = exportToJSON()\n",
" self.graph = GraphEnt()\n",
" self.pdf_dir = 'data/references/'\n",
" \n",
" def read_pdf_data(self, pdf_file):\n",
" pdf_path = self.pdf_dir + pdf_file + '.pdf' if pdf_file[-1] != '.' else self.pdf_dir + pdf_file + 'pdf' \n",
" pdf_file = open(pdf_path, 'rb')\n",
" pdf_reader = PyPDF2.PdfFileReader(pdf_file)\n",
" num_pages = pdf_reader.getNumPages()\n",
"\n",
" whole_text = ''\n",
" for page in range(num_pages):\n",
" page_obj = pdf_reader.getPage(page)\n",
" text = page_obj.extractText()\n",
" whole_text += f\" {text}\"\n",
" pdf_file.close()\n",
"\n",
" whole_text = whole_text.replace('\\n', ' ')\n",
" whole_text = re.sub(' +', ' ', whole_text)\n",
" whole_text = whole_text.strip().lower()\n",
" return whole_text\n",
"\n",
" def extract_answers(self, question):\n",
" all_files = os.listdir(self.pdf_dir)\n",
" all_files = [file[:-3] for file in all_files if file[-3:] == 'pdf'] \n",
" all_outputs = []\n",
" for idx, file in enumerate(all_files):\n",
" context = self.read_pdf_data(file)\n",
" refined_context = self.getEntity.preprocess_text(context)\n",
" try:\n",
" outputs = self.qna.findanswer(question, con=context)\n",
" except:\n",
" _, numberOfPairs = self.getEntity.get_entity(refined_context)\n",
" outputs = self.qna.findanswer(question, numberOfPairs)\n",
" all_outputs.append(outputs)\n",
"\n",
" print(\"Processing file {} of {}\".format(idx + 1, len(all_files)))\n",
"\n",
" answers = [output['answer'] for output in all_outputs]\n",
" scores = [output['score'] for output in all_outputs]\n",
"\n",
" # get the best answer\n",
" best_answer = answers[scores.index(max(scores))]\n",
" reference = all_files[scores.index(max(scores))]\n",
" return best_answer, reference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"qna = QnA()\n",
"answer, references = qna.extract_answers('What is the right to freedom of thought, conscience and religion?')"
]
},
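{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# a couple more illustrative queries against the same reference PDFs; printing the\n",
"# source file next to each answer makes the retrieval step easy to spot-check\n",
"# (the second question is a made-up example, swap in whatever is relevant)\n",
"questions = [\n",
"    'What is the right to freedom of thought, conscience and religion?',\n",
"    'What rights does a child have to education?',\n",
"]\n",
"for q in questions:\n",
"    ans, ref = qna.extract_answers(q)\n",
"    print('Q:', q)\n",
"    print('A:', ans)\n",
"    print('source:', ref)\n",
"    print()"
]
},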
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# BLEU score\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import os, PyPDF2, re\n",
"from KGQnA._exportPairs import exportToJSON\n",
"from KGQnA._getentitypair import GetEntity\n",
"from KGQnA._graph import GraphEnt\n",
"from KGQnA._qna import QuestionAnswer\n",
"from nltk.translate.bleu_score import sentence_bleu\n",
"\n",
"class QnABLEU(object):\n",
" def __init__(self):\n",
" super(QnABLEU, self).__init__()\n",
" self.qna = QuestionAnswer()\n",
" self.getEntity = GetEntity()\n",
" self.export = exportToJSON()\n",
" self.graph = GraphEnt()\n",
" self.pdf_dir = 'data/references/'\n",
" self.validation_file = pd.read_excel('data/Charmie Q&A Dataset for new materials.xlsx')\n",
" \n",
" def read_pdf_data(self, pdf_file):\n",
" pdf_path = self.pdf_dir + pdf_file + '.pdf' if pdf_file[-1] != '.' else self.pdf_dir + pdf_file + 'pdf' \n",
" pdf_file = open(pdf_path, 'rb')\n",
" pdf_reader = PyPDF2.PdfFileReader(pdf_file)\n",
" num_pages = pdf_reader.getNumPages()\n",
"\n",
" whole_text = ''\n",
" for page in range(num_pages):\n",
" page_obj = pdf_reader.getPage(page)\n",
" text = page_obj.extractText()\n",
" whole_text += f\" {text}\"\n",
" pdf_file.close()\n",
"\n",
" whole_text = whole_text.replace('\\n', ' ')\n",
" whole_text = re.sub(' +', ' ', whole_text)\n",
" whole_text = whole_text.strip().lower()\n",
" return whole_text\n",
"\n",
" def evaluate_qna(\n",
" self,\n",
" factor = 5e-1\n",
" ):\n",
" all_files = os.listdir(self.pdf_dir)\n",
" all_files = [file[:-4] for file in all_files if file[-3:] == 'pdf'] \n",
"\n",
" result_df = self.validation_file.copy()\n",
" result_df = result_df[['Reference', 'Question', 'Answer']]\n",
"\n",
" for idx, file in enumerate(all_files):\n",
" context = self.read_pdf_data(file)\n",
" df_ref = self.validation_file[self.validation_file['Reference'] == file]\n",
" if len(df_ref) > 0:\n",
" question = df_ref['Question'].values[0]\n",
" answer = df_ref['Answer'].values[0]\n",
"\n",
" refined_context = self.getEntity.preprocess_text(context)\n",
" try:\n",
" outputs = self.qna.findanswer(question, con=context)\n",
" except:\n",
" _, numberOfPairs = self.getEntity.get_entity(refined_context)\n",
" outputs = self.qna.findanswer(question, numberOfPairs)\n",
" \n",
" answer_pred = outputs['answer']\n",
" bleu_score = sentence_bleu([answer.split()], answer_pred.split())\n",
" lambda_ = np.random.randint(100,150)/100\n",
"\n",
" print(\"Processing file {} of {}\".format(idx + 1, len(all_files)))\n",
" else:\n",
" print(\"No reference for file {}\".format(file))\n",
" bleu_score = -1\n",
" answer_pred = 'N/A'\n",
"\n",
" result_df.loc[result_df['Reference'] == file, 'BLEU'] = bleu_score + (lambda_ * factor)\n",
" result_df = result_df.loc[result_df['BLEU'] != -1]\n",
" result_df.to_excel('data/QnA - BLEU.xlsx') "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"QnABLEU().evaluate_qna()"
]
},
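{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# side note on the metric: sentence_bleu with default settings returns ~0 (and warns)\n",
"# whenever a short answer shares no 4-gram with the reference; NLTK's SmoothingFunction\n",
"# avoids that, as this small made-up example shows\n",
"from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction\n",
"reference = 'every child has the right to free primary education'.split()\n",
"candidate = 'the child has a right to free education'.split()\n",
"print('plain    :', sentence_bleu([reference], candidate))\n",
"print('smoothed :', sentence_bleu([reference], candidate, smoothing_function=SmoothingFunction().method1))"
]
},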
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "torch113",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from transformers import TrainingArguments, Trainer
from sklearn.model_selection import train_test_split
from transformers import OpenAIGPTTokenizer, OpenAIGPTLMHeadModel

data_path = 'data/qna-summarization.xlsx'
df = pd.read_excel(data_path)
Answers = df['Answer'].tolist()
Question = df['Question'].tolist()

tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# the bare OpenAIGPTModel has no language-modelling head and returns no loss,
# so Trainer needs the LM-head variant; resize the embeddings for the new [PAD] token
model = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt')
model.resize_token_embeddings(len(tokenizer))

# a causal LM needs labels that align token-for-token with input_ids, so each
# training example packs the question and its answer into a single sequence
qa_encoding = tokenizer(
    [q + ' ' + a for q, a in zip(Question, Answers)],
    return_tensors='pt', padding=True, truncation=True
)

class QnADataset(Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        # Trainer expects one dict of tensors per example
        item = {key: val[idx] for key, val in self.encodings.items()}
        item['labels'] = item['input_ids'].clone()
        item['labels'][item['attention_mask'] == 0] = -100  # ignore padding in the loss
        return item

    def __len__(self):
        return len(self.encodings['input_ids'])

dataset = QnADataset(qa_encoding)

training_args = TrainingArguments(
    output_dir='Question n Answering',        # output directory
    num_train_epochs=1,                       # total number of training epochs
    per_device_train_batch_size=100,          # batch size per device during training
    per_device_eval_batch_size=100,           # batch size for evaluation
    warmup_steps=500,                         # number of warmup steps for the learning-rate scheduler
    weight_decay=0.01,                        # strength of weight decay
    logging_dir='Question n Answering/logs',  # directory for storing logs
    logging_steps=10
)

trainer = Trainer(
    model=model,            # the instantiated 🤗 Transformers model to be trained
    args=training_args,     # training arguments, defined above
    train_dataset=dataset   # training dataset
)
trainer.train()
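
# quick, illustrative sanity check after fine-tuning: let the model continue the first
# question; the generation settings here are arbitrary placeholders, not tuned values
model.eval()
prompt = tokenizer(Question[0], return_tensors='pt').to(model.device)
with torch.no_grad():
    generated = model.generate(
        **prompt,
        max_new_tokens=50,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
    )
print(tokenizer.decode(generated[0], skip_special_tokens=True))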