Mean shift clustering - issue with finding the center of my clusters

Hi all, as you can see from the attached image, I cannot detect the centers of my dots (in blue) using mean shift clustering. My code is below. I get the same result whatever value I choose for the bandwidth. Thanks a lot for your help.
My code:
%%
% Import the data
% Prompt the user to choose a file
[filename, filepath] = uigetfile('*.txt', 'Select a text file');
file_name = filename;
remove = '.txt';
file_name_clean = strrep(file_name, remove, '');
%%
% Plotting
plot_name = ['Intensity_' file_name_clean '.svg'];
% Import data from text file
opts = delimitedTextImportOptions("NumVariables", 28);
opts.DataLines = [2, Inf];
opts.Delimiter = "\t";
opts.VariableNames = ["channel_name", "x", "y", "x_c", "y_c"];
opts.SelectedVariableNames = ["x", "y"]; % Only select the x and y columns
opts.VariableTypes = ["string", "double", "double", "double", "double"];
opts.ExtraColumnsRule = "ignore";
opts.EmptyLineRule = "read";
% Construct the full file path
file_path = fullfile(filepath, file_name);
data = readmatrix(file_path, opts);
% Perform Mean Shift clustering
bandwidth = 50; % bandwidth parameter for Mean Shift
[cluster_centers, data2cluster, cluster2dataCell] = MeanShiftCluster(data, bandwidth);
% Plotting the data with logarithmic x-axis and error bars for averages and standard deviations
figure;
plot(data(:,2), data(:,1), '.', 'MarkerSize', 10, 'DisplayName', 'XY coordinates');
hold on;
% Set x-axis limit starting from 0
xlim([0, max(data(:,2))]);
% Set y-axis limit starting from 0
ylim([0, max(data(:,1))]);
% Plot cluster centers
hold on;
plot(cluster_centers(:,2), cluster_centers(:,1), 'kx', 'MarkerSize', 15, 'LineWidth', 3, 'DisplayName', 'Cluster Centers');
hold off;
xlabel('X');
ylabel('Y');
title('Mean Shift Clustering');
legend('XY coordinates', 'Cluster Centers');
  2 Comments
MARCO on 7 May 2024
Here are the x,y coordinates of the blue dots. The original dataset is too big to share here.
MARCO on 7 May 2024
function [clustCent,data2cluster,cluster2dataCell] = MeanShiftCluster(dataPts,bandWidth,plotFlag);
%perform MeanShift Clustering of data using a flat kernel
%
% ---INPUT---
% dataPts - input data, (numDim x numPts)
% bandWidth - is bandwidth parameter (scalar)
% plotFlag - display output if 2 or 3 D (logical)
% ---OUTPUT---
% clustCent - is locations of cluster centers (numDim x numClust)
% data2cluster - for every data point which cluster it belongs to (numPts)
% cluster2dataCell - for every cluster which points are in it (numClust)
%
% Bryan Feldman 02/24/06
% MeanShift first appears in
% K. Fukunaga and L.D. Hostetler, "The Estimation of the Gradient of a
% Density Function, with Applications in Pattern Recognition"
%*** Check input ****
if nargin < 2
error('no bandwidth specified')
end
if nargin < 3
plotFlag = false; % plotting is off unless requested
end
%**** Initialize stuff ***
[numDim,numPts] = size(dataPts);
numClust = 0;
bandSq = bandWidth^2;
initPtInds = 1:numPts;
maxPos = max(dataPts,[],2); %biggest size in each dimension
minPos = min(dataPts,[],2); %smallest size in each dimension
boundBox = maxPos-minPos; %bounding box size
sizeSpace = norm(boundBox); %indicator of size of data space
stopThresh = 1e-3*bandWidth; %when mean has converged
clustCent = []; %center of clust
beenVisitedFlag = zeros(1,numPts,'uint8'); %track if a points been seen already
numInitPts = numPts; %number of points to possibly use as initialization points
clusterVotes = zeros(1,numPts,'uint16'); %used to resolve conflicts on cluster membership
while numInitPts
tempInd = ceil( (numInitPts-1e-6)*rand); %pick a random seed point
stInd = initPtInds(tempInd); %use this point as start of mean
myMean = dataPts(:,stInd); % initialize mean to this point's location
myMembers = []; % points that will get added to this cluster
thisClusterVotes = zeros(1,numPts,'uint16'); %used to resolve conflicts on cluster membership
while 1 %loop until convergence
sqDistToAll = sum((repmat(myMean,1,numPts) - dataPts).^2); %dist squared from mean to all points still active
inInds = find(sqDistToAll < bandSq); %points within bandWidth
thisClusterVotes(inInds) = thisClusterVotes(inInds)+1; %add a vote for all the in points belonging to this cluster
myOldMean = myMean; %save the old mean
myMean = mean(dataPts(:,inInds),2); %compute the new mean
myMembers = [myMembers inInds]; %add any point within bandWidth to the cluster
beenVisitedFlag(myMembers) = 1; %mark that these points have been visited
%*** plot stuff ****
if plotFlag
figure(12345),clf,hold on
if numDim == 2
plot(dataPts(1,:),dataPts(2,:),'.')
plot(dataPts(1,myMembers),dataPts(2,myMembers),'ys')
plot(myMean(1),myMean(2),'go')
plot(myOldMean(1),myOldMean(2),'rd')
pause
end
end
%**** if mean doesn't move much stop this cluster ***
if norm(myMean-myOldMean) < stopThresh
%check for merge possibilities
mergeWith = 0;
for cN = 1:numClust
distToOther = norm(myMean-clustCent(:,cN)); %distance from possible new clust max to old clust max
if distToOther < bandWidth/2 %if its within bandwidth/2 merge new and old
mergeWith = cN;
break;
end
end
if mergeWith > 0 % something to merge
clustCent(:,mergeWith) = 0.5*(myMean+clustCent(:,mergeWith)); %record the max as the mean of the two merged (I know, biased towards new ones)
%clustMembsCell{mergeWith} = unique([clustMembsCell{mergeWith} myMembers]); %record which points inside
clusterVotes(mergeWith,:) = clusterVotes(mergeWith,:) + thisClusterVotes; %add these votes to the merged cluster
else %its a new cluster
numClust = numClust+1; %increment clusters
clustCent(:,numClust) = myMean; %record the mean
%clustMembsCell{numClust} = myMembers; %store my members
clusterVotes(numClust,:) = thisClusterVotes;
end
break;
end
end
initPtInds = find(beenVisitedFlag == 0); %we can initialize with any of the points not yet visited
numInitPts = length(initPtInds); %number of active points in set
end
[val,data2cluster] = max(clusterVotes,[],1); %a point belongs to the cluster with the most votes
%*** If they want the cluster2data cell find it for them
if nargout > 2
cluster2dataCell = cell(numClust,1);
for cN = 1:numClust
myMembers = find(data2cluster == cN);
cluster2dataCell{cN} = myMembers;
end
end
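A minimal sketch of what the inner loop of the function above does, for readers who want to see the flat-kernel update in isolation. The toy data and the bandwidth value are assumptions chosen only for illustration (they do not come from this thread). The sketch follows a single seed until the window mean stops moving; it leaves out the voting and merging steps of MeanShiftCluster and relies on implicit expansion (R2016b or later).

```matlab
% Toy data (hypothetical): two tight groups of 2-D points, stored numDim x numPts
dataPts = [0 0.1 -0.1 0.05 5 5.1 4.9 5.05;
           0 0.1 0.05 -0.1 5 4.9 5.1 5.05];
bandWidth  = 1;                  % flat-kernel radius (assumed value)
stopThresh = 1e-3*bandWidth;     % same convergence rule as MeanShiftCluster
myMean = dataPts(:,1);           % seed: start from the first point

while true
    sqDist  = sum((dataPts - myMean).^2, 1);  % squared distance from the mean to every point
    inInds  = sqDist < bandWidth^2;           % points inside the window
    newMean = mean(dataPts(:,inInds), 2);     % shift the mean to the window average
    if norm(newMean - myMean) < stopThresh    % mean has stopped moving
        break
    end
    myMean = newMean;
end

disp(myMean')   % converged position of this seed
```

Note that the data are laid out with one point per column, which is the layout the function's header comment asks for.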

Answers (3)

Mathieu NOE on 13 May 2024
Hello Marco,
It seems your issue is simply that the function expects one data point per column (numDim x numPts), whereas your data array has one point per row.
See these lines in MeanShiftCluster.m:
%**** Initialize stuff ***
[numDim,numPts] = size(dataPts);
So, with the data file you provided, I needed to transpose the data array:
[cluster_centers, data2cluster, cluster2dataCell] = MeanShiftCluster(data', bandwidth); % NB : data' (transposed)
Full code:
%%
clc
clearvars
close all
% Import the data
% Prompt the user to choose a file
[filename, filepath] = uigetfile('*.txt', 'Select a text file');
file_name = filename;
remove = '.txt';
file_name_clean = strrep(file_name, remove, '');
%%
% Plotting
plot_name = ['Intensity_' file_name_clean '.svg'];
% Import data from text file
opts = delimitedTextImportOptions("NumVariables", 28);
opts.DataLines = [2, Inf];
opts.Delimiter = "\t";
opts.VariableNames = ["channel_name", "x", "y", "x_c", "y_c"];
opts.SelectedVariableNames = ["x", "y"]; % Only select the x and y columns
opts.VariableTypes = ["string", "double", "double", "double", "double"];
opts.ExtraColumnsRule = "ignore";
opts.EmptyLineRule = "read";
% Construct the full file path
file_path = fullfile(filepath, file_name);
% data = readmatrix(file_path, opts);
data = readmatrix(file_path); % <= works better in this case without opts
% Perform Mean Shift clustering
bandwidth = 50; % bandwidth parameter for Mean Shift
[cluster_centers, data2cluster, cluster2dataCell] = MeanShiftCluster(data', bandwidth); % NB : data' (transposed)
% Plot the data points (column 2 on the x-axis, column 1 on the y-axis)
figure;
plot(data(:,2), data(:,1), '.', 'MarkerSize', 15, 'DisplayName', 'XY coordinates');
hold on;
% Set x-axis limit starting from 0
xlim([0, max(data(:,2))]);
% Set y-axis limit starting from 0
ylim([0, max(data(:,1))]);
% Plot cluster centers
hold on;
% plot(cluster_centers(:,2), cluster_centers(:,1), 'kx', 'MarkerSize', 15, 'DisplayName', 'Cluster Centers');
plot(cluster_centers(2,:), cluster_centers(1,:), 'kx', 'MarkerSize', 15, 'DisplayName', 'Cluster Centers');
hold off;
xlabel('X');
ylabel('Y');
title('Mean Shift Clustering');
legend('XY coordinates', 'Cluster Centers');
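To make the orientation issue concrete, here is a small sketch of what `size` reports for a point-per-row array before and after transposing. The six toy points are hypothetical; only the `[numDim,numPts] = size(dataPts)` line mirrors the function.

```matlab
% Hypothetical toy data: 6 points in 2-D, one point per row (N x 2),
% which is the layout readmatrix returns for a two-column file
data = [1 1; 1.2 0.9; 0.9 1.1; 10 10; 10.1 9.8; 9.9 10.2];

% What MeanShiftCluster sees when the array is passed as-is
[numDim, numPts] = size(data);
fprintf('as-is:      numDim = %d, numPts = %d\n', numDim, numPts);

% What it sees after transposing (one point per column)
[numDimT, numPtsT] = size(data');
fprintf('transposed: numDim = %d, numPts = %d\n', numDimT, numPtsT);
```

Without the transpose, the function treats the array as 2 points in a 6-dimensional space, so the returned centers cannot match the dots in the plot, whatever the bandwidth.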
  8 Comments
MARCO on 16 May 2024
Edited: MARCO on 16 May 2024
Thank you for your help! I have to find a way to make the dataset lighter, otherwise I cannot analyse it.
Would you suggest opening a .tif image and analysing that instead? How can I do so?
Mathieu NOE on 16 May 2024
I have to say I'm not an expert in image processing (and I don't have the required toolbox either), but there are many answers on this forum about how to detect circles or blobs in images and find their centers,
and probably dozens more examples if you search the File Exchange (FEX).

Mathieu NOE on 16 May 2024
This is probably my best contribution so far, and I am posting it in the hope that you find it interesting enough to accept it! :)
I followed my idea of splitting the data into smaller chunks: the data are split along the x axis only, the clustering is run in each x window, and the resulting cluster centers are concatenated.
One thing I noticed is that you may get duplicate centers at the junction between two data batches. The trick is to apply the same clustering once more to the concatenated cluster centers, which yields the "unique" centers.
I also tried different split factors (x_inter in the code below) to find the best trade-off between the time spent clustering and the time spent concatenating the results. There is an optimum to find.
The results on your data file are:
x_inter = 10; Elapsed time is 5.850134 seconds.
x_inter = 50; Elapsed time is 2.427936 seconds.
x_inter = 100; Elapsed time is 2.262699 seconds.
x_inter = 200; Elapsed time is 2.621037 seconds.
x_inter = 500; Elapsed time is 5.565575 seconds.
Here is the code:
%%
clc
clearvars
close all
% Import the data
% Prompt the user to choose a file
% [filename, filepath] = uigetfile('*.txt', 'Select a text file');
filepath = pwd;
filename = 'selected_dataset.txt';
remove = '.txt';
file_name_clean = strrep(filename, remove, '');
%%
% Plotting
plot_name = ['Intensity_' file_name_clean '.svg'];
% Import data from text file
opts = delimitedTextImportOptions("NumVariables", 28);
opts.DataLines = [2, Inf];
opts.Delimiter = "\t";
opts.VariableNames = ["channel_name", "x", "y", "x_c", "y_c"];
opts.SelectedVariableNames = ["x", "y"]; % Only select the x and y columns
opts.VariableTypes = ["string", "double", "double", "double", "double"];
opts.ExtraColumnsRule = "ignore";
opts.EmptyLineRule = "read";
% Construct the full file path
file_path = fullfile(filepath, filename);
% data = readmatrix(file_path, opts);
data = readmatrix(file_path);
%% Split the big data set in smaller chunks
x_inter = 100; % split the data along x intervals
minx = min(data(:,2));
maxx = max(data(:,2));
dx = (maxx - minx)/x_inter;
cx_all = [];
cy_all = [];
% Perform Mean Shift clustering
bandwidth = 50; % bandwidth parameter for Mean Shift
tic
for ck = 1:x_inter
xmin = minx+(ck-1)*dx;
xmax = xmin+dx;
ind = (data(:,2)>=xmin) & ((data(:,2)<xmax) | (ck==x_inter)); % the last chunk also keeps points on the right edge (x == maxx)
data_batch = data(ind,:);
if ~isempty(data_batch) % if you split by too much, data_batch may be empty - so check it !
% Perform Mean Shift clustering
[cluster_centers, ~, ~] = MeanShiftCluster(data_batch', bandwidth); % NB : data_batch' (transposed) (row oriented array)
cx = cluster_centers(2,:);
cy = cluster_centers(1,:);
cx_all = [cx_all cx];
cy_all = [cy_all cy];
end
end
% As there may be some redundant cluster centers due to the data splitting,
% we repeat the MeanShiftCluster process once more on the result
[cluster_centers, ~, ~] = MeanShiftCluster([cx_all;cy_all], bandwidth);
cx = cluster_centers(1,:);
cy = cluster_centers(2,:);
toc
% Plot the data points (column 2 on the x-axis, column 1 on the y-axis)
figure;
plot(data(:,2), data(:,1), '.', 'MarkerSize', 15, 'DisplayName', 'XY coordinates');
hold on;
% Set x-axis limit starting from 0
xlim([0, max(data(:,2))]);
% Set y-axis limit starting from 0
ylim([0, max(data(:,1))]);
% Plot cluster centers
hold on;
% plot(cluster_centers(2,:), cluster_centers(1,:), 'kx', 'MarkerSize', 15, 'DisplayName', 'Cluster Centers');
plot(cx, cy, 'kx', 'MarkerSize', 15, 'DisplayName', 'Cluster Centers');
hold off;
xlabel('X');
ylabel('Y');
title('Mean Shift Clustering');
legend('XY coordinates', 'Cluster Centers');
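The answer removes duplicate centers at chunk borders by running MeanShiftCluster a second time. As an alternative illustration only (this is not the method used above), the same idea can be sketched as a greedy merge of centers that lie closer than a threshold. The toy centers and the `mergeDist` value are assumptions; `bandwidth/2` is suggested only because the function uses that distance for its own merge test.

```matlab
% Hypothetical concatenated centers (2 x numCenters); columns 1-2 and 3-4
% stand for near-duplicates found on either side of a chunk border
centers   = [10 10.5 200 201 400;
             20 19.8 300 299  50];
mergeDist = 25;              % assumed threshold, e.g. bandwidth/2 with bandwidth = 50

merged = zeros(2,0);         % unique centers found so far
for k = 1:size(centers,2)
    c = centers(:,k);
    d = sqrt(sum((merged - c).^2, 1));   % distance to every center kept so far
    j = find(d < mergeDist, 1);          % first kept center that is close enough
    if isempty(j)
        merged(:,end+1) = c;             % nothing close: keep as a new center
    else
        merged(:,j) = 0.5*(merged(:,j) + c);   % close: average the two
    end
end

disp(merged)
```

A greedy merge depends on the order of the centers, so the second clustering pass used in the answer is the more robust choice when duplicates come in groups of three or more.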
  7 Comments
MARCO on 21 May 2024
@Mathieu NOE Now I have figured out what I was doing wrong: I was loading the full .txt file and not only the x and y values. Now it works perfectly! Thanks a lot!
Out of curiosity: dbscan is much faster in Python than in MATLAB.
Now I have another question for both of you, @Mathieu NOE and @Image Analyst. If needed, I will open a separate question.
I have some 3D images in which I would like to analyse the dimensions of the red structures. Do you have any suggestions about what to use?

Image Analyst on 21 May 2024
How did you read in selected_dataset.rtf? Readmatrix() does not like that extension.
I don't think dbscan should take a long time. I'm attaching a demo of it. It should work for random (x,y) locations but if you have data in a regular grid, such that the locations can be considered pixels on an image, then you can use image analysis to find things like centroids, areas, diameters, etc.
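The attached demo is not reproduced here. As a rough sketch of the dbscan route for scattered (x,y) locations, assuming the Statistics and Machine Learning Toolbox is installed (it provides `dbscan`), cluster centers can be taken as the mean of each labelled group. The toy points and the `epsilon` and `minpts` values below are assumptions for illustration and would need tuning on real data.

```matlab
% Hypothetical toy data: two small groups plus one isolated point (one point per row)
X = [0 0; 0.1 0; 0 0.1; 5 5; 5.1 5; 5 5.1; 20 20];
epsilon = 0.5;    % neighbourhood radius (assumed value)
minpts  = 2;      % minimum neighbours, counting the point itself, to form a core point

idx = dbscan(X, epsilon, minpts);    % cluster label per row; -1 marks noise

labels  = unique(idx(idx > 0));      % ignore noise points
centers = zeros(numel(labels), 2);
for k = 1:numel(labels)
    centers(k,:) = mean(X(idx == labels(k), :), 1);   % centroid of cluster k
end

disp(centers)
```

Unlike the mean shift function above, `dbscan` takes one point per row, so an N x 2 array from `readmatrix` can be passed without transposing.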
