Mean shift clustering - issue with finding the center of my clusters

Hi all, as you can see from the attached image, I cannot detect the centers of my dots (in blue) using mean shift clustering. I report my code below, and I want to point out that I got the same result no matter what value I used for the bandwidth. Thanks a lot for helping me.
My code:
%%
% Import the data
% Prompt the user to choose a file
[filename, filepath] = uigetfile('*.txt', 'Select a text file');
file_name = filename;
remove = '.txt';
file_name_clean = strrep(file_name, remove, '');
%%
% Plotting
plot_name = ['Intensity_' file_name_clean '.svg'];
% Import data from text file
opts = delimitedTextImportOptions("NumVariables", 28);
opts.DataLines = [2, Inf];
opts.Delimiter = "\t";
opts.VariableNames = ["channel_name", "x", "y", "x_c", "y_c"];
opts.SelectedVariableNames = ["x", "y"]; % Only select the x and y columns
opts.VariableTypes = ["string", "double", "double", "double", "double"];
opts.ExtraColumnsRule = "ignore";
opts.EmptyLineRule = "read";
% Construct the full file path
file_path = fullfile(filepath, file_name);
data = readmatrix(file_path, opts);
% Perform Mean Shift clustering
bandwidth = 50; % bandwidth parameter for Mean Shift
[cluster_centers, data2cluster, cluster2dataCell] = MeanShiftCluster(data, bandwidth);
% Plot the data points and the cluster centers
figure;
plot(data(:,2), data(:,1), '.', 'MarkerSize', 10, 'DisplayName', 'XY coordinates');
hold on;
% Set x-axis limit starting from 0
xlim([0, max(data(:,2))]);
% Set y-axis limit starting from 0
ylim([0, max(data(:,1))]);
% Plot cluster centers
hold on;
plot(cluster_centers(:,2), cluster_centers(:,1), 'kx', 'MarkerSize', 15, 'LineWidth', 3, 'DisplayName', 'Cluster Centers');
hold off;
xlabel('X');
ylabel('Y');
title('Mean Shift Clustering');
legend('XY coordinates', 'Cluster Centers');

2 Comments

Here you have the x y coordinates of the blue dots. The original dataset is too big to be shared here.
function [clustCent,data2cluster,cluster2dataCell] = MeanShiftCluster(dataPts,bandWidth,plotFlag);
%perform MeanShift Clustering of data using a flat kernel
%
% ---INPUT---
% dataPts - input data, (numDim x numPts)
% bandWidth - is bandwidth parameter (scalar)
% plotFlag - display output if 2 or 3 D (logical)
% ---OUTPUT---
% clustCent - is locations of cluster centers (numDim x numClust)
% data2cluster - for every data point which cluster it belongs to (numPts)
% cluster2dataCell - for every cluster which points are in it (numClust)
%
% Bryan Feldman 02/24/06
% MeanShift first appears in
% K. Fukunaga and L.D. Hostetler, "The Estimation of the Gradient of a
% Density Function, with Applications in Pattern Recognition"
%*** Check input ****
if nargin < 2
error('no bandwidth specified')
end
if nargin < 3
plotFlag = false; % default: no step-by-step plotting
end
%**** Initialize stuff ***
[numDim,numPts] = size(dataPts);
numClust = 0;
bandSq = bandWidth^2;
initPtInds = 1:numPts;
maxPos = max(dataPts,[],2); %biggest size in each dimension
minPos = min(dataPts,[],2); %smallest size in each dimension
boundBox = maxPos-minPos; %bounding box size
sizeSpace = norm(boundBox); %indicator of size of data space
stopThresh = 1e-3*bandWidth; %when mean has converged
clustCent = []; %center of clust
beenVisitedFlag = zeros(1,numPts,'uint8'); %track if a point has been seen already
numInitPts = numPts; %number of points to possibly use as initialization points
clusterVotes = zeros(1,numPts,'uint16'); %used to resolve conflicts on cluster membership
while numInitPts
tempInd = ceil( (numInitPts-1e-6)*rand); %pick a random seed point
stInd = initPtInds(tempInd); %use this point as start of mean
myMean = dataPts(:,stInd); % initialize mean to this point's location
myMembers = []; % points that will get added to this cluster
thisClusterVotes = zeros(1,numPts,'uint16'); %used to resolve conflicts on cluster membership
while 1 %loop until convergence
sqDistToAll = sum((repmat(myMean,1,numPts) - dataPts).^2); %dist squared from mean to all points still active
inInds = find(sqDistToAll < bandSq); %points within bandWidth
thisClusterVotes(inInds) = thisClusterVotes(inInds)+1; %add a vote for all the in points belonging to this cluster
myOldMean = myMean; %save the old mean
myMean = mean(dataPts(:,inInds),2); %compute the new mean
myMembers = [myMembers inInds]; %add any point within bandWidth to the cluster
beenVisitedFlag(myMembers) = 1; %mark that these points have been visited
%*** plot stuff ****
if plotFlag
figure(12345),clf,hold on
if numDim == 2
plot(dataPts(1,:),dataPts(2,:),'.')
plot(dataPts(1,myMembers),dataPts(2,myMembers),'ys')
plot(myMean(1),myMean(2),'go')
plot(myOldMean(1),myOldMean(2),'rd')
pause
end
end
%**** if mean doesn't move much stop this cluster ***
if norm(myMean-myOldMean) < stopThresh
%check for merge possibilities
mergeWith = 0;
for cN = 1:numClust
distToOther = norm(myMean-clustCent(:,cN)); %distance from possible new clust max to old clust max
if distToOther < bandWidth/2 %if it's within bandwidth/2 merge new and old
mergeWith = cN;
break;
end
end
if mergeWith > 0 % something to merge
clustCent(:,mergeWith) = 0.5*(myMean+clustCent(:,mergeWith)); %record the max as the mean of the two merged (I know biased towards new ones)
%clustMembsCell{mergeWith} = unique([clustMembsCell{mergeWith} myMembers]); %record which points inside
clusterVotes(mergeWith,:) = clusterVotes(mergeWith,:) + thisClusterVotes; %add these votes to the merged cluster
else %it's a new cluster
numClust = numClust+1; %increment clusters
clustCent(:,numClust) = myMean; %record the mean
%clustMembsCell{numClust} = myMembers; %store my members
clusterVotes(numClust,:) = thisClusterVotes;
end
break;
end
end
initPtInds = find(beenVisitedFlag == 0); %we can initialize with any of the points not yet visited
numInitPts = length(initPtInds); %number of active points in set
end
[val,data2cluster] = max(clusterVotes,[],1); %a point belongs to the cluster with the most votes
%*** If they want the cluster2data cell find it for them
if nargout > 2
cluster2dataCell = cell(numClust,1);
for cN = 1:numClust
myMembers = find(data2cluster == cN);
cluster2dataCell{cN} = myMembers;
end
end
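Since the thread later mentions moving to Python, here is a minimal, stdlib-only Python sketch that follows the same steps as MeanShiftCluster.m above: pick a random unvisited seed, replace the mean by the average of all points within bandWidth, stop once it moves less than 1e-3*bandWidth, merge a new mode into an existing one closer than bandWidth/2, and finally assign each point to the cluster that collected the most votes for it. The name mean_shift_flat and its seed argument are hypothetical, and it takes one tuple per point, i.e. the transpose of the numDim x numPts layout the MATLAB function expects. Like the original, every step scans all points, so it is only meant for small data sets.

```python
import math
import random


def mean_shift_flat(points, bandwidth, seed=0):
    """Flat-kernel mean shift in the spirit of MeanShiftCluster.m.

    points: list of (x, y, ...) tuples, one tuple per point (numPts x numDim,
    the transpose of the MATLAB layout).
    Returns (centers, labels); labels[i] is the cluster with the most votes for point i.
    """
    rng = random.Random(seed)
    n = len(points)
    band_sq = bandwidth ** 2
    stop_thresh = 1e-3 * bandwidth           # same convergence rule as the MATLAB code
    visited = [False] * n
    centers = []                             # mode locations
    votes = []                               # votes[c][i]: times point i fell in cluster c's window
    unvisited = list(range(n))
    while unvisited:
        mean = list(points[rng.choice(unvisited)])   # random unvisited seed
        this_votes = [0] * n
        while True:
            # all points strictly inside the flat window around the current mean
            inside = [i for i, p in enumerate(points)
                      if sum((a - b) ** 2 for a, b in zip(p, mean)) < band_sq]
            for i in inside:
                this_votes[i] += 1
                visited[i] = True
            old = mean
            mean = [sum(points[i][d] for i in inside) / len(inside)
                    for d in range(len(old))]
            if math.dist(mean, old) < stop_thresh:
                break
        # merge with an existing center closer than bandwidth/2, else start a new cluster
        for c, center in enumerate(centers):
            if math.dist(mean, center) < bandwidth / 2:
                centers[c] = [0.5 * (a + b) for a, b in zip(mean, center)]
                votes[c] = [u + v for u, v in zip(votes[c], this_votes)]
                break
        else:
            centers.append(mean)
            votes.append(this_votes)
        unvisited = [i for i in range(n) if not visited[i]]
    labels = [max(range(len(centers)), key=lambda c: votes[c][i]) for i in range(n)]
    return centers, labels
```

For example, two tight groups of four points far apart, with a bandwidth of 5, give two centers at (0.5, 0.5) and (100.5, 100.5).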


Answers (3)

Hello Marco,
it seems your issue is simply that the function expects row-oriented data, i.e. one row per dimension and one column per point (numDim x numPts).
See these lines in MeanShiftCluster.m:
%**** Initialize stuff ***
[numDim,numPts] = size(dataPts);
So, with the data file you provided, I needed to transpose the data array:
[cluster_centers, data2cluster, cluster2dataCell] = MeanShiftCluster(data', bandwidth); % NB : data' (transposed)
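The reason the transpose matters: readmatrix returns one row per point, so for an N x 2 array `[numDim,numPts] = size(dataPts)` sets numDim = N and numPts = 2, and the function ends up clustering two "points" in an N-dimensional space. A small Python illustration of that unpacking; matlab_size and transpose are hypothetical helpers written for this example, not MATLAB functions.

```python
def matlab_size(matrix):
    """Return (rows, cols): what [numDim, numPts] = size(dataPts) unpacks."""
    return len(matrix), len(matrix[0])


def transpose(matrix):
    """Equivalent of MATLAB's data' for a real-valued list-of-rows matrix."""
    return [list(col) for col in zip(*matrix)]


# readmatrix returns one row per point: [x y]
data = [[10.0, 20.0], [11.0, 21.0], [300.0, 40.0]]

num_dim, num_pts = matlab_size(data)
print(num_dim, num_pts)      # 3 2 -> three "dimensions" and only two "points": wrong

num_dim, num_pts = matlab_size(transpose(data))
print(num_dim, num_pts)      # 2 3 -> two dimensions, three points: what the function expects
```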
Full code:
%%
clc
clearvars
close all
% Import the data
% Prompt the user to choose a file
[filename, filepath] = uigetfile('*.txt', 'Select a text file');
file_name = filename;
remove = '.txt';
file_name_clean = strrep(file_name, remove, '');
%%
% Plotting
plot_name = ['Intensity_' file_name_clean '.svg'];
% Import data from text file
opts = delimitedTextImportOptions("NumVariables", 28);
opts.DataLines = [2, Inf];
opts.Delimiter = "\t";
opts.VariableNames = ["channel_name", "x", "y", "x_c", "y_c"];
opts.SelectedVariableNames = ["x", "y"]; % Only select the x and y columns
opts.VariableTypes = ["string", "double", "double", "double", "double"];
opts.ExtraColumnsRule = "ignore";
opts.EmptyLineRule = "read";
% Construct the full file path
file_path = fullfile(filepath, file_name);
% data = readmatrix(file_path, opts);
data = readmatrix(file_path); % <= works better in this case without opts
% Perform Mean Shift clustering
bandwidth = 50; % bandwidth parameter for Mean Shift
[cluster_centers, data2cluster, cluster2dataCell] = MeanShiftCluster(data', bandwidth); % NB : data' (transposed)
% Plot the data points and the cluster centers
figure;
plot(data(:,2), data(:,1), '.', 'MarkerSize', 15, 'DisplayName', 'XY coordinates');
hold on;
% Set x-axis limit starting from 0
xlim([0, max(data(:,2))]);
% Set y-axis limit starting from 0
ylim([0, max(data(:,1))]);
% Plot cluster centers
hold on;
% plot(cluster_centers(:,2), cluster_centers(:,1), 'kx', 'MarkerSize', 15, 'DisplayName', 'Cluster Centers');
plot(cluster_centers(2,:), cluster_centers(1,:), 'kx', 'MarkerSize', 15, 'DisplayName', 'Cluster Centers');
hold off;
xlabel('X');
ylabel('Y');
title('Mean Shift Clustering');
legend('XY coordinates', 'Cluster Centers');

8 Comments

Hi Mathieu, I am sorry for not answering earlier, and I am really grateful for your help.
By the way, when I run the code, it gets stuck in the analysis. I am not sure whether my laptop (MacBook Pro M2) is not powerful enough to process this data or whether there is still an issue in the code.
Hello Marco,
for the reduced dataset you provided here (102499 points), my PC needs 2 minutes to give the result,
and my PC is far from super fast (Intel Core i5).
The time required to work with larger data sets may be prohibitive, so maybe you should split the data into smaller batches and then combine the results. Here are my results:
1,000 points : Elapsed time is 0.092944 seconds.
10,000 points : Elapsed time is 2.526526 seconds.
100,000 points : Elapsed time is 95.893421 seconds.
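These three timings grow faster than linearly, which is expected: each pass of the inner loop of MeanShiftCluster.m computes the distance from the current mean to all numPts points, and the number of seed points can also grow with the data. The local exponent k in time ~ N^k can be estimated from consecutive measurements. This Python snippet uses only the numbers reported above and gives roughly 1.4 between 1,000 and 10,000 points and roughly 1.6 between 10,000 and 100,000 points, so extrapolating to the full dataset from the small runs is likely to be optimistic.

```python
import math

# (number of points, elapsed seconds) as reported above
timings = [(1_000, 0.092944), (10_000, 2.526526), (100_000, 95.893421)]


def scaling_exponents(timings):
    """Slope of log(time) versus log(N) between consecutive measurements."""
    out = []
    for (n1, t1), (n2, t2) in zip(timings, timings[1:]):
        out.append(math.log(t2 / t1) / math.log(n2 / n1))
    return out


for k in scaling_exponents(timings):
    print(f"time ~ N^{k:.2f}")
```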
I wonder whether it would be much faster, though probably less accurate, to simply plot the data, save it as an image, and then do some image processing (finding the center of each group of dots).
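A rough sketch of that idea that skips the figure entirely: bin the points on a regular grid (the grid cells play the role of pixels), group 8-connected occupied cells into blobs with a flood fill, and take the mean of the points in each blob as its center. This is stdlib Python under the assumption that a cell size of about the dot radius keeps neighbouring dots apart; blob_centers and cell are hypothetical names, and the result depends on cell in the same way an image-based version depends on marker size and resolution.

```python
from collections import deque


def blob_centers(points, cell):
    """Bin (x, y) points into square cells of size `cell`, group 8-connected
    occupied cells into blobs, and return (centroid, point_count) per blob."""
    grid = {}                                   # (i, j) -> indices of points in that cell
    for idx, (x, y) in enumerate(points):
        grid.setdefault((int(x // cell), int(y // cell)), []).append(idx)

    seen = set()
    blobs = []
    for start in grid:
        if start in seen:
            continue
        seen.add(start)
        queue, members = deque([start]), []
        while queue:                            # breadth-first flood fill over occupied cells
            i, j = queue.popleft()
            members.extend(grid[(i, j)])
            for di in (-1, 0, 1):
                for dj in (-1, 0, 1):
                    nb = (i + di, j + dj)
                    if nb in grid and nb not in seen:
                        seen.add(nb)
                        queue.append(nb)
        cx = sum(points[k][0] for k in members) / len(members)
        cy = sum(points[k][1] for k in members) / len(members)
        blobs.append(((cx, cy), len(members)))
    return blobs
```

Because the centroid is the mean of the original coordinates rather than of pixel positions, no conversion back from image coordinates is needed.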
Both solutions may be interesting. I will try on a small dataset, as my PC did not give me anything after 20 minutes of processing (as you can see, it stops in the clustering analysis).
By the way, is it possible that the MacBook Pro M2 is that slow?
I don't think your PC is to blame (you can run the bench command to test it), but the task is probably very CPU- and memory-intensive.
Thank you for your help! I have to find a way to make the dataset lighter, otherwise I cannot analyse it.
Do you suggest opening a .tif image and analysing it? How can I do so?
I have to say I'm not an expert in image processing (and I don't have the required toolbox either), but there are many answers on this forum about how to detect circles or blobs in images and find their centers,
and probably dozens more examples if you search the File Exchange (FEX).


Now, probably my best contribution so far, and I post it here in the hope that you will find it interesting enough to accept it! :)
So I followed my idea of splitting the data into smaller chunks: I split along the x axis only, repeat the clustering in each x window, and then concatenate the resulting cluster centers.
One thing I noticed, though, is that you may get some duplicate centers at the junction between two data batches, so the trick was to apply the same process once more to the concatenated cluster centers; this way you get the "unique" centers.
I also tried different split factors (x_inter in the code below) to find the best trade-off between the clustering time and the time needed to concatenate the results; there is an optimum to find.
The results on your data file are:
x_inter = 10; Elapsed time is 5.850134 seconds.
x_inter = 50; Elapsed time is 2.427936 seconds.
x_inter = 100; Elapsed time is 2.262699 seconds.
x_inter = 200; Elapsed time is 2.621037 seconds.
x_inter = 500; Elapsed time is 5.565575 seconds.
Here is the code:
%%
clc
clearvars
close all
% Import the data
% Prompt the user to choose a file
% [filename, filepath] = uigetfile('*.txt', 'Select a text file');
filepath = pwd;
filename = 'selected_dataset.txt';
remove = '.txt';
file_name_clean = strrep(filename, remove, '');
%%
% Plotting
plot_name = ['Intensity_' file_name_clean '.svg'];
% Import data from text file
opts = delimitedTextImportOptions("NumVariables", 28);
opts.DataLines = [2, Inf];
opts.Delimiter = "\t";
opts.VariableNames = ["channel_name", "x", "y", "x_c", "y_c"];
opts.SelectedVariableNames = ["x", "y"]; % Only select the x and y columns
opts.VariableTypes = ["string", "double", "double", "double", "double"];
opts.ExtraColumnsRule = "ignore";
opts.EmptyLineRule = "read";
% Construct the full file path
file_path = fullfile(filepath, filename);
% data = readmatrix(file_path, opts);
data = readmatrix(file_path);
%% Split the big data set in smaller chunks
x_inter = 100; % split the data along x intervals
minx = min(data(:,2));
maxx = max(data(:,2));
dx = (maxx - minx)/x_inter;
cx_all = [];
cy_all = [];
% Perform Mean Shift clustering
bandwidth = 50; % bandwidth parameter for Mean Shift
tic
for ck = 1:x_inter
xmin = minx+(ck-1)*dx;
xmax = xmin+dx;
ind = (data(:,2)>=xmin) & ((data(:,2)<xmax) | (ck == x_inter)); % the last window also keeps the point(s) at x == maxx
data_batch = data(ind,:);
if ~isempty(data_batch) % if you split by too much, data_batch may be empty - so check it !
% Perform Mean Shift clustering
[cluster_centers, ~, ~] = MeanShiftCluster(data_batch', bandwidth); % NB : data_batch' (transposed) (row oriented array)
cx = cluster_centers(2,:);
cy = cluster_centers(1,:);
cx_all = [cx_all cx];
cy_all = [cy_all cy];
end
end
% as there may be some redundant cluster centers due to the data splitting
% process, we repeat the MeanShiftCluster process once more on the result
[cluster_centers, ~, ~] = MeanShiftCluster([cx_all;cy_all], bandwidth);
cx = cluster_centers(1,:);
cy = cluster_centers(2,:);
toc
% Plot the data points and the cluster centers
figure;
plot(data(:,2), data(:,1), '.', 'MarkerSize', 15, 'DisplayName', 'XY coordinates');
hold on;
% Set x-axis limit starting from 0
xlim([0, max(data(:,2))]);
% Set y-axis limit starting from 0
ylim([0, max(data(:,1))]);
% Plot cluster centers
hold on;
% plot(cluster_centers(2,:), cluster_centers(1,:), 'kx', 'MarkerSize', 15, 'DisplayName', 'Cluster Centers');
plot(cx, cy, 'kx', 'MarkerSize', 15, 'DisplayName', 'Cluster Centers');
hold off;
xlabel('X');
ylabel('Y');
title('Mean Shift Clustering');
legend('XY coordinates', 'Cluster Centers');
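The split-then-merge strategy can be summarised in a short Python sketch. To stay self-contained it uses a simplified flat-kernel mean shift (a window is started at every point rather than at random unvisited points, and a new mode is dropped rather than averaged when it lies within bandwidth/2 of an existing one), so it illustrates the chunking logic and is not a port of MeanShiftCluster.m; mean_shift and chunked_mean_shift are hypothetical names. A group of points cut in two by a window edge yields two centers in the first pass, and the second pass on the concatenated centers merges them.

```python
import math


def mean_shift(points, bandwidth, tol=1e-3):
    """Simplified flat-kernel mean shift: start a window at every point, move it
    to the mean of the points inside until it moves less than tol*bandwidth,
    and keep one mode per group of modes closer than bandwidth/2."""
    modes = []
    for p in points:
        mean = tuple(p)
        while True:
            inside = [q for q in points if math.dist(q, mean) < bandwidth]
            new = tuple(sum(c) / len(inside) for c in zip(*inside))
            moved = math.dist(new, mean)
            mean = new
            if moved < tol * bandwidth:
                break
        if all(math.dist(mean, m) >= bandwidth / 2 for m in modes):
            modes.append(mean)
    return modes


def chunked_mean_shift(points, bandwidth, n_chunks):
    """Cluster each x window separately, then cluster the concatenated centers
    once more to merge duplicates found on both sides of a window edge."""
    xs = [p[0] for p in points]
    x_min, x_max = min(xs), max(xs)
    dx = (x_max - x_min) / n_chunks
    centers = []
    for k in range(n_chunks):
        lo, hi = x_min + k * dx, x_min + (k + 1) * dx
        last = k == n_chunks - 1                # the last window keeps everything >= lo
        batch = [p for p in points if lo <= p[0] and (p[0] < hi or last)]
        if batch:                               # a window may be empty
            centers.extend(mean_shift(batch, bandwidth))
    return mean_shift(centers, bandwidth)
```

With 12 points forming three groups, where the middle group straddles the edge between two x windows, the first pass produces four centers and the final result has three.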

7 Comments

In the meantime I tried the image-processing technique.
The accuracy of the result depends somewhat on the dot marker size and the image resolution.
For the data file provided, I think I have found the optimal settings, and the result is obtained in less than 4 seconds.
Basically, once the first figure has been saved as a black-and-white image, we look for the blob boundaries and take the center of each boundary's points.
Try it!
%%
clc
clearvars
close all
% Import the data
% Prompt the user to choose a file
% [filename, filepath] = uigetfile('*.txt', 'Select a text file');
filepath = pwd;
filename = 'selected_dataset.txt';
remove = '.txt';
file_name_clean = strrep(filename, remove, '');
%%
% Plotting
plot_name = ['Intensity_' file_name_clean '.svg'];
% % Import data from text file
% opts = delimitedTextImportOptions("NumVariables", 28);
% opts.DataLines = [2, Inf];
% opts.Delimiter = "\t";
% opts.VariableNames = ["channel_name", "x", "y", "x_c", "y_c"];
% opts.SelectedVariableNames = ["x", "y"]; % Only select the x and y columns
% opts.VariableTypes = ["string", "double", "double", "double", "double"];
% opts.ExtraColumnsRule = "ignore";
% opts.EmptyLineRule = "read";
% Construct the full file path
file_path = fullfile(filepath, filename);
% data = readmatrix(file_path, opts);
data = readmatrix(file_path);
minx = min(data(:,2));
maxx = max(data(:,2));
miny = min(data(:,1));
maxy = max(data(:,1));
% create a plot with no borders
fh = figure('Menu','none','ToolBar','none');
ah = axes('Units','Normalized','Position',[0 0 1 1]);
plot(data(:,2), data(:,1), '.', 'MarkerSize', 3, 'DisplayName', 'XY coordinates');
xlim([minx maxx]);
ylim([miny maxy]);
grid off
axis off
% save to image file
fname = 'demo.bmp';
print(fname,'-dbmpmono','-r500');
% read again the file
tic
yourImage = imread(fname);
% yourImage = ~yourImage; % take the complementary
figure,
imagesc(yourImage)
hold on
% extract all isocline for a given level
level = 0.5;
[C,h] = contour(yourImage,level*[1 1]);
ind = find(C(1,:)==level); % index of beginning of each isocline data in C
[~,nc] = size(C);
ind = [ind nc+1]; % add end (+1)
for k = 1:numel(ind)-1
% contour line
xcl = C(1,ind(k)+1:ind(k+1)-1);
ycl = C(2,ind(k)+1:ind(k+1)-1);
plot(xcl,ycl,'*r', 'MarkerSize', 2);
% center of each contour line
xc(k) = mean(xcl);
yc(k) = mean(ycl);
end
toc
plot(xc,yc,'kx', 'MarkerSize', 5);
hold off
%% final plot
[m,n] = size(yourImage);
% flip yc (image row indices increase downwards, while data y increases upwards)
yc = m - yc;
% convert back xc and yc from image coordinates (pixels) to data coordinates
xc = xc/n*(maxx - minx)+ minx;
yc = yc/m*(maxy - miny)+ miny;
figure;
plot(data(:,2), data(:,1), '.', 'MarkerSize', 15, 'DisplayName', 'XY coordinates');
hold on;
plot(xc, yc, 'kx', 'MarkerSize', 5, 'DisplayName', 'Cluster Centers');
hold off;
xlabel('X');
ylabel('Y');
title('Mean Shift Clustering');
legend('XY coordinates', 'Cluster Centers');
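The loop above finds the start of each isoline with find(C(1,:)==level). MATLAB's contour matrix stores each isoline as a header column [level; number of vertices] followed by that many [x; y] columns, so the headers can also be located by skipping ahead using the vertex count, which does not rely on no coordinate ever being equal to the level. A Python sketch of that parsing, assuming the matrix has been exported as two lists (one per row); contour_segments and segment_centers are hypothetical names.

```python
def contour_segments(C):
    """Split a MATLAB-style contour matrix into (level, xs, ys) segments.

    C is a pair of rows [row1, row2]. Each segment starts with a header column
    (level, n) followed by n vertex columns (x, y).
    """
    row1, row2 = C
    segments = []
    k = 0
    while k < len(row1):
        level, n = row1[k], int(row2[k])
        xs = row1[k + 1:k + 1 + n]
        ys = row2[k + 1:k + 1 + n]
        segments.append((level, xs, ys))
        k += n + 1                       # jump to the next header column
    return segments


def segment_centers(C):
    """Mean of each segment's vertices, like xc(k) = mean(xcl) in the MATLAB loop."""
    return [(sum(xs) / len(xs), sum(ys) / len(ys)) for _, xs, ys in contour_segments(C)]
```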
Hi guys, thanks a lot for your support! I will have a look at the code as soon as I can. I have been very busy in the wet lab these days.
@MARCO Can you give some more context? What are the blue dots and the black X's? I'm not really seeing any clusters of blue dots; they look more or less randomly located all over the plot/image.
  1. Do you want to cluster the blue dots into some small number of clusters, like 5 or 10? If so, why? I'm not seeing any obvious clusters of blue dots.
  2. Or do you simply want to take the blue dots as an image (because you don't have their coordinates) and find the centroid of each blue dot?
Why are you using mean shift clustering? I would think something like dbscan would be better. https://en.wikipedia.org/wiki/DBSCAN
I'm an expert in image processing and reasonably competent in cluster analysis/machine learning.
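For reference, DBSCAN itself is short. This textbook version in stdlib Python labels clusters 0, 1, 2, ... and noise -1; eps (neighbourhood radius) and min_pts (minimum number of neighbours, counting the point itself, for a core point) are the usual parameters, and suitable values depend on the data. It computes each neighbourhood on the fly, so memory stays O(N), but it needs O(N^2) distance evaluations; for hundreds of thousands of localisations the neighbour query would have to use a grid or k-d tree instead.

```python
import math


def dbscan(points, eps, min_pts):
    """Plain DBSCAN. Returns one label per point: 0, 1, 2, ... for clusters
    and -1 for noise. Neighbourhoods include the point itself."""
    n = len(points)
    labels = [None] * n                      # None = not processed yet

    def neighbours(i):
        return [j for j in range(n) if math.dist(points[i], points[j]) <= eps]

    cluster = -1
    for i in range(n):
        if labels[i] is not None:
            continue
        nbrs = neighbours(i)
        if len(nbrs) < min_pts:
            labels[i] = -1                   # noise for now; may later become a border point
            continue
        cluster += 1
        labels[i] = cluster
        seeds = list(nbrs)
        while seeds:
            j = seeds.pop()
            if labels[j] == -1:
                labels[j] = cluster          # border point previously marked as noise
            if labels[j] is not None:
                continue                     # already assigned: do not expand again
            labels[j] = cluster
            j_nbrs = neighbours(j)
            if len(j_nbrs) >= min_pts:       # core point: keep expanding the cluster
                seeds.extend(j_nbrs)
    return labels
```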
Hi @Mathieu NOE, I am trying to run your image-processing code, but I am facing some problems as well.
I had to remove the xlim and ylim calls because they report this error: Limits must be a 2-element vector of increasing numeric values.
Then the code ran to the end, but only empty figures appeared. In this case I tried to open the entire dataset as a .txt file.
By the way, I moved the code to Python and was actually able to get the mean shift clustering working on my laptop. I believe MATLAB does not run well on Apple CPUs.
Hi @Image Analyst, let me try to explain better what I would like to do.
In the attached screenshot, I show you what one blue dot looks like when you zoom into it. Specifically, I am analysing images made by STochastic Optical Reconstruction Microscopy (STORM). My idea was to find the center of these dots and then analyse some parameters of every cluster found around each center (such as the number of dots composing the cluster, its area, etc.). I suppose that if the protein I am studying aggregates, I should see changes in these cluster parameters.
If I remember correctly, I tried dbscan, but on my dataset it requires a large amount of memory that I do not have on my laptop.
Any further help is welcome!
Thanks a lot!
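Once each localisation has a cluster label (for example data2cluster from MeanShiftCluster, or the output of dbscan), the per-cluster parameters mentioned above can be computed directly from the coordinates. A Python sketch: "area" is taken here as the convex-hull area of each cluster's points, which is only one possible definition (an assumption); cluster_stats, convex_hull and polygon_area are hypothetical names, and label -1 is treated as noise following the DBSCAN convention.

```python
def convex_hull(pts):
    """Andrew's monotone chain; returns the hull vertices counter-clockwise."""
    pts = sorted(set(pts))
    if len(pts) <= 2:
        return pts

    def cross(o, a, b):
        return (a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0])

    lower, upper = [], []
    for p in pts:
        while len(lower) >= 2 and cross(lower[-2], lower[-1], p) <= 0:
            lower.pop()
        lower.append(p)
    for p in reversed(pts):
        while len(upper) >= 2 and cross(upper[-2], upper[-1], p) <= 0:
            upper.pop()
        upper.append(p)
    return lower[:-1] + upper[:-1]


def polygon_area(poly):
    """Shoelace formula; 0 for fewer than 3 vertices."""
    if len(poly) < 3:
        return 0.0
    s = 0.0
    for (x1, y1), (x2, y2) in zip(poly, poly[1:] + poly[:1]):
        s += x1 * y2 - x2 * y1
    return abs(s) / 2


def cluster_stats(points, labels):
    """Per-cluster localisation count, centroid and convex-hull area.
    Points labelled -1 (noise) are skipped."""
    groups = {}
    for p, lab in zip(points, labels):
        if lab != -1:
            groups.setdefault(lab, []).append(p)
    stats = {}
    for lab, pts in groups.items():
        cx = sum(p[0] for p in pts) / len(pts)
        cy = sum(p[1] for p in pts) / len(pts)
        stats[lab] = {"count": len(pts), "centroid": (cx, cy),
                      "hull_area": polygon_area(convex_hull(pts))}
    return stats
```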
Hello @MARCO,
I suspect you had problems with my code because you tried to read your .rtf file (MATLAB cannot read it).
I copied the data from your .rtf file into a .txt file (attached);
it should work now with the provided file.
FYI, I also tried dbscan, but on my laptop it was taking forever to complete; that's why I looked for faster methods.
@Mathieu NOE Now I see what I was doing wrong: I was loading the full .txt file and not only the x and y values. Now it works perfectly! Thanks a lot!
Out of curiosity, dbscan is much faster in Python than in MATLAB.
Now, I have another question for both of you, @Mathieu NOE and @Image Analyst. If needed, I will open a separate question.
I have some 3D images where I would like to analyse the dimensions of those red structures. Do you have any suggestions about what to use?


How did you read in selected_dataset.rtf? readmatrix() does not like that extension.
I don't think dbscan should take a long time. I'm attaching a demo of it. It should work for random (x,y) locations but if you have data in a regular grid, such that the locations can be considered pixels on an image, then you can use image analysis to find things like centroids, areas, diameters, etc.


Asked: 7 May 2024
Commented: 21 May 2024
