ecsimsw
Scikit-learn / SVM 본문
Scikit-learn library
- svm의 SVC 알고리즘으로 학습시키고, 테스트 데이터를 입력하여 정확도 확인
- ex1) XOR 규칙을 알려주지 않은 상황에서 데이터 변수와 정답 레이블을 입력하는 것으로 XOR을 학습한 것이다.
<<<example1>>> find rule of result (xor)
from sklearn import svm
data=[
[0,0,0],
[0,1,1],
[1,0,1],
[1,1,0]
]
learning_data = []
learning_result = []
for row in data:
x=row[0]
y=row[1]
r=row[2]
learning_data.append([x,y])
learning_result.append(r)
clf = svm.SVC()
clf.fit(learning_data,learning_result)
pre = clf.predict([[1,1],[1,0],[0,1]])
for lr in pre:
print(lr)
<<example 1-2>>> Find rule of result / Using Data Frame
import pandas
from sklearn import svm
data=[
[0,0,0],
[0,1,1],
[1,0,1],
[1,1,0]
]
dataFrame = pandas.DataFrame(data)
learning_data = dataFrame.loc[:,0:1]
learning_result = dataFrame.loc[:,2]
clf = svm.SVC()
clf.fit(learning_data,learning_result)
pre = clf.predict([[1,1],[1,0],[0,1]])
for lr in pre:
print(lr)
- ex2) 서로 다른 언어의 기사를 데이터셋으로, 언어별 각 알파벳의 사용 비중을 변수 데이터로 하고, 언어의 종류를 레이블로 학습하여 새로운 기사가 대입됐을 때, 어느 언어로 쓰였는지 분류하였다.
<<<example 2>>> language classification
import sys
from sklearn import svm, metrics
import urllib
import os.path, glob
import matplotlib.pyplot as plt
import pandas as pd
path_test ="./test/"
path_data ="./data/"
# check frequency of using each alphabet
def checkFreq(fname,path):
lang = fname.split('-')[0]
with open(path+fname+'.txt',mode="r", encoding="utf-8") as f:
text = f.read()
text.lower()
cnt = [0 for n in range(0,26)]
for char in text:
if ord('a')<= ord(char) <=ord('z') :
#print(ord(char)-ord('a'))
cnt[ord(char)-ord('a')] += 1
freq = list(map(lambda char: char/sum(cnt), cnt))
return (freq,lang)
# get data files
data_freq=[]
data_lang=[]
files_data = glob.glob(path_data+'*.txt')
for f in files_data:
fname= f.split('\\')[1].split('.txt')[0]
data = checkFreq(fname,path_data)
data_freq.append(data[0])
data_lang.append(data[1])
# learning
clf = svm.SVC()
clf.fit(data_freq,data_lang)
# get test files
test_freq=[]
test_lang=[]
files_test = glob.glob(path_test+'*.txt')
for f in files_test:
fname= f.split('\\')[1].split('.txt')[0]
data = checkFreq(fname,path_test)
test_freq.append(data[0])
test_lang.append(data[1])
# predict & report
predict = clf.predict(test_freq)
ac_score = metrics.accuracy_score(test_lang,predict)
cl_report = metrics.classification_report(test_lang, predict)
print("score : ",ac_score)
print("report : ")
print(cl_report)
# print frequency graph
graph_dic={}
for i in range(0,len(data_lang)):
if not(data_lang[i] in graph_dic):
graph_dic[data_lang[i]] = data_freq[i]
ascii_list = [chr(n) for n in range(ord('a'),ord('z')+1)]
df=pd.DataFrame(graph_dic,index = ascii_list)
plt.style.use('ggplot')
df.plot(kind='bar', subplots=True, ylim=(0,0.15))
plt.savefig(path_test+"test graph")
'Machine Learning' 카테고리의 다른 글
Cross validation / 교차 검증 / 모델 평가 (0) | 2020.01.29 |
---|---|
Support Vector Machine / Random Forest / 언어 구분 학습 (0) | 2020.01.29 |
Crawling / Scraping / 구글 이미지 크롤러 (0) | 2020.01.19 |
Machine Learning / Generic Algorithm (0) | 2019.11.23 |
Teachable machine / teachable machine 사용법 / 구글 티쳐블 머신 (0) | 2019.11.23 |
Comments