I need to implement Algorithm 2, the Discriminant Algorithm for Reduced Multidimensional Features, from the paper https://www.jstage.jst.go.jp/article/nolta/1/1/1_1_37/_pdf/-char/en. I am using TensorLy, SciPy and NumPy. My implementation fails to converge. Additionally, the projection matrix psi (and therefore F_matrix) contains large complex values.
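On the complex values: `scipy.sparse.linalg.eigs` wraps ARPACK's general (non-symmetric) driver, so it returns complex arrays even for real input whose eigenvalues are mathematically real. When `M` is passed (and `sigma` is not), SciPy also requires `M` to be positive definite. S_w is not guaranteed to satisfy that, because its rank is at most K - C. Since S_b and S_w are symmetric, a symmetric solver such as `scipy.linalg.eigh` returns real output. A small check with random stand-in matrices (not data from the paper), assuming SciPy 1.5+ for `subset_by_index`:

```python
import numpy as np
from scipy.linalg import eigh
from scipy.sparse.linalg import eigs

rng = np.random.default_rng(0)
L, F = 6, 2

# Random symmetric stand-ins: S_b positive semidefinite, S_w positive definite.
A = rng.standard_normal((L, L))
S_b = A @ A.T
B = rng.standard_normal((L, L))
S_w = B @ B.T + L * np.eye(L)

# ARPACK's general driver: complex eigenvalues and eigenvectors for real input.
vals_eigs, vecs_eigs = eigs(S_b, k=F, M=S_w, which='LM')

# Dense symmetric-definite solver: real output, eigenvalues in ascending order;
# subset_by_index keeps only the F largest.
vals_eigh, vecs_eigh = eigh(S_b, S_w, subset_by_index=[L - F, L - 1])

print(vecs_eigs.dtype, vecs_eigh.dtype)
```

Both calls solve the same generalized problem S_b x = w S_w x. Only the symmetric solver keeps the result real.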
```python
import numpy as np
import tensorly as tl
from scipy.sparse.linalg import eigs


def algoritam2(G, I, F, max_iter, tol):
    """
    Discriminant Algorithm for Reduced Multidimensional Features

    Parameters:
    G : numpy.ndarray
        Core feature tensor of dimension J1 x J2 x ... x JN x K
    I : numpy.ndarray
        Target labels for core tensor G (K samples in C classes)
    F : int
        Number of discriminatory features we want
    max_iter : int
        Maximum number of iterations when maximizing the trace ratio
    tol : float
        Tolerance for the convergence criterion

    Return values:
    Psi : numpy.ndarray
        Discrimination projection matrix, dim L x F
    F_matrix : numpy.ndarray
        Matrix of discriminatory features, dim F x K
    note: L = J1 * J2 * ... * JN
    """
    K = G.shape[-1]
    N = G.ndim - 1
    klase = np.unique(I)
    G_mean = np.mean(G, axis=-1)
    G_c_mean = np.zeros(G.shape[:-1] + (len(klase),))
    G_tilda = np.zeros(G.shape)
    for c in klase:
        ind = np.where(I == c)[0]
        G_c = G[..., ind]  # samples from class c
        K_c = G_c.shape[-1]  # number of samples in class c
        G_c_mean[..., c] = np.mean(G_c, axis=-1)  # mean of class c
        for i in ind:
            G_tilda[..., i] = G[..., i] - G_c_mean[..., c]  # center the samples of class c
        G_c_mean[..., c] = np.sqrt(K_c) * (G_c_mean[..., c] - G_mean)
    S_w = tl.base.unfold(G_tilda, N).T @ tl.base.unfold(G_tilda, N)
    S_b = tl.base.unfold(G_c_mean, N).T @ tl.base.unfold(G_c_mean, N)
    delta, psi = eigs(S_b, k=F, M=S_w, which='LM')
    br_iter = 0
    uvjet = 1
    while br_iter < max_iter and uvjet:
        psi_stari = psi
        br_iter += 1
        fi = np.trace(psi.T @ S_b @ psi) / np.trace(psi.T @ S_w @ psi)
        delta, psi = eigs(S_b - fi * S_w, k=F, which='LM')
        # delta, psi = eigs([email protected]@S_w@[email protected], k=F, which='LM')  (matrix expression garbled in this copy of the post)
        uvjet = (np.linalg.norm(psi - psi_stari) >= tol)
    F_matrix = psi.T @ tl.unfold(G, N).T
    return psi, F_matrix
```
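One fragile spot in the scatter-matrix step: `G_c_mean[..., c]` uses the label value `c` as an array index. That only works when the labels are exactly 0, 1, ..., C-1. A NumPy-only sketch of the same two scatter matrices that enumerates the classes instead follows. `scatter_matrices` is a hypothetical helper name. The unfolding along the last (sample) mode is written as `moveaxis` + `reshape`, which matches the layout of TensorLy's `unfold(G, N)` for that mode. The last lines check the identity S_w + S_b = total scatter, a quick sanity test for any implementation:

```python
import numpy as np


def scatter_matrices(G, labels):
    """Within-class (S_w) and between-class (S_b) scatter matrices.

    G holds K samples along its last axis; labels has length K.
    Both returned matrices are L x L with L = prod(G.shape[:-1]).
    """
    K = G.shape[-1]
    # Unfold along the last (sample) mode: one row per sample, K x L.
    X = np.moveaxis(G, -1, 0).reshape(K, -1).astype(float)
    labels = np.asarray(labels)
    classes = np.unique(labels)
    overall_mean = X.mean(axis=0)

    X_centered = np.empty_like(X)
    between_rows = np.empty((len(classes), X.shape[1]))
    # enumerate() gives a 0..C-1 row index regardless of the label values.
    for idx, c in enumerate(classes):
        mask = labels == c
        class_mean = X[mask].mean(axis=0)
        X_centered[mask] = X[mask] - class_mean
        between_rows[idx] = np.sqrt(mask.sum()) * (class_mean - overall_mean)

    S_w = X_centered.T @ X_centered
    S_b = between_rows.T @ between_rows
    return S_w, S_b


# Labels deliberately not 0..C-1 (10, 20, 30).
rng = np.random.default_rng(1)
G = rng.standard_normal((3, 4, 30))
labels = np.repeat([10, 20, 30], 10)
S_w, S_b = scatter_matrices(G, labels)

# Sanity check: within-class + between-class scatter equals total scatter.
X = np.moveaxis(G, -1, 0).reshape(30, -1)
X_total = X - X.mean(axis=0)
print(np.allclose(S_w + S_b, X_total.T @ X_total))  # True
print(np.linalg.matrix_rank(S_b))  # at most C - 1 = 2
```

The rank bound also shows why S_w can be singular: its rank is at most K - C. Any solver that needs S_w to be positive definite will struggle once L exceeds that.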
Can you help me find the bug?
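The loop above follows the usual iterative trace-ratio scheme: compute the ratio, then take the F leading eigenvectors of S_b - fi*S_w. A sketch of that scheme follows. It is my own rendering under stated assumptions and has not been checked line by line against Algorithm 2. It differs from the posted loop in four places:

- it uses `scipy.linalg.eigh`, so everything stays real;
- it takes the F largest *algebraic* eigenvalues, whereas `which='LM'` on the indefinite matrix S_b - fi*S_w can pick large negative ones;
- it tests convergence on the ratio instead of on psi, because eigenvectors are only determined up to sign (or rotation within the subspace);
- it adds no second eigendecomposition that overwrites psi.

The name `trace_ratio` and the small ridge `reg` on S_w (used to keep the denominator positive when S_w is singular) are assumptions:

```python
import numpy as np
from scipy.linalg import eigh


def trace_ratio(S_b, S_w, F, max_iter=100, tol=1e-10, reg=1e-6):
    """Maximize tr(V.T S_b V) / tr(V.T S_w V) over L x F orthonormal V (sketch).

    Returns (V, ratio, iterations). reg is a hypothetical ridge added to S_w.
    """
    L = S_b.shape[0]
    S_b = 0.5 * (S_b + S_b.T)                   # enforce exact symmetry
    S_w = 0.5 * (S_w + S_w.T) + reg * np.eye(L)

    # Start from the F leading generalized eigenvectors, then orthonormalize.
    _, V = eigh(S_b, S_w, subset_by_index=[L - F, L - 1])
    V, _ = np.linalg.qr(V)
    ratio = np.trace(V.T @ S_b @ V) / np.trace(V.T @ S_w @ V)

    for it in range(1, max_iter + 1):
        # F largest algebraic eigenvalues of the symmetric matrix S_b - ratio*S_w.
        _, V = eigh(S_b - ratio * S_w, subset_by_index=[L - F, L - 1])
        new_ratio = np.trace(V.T @ S_b @ V) / np.trace(V.T @ S_w @ V)
        # Converge on the objective; V itself is only defined up to sign/rotation.
        if abs(new_ratio - ratio) <= tol * max(1.0, abs(ratio)):
            return V, new_ratio, it
        ratio = new_ratio
    return V, ratio, max_iter


# Demo with random stand-in scatter matrices (not data from the paper).
rng = np.random.default_rng(2)
L, F = 8, 3
A = rng.standard_normal((L, 2))
S_b = A @ A.T            # rank-deficient, like a between-class scatter
B = rng.standard_normal((L, 20))
S_w = B @ B.T
V, ratio, n_iter = trace_ratio(S_b, S_w, F)
print(V.dtype, V.shape, n_iter, ratio)
```

With psi = V, the features follow as in the posted code, `F_matrix = V.T @ tl.unfold(G, N).T`. The ratio never decreases from one iteration to the next in this scheme. Dense `eigh` is reasonable only while L = J1 * ... * JN stays modest.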