In a nutshell, to construct a kd-tree out of 3d points you divide your data into two sets along one dimension (the x-axis), then divide each of those sets along the next dimension (the y-axis), and so on, cycling through the dimensions until each set contains only one point. The Wikipedia page can probably explain it better than I can.
http://en.wikipedia.org/wiki/Kd-tree
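So why does dividing the data up this way make looking up points quicker? Instead of iterating over all points to determine which is closest, we start at the top of the tree and follow the appropriate branch at each level until we reach the bottom. Descending alone finds a nearby point but not always the closest one. To guarantee the closest point, the search must also check the other side of any split whose dividing plane is nearer than the best point found so far, and in practice only a few such branches need visiting. You end up doing a fraction of the calculations, which makes this method very fast.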
Here is an example of constructing and using a kd-tree in Maya with Python. Select two mesh objects (first the mesh to build the tree from, then the mesh to compare) and execute the following code in the Python script editor. After you've executed the script, the "lCmpPoints" Python list will contain the closest vertex on the first mesh for each vertex on the second mesh. You can check the results by running lCmpPoints[100]
www.stevenkalinowski.com/scripts/sak_kdtree.zip
import maya.OpenMaya as OpenMaya
import maya.cmds as cmds
import time

# Here I am using a kd-tree to store mesh points; this makes looking up the
# closest vertex to a given world-space position very fast.
#
# The code avoids Python-2-only constructs (cmp sort, print statement) so it
# runs in both older (Python 2) and newer (Python 3) versions of Maya.


class Point(object):
    """A mesh vertex: its index (int) and 3d position (tuple)."""

    def __init__(self, i, pos):
        self.index = i
        self.position = pos

    def __str__(self):
        return str('Vertex:%d at (%4.3f, %4.3f, %4.3f)' % (
            self.index, self.position[0], self.position[1], self.position[2]))

    def __repr__(self):
        return repr(str(self))

    @staticmethod
    def axisAverage(pVtx, pAxis):
        """Return the mean of one axis over a list of Point objects.

        pVtx  - non-empty list of Point objects (their .position is read)
        pAxis - int index of axis, 0 for X, 1 for Y, and 2 for Z
        """
        lTotal = 0.0
        for v in pVtx:
            lTotal += v.position[pAxis]
        return lTotal / len(pVtx)


class KdTree(object):
    """A 3d kd-tree of Point objects supporting exact closest-point lookup."""

    # 3d space so we have 3 dimensions
    kDimensions = 3

    class Node(object):
        """A tree node: the split location, its points and its child nodes."""

        def __init__(self, pOrigin, pPts):
            # Floating point location of the axis division for this node
            self.origin = pOrigin
            # Python list of the Point objects contained in this subtree
            self.points = pPts
            # Child Node objects (both None for a leaf)
            self.leftNode = None
            self.rightNode = None

    def __init__(self, pPoints):
        """Build the tree from a list of Point objects.

        The caller's list is not modified. An empty list gives an empty tree,
        for which closestPoint returns None.
        """
        self.node = KdTree.addPoints(pPoints)

    @staticmethod
    def addPoints(pPoints, depth=0):
        """Recursively build the subtree for pPoints and return its root Node.

        The axis cycles X, Y, Z, X, ... with depth. Points are split at the
        median along that axis. Unlike a mean split, a median split always
        leaves both halves non-empty. Splitting therefore never stops early
        on points that share a coordinate, and the tree stays about
        log2(len(pPoints)) deep. Returns None for an empty list.
        """
        if not pPoints:
            return None

        lAxis = depth % KdTree.kDimensions
        # sorted() returns a new list, leaving the caller's list untouched
        lSorted = sorted(pPoints, key=lambda p: p.position[lAxis])

        if len(lSorted) == 1:
            # A single point cannot be split any further: this is a leaf
            return KdTree.Node(lSorted[0].position[lAxis], lSorted)

        lMid = len(lSorted) // 2
        # Place the division halfway between the two middle values. Every
        # left point is then <= origin and every right point >= origin, which
        # the pruning in _searchNode relies on.
        lOrigin = (lSorted[lMid - 1].position[lAxis] +
                   lSorted[lMid].position[lAxis]) * 0.5

        lNode = KdTree.Node(lOrigin, lSorted)
        lNode.leftNode = KdTree.addPoints(lSorted[:lMid], depth + 1)
        lNode.rightNode = KdTree.addPoints(lSorted[lMid:], depth + 1)
        return lNode

    def closestPoint(self, pPos, depth=0, node=None):
        """Return the Point in the tree closest to pPos (an x, y, z sequence).

        The search first descends to the leaf on pPos's side of each split,
        then backtracks. The other side of a split is visited only when the
        split plane is nearer than the best point found so far. On equal
        distances, the first point found wins.

        depth and node allow starting from a subtree; by default the search
        starts at the root. Returns None if the tree is empty.
        """
        if node is None:
            node = self.node
        if node is None:
            return None
        # [best Point, its squared distance]; updated in place by the search
        lBest = [None, float('inf')]
        self._searchNode(node, pPos, depth, lBest)
        return lBest[0]

    def _searchNode(self, node, pPos, depth, pBest):
        """Update pBest with any point in node's subtree closer to pPos."""
        if node.leftNode is None or node.rightNode is None:
            # Leaf: compare its points directly
            for lPt in node.points:
                lDist = KdTree._sqrDistance(lPt.position, pPos)
                if lDist < pBest[1]:
                    pBest[0] = lPt
                    pBest[1] = lDist
            return

        lAxis = depth % KdTree.kDimensions
        lDelta = pPos[lAxis] - node.origin
        if lDelta <= 0:
            lNear, lFar = node.leftNode, node.rightNode
        else:
            lNear, lFar = node.rightNode, node.leftNode

        self._searchNode(lNear, pPos, depth + 1, pBest)
        # Every point on the far side is at least |lDelta| away along this
        # axis, so that side can only hold a closer point if the split plane
        # itself is closer than the current best.
        if lDelta * lDelta < pBest[1]:
            self._searchNode(lFar, pPos, depth + 1, pBest)

    @staticmethod
    def _sqrDistance(pA, pB):
        """Return the squared 3d distance between two position sequences."""
        lTotal = 0.0
        for i in range(KdTree.kDimensions):
            lDiff = pA[i] - pB[i]
            lTotal += lDiff * lDiff
        return lTotal


# Now we'll do a test by creating a kd-tree with one mesh, then look up the
# closest vertex for each vertex in another mesh.
# Select two meshes: first the one to make the tree, then the one to compare.
lSelList = OpenMaya.MSelectionList()
OpenMaya.MGlobal.getActiveSelectionList(lSelList)
if lSelList.length() == 2:
    lMeshPath = OpenMaya.MDagPath()
    lCmpMeshPath = OpenMaya.MDagPath()
    lSelList.getDagPath(0, lMeshPath)
    lSelList.getDagPath(1, lCmpMeshPath)

    # Make sure both objects are meshes
    if (lMeshPath.hasFn(OpenMaya.MFn.kMesh) and
            lCmpMeshPath.hasFn(OpenMaya.MFn.kMesh)):
        lMeshIter = OpenMaya.MItMeshVertex(lMeshPath)
        # Python list of Point objects
        lPoints = []
        while not lMeshIter.isDone():
            lPos = lMeshIter.position(OpenMaya.MSpace.kWorld)
            # Add each vertex to the list as a Point object
            lPoints.append(Point(lMeshIter.index(), (lPos.x, lPos.y, lPos.z)))
            lMeshIter.next()

        # Create the tree
        lKdTree = KdTree(lPoints)

        # Now we'll time how long it takes to iterate the second mesh looking
        # up each vertex's closest point
        lStartTime = time.time()
        lCmpMeshIter = OpenMaya.MItMeshVertex(lCmpMeshPath)
        lCmpPoints = []
        while not lCmpMeshIter.isDone():
            lCmpPos = lCmpMeshIter.position(OpenMaya.MSpace.kWorld)
            # Find the closest point and append it to the list
            lCmpPoints.append(
                lKdTree.closestPoint((lCmpPos.x, lCmpPos.y, lCmpPos.z)))
            lCmpMeshIter.next()
        lTimeTaken = time.time() - lStartTime

        # Report how long it took (same output as the old print statement)
        print('%s Using KdTree' % lTimeTaken)
    else:
        OpenMaya.MGlobal.displayWarning('Both selected objects must be meshes.')
else:
    OpenMaya.MGlobal.displayWarning('Select exactly two meshes.')