## Functions and script for clustering with a Dirichlet process (a stochastic process)
## written by DongHyuk.yi
import numpy as np
from numpy.random import multinomial as mnrand
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import warnings
warnings.filterwarnings('ignore')  # suppress warnings
np.seterr(all='ignore')            # suppress numpy floating-point warnings
## Function: PDF of a multivariate normal distribution
# the multivariate-normal routine in scipy.stats was not usable in Sage here, so define an alternative
def mvnpdf(x, mu, cov):
    # density of a multivariate normal N(mu, cov) evaluated at x (2-dimensional data assumed)
    diff = np.reshape(mu - x, (1, 2))
    diff_trans = matrix(diff.T)            # for SAGE
    diff = matrix(diff)                    # for SAGE
    cov_inv = matrix(np.linalg.inv(cov))   # for SAGE
    # exponent: -0.5 * (mu - x) * cov^{-1} * (mu - x)^T  (a 1x1 matrix)
    exponent = -0.5 * diff * cov_inv * diff_trans
    # normalizing constant: sqrt(det(2*pi*cov))
    low_part = np.sqrt(det(matrix(2 * np.pi * cov)))  # for SAGE
    pdf = np.exp(float(exponent[0, 0])) / low_part
    return pdf
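## Quick sanity check for mvnpdf (a hedged sketch, not part of the original run):
## a standard 2-D normal evaluated at its mean should give 1/(2*pi) ~ 0.1592.
# print(mvnpdf(np.zeros(2), np.zeros(2), np.eye(2)))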
## Function: convert a list of per-dimension variances into a diagonal covariance matrix
def calc_cov(sigma):
    # e.g. calc_cov([0.5, 0.5]) -> [[0.5, 0.0], [0.0, 0.5]]
    cov = np.diag(sigma)
    return cov
## Main function
def DP(data, alpha, sigma, init_cluster, iteration=10):
    ## data = raw data, alpha = concentration parameter, sigma = per-dimension variances of the Gaussian likelihood,
    ## init_cluster = initial number of clusters, iteration = number of full sweeps over the data (default: 10)
    print('function message : start')
    # set values from the input data
    data = np.array(data)
    nr_data = data.shape[0]
    dp_alpha = alpha
    dp_cov = calc_cov(sigma)
    nr_cluster = init_cluster
    ## GMM with DP
    # 1. assign an arbitrary cluster to each data point
    dp_cluster = np.random.randint(low=1, high=nr_cluster + 1, size=(nr_data))
    count = 0
    # compute the initial center of each cluster
    for ci in np.unique(dp_cluster):
        cluster_index = np.nonzero(dp_cluster == ci)
        center_pos = np.mean(data[cluster_index], axis=0)
        if count == 0:
            dp_center = center_pos
        else:
            dp_center = np.vstack([dp_center, center_pos])
        count += 1
    p = np.zeros(1)
    lik_list = np.zeros(1)
    for iter in range(iteration):
        print('%d iteration is doing' % iter)
        for i in range(0, nr_data):
            cur_point = data[i]
            p = np.zeros(nr_cluster + 1)
            lik_list = np.zeros(nr_cluster)
            for j in range(0, nr_cluster):
                ## Polya urn / Chinese restaurant process step
                i_cluster = dp_cluster[i]            # cluster currently assigned to point i
                n_j = np.sum(dp_cluster == (j + 1))  # number of points in cluster j+1
                if i_cluster == (j + 1):             # exclude the current point from its own cluster count
                    n_j = n_j - 1
                # probability of joining existing cluster j+1:
                #   p_j proportional to n_j / (n - 1 + alpha) * N(x_i | mu_j, cov)
                if np.isnan(dp_center[j]).any():
                    lik = 0
                else:
                    lik = mvnpdf(cur_point, dp_center[j], dp_cov)
                cur_p = lik * n_j / (nr_data - 1 + dp_alpha)
                p[j] = cur_p
                lik_list[j] = lik
            # probability of opening a new cluster:
            #   p_new proportional to alpha / (n - 1 + alpha) * lik_new
            # (the maximum likelihood over existing clusters is used here as a surrogate for lik_new)
            lik_new = np.max(lik_list)
            new_p = lik_new * dp_alpha / (nr_data - 1 + dp_alpha)
            p[nr_cluster] = new_p
            # (1) normalize the cluster probabilities and (2) draw a cluster index from a multinomial
            p_norm = p / np.sum(p)
            try:  # guard against degenerate probability vectors; to be cleaned up
                z = np.nonzero(mnrand(1, p_norm) == 1)
                z = int(z[0])
            except:
                continue
            # decide whether a new cluster is created
            dp_cluster[i] = z + 1
            if (z + 1) > nr_cluster:
                nr_cluster += 1  # a new cluster has been opened
            # update the cluster information (recompute every cluster center)
            del dp_center
            for ci in range(nr_cluster):
                cluster_index = np.nonzero(dp_cluster == (ci + 1))
                filter_data = data[cluster_index]
                try:
                    center_pos = np.mean(filter_data, axis=0)  # an empty cluster yields a nan center
                except:
                    continue
                if ci == 0:
                    dp_center = center_pos
                else:
                    dp_center = np.vstack((dp_center, center_pos))
        print('number of current clusters : %d' % len(np.unique(dp_cluster)))
    print('function message : finish')
    print('\n')
    # results: (1) the center of each cluster, (2) the cluster label of each data point
    return dp_center, dp_cluster
# end of 'DP' function
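## Hedged helper sketch (illustrative, not part of the original script): summarize how many
## data points each cluster label received in the result returned by DP.
# def cluster_sizes(labels):
#     labels = np.asarray(labels)
#     return {int(c): int(np.sum(labels == c)) for c in np.unique(labels)}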
## Execution script
## total number of updates: [number of data points] x iteration
# load data (DATA is the Sage worksheet's data directory)
data = np.loadtxt(DATA + 'sample.csv', delimiter=',')
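## If the Sage worksheet's DATA directory has no sample.csv, synthetic 2-D data can be
## substituted; a minimal sketch (the mixture means and covariance below are illustrative
## assumptions, not values taken from the original data set):
# centers = [np.array([2.5, 7.0]), np.array([6.0, 8.0]), np.array([5.5, 5.5])]
# data = np.vstack([np.random.multivariate_normal(c, 0.3 * np.eye(2), size=50)
#                   for c in centers])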
# set the hyperparameters of the Dirichlet process
alpha = 50
sigma = [0.5, 0.5]
# set other values
init_cluster = 4  # any integer larger than 1 should work (input validation is still needed)
iteration = 100
# call the 'DP' function
# dp_center  : center point of each cluster
# dp_cluster : cluster label of each data point
dp_center, dp_cluster = DP(data, alpha, sigma, init_cluster, iteration)
## print the clustering results (1st: raw data, 2nd: center of each cluster, 3rd: cluster label of each data point)
print('showing raw data(2-dimensional)'); print(data); print('\n'); print('\n')
print('center point of each cluster'); print(dp_center); print('\n'); print('\n')
print('given cluster of each data'); print(dp_cluster)
point(data)  # scatter plot of the raw data (Sage)
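## matplotlib is imported above but unused; a hedged sketch for visualising the result
## (colour each point by its cluster label and overlay the cluster centers; in a Sage
## worksheet plt.savefig is usually needed instead of plt.show):
# plt.scatter(data[:, 0], data[:, 1], c=dp_cluster, cmap=cm.jet, s=15)
# centers = np.atleast_2d(dp_center)
# plt.scatter(centers[:, 0], centers[:, 1], c='black', marker='x', s=80)
# plt.savefig('dp_clusters.png')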
## Output (truncated by the Sage notebook; full log in full_output.txt):
function message : start
0 iteration is doing
number of current clusters : 4
1 iteration is doing
number of current clusters : 4
2 iteration is doing
number of current clusters : 4
3 iteration is doing
number of current clusters : 4
4 iteration is doing
number of current clusters : 4
5 iteration is doing
number of current clusters : 4
6 iteration is doing
number of current clusters : 4
7 iteration is doing
number of current clusters : 4
8 iteration is doing
number of current clusters : 4
9 iteration is doing
number of current clusters : 4
10 iteration is doing
number of current clusters : 4
11 iteration is doing
number of current clusters : 4
12 iteration is doing
number of current clusters : 4
13 iteration is doing
number of current clusters : 4
14 iteration is doing
number of current clusters : 4
15 iteration is doing
number of current clusters : 4
16 iteration is doing
number of current clusters : 5
17 iteration is doing
number of current clusters : 5
18 iteration is doing
number of current clusters : 4
19 iteration is doing
number of current clusters : 4
20 iteration is doing
number of current clusters : 4
21 iteration is doing
number of current clusters : 4
22 iteration is doing
number of current clusters : 4
23 iteration is doing
number of current clusters : 4
24 iteration is doing
number of current clusters : 4
25 iteration is doing
number of current clusters : 4
26 iteration is doing
number of current clusters : 4
27 iteration is doing
number of current clusters : 4
28 iteration is doing
number of current clusters : 4
...
[ 6.0713 7.6785 ]
[ 7.5075 7.5547 ]
[ 5.9073 8.6857 ]
[ 5.5931 8.4311 ]
[ 2.6977 5.9013 ]
[ 4.2878 7.2245 ]
[ 2.4675 5.8667 ]
[ 2.5269 6.9344 ]
[ 2.9182 6.9404 ]
[ 2.2525 6.5458 ]
[ 2.7593 7.9806 ]
[ 2.4569 7.8856 ]
[ 2.6884 7.2995 ]
[ 3.3518 6.5285 ]
[ 3.1219 6.7108 ]
[ 1.5529 6.8987 ]
[ 3.8122 7.4315 ]
[ 2.3163 5.9903 ]
[ 2.65 7.0086 ]
[ 2.6037 5.469 ]
[ 1.67 7.2917 ]
[ 2.2802 7.7611 ]
[ 2.9473 7.3843 ]
[ 2.6626 6.5118 ]
[ 2.9231 8.0081 ]
[ 2.453 6.2581 ]
[ 2.7578 7.5046 ]
[ 3.7104 7.0077 ]
[ 2.3597 8.2589 ]
[ 1.4814 6.5148 ]
[ 3.3948 6.5937 ]
[ 2.1707 6.8331 ]
[ 2.6787 6.9252 ]
[ 2.8403 6.5456 ]]
center point of each cluster
[[ 5.39538976 5.67203684]
[ 6.19154304 5.40182625]
[ 5.56969799 5.85872367]
[ 5.48545157 5.75329946]
[ nan nan]]
given cluster of each data
[4 3 4 2 1 4 4 3 3 4 4 4 1 4 2 4 1 3 4 3 4 4 4 3 2 4 3 3 2 3 4 1 4 2 1 3 4
 3 2 3 1 3 3 4 4 1 4 4 4 3 4 4 4 4 4 3 3 4 1 4 4 4 4 4 4 3 2 4 3 4 4 4 4 4
 1 4 4 4 4 4 4 3 4 4 2 3 4 3 3 4 4 4 3 4 4 4 1 4 1 2 3 3 1 3 4 4 4 4 4 4 2
 3 4 3 4 2 4 1 3 4 4 4 2 3 3 4 1 3 3 4 2 1 1 1 4 1 3 2 1 4 3 4 4 1 3 2 1 4
 3 4 3 4 1 4 4 2 4 2 4 2 4 3 3 4 3 4 3 1 4 3 3 3 3 4 4 3 4 3 1 4 4 3 4 4 4
 4 4 4 4 3 4 3 1 4 4 4 4 4 3 4 4 3 2 4 4 4 1 3 4 3 4 4 4 3 3 1 4 3 4 3 3 4
 2 2 1 4 4 4 4 3 4 4 4 4 4 4 1 4 4 4 4 3 2 4 3 1 1 4 4 4 4 2 4 4 4 4 3 2 4
 4 3 4 4 4 4 4 1 3 4 1 3 1 3 4 4 4 4 4 3 4 3 4 3 1 4 1 4 3 4 4 1 3 3 4 4 3
 1 3 3 3]
