
How does Kruskal’s Algorithm progress?
As a continuation of the Prim’s Algorithm animation, I have also implemented Kruskal’s Algorithm as applied to randomly distributed points. Basically, instead of starting from an origin point, Kruskal’s Algorithm starts by finding the distances between every pair of points and then sorts those distances in increasing order. Then, starting from the smallest distance, point pairs are merged into disjoint sets to create trees until the independent trees merge into a single tree (hence a single cluster). One important point in the algorithm is that edges that would create cycles within the trees are skipped so that the minimum total distance is guaranteed. Below are some end results of trees formed for given sets of randomly distributed points:
Also, with a very small twist, the very same algorithm can be used to form clusters in such a way that points within a cluster are close to each other while points from different clusters are farther apart. To that end, instead of merging all the points into a single tree, the iteration continues only until k trees remain, where k is the number of clusters desired. As an example, below is the final result of obtaining 4 clusters for the given points.

The code that implements Kruskal’s algorithm in C++ is provided below. Note that it can also stop at a desired number of clusters for the given points. Please share your comments/questions below and stay tuned for the next post.
#include <algorithm>
#include <cassert>
#include <climits>
#include <cmath>
#include <cstdio>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

using std::pair;
using std::vector;

/// When true, intermediate state (edge lists, disjoint-set tables) is printed to stdout.
bool bVerbose = false;
/// Offset added to point indices when printing to screen, e.g. points start from 1 not 0.
int offset = 1;

/// An edge of the complete graph: <distance, <point index 1, point index 2>>.
using DistPointPair = pair<double, pair<int, int>>;

// Auxiliary functions

/// Appends the line "point1 point2" to `filename` (the file is opened in append mode).
/// @throws std::runtime_error if the file cannot be opened.
void dumpConnectedPointPairs(int point1, int point2, const std::string& filename = "") {
    std::ofstream outFile(filename, std::ios::out | std::ios::app);
    if (!outFile.is_open()) {
        throw std::runtime_error("Could not open file");
    }
    outFile << point1 << " " << point2 << '\n';
    // The stream is flushed and closed by its destructor.
}

/// Prints the edge list as three rows (first points, second points, distances),
/// preceded by the caption `str`. Point indices are shifted by `offset`.
void displayVectorofPair(const vector<DistPointPair>& dvp, const std::string& str = "") {
    if (bVerbose) {
        std::cout << "size= " << dvp.size() << std::endl;
    }
    std::cout << str << std::endl;

    std::cout << "Point 1 : ";
    for (const auto& edge : dvp) {
        std::cout << edge.second.first + offset << ' ';
    }
    std::cout << "\nPoint 2 : ";
    for (const auto& edge : dvp) {
        std::cout << edge.second.second + offset << ' ';
    }
    std::cout << "\nDistance: ";
    for (const auto& edge : dvp) {
        std::cout << edge.first << ' ';
    }
    std::cout << std::endl;
}

/// One node of the disjoint-set forest.
class DisjointSetsElement {
public:
    int size, parent, rank;

    DisjointSetsElement(int size = 0, int parent = -1, int rank = 0)
        : size(size), parent(parent), rank(rank) {}
};

/// Disjoint-set (union-find) forest with path compression and union by rank.
class DisjointSets {
public:
    int size;
    int max_table_size;
    vector<DisjointSetsElement> sets;

    /// Creates `size` singleton sets, i.e. performs makeSet(i) for every i in [0, size).
    explicit DisjointSets(int size) : size(size), max_table_size(0), sets(size) {
        for (int i = 0; i < size; i++) {
            sets[i].parent = i;  // at first every element is its own parent (self-loop)
            sets[i].rank = 0;    // already the default value; repeated for the sake of clarity
        }
    }

    /// Find(): returns the root of the set containing `i` and compresses the path to it.
    /// Compression keeps the trees shallow, which together with union by rank gives the
    /// near-constant (log*) amortized cost per operation.
    int getParent(int i) {
        if (i != sets[i].parent) {
            sets[i].parent = getParent(sets[i].parent);
        }
        return sets[i].parent;
    }

    /// Union by rank of the sets containing `i` and `j`; does nothing if they are already one set.
    void merge(int i, int j) {
        const int i_id = getParent(i);
        const int j_id = getParent(j);
        if (i_id == j_id) {
            return;
        }
        if (sets[i_id].rank > sets[j_id].rank) {
            sets[j_id].parent = i_id;
        } else {
            sets[i_id].parent = j_id;
            if (sets[i_id].rank == sets[j_id].rank) {
                sets[j_id].rank += 1;
            }
        }
    }

    /// Prints the vertex / parent / rank tables for visualization and debugging.
    void printSets() {
        std::cout << "Vertex: ";
        for (int i = 0; i < size; ++i) {
            std::cout << i + offset << ' ';
        }
        std::cout << "\nParent: ";
        for (const auto& s : sets) {
            std::cout << s.parent + offset << ' ';
        }
        std::cout << "\nRank  : ";
        for (const auto& s : sets) {
            std::cout << s.rank << ' ';
        }
        std::cout << std::endl;
    }

    /// Returns the number of disjoint sets (i.e. clusters) currently in the forest.
    /// Not a classical disjoint-set or Kruskal component.
    ///
    /// The count is taken over the *roots* (getParent), not over the raw `parent` fields:
    /// after a union, elements whose path has not been compressed yet still point at a
    /// former root, so counting raw parents would overstate the number of sets.
    int findNumberofUniqueParents() {
        std::set<int> roots;  // a set stores unique elements, so it is a natural choice here
        for (int i = 0; i < size; ++i) {
            roots.insert(getParent(i));
        }
        if (bVerbose) {
            std::cout << "Unique Number of Parents : " << roots.size()
                      << "  - Unique parent vertices: ";
            for (const auto& r : roots) {
                std::cout << r + offset << ' ';
            }
            std::cout << std::endl;
        }
        return static_cast<int>(roots.size());
    }
};

/// Orders edges by increasing distance. Takes const references, as std::sort requires of
/// its comparator.
bool compare(const DistPointPair& lhs, const DistPointPair& rhs) {
    return lhs.first < rhs.first;
}

/// Runs Kruskal's algorithm on the points (x[i], y[i]) until `desiredNoClusters` clusters
/// remain and returns the smallest distance between two points of different clusters.
///
/// Every accepted edge is appended to "connectedpointpairs.txt" (0-based point indices,
/// one pair per line); a file left over from an earlier run is deleted first. The result
/// is also printed to stdout.
///
/// @param x,y               point coordinates; both vectors must have the same length
/// @param desiredNoClusters number of clusters to stop at
/// @return the distance of the first edge that would join two of the final clusters, or
///         0.0 when no such edge exists (fewer than two clusters remain)
/// @throws std::invalid_argument if x and y differ in length
/// @throws std::length_error     if there are more points than an int can index
/// @throws std::runtime_error    if the output file cannot be opened
double clustering(const vector<int>& x, const vector<int>& y, int desiredNoClusters) {
    const size_t nVertex = x.size();
    if (y.size() != nVertex) {
        throw std::invalid_argument("x and y must contain the same number of coordinates");
    }
    if (nVertex > static_cast<size_t>(INT_MAX)) {
        throw std::length_error("too many points");
    }

    // Build every distinct edge exactly once (i < j), which also excludes zero-length
    // self edges. E.g. if the distance between Point j and k is 1.3, the entry
    // <1.3, <j,k>> is stored.
    vector<DistPointPair> distVector;
    if (nVertex > 1) {
        distVector.reserve(nVertex * (nVertex - 1) / 2);
    }
    for (size_t i = 0; i < nVertex; ++i) {
        for (size_t j = i + 1; j < nVertex; ++j) {
            // Subtract in double: the int expression (x[i]-x[j])*(x[i]-x[j]) can overflow.
            const double dx = static_cast<double>(x[i]) - static_cast<double>(x[j]);
            const double dy = static_cast<double>(y[i]) - static_cast<double>(y[j]);
            distVector.emplace_back(std::sqrt(dx * dx + dy * dy),
                                    std::make_pair(static_cast<int>(i), static_cast<int>(j)));
        }
    }
    if (bVerbose) {
        displayVectorofPair(distVector, "Dist Point Pair Vector:");
    }

    // Sort the edges in increasing distance order.
    std::sort(distVector.begin(), distVector.end(), compare);
    if (bVerbose) {
        displayVectorofPair(distVector, "Dist Point Pair Vector (Sorted):");
    }

    // Make a singleton set for every point.
    DisjointSets allsets(static_cast<int>(nVertex));
    if (bVerbose) {
        std::cout << "\nAt the beginning, the sets of points" << std::endl;
        allsets.printSets();
    }

    // Delete the output file if it already exists (edges are appended to it below).
    std::remove("connectedpointpairs.txt");

    // Every successful merge joins exactly two sets, so the cluster count can be tracked
    // directly instead of being recomputed from the forest after each merge.
    int currentNoClusters = static_cast<int>(nVertex);
    double finalDist = 0.0;
    for (const auto& edge : distVector) {
        const int point1 = edge.second.first;
        const int point2 = edge.second.second;

        // Points that already share a root are connected; this edge would close a cycle,
        // which a minimum spanning tree must not contain.
        if (allsets.getParent(point1) == allsets.getParent(point2)) {
            continue;
        }

        // Once the desired number of clusters is reached, the first edge that still joins
        // two different clusters is the distance we are looking for. Shorter edges skipped
        // above all lie inside a single cluster.
        if (currentNoClusters <= desiredNoClusters) {
            finalDist = edge.first;
            break;
        }

        allsets.merge(point1, point2);
        --currentNoClusters;
        if (bVerbose) {
            allsets.printSets();
        }
        // Dump the connection info between points after each merge (for visualization).
        dumpConnectedPointPairs(point1, point2, "connectedpointpairs.txt");
        if (bVerbose) {
            allsets.findNumberofUniqueParents();  // prints the current roots
        }
    }

    std::cout << "Final minimum Distance between clusters = " << finalDist << std::endl;
    return finalDist;
}

/// Reads "N, then N lines of 'x y', then K" from stdin and prints the cluster spacing.
/// Pass -verbose on the command line to print the intermediate state.
int main(int argc, char** argv) {
    for (int i = 1; i < argc; ++i) {
        if (std::string(argv[i]) == "-verbose") {
            std::cout << "Verbose option is requested" << std::endl;
            bVerbose = true;
        }
    }

    long long nInput = 0;
    if (!(std::cin >> nInput) || nInput < 0 || nInput > INT_MAX) {
        std::cerr << "Invalid number of points" << std::endl;
        return 1;
    }
    const size_t n = static_cast<size_t>(nInput);

    vector<int> x(n), y(n);
    for (size_t i = 0; i < n; i++) {
        if (!(std::cin >> x[i] >> y[i])) {
            std::cerr << "Invalid coordinates for point " << i + 1 << std::endl;
            return 1;
        }
    }

    int k = 0;
    if (!(std::cin >> k)) {
        std::cerr << "Invalid number of clusters" << std::endl;
        return 1;
    }

    try {
        // Set the precision before calling clustering() so that the line it prints uses
        // the same precision regardless of operand evaluation order.
        std::cout << std::setprecision(10);
        const double spacing = clustering(x, y, k);
        std::cout << spacing << std::endl;
    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }
    return 0;
}
// Once compiled and run, the above code will dump a text file (connectedpointpairs.txt) containing the step-by-step progression of the edge generation between the points. The input file format is as follows:
| Input file format | Sample input file |
|---|---|
| Npoints<br>x1 y1<br>...<br>xN yN<br>NClusters | 3<br>1 2<br>4 5<br>2 3<br>1 |
The produced output file and the original input file can then be fed to the following Python code to visualize and animate that progression as shown in the above video using the following command:
./plotpathsparents.py -input testinput1 -connectedpoints connectedpointpairs.txt
#!<path to your python>/bin/python2.7'''Plots the points of a given set and the paths between themthat gives the minimum spanning tree among them'''import osfrom pylab import matplotlib, plt, sqrtdef FindFirstandLastCost(xpoints, ypoints, parentlines, nParents): ''' This can be used in the visualization of the cost function to find the max and min extents ''' str_parents = parentlines[1].split() # 0th is dummy parents = [int(i) for i in str_parents] cost_first = 0 for i in range(1, nParents): cost_first += sqrt((xpoints[i]-xpoints[parents[i]])**2 + (ypoints[i]-ypoints[parents[i]])**2) str_parents = parentlines[-1].split() parents = [int(i) for i in str_parents] cost_last = 0 for i in range(1, nParents): cost_last += sqrt((xpoints[i]-xpoints[parents[i]])**2 + (ypoints[i]-ypoints[parents[i]])**2) return (cost_first, cost_last)def PlotPointsOnly(ax, xpoints, ypoints, nPoints, k=0): ''' Here just the plotting of the points as small circles done. The connections between them (i.e. paths) are done elsewhere ''' for i in range(nPoints): if i == 0: plt.plot(xpoints[i], ypoints[i], 'mo', markersize=10) #ax.text(x[i]+2,y[i],'Origin',color='r',fontsize=14) #ax.text(x[i]+2,y[i],r'$P_{Origin}$',color='r',fontsize=15) #plt.figtext(0.1,0.1,'Origin',color='r',fontsize=10) else: plt.plot(xpoints[i], ypoints[i], 'mo', markersize=10) plt.axis('image') plt.grid('on') plt.xlabel('x') plt.ylabel('y') plt.title('Iteration # '+str(k)+' of '+str(nPoints), fontsize=16) if ARGS.verbose: print "min(x) = ", min(xpoints), " max(x)=", max(xpoints) print "min(y) = ", min(ypoints), " max(y)=", max(ypoints) axes = plt.gca() #axes.set_xlim( -200, 215 ) axes.set_xlim(min(xpoints)-5, max(xpoints)+5) axes.set_ylim(min(ypoints)-5, max(ypoints)+5) returndef GetPointsAndPaths(): ''' Read the input files and get necessary data for plotting ''' # Read the input files for processing with open(ARGS.input) as f: lines = f.read().splitlines() nPoints = int(lines[0]) if ARGS.verbose: print "lines=", lines print 
"nPoints = ", nPoints x, y = [[], []] for i in range(nPoints): temp = (lines[i+1]).split() #print "temp= ", temp x.append(float(temp[0])) y.append(float(temp[1])) print "x= ", x, " min(x) = ", min(x), " max(x) = ", max(x) print "y= ", y, " min(y) = ", min(y), " max(y) = ", max(y) with open(ARGS.connectedpoints) as f: connectedpointslines = f.read().splitlines() print "connectedpointslines = ", connectedpointslines nConnectedPoints = len(connectedpointslines) print "nConnectedPoints = ", nConnectedPoints return (x, y, connectedpointslines, nPoints, nConnectedPoints)# --------------------------------------------------def main(): x, y, connectedpointslines, nPoints, nConnectedPoints = GetPointsAndPaths() # For interactive plotting in python, check: http://stackoverflow.com/questions/11874767/real-time-plotting-in-while-loop-with-matplotlib plt.ion() # - Plot just the points (no connecting paths) fig2 = plt.figure(5, figsize=(16, 9)) ## This mean 16x9 inches (number of pixels is (16x9)*dpi value set in savefig fig2.patch.set_facecolor('white') ax = plt.subplot2grid((1, 3), (0, 0), colspan=2) PlotPointsOnly(ax, x, y, nPoints) plt.savefig('test_dpi240_16x9__initial.png', facecolor='w', dpi=240) #Find the first and last costs to set the y-limits of cost-vs-iteration plots #cost_first, cost_last= FindFirstandLastCost(x, y, connectedpointslines,nConnectedPoints) #if ARGS.verbose: print "cost_first, cost_last= ",cost_first," ", cost_last print "Start iterations:" cost_arr = [] for k in range(0, nConnectedPoints): #for k in range(1): #for k in range(nConnectedPoints-1,nConnectedPoints): print "k= ", k, " of ", nConnectedPoints str_connectedpoints = connectedpointslines[k].split() #print "str_connectedpoints = ", str_connectedpoints connectedpoints = [int(i) for i in str_connectedpoints] #print "connectedpoints= ", connectedpoints plt.cla() plt.clf() #ax = plt.subplot(1,2,1) ax = plt.subplot2grid((1, 3), (0, 0), colspan=2) PlotPointsOnly(ax, x, y, nPoints, k) # plot the paths 
cost_k = 0 for i in range(0, k+1): str_connectedpoints = connectedpointslines[i].split() connectedpoints = [int(j) for j in str_connectedpoints] m, n = connectedpoints plt.plot([x[m], x[n]], [y[m], y[n]], '-k', linewidth=2) cost_k += sqrt((x[m]-x[n])**2 + (y[m]-y[n])**2) cost_arr.append(cost_k) #print "cost_arr =", cost_arr print "cost_k =", cost_k #PlotPointsOnly(ax,x,y,nPoints,k) #plt.pause(0.5) # Plot the cost vs iterations ax = plt.subplot2grid((1, 3), (0, 2), colspan=1) plt.plot(cost_arr, '-mo', markersize=6) axes = plt.gca() axes.set_xlim(-1, k+2) axes.set_ylim(0, 700) plt.grid('on') plt.title(r'Cost: $\sum_{\forall\, i,j} (P_{i}-P_{j})_{connected}$ = ' +str(int(cost_k*10)/10.0), fontsize=15, y=1.02) plt.xlabel('Number of iterations', fontsize=14) plt.ylabel('Cost: Sum of all connected distances', fontsize=14) # Dump png file for later video processing plt.savefig('test_dpi240_16x9_'+str(k)+'.png', facecolor='w', dpi=240) plt.pause(0.1) # This is for interactive plotting plt.ion() while True: plt.pause(0.05)# -- Parse the input ---------------------------------------------------------def ParseInput(): ''' Read input arguments to be plotted by this script ''' import argparse parser = argparse.ArgumentParser() parser.add_argument("-v", "--verbose", help="Increase output verbosity", action="store_true") parser.add_argument("-input", type=str, default=None, help="enter original input") parser.add_argument("-connectedpoints", type=str, default=None, help="Enter cost and distance file") args = parser.parse_args() if not args.input or not args.connectedpoints: print "Enter cost and input files" exit(1) if args.verbose: print "input=", args.input print "connectedpoints=", args.connectedpoints return args# -----------------------------------------------------------------------------# This is the standard boilerplate that calls the main() function.if __name__ == '__main__': ARGS = ParseInput() main()For a similar animation of how the Prim’s algorithm work, please 
checkthis post.
Keywords:
Алгоритм Краскала, Algoritmo de Kruskal, 크러스컬 알고리즘, 克鲁斯克尔演算法
2 thoughts on “How does Kruskal’s Algorithm progress?”
- Pingback:Kruskal's Algorithm - Your Cheer






