###
# Handwritten-digit OCR based on a support vector machine (SVM).
# Data: the digit images from chapter 2 of "Machine Learning in Action".
# 基于SVM的手写数字的识别程序
# 数据:采用了《Machine Learning in Action》第二章的数据
###
def init_models(self):
    """Append ten fresh models to self.models, one per digit class 0-9.

    Each new model's multiplier list ``a`` is filled with one ``0`` per
    training sample in ``self.samples``. Models are appended in digit order,
    so (assuming ``self.models`` starts empty) ``self.models[k]`` is the
    model later selected by ``gv.cur_mno == k``.
    """
    sample_count = len(self.samples)
    for _digit in range(10):
        new_model = model()
        for _ in range(sample_count):
            new_model.a.append(0)
        self.models.append(new_model)
def init_cache_kernel(self):
    """Precompute kernel(x_row, x_col) for every pair of training samples.

    Rows are appended to ``self.cache_kernel`` in sample order and the row
    index is printed as a progress indicator. The matrix is symmetric, so
    entries below the diagonal are copied from the already-computed
    transposed entry instead of calling ``kernel`` again.
    """
    for row, sample_row in enumerate(self.samples):
        print(row)
        self.cache_kernel.append([])
        current = self.cache_kernel[row]
        for col, sample_col in enumerate(self.samples):
            if row <= col:
                current.append(kernel(sample_row, sample_col))
            else:
                # Lower triangle: reuse K[col][row] computed on an earlier row.
                current.append(self.cache_kernel[col][row])
class image:
    """One digit sample, from either the training or the test set.

    Attributes (filled in by loaddata after construction):
        data  -- list of row strings returned by parse_image; each character
                 is converted with int() by the kernel functions
        num   -- the digit taken from the first character of the file name
        label -- +1/-1 targets, one appended per model by update_samples_label
        fn    -- the file name the sample was read from
    """

    def __init__(self):
        self.fn = ""
        self.label = []
        self.num = 0
        self.data = []

    def printself(self):
        """Dump this sample to stdout; the label shown is the one for model gv.cur_mno."""
        print("data")
        for row in self.data:
            print(row)
        print("num %s" % (self.num,))
        print("label %s" % (self.label[gv.cur_mno],))
        print("fn %s" % (self.fn,))
# global variables gv = GV()
def parse_image(path):
    """Read a text-encoded digit image and return its rows as strings.

    Each line of the file at *path* becomes one list element with any
    trailing CR/LF line terminator removed. Stripping the terminator
    (rather than blindly chopping two characters) keeps every pixel intact
    for LF-only files, for a last line without a newline, and under
    Python 3's universal-newline translation. The file is closed on return.
    """
    with open(path, "r") as fp:
        return [line.rstrip("\r\n") for line in fp]
# Load training or test samples from a directory of digit images.
def loaddata(dirpath, col):
    """Read every digit image file in *dirpath* and append it to *col*.

    File names must start with the digit they depict (e.g. "3_17.txt"); that
    first character becomes ``img.num`` and the full name becomes ``img.fn``.
    Entries that are not regular files or whose name does not start with a
    digit (hidden files such as ".DS_Store", notes) are skipped instead of
    crashing int(). Files are read in sorted order because os.listdir order
    is arbitrary, and the training loop visits samples in list order, so
    sorting makes runs reproducible.
    """
    for fname in sorted(os.listdir(dirpath)):
        path = os.path.join(dirpath, fname)
        if not fname[:1].isdigit() or not os.path.isfile(path):
            continue
        img = image()
        img.data = parse_image(path)
        img.num = int(fname[0])
        img.fn = fname
        col.append(img)
######
# Gaussian (RBF) kernel
######
def kernel_RBF(mj, mi, dlt=None):
    """Return exp(-||mj - mi||^2 / (2 * dlt^2)) for two digit images.

    Both arguments expose ``data``: a list of row strings whose characters
    are converted with int(); the images are assumed to have the same shape.

    :param dlt: kernel width. Defaults to ``gv.RBF_dlt`` so existing
        two-argument callers behave as before; passing it explicitly lets the
        kernel be evaluated or tuned without touching global state.
    :returns: a float in [0, 1]; exactly 1.0 when both arguments are the
        same object.
    """
    if mj is mi:
        # Zero distance to itself: skip the per-pixel loop.
        return 1.0
    if dlt is None:
        dlt = gv.RBF_dlt
    sq_dist = 0
    for row_j, row_i in zip(mj.data, mi.data):
        for a, b in zip(row_j, row_i):
            diff = int(a) - int(b)
            sq_dist += diff * diff
    # float() keeps this a true division even under Python 2 with an int dlt.
    return math.exp(-float(sq_dist) / (2.0 * dlt * dlt))
######
# Linear kernel
######
def kernel_linear(mj, mi):
    """Return the dot product of the two images' pixel values as a float.

    Pixels are the characters of each ``data`` row converted with int();
    rows and columns are walked following *mj*'s shape and looked up in *mi*
    at the same positions.
    """
    acc = 0.0
    for r, row in enumerate(mj.data):
        for c, pixel in enumerate(row):
            acc += int(pixel) * int(mi.data[r][c])
    return acc
# Decision function g(x) of the current model.
def predict(m):
    """Evaluate g(m) = sum_j a_j * y_j * kernel(x_j, m) + b for model gv.cur_mno.

    Only training samples with a non-zero multiplier a_j contribute; y_j is
    the sample's +1/-1 label for the current model. Returns a float; test()
    treats a positive value as the model accepting *m*.
    """
    current = gv.models[gv.cur_mno]
    label_idx = gv.cur_mno
    total = 0.0
    for j, train_img in enumerate(gv.samples):
        alpha = current.a[j]
        if alpha == 0:
            continue
        total += alpha * train_img.label[label_idx] * kernel(train_img, m)
    return total + current.b
# Same as predict(m), but for training sample i, using the precomputed kernel cache.
def predict_train(i):
    """Evaluate the current model's decision function at training sample *i*.

    Identical to predict() except that kernel values are read from
    ``gv.cache_kernel[j][i]`` instead of being recomputed.
    """
    current = gv.models[gv.cur_mno]
    label_idx = gv.cur_mno
    total = 0.0
    for j, train_img in enumerate(gv.samples):
        alpha = current.a[j]
        if alpha == 0:
            continue
        total += alpha * train_img.label[label_idx] * gv.cache_kernel[j][i]
    return total + current.b
def init_predict_diff_real_dict():
    """Rebuild gv.diff_dict with one predict_diff_real(k) entry per training sample k.

    gv.diff_dict is replaced by a fresh list before it is filled, so values
    from a previously trained model are not carried over.
    """
    fresh = gv.diff_dict = []
    for k in range(len(gv.samples)):
        fresh.append(predict_diff_real(k))
def update_diff_dict(i, new_ai, j, new_bj, new_b):
    """Refresh every entry of gv.diff_dict after an SMO step on samples *i* and *j*.

    Uses predict_diff_real_optimized (defined elsewhere), which receives the
    new values for samples i and j and the new bias; the straightforward,
    unoptimized equivalent is predict_diff_real(idx) for each sample.
    """
    sample_total = len(gv.samples)
    for idx in range(sample_total):
        gv.diff_dict[idx] = predict_diff_real_optimized(
            idx, i, new_ai, j, new_bj, new_b)
def update_samples_label(num):
    """Append the one-vs-rest target for digit *num* to every training sample.

    The target is +1 when the sample's digit equals *num* and -1 otherwise.
    Labels are read as ``label[gv.cur_mno]``, so this must be called once per
    model in order 0, 1, ..., 9.
    """
    for sample in gv.samples:
        sample.label.append(1 if sample.num == num else -1)
###### # svmocr train # 基于算法SMO # T: tolerance 误差容忍度(精度) # times: 迭代次数 # C: 惩罚系数 # Mno: 模型序号0到9 # step: aj移动的最小步长 ###### defSVM_SMO_train(T, times, C, Mno, step): time = 0 gv.cur_mno = Mno update_samples_label(Mno) init_predict_diff_real_dict() updated = True while time < times and updated: updated = False time += 1 for i in range(len(gv.samples)): ai = gv.models[gv.cur_mno].a[i] Ei = gv.diff_dict[i]
# agaist the KKT if (gv.samples[i].label[gv.cur_mno] * Ei < -T and ai < C) or (gv.samples[i].label[gv.cur_mno] * Ei > T and ai > 0): for j in range(len(gv.samples)): if j == i: continue kii = gv.cache_kernel[i][i] kjj = gv.cache_kernel[j][j] kji = kij = gv.cache_kernel[i][j] eta = kii + kjj - 2 * kij if eta <= 0: continue new_aj = gv.models[gv.cur_mno].a[j] + gv.samples[j].label[gv.cur_mno] * (gv.diff_dict[i] - gv.diff_dict[j]) / eta # f 7.106 L = 0.0 H = 0.0 a1_old = gv.models[gv.cur_mno].a[i] a2_old = gv.models[gv.cur_mno].a[j] if gv.samples[i].label[gv.cur_mno] == gv.samples[j].label[gv.cur_mno]: L = max(0, a2_old + a1_old - C) H = min(C, a2_old + a1_old) else: L = max(0, a2_old - a1_old) H = min(C, C + a2_old - a1_old) if new_aj > H: new_aj = H if new_aj < L: new_aj = L if abs(a2_old - new_aj) < step: print"j = %d, is not moving enough" % j continue
# Evaluate on the test data.
def test():
    """Classify every sample in gv.tests with the ten models and print counts.

    For each test image all ten models are evaluated and the digit whose
    model gives the largest decision value wins (standard one-vs-rest rule).
    Taking the argmax, rather than the first model in digit order with a
    positive output, avoids systematically favouring lower digits when
    several models fire. An image counts as recognized (``recog``) when that
    best value is positive, i.e. at least one model accepts it. It counts as
    correct when the winning digit equals the first character of its file
    name. Leaves gv.cur_mno set to the last model evaluated.
    """
    recog = 0
    recog_correct = 0
    for img in gv.tests:
        print("test for %s" % (img.fn,))
        best_mno = None
        best_score = None
        for mno in range(10):
            gv.cur_mno = mno
            score = predict(img)
            if best_score is None or score > best_score:
                best_mno, best_score = mno, score
        if best_score > 0:
            print(best_mno)
            print(img.fn)
            recog += 1
            if best_mno == int(img.fn[0]):
                recog_correct += 1
    print("recog: %s" % (recog,))
    print("recog_correct: %s" % (recog_correct,))
    print("total: %s" % (len(gv.tests),))
def save_models(dirpath="models"):
    """Write each digit model's multipliers and bias to text files in *dirpath*.

    ``<k>_a.model`` holds one multiplier per line and ``<k>_b.model`` holds
    the bias, for k = 0..9. Values are written with repr() so floats
    round-trip exactly; Python 2's str() keeps only 12 significant digits,
    which silently perturbs a reloaded model. The directory is created if it
    does not exist, and files are closed even if a write fails.
    """
    import os

    if not os.path.isdir(dirpath):
        os.makedirs(dirpath)
    for i in range(10):
        model_i = gv.models[i]
        with open(os.path.join(dirpath, str(i) + "_a.model"), "w") as fp:
            fp.writelines(repr(ai) + "\n" for ai in model_i.a)
        with open(os.path.join(dirpath, str(i) + "_b.model"), "w") as fp:
            fp.write(repr(model_i.b))
def load_models(dirpath="models"):
    """Load the multipliers and biases written by save_models from *dirpath*.

    Requires gv.models to be allocated already (gv.init_models()), because
    values are assigned by position into each model's existing ``a`` list.
    Blank lines are ignored.

    :raises ValueError: if a ``<k>_a.model`` file does not hold exactly one
        value per training sample. Such a model was trained on a different
        sample set, and loading it would silently mix stale multipliers.
    """
    import os

    for i in range(10):
        model_i = gv.models[i]
        a_path = os.path.join(dirpath, str(i) + "_a.model")
        with open(a_path, "r") as fp:
            values = [float(line) for line in fp if line.strip()]
        if len(values) != len(model_i.a):
            raise ValueError("%s has %d multipliers, expected %d"
                             % (a_path, len(values), len(model_i.a)))
        for j, value in enumerate(values):
            model_i.a[j] = value
        with open(os.path.join(dirpath, str(i) + "_b.model"), "r") as fp:
            model_i.b = float(fp.readline())
if __name__ == "__main__":
    # True: train the ten one-vs-rest models and save them.
    # False: load previously saved models instead of training.
    training = True

    # Hyper-parameters. gv.RBF_dlt must be set BEFORE init_cache_kernel():
    # that call evaluates kernel() for every training pair, and kernel_RBF
    # reads gv.RBF_dlt. Assigning the width afterwards let the training cache
    # and predict() use different kernels.
    T = 0.0001       # KKT tolerance
    C = 10           # penalty (upper bound) for the multipliers
    step = 0.0001    # minimum movement of a_j for an SMO step to count
    gv.RBF_dlt = 8   # RBF kernel width

    loaddata("trainingDigits/", gv.samples)
    loaddata("testDigits/", gv.tests)
    print(len(gv.samples))
    print(len(gv.tests))

    if training:
        # The cached kernel matrix is only read by the training code.
        gv.init_cache_kernel()
    # Needed in both modes: load_models() assigns into the pre-sized
    # gv.models[i].a lists created here.
    gv.init_models()
    print("init_models done")

    if training:
        for i in range(10):
            print("traning model no: %d" % i)
            SVM_SMO_train(T, 100, C, i, step)
        save_models()
    else:
        load_models()
        # SVM_SMO_train builds each model's labels during training; when
        # loading, rebuild them here in model order 0..9.
        for i in range(10):
            update_samples_label(i)
    test()