How does Kruskal’s Algorithm progress?

As a continuation of the Prim’s Algorithm animation, I have also implemented Kruskal’s Algorithm and applied it to randomly distributed points. Instead of growing a tree from an origin point, Kruskal’s Algorithm starts by computing the distance between every pair of points and sorting those distances in increasing order. Then, starting from the smallest distance, point pairs are joined through disjoint sets, building a forest of independent trees that gradually merge until a single tree (hence a single cluster) remains. An important point in the algorithm is that any pair whose connection would create a cycle within a tree is skipped; this guarantees that the resulting tree has the minimum total length. Below are some end results of trees formed for given sets of randomly distributed points:

Another example of Minimum Spanning Tree
Generated minimum spanning tree
Minimum Spanning Tree over New York City (coarser)
Minimum Spanning Tree over New York City
Minimum Spanning Tree zoomed over Manhattan
Minimum Spanning Tree for San Francisco and vicinity
Minimum Spanning Tree over West of United States

Also, with a very small twist, the very same algorithm can be used to find clusters such that points within the same cluster are close to each other, while points in different clusters are farther apart. To that end, instead of merging all the points into a single tree, the iteration stops as soon as k trees remain, where k is the desired number of clusters. As an example, below is the final result of grouping the given points into 4 clusters.

test_dpi240_16x9_35
Clustering via Kruskal’s algorithm

The C++ code that implements Kruskal’s algorithm is provided below. Note that it can also stop at the desired number of clusters for the given points. Please share your comments/questions below and stay tuned for the next post.

#include <algorithm>#include <iostream>#include <iomanip>#include <cassert>#include <vector>#include <set>#include <cmath>#include <fstream>#include<stdexcept>#include<climits>using std::vector;using std::pair;using namespace std;bool bVerbose = false;int offset = 1; // offset value used for priting to screen, e.g. points start from 1 not 0// Auxiliary functionsvoid dumpConnectedPointPairs(int point1, int point2,const string & filename=""){  fstream outFile;  outFile.open(filename,ios::out | ios::app);  if (outFile.is_open()){    outFile<<point1<<" "<<point2<<endl;  }else {    throw std::runtime_error("Could not open file");  }  outFile.close();}void displayVectorofPair(vector< pair<double, pair<int,int>> > & dvp,const string & str=""){  size_t size= dvp.size();  if (bVerbose){  cout<<"size= "<<size<<endl; }  cout<<str<<endl;  cout<<"Point 1 : ";  for (auto pr : dvp){    pair<int,int> point = pr.second;    int firstpoint= point.first;    cout<<firstpoint+offset<<' ';  }  cout<<"\nPoint 2 : ";  for (auto pr : dvp){    pair<int,int> point = pr.second;    int secondpoint= point.second;    cout<<secondpoint+offset<<' ';  }  cout<<"\nDistance: ";  for (auto pr : dvp){    double dist = pr.first;    cout<<dist<<' ';  }  cout<<endl;}class DisjointSetsElement {  public:    int size, parent, rank;    DisjointSetsElement(int size = 0, int parent = -1, int rank = 0):      size(size), parent(parent), rank(rank) {}};class DisjointSets { public:  int size;  int max_table_size;  vector <DisjointSetsElement> sets;  explicit DisjointSets(int size):  size(size), max_table_size(0), sets(size) {    for (int i = 0; i < size; i++){ // makeSet(i) operation, i.e. create a singleton set       sets[i].parent = i; //at first parent is assigned as itself (self-parenting [self-loop])       sets[i].rank   = 0; // MY: Although already taken care of by the default value assignment. Done for the sake of clarity     }   }   int getParent(int i) { // similar to Find(). Find parent and compress path. 
[This will later allow log*()     if ( i != sets[i].parent) {  //time as the tree depth does not increase much by compression]       sets[i].parent = getParent(sets[i].parent );     }     return sets[i].parent ;   }   void merge(int i, int j){ // merging or union of two sets (union_rank) - this makes the algorithm log*() complexity     int i_id = getParent(i);     int j_id = getParent(j);     if (i_id == j_id)       return;     if ( sets[i_id].rank > sets[j_id].rank ){      sets[j_id].parent = i_id;    } else {      sets[i_id].parent = j_id;      if (sets[i_id].rank == sets[j_id].rank){        sets[j_id].rank +=1;      }    }  }  // Print the disjoint set info for visualization/debugging  void printSets(){ //print info on the sets    cout<<"Vertex: ";    for (int i=0; i<size; ++i){   cout<< i+offset <<' '; }    cout<<"\nParent: ";    for (auto &s : sets){ cout<< s.parent+offset <<' '; }    cout<<"\nRank  : ";    for (auto &s : sets){ cout<<s.rank<<' '; }    cout<<endl;  }  // This is used later on decide on whether the number of clusters required by the user is achieved or not  int findNumberofUniqueParents(){ //Not a classical disjoint set or Kruskal algorithm component    std::set<int> local_set; //since sets can store unique elements, it is a natural choice here    for (auto &s : sets){      local_set.insert(s.parent);    }    if (bVerbose) {      cout<<"Unique Number of Parents : "<<local_set.size()<<"  - Unique parent vertices: ";      for (auto & s : local_set){        cout<< s +offset<<' ';      }      cout<<endl;    }    return static_cast<int>(local_set.size());  }};bool compare ( pair<double, pair<int,int> > & lhs, pair<double, pair<int,int> > & rhs){  return lhs.first < rhs.first;}double clustering(vector<int> x, vector<int> y, int desiredNoClusters ) {  // At the beginning, there are edges among all points, we make sure that we utilize edges only once (i.e. edge  // between point 1 and point 2 is taken but not between Point 2 & Point 1.  
// Also self edges, i.e. Point i to Point i not included as it is zero  size_t nVertex = x.size();  vector<  pair<double,pair<int,int>>  > distVector; // A vector of pair where each pair stores the distance (cost)                               // of each distinct edge and info of that edges on between which points it is as a pair if <int,int>                               // e.g. let distance between Point j and k be 1.3, then a pair of < 1.3, pair<j,k> > will be pushed to vector  for (size_t i=0; i<nVertex; ++i){    for (size_t j=i+1; j<nVertex; ++j){      double dist = sqrt( (x[i]-x[j])*(x[i]-x[j]) + (y[i]-y[j])*(y[i]-y[j]) );      pair<int,int> point = make_pair(i,j);      pair<double, pair<int,int> > distPoint = make_pair(dist,point);      distVector.emplace_back(distPoint);    }  }  if (bVerbose){ displayVectorofPair(distVector,"Dist Point Pair Vector:");   }  // now sort the vector with respect to distance (in increasing distance order)  std::sort(distVector.begin(), distVector.end(),compare);   if (bVerbose){ displayVectorofPair(distVector,"Dist Point Pair Vector (Sorted):"); }  // For all vertices (points), make singleton sets and display for visualization  DisjointSets allsets(nVertex); // makeSet() operation is done here  if (bVerbose) {      cout<<"\nAt the beginning, the sets of points"<<endl;      allsets.printSets();   }  // delete the parent.txt file if it already exists  std::remove("connectedpointpairs.txt");  size_t inext=0;  for (size_t i=0; i<distVector.size(); ++i){    int point1 = (distVector[i].second).first;    int point2 = (distVector[i].second).second;    int parent1 = allsets.getParent(point1);    int parent2 = allsets.getParent(point2);    // Check whether they belong to the same parent. 
If so, that means we cannot merge    // them since it would create a cycle which we don't want as we want minimum spanning tree    bool sameParentsAlready = (parent1 == parent2);    if (!sameParentsAlready ) { //do merge only if they are not connected already      allsets.merge( point1, point2 ); // if they are not connected already (i.e. not have the same parents), merge them      if (bVerbose) { allsets.printSets(); }      // Dump the connection info between points after each merge (for visualization)      dumpConnectedPointPairs(point1,point2,"connectedpointpairs.txt");      // This is used to stop merging when the desired number of clusters are reached      int currentNoClusters = allsets.findNumberofUniqueParents();      if (currentNoClusters == desiredNoClusters ){        inext = i+1;        break;      }     }  }  // Now that we have reached to the desired number of clusters, but within the clusters there can be point pairs  // where the distances are still smaller than those between the clusters. But because they would create cycles  // in the tree, we sweep them until we reach to the distance that is really between the clusters and not creating  // a cycle. At that point, we stop and that distance is the distance we are looking for.  
double finalDist =0.0 ;  for (size_t i=inext; i<distVector.size(); ++i){    int point1 = (distVector[i].second).first;    int point2 = (distVector[i].second).second;    int parent1 = allsets.getParent(point1);    int parent2 = allsets.getParent(point2);    bool sameParentsAlready = (parent1 == parent2);    if (sameParentsAlready ) {      allsets.merge( point1, point2 );     } else {      finalDist = distVector[i].first;      break;     }  }  cout<<"Final minimum Distance between clusters = "<<finalDist<<endl;  return finalDist;} int main(int argc, char** argv) {  for (int i=0; i<argc;++i){    string str1=argv[i];    if (str1.compare("-verbose") == 0){      cout<<"Verbose option is requested"<<endl;       bVerbose = true;     }   }   size_t n;   int k;   std::cin >> n;  vector<int> x(n), y(n);  for (size_t i = 0; i < n; i++) {     std::cin >> x[i] >> y[i];  }  std::cin >> k;  std::cout << std::setprecision(10) << clustering(x, y, k) << std::endl;}

Once compiled and run, the above code writes a text file (connectedpointpairs.txt) containing the step-by-step progression of the edges added between the points. The input file format is as follows:

Input file format:

Npoints
x1 y1
...
xN yN
NClusters

Sample input file:

3
1 2
4 5
2 3
1

The produced output file and the original input file can then be fed to the following Python code, using the command below, to visualize and animate that progression as shown in the video above:

./plotpathsparents.py -input testinput1 -connectedpoints connectedpointpairs.txt

#!&lt;path to your python&gt;/bin/python2.7'''Plots the points of a given set and the paths between themthat gives the minimum spanning tree among them'''import osfrom pylab import matplotlib, plt, sqrtdef FindFirstandLastCost(xpoints, ypoints, parentlines, nParents):  ''' This can be used in the visualization of the cost function to      find the max and min extents  '''  str_parents = parentlines[1].split() # 0th is dummy  parents = [int(i) for i in str_parents]  cost_first = 0  for i in range(1, nParents):    cost_first += sqrt((xpoints[i]-xpoints[parents[i]])**2  + (ypoints[i]-ypoints[parents[i]])**2)  str_parents = parentlines[-1].split()  parents = [int(i) for i in str_parents]  cost_last = 0  for i in range(1, nParents):    cost_last += sqrt((xpoints[i]-xpoints[parents[i]])**2  + (ypoints[i]-ypoints[parents[i]])**2)  return (cost_first, cost_last)def PlotPointsOnly(ax, xpoints, ypoints, nPoints, k=0):  '''    Here just the plotting of the points as small circles done.    The connections between them (i.e. 
paths) are done elsewhere  '''  for i in range(nPoints):    if i == 0:      plt.plot(xpoints[i], ypoints[i], 'mo', markersize=10)      #ax.text(x[i]+2,y[i],'Origin',color='r',fontsize=14)      #ax.text(x[i]+2,y[i],r'$P_{Origin}$',color='r',fontsize=15)      #plt.figtext(0.1,0.1,'Origin',color='r',fontsize=10)    else:      plt.plot(xpoints[i], ypoints[i], 'mo', markersize=10)  plt.axis('image')  plt.grid('on')  plt.xlabel('x')  plt.ylabel('y')  plt.title('Iteration # '+str(k)+' of '+str(nPoints), fontsize=16)  if ARGS.verbose:    print "min(x) = ", min(xpoints), " max(x)=", max(xpoints)    print "min(y) = ", min(ypoints), " max(y)=", max(ypoints)  axes = plt.gca()  #axes.set_xlim( -200, 215 )  axes.set_xlim(min(xpoints)-5, max(xpoints)+5)  axes.set_ylim(min(ypoints)-5, max(ypoints)+5)  returndef GetPointsAndPaths():  '''  Read the input files and get necessary data for plotting  '''  # Read the input files for processing  with open(ARGS.input) as f:    lines = f.read().splitlines()  nPoints = int(lines[0])  if ARGS.verbose:    print "lines=", lines    print "nPoints = ", nPoints   x, y = [[], []]  for i in range(nPoints):    temp = (lines[i+1]).split()    #print "temp= ", temp    x.append(float(temp[0]))    y.append(float(temp[1]))  print "x= ", x, " min(x) = ", min(x), " max(x) = ", max(x)  print "y= ", y, " min(y) = ", min(y), " max(y) = ", max(y)  with open(ARGS.connectedpoints) as f:    connectedpointslines = f.read().splitlines()  print "connectedpointslines  = ", connectedpointslines    nConnectedPoints = len(connectedpointslines)  print "nConnectedPoints = ", nConnectedPoints  return (x, y, connectedpointslines, nPoints, nConnectedPoints)# --------------------------------------------------def main():  x, y, connectedpointslines, nPoints, nConnectedPoints = GetPointsAndPaths()  # For interactive plotting in python, check: http://stackoverflow.com/questions/11874767/real-time-plotting-in-while-loop-with-matplotlib  plt.ion()  # - Plot just the points (no 
connecting paths)  fig2 = plt.figure(5, figsize=(16, 9)) ## This mean 16x9 inches (number of pixels is (16x9)*dpi value set in savefig  fig2.patch.set_facecolor('white')  ax = plt.subplot2grid((1, 3), (0, 0), colspan=2)  PlotPointsOnly(ax, x, y, nPoints)  plt.savefig('test_dpi240_16x9__initial.png', facecolor='w', dpi=240)  #Find the first and last costs to set the y-limits of cost-vs-iteration plots  #cost_first, cost_last= FindFirstandLastCost(x, y, connectedpointslines,nConnectedPoints)  #if ARGS.verbose: print "cost_first, cost_last= ",cost_first," ", cost_last  print "Start iterations:"  cost_arr = []  for k in range(0, nConnectedPoints):  #for k in range(1):  #for k in range(nConnectedPoints-1,nConnectedPoints):    print "k= ", k, " of ", nConnectedPoints    str_connectedpoints = connectedpointslines[k].split()    #print "str_connectedpoints  = ", str_connectedpoints    connectedpoints = [int(i) for i in str_connectedpoints]    #print "connectedpoints= ", connectedpoints    plt.cla()    plt.clf()    #ax = plt.subplot(1,2,1)    ax = plt.subplot2grid((1, 3), (0, 0), colspan=2)    PlotPointsOnly(ax, x, y, nPoints, k)    # plot the paths    cost_k = 0    for i in range(0, k+1):      str_connectedpoints = connectedpointslines[i].split()      connectedpoints = [int(j) for j in str_connectedpoints]      m, n = connectedpoints      plt.plot([x[m], x[n]], [y[m], y[n]], '-k', linewidth=2)      cost_k += sqrt((x[m]-x[n])**2  + (y[m]-y[n])**2)    cost_arr.append(cost_k)    #print "cost_arr =", cost_arr    print "cost_k =", cost_k    #PlotPointsOnly(ax,x,y,nPoints,k)    #plt.pause(0.5)    # Plot the cost vs iterations    ax = plt.subplot2grid((1, 3), (0, 2), colspan=1)    plt.plot(cost_arr, '-mo', markersize=6)    axes = plt.gca()    axes.set_xlim(-1, k+2)    axes.set_ylim(0, 700)    plt.grid('on')    plt.title(r'Cost: $\sum_{\forall\, i,j} (P_{i}-P_{j})_{connected}$ = ' +str(int(cost_k*10)/10.0), fontsize=15, y=1.02)    plt.xlabel('Number of iterations', fontsize=14)    
plt.ylabel('Cost: Sum of all connected distances', fontsize=14)    # Dump png file for later video processing    plt.savefig('test_dpi240_16x9_'+str(k)+'.png', facecolor='w', dpi=240)    plt.pause(0.1)  # This is for interactive plotting plt.ion()  while True:    plt.pause(0.05)# -- Parse the input ---------------------------------------------------------def ParseInput():  '''  Read input arguments to be plotted by this script  '''  import argparse  parser = argparse.ArgumentParser()  parser.add_argument("-v", "--verbose", help="Increase output verbosity", action="store_true")  parser.add_argument("-input", type=str, default=None, help="enter original input")  parser.add_argument("-connectedpoints", type=str, default=None, help="Enter cost and distance file")  args = parser.parse_args()  if not args.input or not args.connectedpoints:    print "Enter cost and input files"    exit(1)  if args.verbose:    print "input=", args.input    print "connectedpoints=", args.connectedpoints  return args# -----------------------------------------------------------------------------# This is the standard boilerplate that calls the main() function.if __name__ == '__main__':  ARGS = ParseInput()  main()

For a similar animation of how Prim’s algorithm works, please check this post.

Keywords:
Алгоритм Краскала, Algoritmo de Kruskal, 크러스컬 알고리즘, 克鲁斯克尔演算法

2 thoughts on “How does Kruskal’s Algorithm progress?”

Leave a commentCancel reply