인공지능/딥러닝

[딥러닝 파이토치 교과서] 7장 시계열 분석 Colab torchtext 오류 해결법

M.랄라 2023. 6. 10. 21:19

딥러닝 파이토치 교과서의 경우 torchtext가 0.8.0, 0.9.0 혹은 0.10.0 중 하나로 코드를 작성하신 것 같다.

그러나 안타깝게도 colab은 torchtext의 0.8.0, 0.9.0, 0.10.0 모두 지원하지 않는 다는 점이 문제다. (2023.06.07 기준)

이것 때문에 시계열 데이터 실습을 못하는건 너무 아깝다고 생각하여 375p [코드 7-4] ~ 379p [코드 7-10] 까지를 대신할 수 있는 코드를 작성하였다. 똑같이 IMDB데이터를 사용하였으므로 7-12부터는 책에있는 코드를 그대로 사용해도 가능할 것 같다.

 

1. https://www.kaggle.com/datasets/atulanandjha/imdb-50k-movie-reviews-test-your-bert?select=test.csv 에 접속해 imdb train.csv, test.csv를 다운받는다.

 

2. [코드 7-4] ~ [코드7-9] 내용

import pandas as pd 
import csv

train_file_path = 'train.csv경로'
test_file_path = 'test.csv경로'

train_df = pd.read_csv(train_file_path)
train_df = train_df.rename(columns={'sentiment': 'label'})
train_df = train_df.reset_index()

test_df = pd.read_csv(test_file_path)
test_df = test_df.rename(columns={'sentiment': 'label'})
test_df = test_df.reset_index()

train_data = []
test_data = []

# train_data 초기화
for index, line in train_df.iterrows():
    original_dict = {
        'text': [],
        'label' : ""
    }

    if (len(line) < 2):
      continue

    original_dict['text'] = line['text'].split(' ')
    original_dict['label'] = line['label']
    train_data.append(original_dict)

# test_data 초기화
for index, line in test_df.iterrows():
    original_dict = {
        'text': [],
        'label' : ""
    }

    if (len(line) < 2):
      continue

    original_dict['text'] = line['text'].split(' ')
    original_dict['label'] = line['label']
    test_data.append(original_dict)
    
import string

for example in train_data:
  text = [x.lower() for x in example['text']]
  text = [x.replace('<br','')for x in text]
  text = [''.join(c for c in s if c not in string.punctuation) for s in text]
  text = [s for s in text if s]
  example['text'] = text
  
import random
from sklearn.model_selection import train_test_split

train_data, valid_data = train_test_split(train_data, random_state=random.seed(0), test_size=0.2)
print(f'Number of training examples : {len(train_data)}')
print(f'Number of valid_data examples : {len(valid_data)}')
print(f'Number of test examples : {len(test_data)}')

위의 코드를 실행하면 아래와 같이 책과 똑같은 결과가 나온다 !

Number of training examples : 20000

Number of valid_data examples : 5000

Number of test examples : 25000