import numpy as np


def clustering_entropy(data, yhat, class_threshold=0):
    """Return the size-weighted total class entropy of a clustering.

    Entropy measures the amount of disorder in the data: the larger the
    disorder, the larger the entropy value. A cluster whose points all fall
    into a single class contributes (approximately) 0.

    Every channel of ``data`` is treated as one class. For each channel, a
    pixel whose value is strictly greater than ``class_threshold`` counts
    towards that class; otherwise it counts towards an additional
    "background" class. Each pixel is therefore tallied once per channel,
    giving ``dim`` tallies per pixel spread over ``dim + 1`` classes.

    Parameters
    ----------
    data : array_like, shape (row, col, dim)
        Per-class maps, one channel per class.
    yhat : array_like, ``row * col`` elements
        Cluster label of every pixel in row-major order.
    class_threshold : scalar, optional
        Values strictly above this threshold mark a pixel as belonging to
        the channel's class. Defaults to 0.

    Returns
    -------
    float
        Sum over clusters of ``H(cluster) * n_cluster / n_total``, where
        ``H`` is the base-2 entropy of the class distribution in the cluster.
        Because of the small offset added inside the logarithm, a perfectly
        pure clustering yields a value within about 1e-5 of 0 rather than
        exactly 0. Returns 0.0 when there are no pixels.

    Raises
    ------
    ValueError
        If ``data`` is not 3-dimensional, has no channels, or if ``yhat``
        does not hold exactly ``row * col`` labels.
    """
    data = np.asarray(data)
    yhat = np.asarray(yhat)
    if data.ndim != 3:
        raise ValueError("data must have shape (row, col, dim)")
    row, col, dim = data.shape
    n_pixels = row * col
    if yhat.size != n_pixels:
        raise ValueError(
            "yhat must contain exactly row * col = %d labels, got %d"
            % (n_pixels, yhat.size)
        )
    if dim == 0:
        raise ValueError("data must contain at least one class channel")
    if n_pixels == 0:
        return 0.0

    classes_num = dim + 1  # the background is an additional class
    eps = 1e-6  # keeps log2 finite for classes that are absent from a cluster

    # Map every pixel to the index of its cluster in the sorted label array.
    labels, cluster_idx = np.unique(yhat.reshape(-1), return_inverse=True)
    cluster_num = len(labels)
    pixels_per_cluster = np.bincount(cluster_idx, minlength=cluster_num)
    # Each pixel is tallied once per channel, hence the factor ``dim``.
    n_w = pixels_per_cluster * dim

    # above[p, d] is True when pixel p belongs to class d.
    above = (data > class_threshold).reshape(n_pixels, dim)

    # w_c[w, c]: number of tallies of class c inside cluster w.
    w_c = np.zeros((cluster_num, classes_num), dtype=np.int64)
    for d in range(dim):
        w_c[:, d] = np.bincount(cluster_idx[above[:, d]],
                                minlength=cluster_num)
    # Every tally that did not hit a channel's class went to the background.
    w_c[:, dim] = n_w - w_c[:, :dim].sum(axis=1)

    # Class distribution and entropy per cluster.
    p_w_c = w_c / n_w[:, np.newaxis]
    h_w = -np.sum(p_w_c * np.log2(p_w_c + eps), axis=1)

    # Weight each cluster's entropy by its share of all pixels.
    return float(np.dot(h_w, pixels_per_cluster / n_pixels))
