악성 URL 분류 AI 경진 대회 - 최종 코드
이번 글은 악성 URL 분류 AI 경진 대회의 최종 코드를 살펴보도록 하겠습니다. 지난 EDA를 통해서 URL로부터 특징을 추출하거나 URL을 BERT를 사용하여 처리하는 과정을 거쳤습니다. 그럼 이에 대하여 최종 코드를 살펴보도록 하겠습니다. 코드의 내용은 아래와 같습니다.
1. Data Load
2. 길이 기반 특징 추출
3. 개수 기반 특징 추출
4. 존재 여부 특징 추출
5. 기타 특징 추출
6. Embedding with BERT
7. Hold out & Scale
8. LGBM Classifier
데이콘 링크
: https://dacon.io/competitions/official/236451/overview/description
악성 URL 분류 AI 경진대회 - DACON
분석시각화 대회 코드 공유 게시물은 내용 확인 후 좋아요(투표) 가능합니다.
dacon.io
EDA 과정
: 2025.03.28 - [Personal Projects/Dacon] - [Dacon] 악성 URL 분류 AI 경진대회 (2) - EDA
[Dacon] 악성 URL 분류 AI 경진대회 (2) - EDA
악성 URL 분류 AI 경진대회 - EDA 이번글은 악성 URL 분류 대회에서 수행했던 탐색적 데이터 분석(EDA)에 대한 글입니다. 이번 대회를 통해서 URL에 대해서 알아보며, 기본적으로 URL을 분류하기 위해
muns-da2.tistory.com
Code Review
* Computing Resource : RunPod, A100 PCle, 80GB RAM, 117GB RAM 8 CPU
0. install & import Library
!pip install pandas
!pip install tld
!pip install googlesearch-python
!pip install matplotlib
!pip install scikit-learn
!pip install transformers
!pip install imblearn
!pip install lgbm
!pip install aiohttp
!pip install seaborn
!pip install lightgbm
!pip install numpy==1.24.1
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold
from sklearn.metrics import roc_auc_score
import seaborn as sns
import matplotlib.pyplot as plt
import warnings
import re
from urllib.parse import urlparse
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import LabelEncoder
from urllib.parse import urlparse, unquote
import string
import hashlib
from lightgbm import LGBMClassifier
from sklearn.metrics import roc_auc_score
import numpy as np
from sklearn.model_selection import train_test_split
import datetime
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
import asyncio
import aiohttp
import time
from scipy.stats import entropy
warnings.filterwarnings(action='ignore')
# 로깅 비활성화
logging.getLogger("whois").setLevel(logging.CRITICAL)
1. Data Load
# 학습/평가 데이터 로드
train_df = pd.read_csv('train.csv')
test_df = pd.read_csv('test.csv')
# '[.]'을 '.'으로 복구
train_df['URL'] = train_df['URL'].str.replace(r'\[\.\]', '.', regex=True)
test_df['URL'] = test_df['URL'].str.replace(r'\[\.\]', '.', regex=True)
def extract_prime_url(url):
if '/' in url:
return url.split('/', 1)[0] # 첫 번째 슬래시 왼쪽 부분
return url # 슬래시가 없으면 전체 URL 반환
def extract_other_domain(url):
if '/' in url:
return url.split('/', 1)[1] # 첫 번째 슬래시 오른쪽 부분
return None
# 새로운 컬럼 추가
train_df['Prime_url'] = train_df['URL'].apply(extract_prime_url)
test_df['Prime_url'] = test_df['URL'].apply(extract_prime_url)
train_df['Other_domain'] = train_df['URL'].apply(extract_other_domain)
test_df['Other_domain'] = test_df['URL'].apply(extract_other_domain)
2. 길이 기반 특징 추출
# # URL 길이
train_df['length'] = train_df['URL'].str.len()
test_df['length'] = test_df['URL'].str.len()
# 최대값과 최소값 구하기
max_length = test_df['length'].max()
min_length = test_df['length'].min()
# 최대 연속 소문자 길이 계산
train_df['max_lowercase_sequence'] = train_df['URL'].apply(lambda x: max([len(seq) for seq in re.findall(r'[a-z]+', x)] or [0]))
test_df['max_lowercase_sequence'] = test_df['URL'].apply(lambda x: max([len(seq) for seq in re.findall(r'[a-z]+', x)] or [0]))
train_df['max_numeric_sequence'] = train_df['URL'].apply(lambda x: max([len(seq) for seq in re.findall(r'\d+', x)] or [0]))
test_df['max_numeric_sequence'] = test_df['URL'].apply(lambda x: max([len(seq) for seq in re.findall(r'\d+', x)] or [0]))
# 최대 연속 대문자 길이 계산
train_df['max_uppercase_sequence'] = train_df['URL'].apply(lambda x: max([len(seq) for seq in re.findall(r'[A-Z]+', x)] or [0]))
test_df['max_uppercase_sequence'] = test_df['URL'].apply(lambda x: max([len(seq) for seq in re.findall(r'[A-Z]+', x)] or [0]))
# 첫 번째 / 이후의 호스트 길이 계산하는 컬럼 추가
def host_length_after_slash(url):
if '/' in url:
# 첫 번째 / 이후의 부분 추출
host_part = url.split('/', 1)[1] # 첫 번째 슬래시 이후의 부분
return len(host_part.split('/')[0]) # 호스트 길이 계산
return 0 # 슬래시가 없으면 0 반환
train_df['Host Length After Slash'] = train_df['URL'].apply(host_length_after_slash)
test_df['Host Length After Slash'] = test_df['URL'].apply(host_length_after_slash)
# 첫 번째 / 이전의 호스트 길이 계산하는 컬럼 추가
def host_length_before_slash(url):
if '/' in url:
# 첫 번째 / 이전의 부분 추출
host_part = url.split('/', 1)[0] # 첫 번째 슬래시 이전의 부분
return len(host_part) # 호스트 길이 계산
return len(url) # 슬래시가 없으면 전체 URL 길이를 반환
train_df['Host Length Before Slash'] = train_df['URL'].apply(host_length_before_slash)
test_df['Host Length Before Slash'] = test_df['URL'].apply(host_length_before_slash)
3. 존재 여부 특징 추출
# 단어 목록
keywords = [
"login", "bank", "secure", "update", "verify",
"account", "password", "security", "transaction",
"sensitive", "confidential", "payment", "access",
"protect", "fraud", "alert", "notify", "register", "dashboard", "profile", "checkout", "cart", "search", "terms", "privacy"
]
# 각 URL에서 단어 포함 개수를 카운트하여 하나의 컬럼에 합산
def count_keywords(url):
return sum(url.count(keyword) for keyword in keywords)
# 모든 URL을 소문자로 변환하여 카운트
train_df['keyword_count'] = train_df['URL'].apply(lambda x: count_keywords(x.lower()))
test_df['keyword_count'] = test_df['URL'].apply(lambda x: count_keywords(x.lower()))
# 단축 URL 서비스 확인 함수
def shortening_service(url):
match = re.search(r'bit\.ly|goo\.gl|shorte\.st|go2l\.ink|x\.co|ow\.ly|t\.co|tinyurl|tr\.im|is\.gd|cli\.gs|'
r'yfrog\.com|migre\.me|ff\.im|tiny\.cc|url4\.eu|twit\.ac|su\.pr|twurl\.nl|snipurl\.com|'
r'short\.to|BudURL\.com|ping\.fm|post\.ly|Just\.as|bkite\.com|snipr\.com|fic\.kr|loopt\.us|'
r'doiop\.com|short\.ie|kl\.am|wp\.me|rubyurl\.com|om\.ly|to\.ly|bit\.do|t\.co|lnkd\.in|'
r'db\.tt|qr\.ae|adf\.ly|goo\.gl|bitly\.com|cur\.lv|tinyurl\.com|ow\.ly|bit\.ly|ity\.im|'
r'q\.gs|is\.gd|po\.st|bc\.vc|twitthis\.com|u\.to|j\.mp|buzurl\.com|cutt\.us|u\.bb|yourls\.org|'
r'x\.co|prettylinkpro\.com|scrnch\.me|filoops\.info|vzturl\.com|qr\.net|1url\.com|tweez\.me|v\.gd|'
r'tr\.im|link\.zip\.net', url)
if match:
return 1 # 단축 URL 서비스가 발견됨
else:
return 0 # 단축 URL 서비스가 없음
# 새로운 컬럼에 단축 URL 여부 추가
train_df['short_url'] = train_df['URL'].apply(shortening_service)
test_df['short_url'] = test_df['URL'].apply(shortening_service)
# 체크할 파일 확장자 목록
extensions = [
".jpg", ".jpeg", ".png", ".gif", ".bmp", # 이미지
".pdf", ".doc", ".docx", ".xls", ".xlsx", # 문서
".mp4", ".avi", ".mov", # 비디오
".mp3", ".wav", # 오디오
".zip", ".rar", # 압축 파일
".tiff", ".tif", # 이미지
".webp", # 이미지
".svg", # 이미지
".ppt", ".pptx", # 문서
".txt", # 문서
".csv", # 문서
".xml", # 문서
".html", ".htm", ".hwpx"# 문서
".mkv", # 비디오
".wmv", # 비디오
".flv", # 비디오
".mpeg", ".mpg", # 비디오
".aac", # 오디오
".flac", # 오디오
".ogg", # 오디오
".7z", # 압축 파일
".tar", # 압축 파일
".gz", # 압축 파일
".bz2", # 압축 파일
".iso", # 기타
".json", # 기타
".md", # 기타
".psd", # 기타
".ai" , # 기타
".lnk", ".vbs"
]
# 파일 확장자 존재 여부 확인 함수 정의
def check_extensions(url):
return any(url.lower().endswith(ext) for ext in extensions)
# 새로운 컬럼 추가: has_extension
train_df['has_extension'] = train_df['URL'].apply(check_extensions)
test_df['has_extension'] = test_df['URL'].apply(check_extensions)
# 체크할 특수 문자 리스트
special_characters = ['-', '=']
# 각 특수 문자의 개수를 카운트하는 label 생성
for char in special_characters:
train_df[f'count_{char}'] = train_df['URL'].apply(lambda x: x.count(char))
test_df[f'count_{char}'] = test_df['URL'].apply(lambda x: x.count(char))
4. 개수 기반 특징 추출
train_df['path_depth'] = train_df['URL'].str.count('/')
test_df['path_depth'] = test_df['URL'].str.count('/')
# 서브도메인 개수
train_df['subdomain_count'] = train_df['URL'].str.split('.').apply(lambda x: len(x) - 2)
test_df['subdomain_count'] = test_df['URL'].str.split('.').apply(lambda x: len(x) - 2)
# URL에서 'www'의 개수 세기
train_df['count-www'] = train_df['URL'].apply(lambda url: url.count('www'))
test_df['count-www'] = test_df['URL'].apply(lambda url: url.count('www'))
# mail
train_df['count-mail'] = train_df['URL'].str.count('mail')
test_df['count-mail'] = test_df['URL'].str.count('mail')
# blog
train_df['count-blog'] = train_df['URL'].str.count('blog')
test_df['count-blog'] = test_df['URL'].str.count('blog')
# 숫자의 개수
train_df['digit_count'] = train_df['URL'].str.count(r'\d')
test_df['digit_count'] = test_df['URL'].str.count(r'\d')
# 소문자 비율 계산
train_df['lowercase_count'] = train_df['URL'].str.count(r'[a-z]')
test_df['lowercase_count'] = test_df['URL'].str.count(r'[a-z]')
# 대문자 비율 계산
train_df['uppercase_count'] = train_df['URL'].str.count(r'[A-Z]')
test_df['uppercase_count'] = test_df['URL'].str.count(r'[A-Z]')
# 문자, 숫자, 특수 문자의 개수를 세는 함수
def count_characters(url):
letters_count = sum(c.isalpha() for c in url) # 문자 수
digits_count = sum(c.isdigit() for c in url) # 숫자 수
special_chars_count = sum(c in string.punctuation for c in url) # 특수 문자 수
return letters_count, digits_count, special_chars_count
# train_df와 test_df에 대해 개수 계산 후 새로운 컬럼에 저장
train_df['letters_count'], train_df['digits_count'], train_df['special_chars_count'] = zip(*train_df['Prime_url'].apply(count_characters))
test_df['letters_count'], test_df['digits_count'], test_df['special_chars_count'] = zip(*test_df['Prime_url'].apply(count_characters))
# 문자, 숫자, 특수 문자의 개수를 세는 함수
def count_characters(text):
if text is None: # None인 경우
return 0, 0, 0 # 모두 0 반환
letters_count = sum(c.isalpha() for c in text) # 문자 수
digits_count = sum(c.isdigit() for c in text) # 숫자 수
special_chars_count = sum(c in string.punctuation for c in text) # 특수 문자 수
return letters_count, digits_count, special_chars_count
# train_df와 test_df에 대해 Other_domain에 대한 개수 계산 후 새로운 컬럼에 저장
train_df['letters_count_other'], train_df['digits_count_other'], train_df['special_chars_count_other'] = zip(*train_df['Other_domain'].apply(count_characters))
test_df['letters_count_other'], test_df['digits_count_other'], test_df['special_chars_count_other'] = zip(*test_df['Other_domain'].apply(count_characters))
# 문자를 제거한 나머지의 개수를 계산하는 함수
def non_alpha_count(url):
non_alpha = re.sub(r'[a-zA-Z]', '', url) # 문자(알파벳)만 제거
return len(non_alpha) if non_alpha else 0
train_df['Non Alpha Count'] = train_df['URL'].apply(non_alpha_count)
test_df['Non Alpha Count'] = test_df['URL'].apply(non_alpha_count)
5. 기타 특징 추출
import pandas as pd
# 주어진 get_url_region 함수
def get_url_region(primary_domain):
ccTLD_to_region = {
".ac": "Ascension Island",
".ad": "Andorra",
".ae": "United Arab Emirates",
".af": "Afghanistan",
".ag": "Antigua and Barbuda",
".ai": "Anguilla",
".al": "Albania",
".am": "Armenia",
".an": "Netherlands Antilles",
".ao": "Angola",
".aq": "Antarctica",
".ar": "Argentina",
".as": "American Samoa",
".at": "Austria",
".au": "Australia",
".aw": "Aruba",
".ax": "Åland Islands",
".az": "Azerbaijan",
".ba": "Bosnia and Herzegovina",
".bb": "Barbados",
".bd": "Bangladesh",
".be": "Belgium",
".bf": "Burkina Faso",
".bg": "Bulgaria",
".bh": "Bahrain",
".bi": "Burundi",
".bj": "Benin",
".bm": "Bermuda",
".bn": "Brunei Darussalam",
".bo": "Bolivia",
".br": "Brazil",
".bs": "Bahamas",
".bt": "Bhutan",
".bv": "Bouvet Island",
".bw": "Botswana",
".by": "Belarus",
".bz": "Belize",
".ca": "Canada",
".cc": "Cocos Islands",
".cd": "Democratic Republic of the Congo",
".cf": "Central African Republic",
".cg": "Republic of the Congo",
".ch": "Switzerland",
".ci": "Côte d'Ivoire",
".ck": "Cook Islands",
".cl": "Chile",
".cm": "Cameroon",
".cn": "China",
".co": "Colombia",
".cr": "Costa Rica",
".cu": "Cuba",
".cv": "Cape Verde",
".cw": "Curaçao",
".cx": "Christmas Island",
".cy": "Cyprus",
".cz": "Czech Republic",
".de": "Germany",
".dj": "Djibouti",
".dk": "Denmark",
".dm": "Dominica",
".do": "Dominican Republic",
".dz": "Algeria",
".ec": "Ecuador",
".ee": "Estonia",
".eg": "Egypt",
".er": "Eritrea",
".es": "Spain",
".et": "Ethiopia",
".eu": "European Union",
".fi": "Finland",
".fj": "Fiji",
".fk": "Falkland Islands",
".fm": "Federated States of Micronesia",
".fo": "Faroe Islands",
".fr": "France",
".ga": "Gabon",
".gb": "United Kingdom",
".gd": "Grenada",
".ge": "Georgia",
".gf": "French Guiana",
".gg": "Guernsey",
".gh": "Ghana",
".gi": "Gibraltar",
".gl": "Greenland",
".gm": "Gambia",
".gn": "Guinea",
".gp": "Guadeloupe",
".gq": "Equatorial Guinea",
".gr": "Greece",
".gs": "South Georgia and the South Sandwich Islands",
".gt": "Guatemala",
".gu": "Guam",
".gw": "Guinea-Bissau",
".gy": "Guyana",
".hk": "Hong Kong",
".hm": "Heard Island and McDonald Islands",
".hn": "Honduras",
".hr": "Croatia",
".ht": "Haiti",
".hu": "Hungary",
".id": "Indonesia",
".ie": "Ireland",
".il": "Israel",
".im": "Isle of Man",
".in": "India",
".io": "British Indian Ocean Territory",
".iq": "Iraq",
".ir": "Iran",
".is": "Iceland",
".it": "Italy",
".je": "Jersey",
".jm": "Jamaica",
".jo": "Jordan",
".jp": "Japan",
".ke": "Kenya",
".kg": "Kyrgyzstan",
".kh": "Cambodia",
".ki": "Kiribati",
".km": "Comoros",
".kn": "Saint Kitts and Nevis",
".kp": "Democratic People's Republic of Korea (North Korea)",
".kr": "Republic of Korea (South Korea)",
".kw": "Kuwait",
".ky": "Cayman Islands",
".kz": "Kazakhstan",
".la": "Laos",
".lb": "Lebanon",
".lc": "Saint Lucia",
".li": "Liechtenstein",
".lk": "Sri Lanka",
".lr": "Liberia",
".ls": "Lesotho",
".lt": "Lithuania",
".lu": "Luxembourg",
".lv": "Latvia",
".ly": "Libya",
".ma": "Morocco",
".mc": "Monaco",
".md": "Moldova",
".me": "Montenegro",
".mf": "Saint Martin (French part)",
".mg": "Madagascar",
".mh": "Marshall Islands",
".mk": "North Macedonia",
".ml": "Mali",
".mm": "Myanmar",
".mn": "Mongolia",
".mo": "Macao",
".mp": "Northern Mariana Islands",
".mq": "Martinique",
".mr": "Mauritania",
".ms": "Montserrat",
".mt": "Malta",
".mu": "Mauritius",
".mv": "Maldives",
".mw": "Malawi",
".mx": "Mexico",
".my": "Malaysia",
".mz": "Mozambique",
".na": "Namibia",
".nc": "New Caledonia",
".ne": "Niger",
".nf": "Norfolk Island",
".ng": "Nigeria",
".ni": "Nicaragua",
".nl": "Netherlands",
".no": "Norway",
".np": "Nepal",
".nr": "Nauru",
".nu": "Niue",
".nz": "New Zealand",
".om": "Oman",
".pa": "Panama",
".pe": "Peru",
".pf": "French Polynesia",
".pg": "Papua New Guinea",
".ph": "Philippines",
".pk": "Pakistan",
".pl": "Poland",
".pm": "Saint Pierre and Miquelon",
".pn": "Pitcairn",
".pr": "Puerto Rico",
".ps": "Palestinian Territory",
".pt": "Portugal",
".pw": "Palau",
".py": "Paraguay",
".qa": "Qatar",
".re": "Réunion",
".ro": "Romania",
".rs": "Serbia",
".ru": "Russia",
".rw": "Rwanda",
".sa": "Saudi Arabia",
".sb": "Solomon Islands",
".sc": "Seychelles",
".sd": "Sudan",
".se": "Sweden",
".sg": "Singapore",
".sh": "Saint Helena",
".si": "Slovenia",
".sj": "Svalbard and Jan Mayen",
".sk": "Slovakia",
".sl": "Sierra Leone",
".sm": "San Marino",
".sn": "Senegal",
".so": "Somalia",
".sr": "Suriname",
".ss": "South Sudan",
".st": "São Tomé and Príncipe",
".sv": "El Salvador",
".sx": "Sint Maarten (Dutch part)",
".sy": "Syria",
".sz": "Eswatini",
".tc": "Turks and Caicos Islands",
".td": "Chad",
".tf": "French Southern Territories",
".tg": "Togo",
".th": "Thailand",
".tj": "Tajikistan",
".tk": "Tokelau",
".tl": "Timor-Leste",
".tm": "Turkmenistan",
".tn": "Tunisia",
".to": "Tonga",
".tr": "Turkey",
".tt": "Trinidad and Tobago",
".tv": "Tuvalu",
".tw": "Taiwan",
".tz": "Tanzania",
".ua": "Ukraine",
".ug": "Uganda",
".uk": "United Kingdom",
".us": "United States",
".uy": "Uruguay",
".uz": "Uzbekistan",
".va": "Vatican City",
".vc": "Saint Vincent and the Grenadines",
".ve": "Venezuela",
".vg": "British Virgin Islands",
".vi": "U.S. Virgin Islands",
".vn": "Vietnam",
".vu": "Vanuatu",
".wf": "Wallis and Futuna",
".ws": "Samoa",
".ye": "Yemen",
".yt": "Mayotte",
".za": "South Africa",
".zm": "Zambia",
".zw": "Zimbabwe"
}
for ccTLD in ccTLD_to_region:
if primary_domain.endswith(ccTLD):
return ccTLD_to_region[ccTLD]
return "ETC"
# Prime_url에서 최상위 도메인 추출하는 함수
def extract_top_level_domain(prime_url):
# '.'로 나누어 마지막 부분을 가져옴
domain_parts = prime_url.split('.')
return '.' + domain_parts[-1] # 최상위 도메인
# Prime_url에서 최상위 도메인 추출 후 지역 찾기
train_df['Top Level Domain'] = train_df['Prime_url'].apply(extract_top_level_domain)
train_df['Region'] = train_df['Top Level Domain'].apply(get_url_region)
test_df['Top Level Domain'] = test_df['Prime_url'].apply(extract_top_level_domain)
test_df['Region'] = test_df['Top Level Domain'].apply(get_url_region)
# LabelEncoder 객체 생성
label_encoder = LabelEncoder()
# Region 컬럼을 레이블 인코딩 (학습 데이터)
train_df['Region_Encoded'] = label_encoder.fit_transform(train_df['Region'])
# 테스트 데이터에 레이블 인코딩 적용 (학습 데이터에서 사용한 인코더 재사용)
test_df['Region_Encoded'] = label_encoder.transform(test_df['Region'])
# 인코딩된 레이블과 원래 레이블 간의 매핑 확인
label_mapping = dict(zip(label_encoder.classes_, range(len(label_encoder.classes_))))
6. Embedding with BERT
import torch
import numpy as np
from transformers import BertTokenizer, BertModel
from tqdm import tqdm # tqdm 라이브러리 임포트
# GPU가 사용 가능한지 확인하고 모델을 GPU로 이동
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# BERT 모델과 토크나이저 로드
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased').to(device)
# train_df의 URL 임베딩 생성 (진행율 표시)
url_embeddings_train = []
for url in tqdm(train_df['URL'], desc="Processing train URLs"):
tokens = tokenizer.encode(url, add_special_tokens=True, max_length=512, truncation=True, return_tensors='pt').to(device)
with torch.no_grad():
outputs = model(tokens)
embeddings = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy() # 결과를 CPU로 이동
url_embeddings_train.append(embeddings)
url_embeddings_train = torch.tensor(url_embeddings_train)
# test_df의 URL 임베딩 생성 (진행율 표시)
url_embeddings_test = []
for url in tqdm(test_df['URL'], desc="Processing test URLs"):
tokens = tokenizer.encode(url, add_special_tokens=True, max_length=512, truncation=True, return_tensors='pt').to(device)
with torch.no_grad():
outputs = model(tokens)
embeddings = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy() # 결과를 CPU로 이동
url_embeddings_test.append(embeddings)
url_embeddings_test = torch.tensor(url_embeddings_test)
7. Hold out & Scale
# X와 y 재정의
X = train_df.drop(columns=['ID', 'label', 'URL', 'Prime_url','Other_domain','Top Level Domain', 'Region'])
y = train_df['label']
# 테스트 데이터 처리
X_test = test_df.drop(columns=['ID', 'URL', 'Prime_url','Other_domain','Top Level Domain', 'Region'])
from sklearn.preprocessing import StandardScaler
feature_cols = X.columns.tolist() # X의 모든 컬럼을 feature_cols로 사용
# 스케일러 초기화
scaler = StandardScaler()
# 학습 데이터 피처 스케일링 (원래 DataFrame 유지)
train_df_scaled = X.copy() # 기존 train_df 유지
train_df_scaled[feature_cols] = scaler.fit_transform(X[feature_cols])
# 테스트 데이터 피처 스케일링 (학습 데이터의 스케일러 사용)
test_df_scaled = X_test.copy() # 기존 test_df 유지
test_df_scaled[feature_cols] = scaler.transform(X_test[feature_cols])
other_numeric_data_tensor = torch.tensor(train_df_scaled.values)
concatenated_features = torch.cat((url_embeddings_train, other_numeric_data_tensor), dim=1)
# PyTorch 텐서 변환 (테스트 데이터)
other_numeric_data_test = torch.tensor(test_df_scaled.values, dtype=torch.float32)
concatenated_features_test = torch.cat((url_embeddings_test, other_numeric_data_test), dim=1)
X_train, X_val, y_train, y_val = train_test_split(concatenated_features, y, stratify = y, shuffle=True, test_size=0.05, random_state=42)
8. LGBM Classifier
# LightGBM Classifier 초기화
model_lgb = LGBMClassifier(
objective='binary',
metric='binary_logloss',
learning_rate=0.03,
num_leaves=60,
n_estimators=5000,
random_state=42
)
# 모델 훈련
model_lgb.fit(X_val, y_val)
# 검증 데이터 예측 및 ROC-AUC 계산
y_val_pred_prob = model_lgb.predict_proba(X_val)[:, 1]
auc = roc_auc_score(y_val, y_val_pred_prob)
print(f"Validation ROC-AUC: {auc:.4f}")
# # Validation ROC-AUC: 0.9529
from sklearn.metrics import roc_auc_score, confusion_matrix, ConfusionMatrixDisplay
# 예측 클래스 생성 (0 또는 1)
y_val_pred = model_lgb.predict(X_val)
# 컨퓨전 매트릭스 계산
cm = confusion_matrix(y_val, y_val_pred)
# 컨퓨전 매트릭스 시각화
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=model_lgb.classes_)
disp.plot(cmap='Blues')
plt.title('Confusion Matrix')
plt.show()
import numpy as np
# 평가 데이터 추론
# 단일 모델의 예측 확률 계산
test_probabilities = model_lgb.predict_proba(concatenated_features_test)[:, 1] # 악성 URL(1)일 확률
print('Inference Done.')
■ 마무리
여기까지 악성 URL 분류에 대한 최종 코드를 살펴보았습니다. 이번 대회를 통해서 URL 구조 그리고 악성 URL을 구분하기 위한 다양한 방식을 시도해보는 시간이었습니다. 언제나 새로운 데이터를 배우고 이를 공부하는 과정에서 또 많은 것을 배우게 되었습니다. 다음 4월에도 이미지 데이터와 함께 금융 데이터 관련 대회가 준비되어 있습니다. 이 또한 잘 준비하면서 알게 되었던 것을 복습하고 새로운 기법들을 배울 수 있었으면 좋겠습니다.
■ 진행 중 대회
대회 1 : https://dacon.io/competitions/official/236460/overview/description
신용카드 고객 세그먼트 분류 AI 경진대회 - DACON
분석시각화 대회 코드 공유 게시물은 내용 확인 후 좋아요(투표) 가능합니다.
dacon.io
대회 2 : https://dacon.io/competitions/official/236459/overview/description
이미지 분류 해커톤: 데이터 속 아이콘의 종류를 맞혀라! - DACON
분석시각화 대회 코드 공유 게시물은 내용 확인 후 좋아요(투표) 가능합니다.
dacon.io
'Personal Projects > Dacon' 카테고리의 다른 글
[Dacon] 악성 URL 분류 AI 경진대회 (2) - EDA (0) | 2025.03.31 |
---|---|
[Dacon] 악성 URL 분류 AI 경진대회 (1) (0) | 2025.03.31 |
[Dacon] 채무 불이행 여부 예측 해커톤 (3) - Code (0) | 2025.03.31 |
[Dacon] 채무 불이행 여부 예측 해커톤 (2) - EDA (0) | 2025.03.31 |
[Dacon] 채무 불이행 여부 예측 해커톤 (1) - 후기 (0) | 2025.03.31 |