ByteMuse.com

The code and musings of @ChrisPolis

About Posts Contact

K-means Clustering and Visualization

April 22nd, 2014

One of the simplest machine learning algorithms that I know is K-means clustering. It is used to classify a data set into k groups with similar attributes and lends itself really well to visualization!

Here is a quick overview of the algorithm:

  • Pick or randomly select k group centroids
  • Group/bin points by nearest centroid
  • Recalculate centroids from points in corresponding bin
  • Continue binning/moving centroids until convergence

More information on K-means clustering


Note: Points are normally distributed around k random points to create a somewhat grouped distribution. Click 'Step' to go through the algorithm.

n = k = Step New Problem

View Source Hide Source

# Draw a value from `normalFun`, retrying until it falls strictly inside
# the open interval (0, 100); values of exactly 0 or 100 are rejected too.
normalPt = (normalFun) ->
  sample = normalFun()
  return sample if 0 < sample < 100
  # Out of range: discard this draw and try again
  normalPt normalFun

# Random coordinate in [5, 95), used both for the centers the points are
# distributed around and for the initial centroids.
randomCenter = ->
  5 + 90 * Math.random()

# Euclidean distance between two points given as [x, y] arrays
distance = (a, b) ->
  dx = a[0] - b[0]
  dy = a[1] - b[1]
  Math.sqrt Math.pow(dx, 2) + Math.pow(dy, 2)
# Mean point of an array of [x, y] pairs, returned as [meanX, meanY].
# For an empty array both components are 0/0, i.e. NaN.
avgXY = (arr) ->
  count = arr.length
  sumX  = d3.sum arr, (pt) -> pt[0]
  sumY  = d3.sum arr, (pt) -> pt[1]
  [sumX / count, sumY / count]

#
# Build the k-means demo when the document fires 'page:load': create the SVG
# layers, scales and Voronoi layout, define `window.initProblem` and
# `window.step`, and draw a first problem.
# NOTE(review): 'page:load' is presumably the Turbolinks event -- confirm it
# also fires on the initial page view.
$(document).on 'page:load', () ->
  # Skip pages that do not contain the visualization container
  return unless $('#kmeans-vis').length

  # Color palette and a square canvas, capped at 600px wide
  colors = d3.scale.category10()
  width  = $('.container').width()
  if width > 600 then width = 600
  height = width

  # SVG root, one group per layer, and scales mapping the 0..100 data domain
  # onto the canvas. The selector must name the same element the guard above
  # checks, otherwise the SVG is appended to an empty selection.
  svg = d3.select('#kmeans-vis').append('svg')
    .attr 'width', width
    .attr 'height', height
  pointsGroup    = svg.append('g').attr('id', 'points')
  centroidsGroup = svg.append('g').attr('id', 'centroids')
  voronoiGroup   = svg.append('g').attr('id', 'voronoi')
  x = d3.scale.linear()
    .range  [0, width]
    .domain [0, 100]
  # y range is flipped so larger data values are drawn higher up
  y = d3.scale.linear()
    .range  [height, 0]
    .domain [0, 100]
  voronoi = d3.geom.voronoi()
    .x((d) -> x(d[0]))
    .y((d) -> y(d[1]))

  # Generate and plot a fresh problem: roughly n points spread around k
  # random centers, plus k random starting centroids. Reads k and n from the
  # #k-val and #n-val inputs and publishes points/centroids/k/n on window.
  window.initProblem = () ->
    kVal = parseInt $('#k-val').val(), 10
    nVal = parseInt $('#n-val').val(), 10
    # Ignore missing or non-positive input and keep the current problem.
    # k = 0 in particular would make n/k infinite and the point loop endless.
    return unless kVal >= 1 and nVal >= 1

    window.points = []
    window.centroids = []
    window.k = kVal
    window.n = nVal

    # Generate pseudo-random points normally distributed around k centers
    for kNdx in [1..k]
      xNorm  = d3.random.normal(randomCenter(), 12)
      yNorm  = d3.random.normal(randomCenter(), 12)
      for ptNdx in [1..(n/k)]
        points.push [normalPt(xNorm), normalPt(yNorm)]

    # Generate centroids
    for kNdx in [1..k]
      centroids.push [randomCenter(), randomCenter()]

    # Plot: clear every layer, then draw centroids (colored by index) and points
    voronoiGroup.selectAll('*').remove()
    centroidsGroup.selectAll('*').remove()
    pointsGroup.selectAll('*').remove()
    centroidsGroup.selectAll('circle')
      .data(centroids).enter()
      .append('circle')
        .style 'fill', (d,ndx) -> colors(ndx)
        .attr  'cx', (d) -> x(d[0])
        .attr  'cy', (d) -> y(d[1])
        .attr  'r', 4.5
    pointsGroup.selectAll('circle')
      .data(points).enter()
      .append('circle')
        .attr 'cx', (d) -> x(d[0])
        .attr 'cy', (d) -> y(d[1])
        .attr 'r', 1.5

  # Run one k-means iteration: draw the Voronoi cells of the current
  # centroids, assign every point to its nearest centroid, then move each
  # centroid to the mean of its assigned points.
  window.step = () ->
    # Nothing to iterate on until a problem has been generated
    return unless window.centroids?.length

    # Render voronoi
    voronoiGroup.selectAll('*').remove()
    voronoiGroup.selectAll('path')
      .data(voronoi(centroids))
      .enter().append('path')
        .style 'fill', (d,ndx) -> colors(ndx)
        .attr  'd', (d) -> "M#{d.join('L')}Z"

    # Bin points based on centroids O(n * k)
    centroidBins = ([] for centroid in centroids)
    for point in points
      # Start from Infinity: distances in the 100x100 domain can exceed 100,
      # and minNdx must be reset per point so it never carries over.
      minDist = Infinity
      minNdx  = 0
      for centroid, centroidNdx in centroids
        if (d = distance(point, centroid)) < minDist
          minDist = d
          minNdx  = centroidNdx
      centroidBins[minNdx].push point

    # Find new centroids
    for bin, binNdx in centroidBins
      # An empty bin has no mean (0/0 is NaN); leave that centroid in place
      continue unless bin.length
      centroids[binNdx] = avgXY(bin)
    centroidsGroup.selectAll('circle')
      .data(centroids)
      .transition()
        .attr  'cx', (d) -> x(d[0])
        .attr  'cy', (d) -> y(d[1])

  # Draw the first problem. This runs inside the handler because
  # window.initProblem does not exist until the handler has executed.
  initProblem()


Tweet
Kmeans preview

© Chris Polis, 2012 - 2016

GitHub · Twitter · LinkedIn · Stack Overflow · Quora