add law_ai_qna.py file

parent d9c8009c
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from transformers import TrainingArguments, Trainer
from sklearn.model_selection import train_test_split
from transformers import OpenAIGPTTokenizer, OpenAIGPTModel
# Load the question/answer pairs; the spreadsheet is read for its 'Question'
# and 'Answer' columns, each turned into a plain Python list.
data_path = 'data/qna-summarization.xlsx'
df = pd.read_excel(data_path)
Answers = df['Answer'].tolist()
Question = df['Question'].tolist()

# Register '[PAD]' as the padding token used by the padding=True calls below.
tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

model = OpenAIGPTModel.from_pretrained('openai-gpt')
# Adding '[PAD]' grew the tokenizer's vocabulary, so the model's embedding
# matrix must grow with it; otherwise the pad token id has no embedding row
# when the model is run on the padded input_ids.
model.resize_token_embeddings(len(tokenizer))

# Tokenize every question and every answer as one padded, truncated batch of
# PyTorch tensors.
question_encoding = tokenizer(Question, return_tensors='pt', padding=True, truncation=True)
answer_encoding = tokenizer(Answers, return_tensors='pt', padding=True, truncation=True)
class QnADataset(Dataset):
    """Pair tokenized questions with tokenized answers, one example per index.

    Both encodings are mappings from a field name to a batch-first sequence
    (list or tensor) holding one row per example, such as the object returned
    by calling a tokenizer on a list of strings.

    Args:
        question_encoding: Mapping of field name -> per-example rows for the
            questions.
        answer_encoding: Mapping of field name -> per-example rows for the
            answers; expected to hold as many rows as ``question_encoding``.
    """

    def __init__(self, question_encoding, answer_encoding):
        self.question_encoding = question_encoding
        self.answer_encoding = answer_encoding

    def __getitem__(self, idx):
        """Return ``(question_fields, answer_fields)`` for example ``idx``.

        Each element of the tuple is a dict with the same keys as the
        corresponding encoding, holding only row ``idx`` of every field.
        Indexing the encoding mapping itself with an integer would look up a
        key rather than an example, so every field is indexed individually.

        Raises:
            IndexError: If ``idx`` is outside the range of stored examples.
        """
        question = {
            field: rows[idx] for field, rows in self.question_encoding.items()
        }
        answer = {
            field: rows[idx] for field, rows in self.answer_encoding.items()
        }
        return question, answer

    def __len__(self):
        """Return the number of examples (rows), not the number of fields."""
        # len() of the mapping counts its fields; the example count is the
        # length of any one field. An empty mapping means zero examples.
        first_field = next(iter(self.question_encoding.values()), ())
        return len(first_field)
# Wrap the tokenized questions and answers in a single paired dataset.
dataset = QnADataset(
    question_encoding=question_encoding,
    answer_encoding=answer_encoding,
)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment