Saturday, January 23, 2010

Space partitioning with a KD-Tree

More than once I've had to compare world space position data with vertex positions on a mesh object. Maya has a handy closestPointOnMesh node, but if you're dealing with a set of raw data instead of an existing object you're kind of stuck. Recently I was introduced to the concept of space partitioning, and after studying several different types I settled on the kd-tree as my favorite (with the octree a close second) for dealing with 3d data.

In a nutshell, to construct a kd-tree out of 3d points you divide your data into two sets along one dimension (the x-axis), then divide each of those sets along the next dimension (the y-axis), and so on, cycling through the dimensions until each set contains only one point. The wiki page probably does a better job of explaining it than I can.
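To make the construction step concrete, here is a minimal sketch in plain Python that runs outside Maya. It splits each set at the median point along the current axis, which is a common variant; the Maya script further down splits at the mean instead. `Node`, `build` and `show` are hypothetical names used only for illustration.

```python
K = 3  # number of dimensions: x, y, z


class Node(object):
    def __init__(self, point, axis, left=None, right=None):
        self.point = point  # the (x, y, z) tuple stored at this node
        self.axis = axis    # 0 = x, 1 = y, 2 = z
        self.left = left    # subtree with coordinate <= point[axis]
        self.right = right  # subtree with coordinate >= point[axis]


def build(points, depth=0):
    """Recursively build a kd-tree, cycling x -> y -> z -> x ... by depth."""
    if not points:
        return None
    axis = depth % K
    # Sort along the current axis and split at the median point
    pts = sorted(points, key=lambda p: p[axis])
    mid = len(pts) // 2
    return Node(pts[mid], axis,
                build(pts[:mid], depth + 1),
                build(pts[mid + 1:], depth + 1))


def show(node, indent=0):
    """Print the tree sideways, one node per line."""
    if node is None:
        return
    print('%s%s axis=%s' % ('  ' * indent, node.point, 'xyz'[node.axis]))
    show(node.left, indent + 1)
    show(node.right, indent + 1)


tree = build([(2, 3, 1), (5, 4, 2), (9, 6, 3), (4, 7, 9), (8, 1, 5), (7, 2, 6)])
show(tree)
# (7, 2, 6) axis=x
#   (5, 4, 2) axis=y
#     (2, 3, 1) axis=z
#     (4, 7, 9) axis=z
#   (9, 6, 3) axis=y
#     (8, 1, 5) axis=z
```

Each level of the printout splits on the next axis, which is exactly the x, y, z cycling described above.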

http://en.wikipedia.org/wiki/Kd-tree

So why does dividing the data up this way make looking up points quicker? Instead of iterating over every point to find the closest one, we start at the top of the tree and follow the branch on the query's side of each split until we reach the bottom, then only step back into a neighboring branch when its splitting plane is closer than the best point found so far. Most of the tree is never visited, so we end up doing a fraction of the distance calculations, which is what makes this method so fast.
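Here is a standalone sketch of a nearest-neighbor lookup in plain Python (no Maya needed). It descends toward the query first, then also backtracks into the other branch whenever the splitting plane is closer than the best point found so far; without that backtracking step the lookup can return a point that is not actually the closest. As assumptions for illustration: it builds a median-split tree stored as nested tuples, and `build`, `nearest`, `dist_sq` and the `stats` counter are hypothetical names. `stats` counts how many stored points are examined, to make "a fraction of the calculations" concrete.

```python
import random

K = 3  # dimensions: x, y, z


def build(points, depth=0):
    """Median-split kd-tree stored as (point, axis, left, right) tuples."""
    if not points:
        return None
    axis = depth % K
    pts = sorted(points, key=lambda p: p[axis])
    mid = len(pts) // 2
    return (pts[mid], axis,
            build(pts[:mid], depth + 1),
            build(pts[mid + 1:], depth + 1))


def dist_sq(a, b):
    """Squared Euclidean distance (no sqrt needed just to compare)."""
    return sum((a[i] - b[i]) ** 2 for i in range(K))


def nearest(node, query, best=None, stats=None):
    """Return the stored point closest to query.

    stats, if given, is a one-element list that counts the stored points
    examined, to show how few of them the search touches.
    """
    if node is None:
        return best
    point, axis, left, right = node
    if stats is not None:
        stats[0] += 1
    if best is None or dist_sq(point, query) < dist_sq(best, query):
        best = point
    diff = query[axis] - point[axis]
    # Search the side of the splitting plane that contains the query first
    near, far = (left, right) if diff <= 0 else (right, left)
    best = nearest(near, query, best, stats)
    # Backtrack only if the plane is closer than the best point so far;
    # otherwise nothing on the far side can beat it
    if diff * diff < dist_sq(best, query):
        best = nearest(far, query, best, stats)
    return best


random.seed(1)
cloud = [tuple(random.uniform(-10, 10) for _ in range(K)) for _ in range(2000)]
tree = build(cloud)
stats = [0]
hit = nearest(tree, (0.5, -2.0, 3.0), stats=stats)
print('%s found after examining %d of %d points' % (hit, stats[0], len(cloud)))
```

The pruning test compares squared distances on both sides, so no square roots are taken anywhere in the search.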

Here is an example of constructing and using a kd-tree in Maya using Python. Select two mesh objects (first the mesh to build the tree from, then the mesh to compare against it) and execute the following code in the Python script editor. Afterwards, the "lCmpPoints" Python list holds, for each vertex on the second mesh, the closest vertex on the first mesh. You can check a result by evaluating lCmpPoints[100], or any other valid index, in the script editor.

www.stevenkalinowski.com/scripts/sak_kdtree.zip

import maya.OpenMaya as OpenMaya
import maya.cmds as cmds
import time

# Here I am using a kdtree to store mesh points, this makes looking up the closest vertex
# to a given worldspace position very fast

# Point class defines a vertex, simply index(int) and 3d position(tuple) attributes
class Point(object):
   def __init__(self, i, pos):
      self.index = i
      self.position = pos
   def __str__(self):
      return str( 'Vertex:%d at (%4.3f, %4.3f, %4.3f)' %
         (self.index, self.position[0], self.position[1], self.position[2]) )
   def __repr__(self):
      return repr(str(self))
   # The axisAverage method averages one axis of a list of Point positions
   # pVtx - list of Point objects
   # pAxis - int index of axis, 0 for X, 1 for Y, and 2 for Z
   @staticmethod
   def axisAverage(pVtx, pAxis):
      lTotal = 0.0
      for v in pVtx:
         lTotal += v.position[pAxis]
      return lTotal / len(pVtx)

class KdTree(object):
   # 3d space so we have 3 dimensions
   kDimensions = 3
   # Node objects hold the points and child nodes, these make up the tree
   class Node(object):
      def __init__(self, pOrigin, pPts):
         # Position of the splitting plane along this Node's axis (a float)
         self.origin = pOrigin
         # Python list of every Point object in this Node's branch
         self.points = pPts
         # Child Node objects
         self.leftNode = None
         self.rightNode = None

   # The KdTree class is instantiated with a list of Point objects for its points
   def __init__(self, pPoints):
      # Create the tree by calling the recursive method addPoints
      self.node = KdTree.addPoints(pPoints)

   # This recursive method creates the entire tree, splitting the list of points
   # into child Nodes until a list holds a single point (or points it cannot split)
   @staticmethod
   def addPoints(pPoints, depth=0):
      if not pPoints:
         return
      # On each recursion we cycle through each axis
      lAxis = depth % KdTree.kDimensions
      # Sort the points list according to the axis we are splitting along
      # (a key function works in both Python 2 and Python 3; cmp= does not)
      pPoints.sort(key=lambda p: p.position[lAxis])
      # Find the mean of all points in the list for this particular axis
      lOriginAxis = Point.axisAverage(pPoints, lAxis)
      # Split the points into two lists, one for each side of the axis
      lLeftPts = [x for x in pPoints if x.position[lAxis] <= lOriginAxis]
      lRightPts = [x for x in pPoints if x.position[lAxis] > lOriginAxis]
      # Create the Node object
      lNode = KdTree.Node(lOriginAxis, pPoints)

      # Now create the child nodes
      # If either list is empty, every point shares the same coordinate on this
      # axis, so the mean cannot separate them; this Node then becomes a leaf
      # that keeps all of those points (usually just one)
      if lLeftPts and lRightPts:
         lNode.leftNode = KdTree.addPoints(lLeftPts, depth+1)
         lNode.rightNode = KdTree.addPoints(lRightPts, depth+1)

      return lNode

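   # Squared distance between a Point object and a position tuple
   @staticmethod
   def distanceSq(pPt, pPos):
      return sum((pPt.position[i] - pPos[i]) ** 2 for i in range(KdTree.kDimensions))

   # The closestPoint method is called to look up the closest point in the tree
   # to a given position in world space
   def closestPoint(self, pPos, depth=0, node=None, best=None):
      # The first time this method is called, we start with the root node
      if node is None:
         node = self.node
         if node is None:
            return None

      # A Node without children is a leaf; check each of its points, since a
      # leaf can hold several points that share a coordinate on the split axis
      if node.leftNode is None or node.rightNode is None:
         for lPt in node.points:
            if best is None or KdTree.distanceSq(lPt, pPos) < KdTree.distanceSq(best, pPos):
               best = lPt
         return best

      # On each recursion we cycle through each axis
      lAxis = depth % KdTree.kDimensions

      # Search the child Node on our side of the split axis first
      if pPos[lAxis] <= node.origin:
         lNear, lFar = node.leftNode, node.rightNode
      else:
         lNear, lFar = node.rightNode, node.leftNode
      best = self.closestPoint(pPos, depth+1, lNear, best)

      # Then backtrack: the other side can only hold a closer point if the
      # splitting plane is nearer than the best point found so far
      lPlaneDist = pPos[lAxis] - node.origin
      if lPlaneDist * lPlaneDist < KdTree.distanceSq(best, pPos):
         best = self.closestPoint(pPos, depth+1, lFar, best)

      return best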

# Now we'll do a test by creating a kdtree with one mesh, then look up the closest
# position for each vertex in another mesh
# Select two meshes, first the one to make the tree, then the one to compare
lSelList = OpenMaya.MSelectionList()
OpenMaya.MGlobal.getActiveSelectionList(lSelList)

if lSelList.length() == 2:
   lMeshPath = OpenMaya.MDagPath()
   lCmpMeshPath = OpenMaya.MDagPath()

   lSelList.getDagPath(0, lMeshPath)
   lSelList.getDagPath(1, lCmpMeshPath)
   # Make sure both objects are meshes
   if lMeshPath.hasFn(OpenMaya.MFn.kMesh) and lCmpMeshPath.hasFn(OpenMaya.MFn.kMesh):
      lMeshIter = OpenMaya.MItMeshVertex(lMeshPath)
      # Python list of Point objects
      lPoints = []

      while not lMeshIter.isDone():
         lPos = lMeshIter.position(OpenMaya.MSpace.kWorld)
         # Add each vertex to the list as a Point object
         lPoints.append( Point(lMeshIter.index(), (lPos.x, lPos.y, lPos.z)) )
         lMeshIter.next()
      # Create the tree
      lKdTree = KdTree(lPoints)

      # Now we'll time how long it takes to iterate the second mesh looking up each vertex's
      # closest point
      lStartTime = time.time()
      lCmpMeshIter = OpenMaya.MItMeshVertex(lCmpMeshPath)
      lCmpPoints = []

      while not lCmpMeshIter.isDone():
         lCmpPos = lCmpMeshIter.position(OpenMaya.MSpace.kWorld)
         # Find the closest point and append it to the list
         lCmpPoints.append( lKdTree.closestPoint( (lCmpPos.x, lCmpPos.y, lCmpPos.z) ) )

         lCmpMeshIter.next()

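      lTimeTaken = time.time() - lStartTime
      # Report how long it took (this form of print works in Python 2 and 3)
      print('%f seconds using KdTree' % lTimeTaken)
   else:
      cmds.warning('Both selected objects must be meshes')
else:
   cmds.warning('Select exactly two mesh objects')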
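If you'd like to confirm that an entry in lCmpPoints really is the nearest vertex, or see what the tree buys you over the time printed above, a brute-force scan gives the reference answer. This is a minimal sketch, not part of the original script: `closest_point_brute` and `time_brute` are hypothetical helpers that only read a `.position` (x, y, z) tuple from each item, so in Maya you could pass them the script's lPoints list. The namedtuple below is just a stand-in for the Point class so the example runs on its own.

```python
import collections
import time


def closest_point_brute(points, pos):
    """Return the item in points whose .position is nearest pos.

    Checks every point, so each query costs O(n); use it as a reference
    answer and timing baseline rather than for regular lookups.
    """
    def dist_sq(pt):
        return sum((pt.position[i] - pos[i]) ** 2 for i in range(3))
    return min(points, key=dist_sq)


def time_brute(points, queries):
    """Time a brute-force lookup for every query; returns (seconds, results)."""
    start = time.time()
    results = [closest_point_brute(points, q) for q in queries]
    return time.time() - start, results


# Stand-in for the script's Point objects: an index and an (x, y, z) tuple
Pt = collections.namedtuple('Pt', 'index position')
pts = [Pt(0, (0.0, 0.0, 0.0)), Pt(1, (1.0, 0.0, 0.0)), Pt(2, (0.0, 2.0, 0.0))]
print(closest_point_brute(pts, (0.9, 0.1, 0.0)).index)  # prints 1
```

Running time_brute over the same query positions as the KdTree loop gives a like-for-like timing comparison; on a dense mesh the brute-force version should take noticeably longer.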