
| Current Path : /var/www/web-klick.de/dsh/50_dev2017/1310__algorithms/Julia/Notebooks/ |
Linux ift1.ift-informatik.de 5.4.0-216-generic #236-Ubuntu SMP Fri Apr 11 19:53:21 UTC 2025 x86_64 |
| Current File : /var/www/web-klick.de/dsh/50_dev2017/1310__algorithms/Julia/Notebooks/KDTree.ipynb |
{
"metadata": {
"name": ""
},
"nbformat": 3,
"nbformat_minor": 0,
"worksheets": [
{
"cells": [
{
"cell_type": "code",
"collapsed": false,
"input": [
"import numpy as np\n",
"from heapq import heappush, heappop\n",
"import scipy.sparse"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 1
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"def minkowski_distance_p(x, y, p=2):\n",
"    \"\"\"Return the L**p distance between x and y, raised to the power p.\n",
"\n",
"    The final p-th root is skipped for efficiency; code that only\n",
"    compares distances does not need it. For p == 1 and p == inf the\n",
"    result already equals the true L**p distance. x and y are broadcast\n",
"    against each other and the reduction runs over the last axis.\n",
"    \"\"\"\n",
"    diff = np.abs(np.asarray(y) - np.asarray(x))\n",
"    if p == np.inf:\n",
"        return np.amax(diff, axis=-1)\n",
"    if p == 1:\n",
"        return np.sum(diff, axis=-1)\n",
"    return np.sum(diff ** p, axis=-1)\n",
"def minkowski_distance(x, y, p=2):\n",
"    \"\"\"Return the L**p (Minkowski) distance between x and y.\n",
"\n",
"    Wraps minkowski_distance_p and applies the p-th root, except for\n",
"    p == 1 and p == inf where the root would be the identity.\n",
"    \"\"\"\n",
"    powered = minkowski_distance_p(np.asarray(x), np.asarray(y), p)\n",
"    if p in (1, np.inf):\n",
"        return powered\n",
"    return powered ** (1. / p)"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 2
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"class Rectangle(object):\n",
"    \"\"\"Hyperrectangle class.\n",
"\n",
"    Represents a Cartesian product of intervals [mins[i], maxes[i]],\n",
"    one interval per dimension.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, maxes, mins):\n",
"        \"\"\"Construct a hyperrectangle from two opposite corners.\n",
"\n",
"        The corners are reordered element-wise, so it does not matter\n",
"        which one is passed as maxes and which as mins. The result must\n",
"        be 1-D (one interval per dimension).\n",
"        \"\"\"\n",
"        # The builtin float is used because the np.float alias was removed\n",
"        # in NumPy 1.24; both denote float64.\n",
"        self.maxes = np.maximum(maxes, mins).astype(float)\n",
"        self.mins = np.minimum(maxes, mins).astype(float)\n",
"        self.m, = self.maxes.shape\n",
"\n",
"    def __repr__(self):\n",
"        # list() is required on Python 3, where zip() returns a lazy\n",
"        # iterator whose repr would not show the intervals.\n",
"        return \"<Rectangle %s>\" % list(zip(self.mins, self.maxes))\n",
"\n",
"    def volume(self):\n",
"        \"\"\"Total volume (product of the side lengths).\"\"\"\n",
"        return np.prod(self.maxes - self.mins)\n",
"\n",
"    def split(self, d, split):\n",
"        \"\"\"Produce two hyperrectangles by splitting along axis d.\n",
"\n",
"        Returns (less, greater): copies of this rectangle whose upper\n",
"        (respectively lower) bound on axis d is replaced by split.\n",
"\n",
"        In general, if you need to compute maximum and minimum\n",
"        distances to the children, it can be done more efficiently\n",
"        by updating the maximum and minimum distances to the parent.\n",
"        \"\"\"\n",
"        # FIXME: do this (incremental distance updates are not implemented)\n",
"        mid = np.copy(self.maxes)\n",
"        mid[d] = split\n",
"        less = Rectangle(self.mins, mid)\n",
"        mid = np.copy(self.mins)\n",
"        mid[d] = split\n",
"        greater = Rectangle(mid, self.maxes)\n",
"        return less, greater\n",
"\n",
"    def min_distance_point(self, x, p=2.):\n",
"        \"\"\"Compute the minimum distance between x and a point in the hyperrectangle.\n",
"\n",
"        This is zero when x lies inside the rectangle.\n",
"        \"\"\"\n",
"        return minkowski_distance(0, np.maximum(0, np.maximum(self.mins - x, x - self.maxes)), p)\n",
"\n",
"    def max_distance_point(self, x, p=2.):\n",
"        \"\"\"Compute the maximum distance between x and a point in the hyperrectangle.\"\"\"\n",
"        return minkowski_distance(0, np.maximum(self.maxes - x, x - self.mins), p)\n",
"\n",
"    def min_distance_rectangle(self, other, p=2.):\n",
"        \"\"\"Compute the minimum distance between points in the two hyperrectangles.\"\"\"\n",
"        return minkowski_distance(0, np.maximum(0, np.maximum(self.mins - other.maxes, other.mins - self.maxes)), p)\n",
"\n",
"    def max_distance_rectangle(self, other, p=2.):\n",
"        \"\"\"Compute the maximum distance between points in the two hyperrectangles.\"\"\"\n",
"        return minkowski_distance(0, np.maximum(self.maxes - other.mins, other.maxes - self.mins), p)"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 3
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"class KDTree(object):\n",
"    \"\"\"kd-tree for quick nearest-neighbor lookup\n",
"\n",
"    This class provides an index into a set of k-dimensional points\n",
"    which can be used to rapidly look up the nearest neighbors of any\n",
"    point.\n",
"\n",
"    The algorithm used is described in Maneewongvatana and Mount 1999.\n",
"    The general idea is that the kd-tree is a binary tree, each of whose\n",
"    nodes represents an axis-aligned hyperrectangle. Each node specifies\n",
"    an axis and splits the set of points based on whether their coordinate\n",
"    along that axis is greater than or less than a particular value.\n",
"\n",
"    During construction, the axis and splitting point are chosen by the\n",
"    \"sliding midpoint\" rule, which ensures that the cells do not all\n",
"    become long and thin.\n",
"\n",
"    The tree can be queried for the k closest neighbors of any given point\n",
"    (optionally returning only those within some maximum distance of the\n",
"    point). It can also be queried, with a substantial gain in efficiency,\n",
"    for the k approximate closest neighbors.\n",
"\n",
"    For large dimensions (20 is already large) do not expect this to run\n",
"    significantly faster than brute force. High-dimensional nearest-neighbor\n",
"    queries are a substantial open problem in computer science.\n",
"\n",
"    The tree also supports all-neighbors queries, both with arrays of points\n",
"    and with other kd-trees. These do use a reasonably efficient algorithm,\n",
"    but the kd-tree is not necessarily the best data structure for this\n",
"    sort of calculation.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, data, leafsize=10):\n",
"        \"\"\"Construct a kd-tree.\n",
"\n",
"        Parameters:\n",
"        ===========\n",
"\n",
"        data : array-like, shape (n,k)\n",
"            The data points to be indexed. This array is not copied, and\n",
"            so modifying this data will result in bogus results.\n",
"        leafsize : positive integer\n",
"            The number of points at which the algorithm switches over to\n",
"            brute-force.\n",
"\n",
"        Raises ValueError if leafsize < 1.\n",
"        \"\"\"\n",
"        self.data = np.asarray(data)\n",
"        self.n, self.m = np.shape(self.data)\n",
"        self.leafsize = int(leafsize)\n",
"        if self.leafsize < 1:\n",
"            raise ValueError(\"leafsize must be at least 1\")\n",
"        # Bounding box of all points: the cell of the root node and the\n",
"        # starting bounds used by every query method.\n",
"        self.maxes = np.amax(self.data, axis=0)\n",
"        self.mins = np.amin(self.data, axis=0)\n",
"\n",
"        self.tree = self.__build(np.arange(self.n), self.maxes, self.mins)\n",
"\n",
"    class node(object):\n",
"        \"\"\"Common base class of the two node types.\"\"\"\n",
"\n",
"    class leafnode(node):\n",
"        \"\"\"Leaf: stores the indices into KDTree.data of its points.\"\"\"\n",
"        def __init__(self, idx):\n",
"            self.idx = idx\n",
"            # number of points in this subtree (used by count_neighbors)\n",
"            self.children = len(idx)\n",
"\n",
"    class innernode(node):\n",
"        \"\"\"Inner node: splits its cell at coordinate split along axis split_dim.\"\"\"\n",
"        def __init__(self, split_dim, split, less, greater):\n",
"            self.split_dim = split_dim\n",
"            self.split = split\n",
"            self.less = less\n",
"            self.greater = greater\n",
"            # number of points in this subtree (used by count_neighbors)\n",
"            self.children = less.children + greater.children\n",
"\n",
"    def __build(self, idx, maxes, mins):\n",
"        \"\"\"Recursively build the subtree holding the points self.data[idx].\n",
"\n",
"        maxes/mins are the bounds of the cell being split. They are inherited\n",
"        from the parent split rather than recomputed from the points. The cell\n",
"        is cut across its longest side using the sliding midpoint rule.\n",
"        \"\"\"\n",
"        if len(idx) <= self.leafsize:\n",
"            return KDTree.leafnode(idx)\n",
"        data = self.data[idx]\n",
"        d = np.argmax(maxes - mins)\n",
"        maxval = maxes[d]\n",
"        minval = mins[d]\n",
"        if maxval == minval:\n",
"            # The longest side has zero length, so all points in the cell\n",
"            # are identical and cannot be split; warn user?\n",
"            return KDTree.leafnode(idx)\n",
"        data = data[:, d]\n",
"\n",
"        # sliding midpoint rule; see Maneewongvatana and Mount 1999\n",
"        # for arguments that this is a good idea.\n",
"        split = (maxval + minval) / 2\n",
"        less_idx = np.nonzero(data <= split)[0]\n",
"        greater_idx = np.nonzero(data > split)[0]\n",
"        if len(less_idx) == 0:\n",
"            # nothing on the low side: slide the split up to the smallest point\n",
"            split = np.amin(data)\n",
"            less_idx = np.nonzero(data <= split)[0]\n",
"            greater_idx = np.nonzero(data > split)[0]\n",
"        if len(greater_idx) == 0:\n",
"            # nothing on the high side: slide the split down to the largest point\n",
"            split = np.amax(data)\n",
"            less_idx = np.nonzero(data < split)[0]\n",
"            greater_idx = np.nonzero(data >= split)[0]\n",
"        if len(less_idx) == 0:\n",
"            # _still_ zero? all must have the same value\n",
"            assert np.all(data == data[0]), \"Troublesome data array: %s\" % data\n",
"            split = data[0]\n",
"            less_idx = np.arange(len(data) - 1)\n",
"            greater_idx = np.array([len(data) - 1])\n",
"\n",
"        lessmaxes = np.copy(maxes)\n",
"        lessmaxes[d] = split\n",
"        greatermins = np.copy(mins)\n",
"        greatermins[d] = split\n",
"        return KDTree.innernode(d, split,\n",
"                                self.__build(idx[less_idx], lessmaxes, mins),\n",
"                                self.__build(idx[greater_idx], maxes, greatermins))\n",
"\n",
"    def __query(self, x, k=1, eps=0, p=2, distance_upper_bound=np.inf):\n",
"        \"\"\"Best-first search for the k nearest neighbors of the single point x.\n",
"\n",
"        Returns a list of (distance, index) pairs sorted by distance. During\n",
"        the search all distances are kept as p-th powers (see\n",
"        minkowski_distance_p); the root is taken only when returning. With\n",
"        k=None every point closer than distance_upper_bound is returned.\n",
"        \"\"\"\n",
"        side_distances = np.maximum(0, np.maximum(x - self.maxes, self.mins - x))\n",
"        if p != np.inf:\n",
"            # Not in place: for integer data, an in-place power with a float p\n",
"            # fails to cast the float result back into the integer array.\n",
"            side_distances = side_distances ** p\n",
"            min_distance = np.sum(side_distances)\n",
"        else:\n",
"            min_distance = np.amax(side_distances)\n",
"\n",
"        # priority queue for chasing nodes\n",
"        # entries are:\n",
"        #  minimum distance between the cell and the target\n",
"        #  distances between the nearest side of the cell and the target\n",
"        #  a push counter, so that two cells with equal keys never fall\n",
"        #   through to comparing node objects (a TypeError on Python 3;\n",
"        #   it happens when x lies exactly on a splitting plane)\n",
"        #  the head node of the cell\n",
"        counter = 0\n",
"        q = [(min_distance,\n",
"              tuple(side_distances),\n",
"              counter,\n",
"              self.tree)]\n",
"        # priority queue for the nearest neighbors\n",
"        # furthest known neighbor first\n",
"        # entries are (-distance**p, i)\n",
"        neighbors = []\n",
"\n",
"        if eps == 0:\n",
"            epsfac = 1\n",
"        elif p == np.inf:\n",
"            epsfac = 1 / (1 + eps)\n",
"        else:\n",
"            epsfac = 1 / (1 + eps) ** p\n",
"\n",
"        if p != np.inf and distance_upper_bound != np.inf:\n",
"            distance_upper_bound = distance_upper_bound ** p\n",
"\n",
"        while q:\n",
"            min_distance, side_distances, _, node = heappop(q)\n",
"            if isinstance(node, KDTree.leafnode):\n",
"                # brute-force\n",
"                data = self.data[node.idx]\n",
"                ds = minkowski_distance_p(data, x[np.newaxis, :], p)\n",
"                for i in range(len(ds)):\n",
"                    if ds[i] < distance_upper_bound:\n",
"                        if len(neighbors) == k:\n",
"                            heappop(neighbors)\n",
"                        heappush(neighbors, (-ds[i], node.idx[i]))\n",
"                        if len(neighbors) == k:\n",
"                            distance_upper_bound = -neighbors[0][0]\n",
"            else:\n",
"                # we don't push cells that are too far onto the queue at all,\n",
"                # but since the distance_upper_bound decreases, we might get\n",
"                # here even if the cell's too far\n",
"                if min_distance > distance_upper_bound * epsfac:\n",
"                    # since this is the nearest cell, we're done, bail out\n",
"                    break\n",
"                # compute minimum distances to the children and push them on\n",
"                if x[node.split_dim] < node.split:\n",
"                    near, far = node.less, node.greater\n",
"                else:\n",
"                    near, far = node.greater, node.less\n",
"\n",
"                # near child is at the same distance as the current node\n",
"                counter += 1\n",
"                heappush(q, (min_distance, side_distances, counter, near))\n",
"\n",
"                # far child is further by an amount depending only\n",
"                # on the split value\n",
"                sd = list(side_distances)\n",
"                if p == np.inf:\n",
"                    min_distance = max(min_distance, abs(node.split - x[node.split_dim]))\n",
"                elif p == 1:\n",
"                    sd[node.split_dim] = np.abs(node.split - x[node.split_dim])\n",
"                    min_distance = min_distance - side_distances[node.split_dim] + sd[node.split_dim]\n",
"                else:\n",
"                    sd[node.split_dim] = np.abs(node.split - x[node.split_dim]) ** p\n",
"                    min_distance = min_distance - side_distances[node.split_dim] + sd[node.split_dim]\n",
"\n",
"                # far child might be too far, if so, don't bother pushing it\n",
"                if min_distance <= distance_upper_bound * epsfac:\n",
"                    counter += 1\n",
"                    heappush(q, (min_distance, tuple(sd), counter, far))\n",
"\n",
"        if p == np.inf:\n",
"            return sorted([(-d, i) for (d, i) in neighbors])\n",
"        else:\n",
"            return sorted([((-d) ** (1. / p), i) for (d, i) in neighbors])\n",
"\n",
"    def query(self, x, k=1, eps=0, p=2, distance_upper_bound=np.inf):\n",
"        \"\"\"query the kd-tree for nearest neighbors\n",
"\n",
"        Parameters:\n",
"        ===========\n",
"\n",
"        x : array-like, last dimension self.m\n",
"            An array of points to query.\n",
"        k : integer or None\n",
"            The number of nearest neighbors to return. None returns every\n",
"            neighbor within distance_upper_bound.\n",
"        eps : nonnegative float\n",
"            Return approximate nearest neighbors; the kth returned value\n",
"            is guaranteed to be no further than (1+eps) times the\n",
"            distance to the real kth nearest neighbor.\n",
"        p : float, 1<=p<=infinity\n",
"            Which Minkowski p-norm to use.\n",
"            1 is the sum-of-absolute-values \"Manhattan\" distance\n",
"            2 is the usual Euclidean distance\n",
"            infinity is the maximum-coordinate-difference distance\n",
"        distance_upper_bound : nonnegative float\n",
"            Return only neighbors within this distance. This is used to prune\n",
"            tree searches, so if you are doing a series of nearest-neighbor\n",
"            queries, it may help to supply the distance to the nearest neighbor\n",
"            of the most recent point.\n",
"\n",
"        Returns:\n",
"        ========\n",
"\n",
"        d : array of floats\n",
"            The distances to the nearest neighbors.\n",
"            If x has shape tuple+(self.m,), then d has shape tuple if\n",
"            k is one, or tuple+(k,) if k is larger than one. Missing\n",
"            neighbors are indicated with infinite distances. If k is None,\n",
"            then d is an object array of shape tuple, containing lists\n",
"            of distances. In either case the hits are sorted by distance\n",
"            (nearest first).\n",
"        i : array of integers\n",
"            The locations of the neighbors in self.data. i is the same\n",
"            shape as d. Missing neighbors are indicated with self.n.\n",
"\n",
"        Raises:\n",
"        =======\n",
"\n",
"        ValueError\n",
"            If the last dimension of x is not self.m, if p < 1, or if k is\n",
"            neither None nor a number >= 1.\n",
"        \"\"\"\n",
"        x = np.asarray(x)\n",
"        if np.shape(x)[-1] != self.m:\n",
"            raise ValueError(\"x must consist of vectors of length %d but has shape %s\" % (self.m, np.shape(x)))\n",
"        if p < 1:\n",
"            raise ValueError(\"Only p-norms with 1<=p<=infinity permitted\")\n",
"        # Validate k before searching, testing for None before any ordering\n",
"        # comparison: on Python 3 the expression None > 1 raises TypeError.\n",
"        if k is not None and not k >= 1:\n",
"            raise ValueError(\"Requested %s nearest neighbors; acceptable numbers are integers greater than or equal to one, or None\" % k)\n",
"        retshape = np.shape(x)[:-1]\n",
"        if retshape != ():\n",
"            # Many query points: fill preallocated results. Missing neighbors\n",
"            # keep the sentinels inf (distance) and self.n (index).\n",
"            if k is None:\n",
"                dd = np.empty(retshape, dtype=object)\n",
"                ii = np.empty(retshape, dtype=object)\n",
"            elif k > 1:\n",
"                dd = np.full(retshape + (k,), np.inf)\n",
"                ii = np.full(retshape + (k,), self.n, dtype=int)\n",
"            else:\n",
"                dd = np.full(retshape, np.inf)\n",
"                ii = np.full(retshape, self.n, dtype=int)\n",
"            for c in np.ndindex(retshape):\n",
"                hits = self.__query(x[c], k=k, eps=eps, p=p, distance_upper_bound=distance_upper_bound)\n",
"                if k is None:\n",
"                    dd[c] = [d for (d, i) in hits]\n",
"                    ii[c] = [i for (d, i) in hits]\n",
"                elif k > 1:\n",
"                    for j, (d, i) in enumerate(hits):\n",
"                        dd[c + (j,)] = d\n",
"                        ii[c + (j,)] = i\n",
"                elif hits:\n",
"                    dd[c], ii[c] = hits[0]\n",
"            return dd, ii\n",
"        else:\n",
"            # A single query point: return scalars/1-D arrays directly.\n",
"            hits = self.__query(x, k=k, eps=eps, p=p, distance_upper_bound=distance_upper_bound)\n",
"            if k is None:\n",
"                return [d for (d, i) in hits], [i for (d, i) in hits]\n",
"            elif k == 1:\n",
"                if hits:\n",
"                    return hits[0]\n",
"                else:\n",
"                    return np.inf, self.n\n",
"            else:\n",
"                dd = np.full(k, np.inf)\n",
"                ii = np.full(k, self.n, dtype=int)\n",
"                for j, (d, i) in enumerate(hits):\n",
"                    dd[j] = d\n",
"                    ii[j] = i\n",
"                return dd, ii\n",
"\n",
"    def __query_ball_point(self, x, r, p=2., eps=0):\n",
"        \"\"\"Return a list of the indices of all points within distance r of the single point x.\"\"\"\n",
"        R = Rectangle(self.maxes, self.mins)\n",
"\n",
"        def traverse_checking(node, rect):\n",
"            # Prune with a strict > so that a cell whose nearest point is\n",
"            # exactly r away is still searched: the leaf test below keeps\n",
"            # points at distance <= r, and query_ball_tree prunes the same way.\n",
"            if rect.min_distance_point(x, p) > r / (1. + eps):\n",
"                return []\n",
"            elif rect.max_distance_point(x, p) < r * (1. + eps):\n",
"                # the whole cell is within range: take every point in it\n",
"                return traverse_no_checking(node)\n",
"            elif isinstance(node, KDTree.leafnode):\n",
"                d = self.data[node.idx]\n",
"                return node.idx[minkowski_distance(d, x, p) <= r].tolist()\n",
"            else:\n",
"                less, greater = rect.split(node.split_dim, node.split)\n",
"                return traverse_checking(node.less, less) + traverse_checking(node.greater, greater)\n",
"\n",
"        def traverse_no_checking(node):\n",
"            if isinstance(node, KDTree.leafnode):\n",
"                return node.idx.tolist()\n",
"            else:\n",
"                return traverse_no_checking(node.less) + traverse_no_checking(node.greater)\n",
"\n",
"        return traverse_checking(self.tree, R)\n",
"\n",
"    def query_ball_point(self, x, r, p=2., eps=0):\n",
"        \"\"\"Find all points within r of x\n",
"\n",
"        Parameters\n",
"        ==========\n",
"\n",
"        x : array_like, shape tuple + (self.m,)\n",
"            The point or points to search for neighbors of\n",
"        r : positive float\n",
"            The radius of points to return\n",
"        p : float 1<=p<=infinity\n",
"            Which Minkowski p-norm to use\n",
"        eps : nonnegative float\n",
"            Approximate search. Branches of the tree are not explored\n",
"            if their nearest points are further than r/(1+eps), and branches\n",
"            are added in bulk if their furthest points are nearer than r*(1+eps).\n",
"\n",
"        Returns\n",
"        =======\n",
"\n",
"        results : list or array of lists\n",
"            If x is a single point, returns a list of the indices of the neighbors\n",
"            of x. If x is an array of points, returns an object array of shape tuple\n",
"            containing lists of neighbors.\n",
"\n",
"\n",
"        Note: if you have many points whose neighbors you want to find, you may save\n",
"        substantial amounts of time by putting them in a KDTree and using query_ball_tree.\n",
"        \"\"\"\n",
"        x = np.asarray(x)\n",
"        if x.shape[-1] != self.m:\n",
"            raise ValueError(\"Searching for a %d-dimensional point in a %d-dimensional KDTree\" % (x.shape[-1], self.m))\n",
"        if len(x.shape) == 1:\n",
"            return self.__query_ball_point(x, r, p, eps)\n",
"        else:\n",
"            retshape = x.shape[:-1]\n",
"            # builtin object replaces the np.object alias removed in NumPy 1.24\n",
"            result = np.empty(retshape, dtype=object)\n",
"            for c in np.ndindex(retshape):\n",
"                result[c] = self.__query_ball_point(x[c], r, p=p, eps=eps)\n",
"            return result\n",
"\n",
"    def query_ball_tree(self, other, r, p=2., eps=0):\n",
"        \"\"\"Find all pairs of points whose distance is at most r\n",
"\n",
"        Parameters\n",
"        ==========\n",
"\n",
"        other : KDTree\n",
"            The tree containing points to search against\n",
"        r : positive float\n",
"            The maximum distance\n",
"        p : float 1<=p<=infinity\n",
"            Which Minkowski norm to use\n",
"        eps : nonnegative float\n",
"            Approximate search. Branches of the tree are not explored\n",
"            if their nearest points are further than r/(1+eps), and branches\n",
"            are added in bulk if their furthest points are nearer than r*(1+eps).\n",
"\n",
"        Returns\n",
"        =======\n",
"\n",
"        results : list of lists\n",
"            For each element self.data[i] of this tree, results[i] is a list of the\n",
"            indices of its neighbors in other.data.\n",
"        \"\"\"\n",
"        results = [[] for _ in range(self.n)]\n",
"\n",
"        # Dual-tree traversal: recurse on pairs of cells, dropping pairs that\n",
"        # are too far apart and accepting close pairs wholesale.\n",
"        def traverse_checking(node1, rect1, node2, rect2):\n",
"            if rect1.min_distance_rectangle(rect2, p) > r / (1. + eps):\n",
"                return\n",
"            elif rect1.max_distance_rectangle(rect2, p) < r * (1. + eps):\n",
"                traverse_no_checking(node1, node2)\n",
"            elif isinstance(node1, KDTree.leafnode):\n",
"                if isinstance(node2, KDTree.leafnode):\n",
"                    d = other.data[node2.idx]\n",
"                    for i in node1.idx:\n",
"                        results[i] += node2.idx[minkowski_distance(d, self.data[i], p) <= r].tolist()\n",
"                else:\n",
"                    less, greater = rect2.split(node2.split_dim, node2.split)\n",
"                    traverse_checking(node1, rect1, node2.less, less)\n",
"                    traverse_checking(node1, rect1, node2.greater, greater)\n",
"            elif isinstance(node2, KDTree.leafnode):\n",
"                less, greater = rect1.split(node1.split_dim, node1.split)\n",
"                traverse_checking(node1.less, less, node2, rect2)\n",
"                traverse_checking(node1.greater, greater, node2, rect2)\n",
"            else:\n",
"                less1, greater1 = rect1.split(node1.split_dim, node1.split)\n",
"                less2, greater2 = rect2.split(node2.split_dim, node2.split)\n",
"                traverse_checking(node1.less, less1, node2.less, less2)\n",
"                traverse_checking(node1.less, less1, node2.greater, greater2)\n",
"                traverse_checking(node1.greater, greater1, node2.less, less2)\n",
"                traverse_checking(node1.greater, greater1, node2.greater, greater2)\n",
"\n",
"        def traverse_no_checking(node1, node2):\n",
"            if isinstance(node1, KDTree.leafnode):\n",
"                if isinstance(node2, KDTree.leafnode):\n",
"                    for i in node1.idx:\n",
"                        results[i] += node2.idx.tolist()\n",
"                else:\n",
"                    traverse_no_checking(node1, node2.less)\n",
"                    traverse_no_checking(node1, node2.greater)\n",
"            else:\n",
"                traverse_no_checking(node1.less, node2)\n",
"                traverse_no_checking(node1.greater, node2)\n",
"\n",
"        traverse_checking(self.tree, Rectangle(self.maxes, self.mins),\n",
"                          other.tree, Rectangle(other.maxes, other.mins))\n",
"        return results\n",
"\n",
"    def count_neighbors(self, other, r, p=2.):\n",
"        \"\"\"Count how many nearby pairs can be formed.\n",
"\n",
"        Count the number of pairs (x1,x2) that can be formed, with x1 drawn\n",
"        from self and x2 drawn from other, and where distance(x1,x2,p)<=r.\n",
"        This is the \"two-point correlation\" described in Gray and Moore 2000,\n",
"        \"N-body problems in statistical learning\", and the code here is based\n",
"        on their algorithm.\n",
"\n",
"        r may be a single value (an integer count is returned) or a 1-D\n",
"        array of radii (an integer array with one count per radius is\n",
"        returned); anything else raises ValueError.\n",
"        \"\"\"\n",
"\n",
"        def traverse(node1, rect1, node2, rect2, idx):\n",
"            # idx lists the positions in r still undecided for this pair of cells\n",
"            min_r = rect1.min_distance_rectangle(rect2, p)\n",
"            max_r = rect1.max_distance_rectangle(rect2, p)\n",
"            # radii beyond the farthest possible pair count every pair at once\n",
"            c_greater = r[idx] > max_r\n",
"            result[idx[c_greater]] += node1.children * node2.children\n",
"            # keep only radii falling between the closest and farthest pair\n",
"            idx = idx[(min_r <= r[idx]) & (r[idx] <= max_r)]\n",
"            if len(idx) == 0:\n",
"                return\n",
"\n",
"            if isinstance(node1, KDTree.leafnode):\n",
"                if isinstance(node2, KDTree.leafnode):\n",
"                    ds = minkowski_distance(self.data[node1.idx][:, np.newaxis, :],\n",
"                                            other.data[node2.idx][np.newaxis, :, :],\n",
"                                            p).ravel()\n",
"                    ds.sort()\n",
"                    result[idx] += np.searchsorted(ds, r[idx], side='right')\n",
"                else:\n",
"                    less, greater = rect2.split(node2.split_dim, node2.split)\n",
"                    traverse(node1, rect1, node2.less, less, idx)\n",
"                    traverse(node1, rect1, node2.greater, greater, idx)\n",
"            else:\n",
"                if isinstance(node2, KDTree.leafnode):\n",
"                    less, greater = rect1.split(node1.split_dim, node1.split)\n",
"                    traverse(node1.less, less, node2, rect2, idx)\n",
"                    traverse(node1.greater, greater, node2, rect2, idx)\n",
"                else:\n",
"                    less1, greater1 = rect1.split(node1.split_dim, node1.split)\n",
"                    less2, greater2 = rect2.split(node2.split_dim, node2.split)\n",
"                    traverse(node1.less, less1, node2.less, less2, idx)\n",
"                    traverse(node1.less, less1, node2.greater, greater2, idx)\n",
"                    traverse(node1.greater, greater1, node2.less, less2, idx)\n",
"                    traverse(node1.greater, greater1, node2.greater, greater2, idx)\n",
"\n",
"        R1 = Rectangle(self.maxes, self.mins)\n",
"        R2 = Rectangle(other.maxes, other.mins)\n",
"        if np.shape(r) == ():\n",
"            r = np.array([r])\n",
"            result = np.zeros(1, dtype=int)\n",
"            traverse(self.tree, R1, other.tree, R2, np.arange(1))\n",
"            return result[0]\n",
"        elif len(np.shape(r)) == 1:\n",
"            r = np.asarray(r)\n",
"            n, = r.shape\n",
"            result = np.zeros(n, dtype=int)\n",
"            traverse(self.tree, R1, other.tree, R2, np.arange(n))\n",
"            return result\n",
"        else:\n",
"            raise ValueError(\"r must be either a single value or a one-dimensional array of values\")\n",
"\n",
"    def sparse_distance_matrix(self, other, max_distance, p=2.):\n",
"        \"\"\"Compute a sparse distance matrix\n",
"\n",
"        Computes a distance matrix between two KDTrees, leaving as zero\n",
"        any distance greater than max_distance.\n",
"\n",
"        Returns a scipy.sparse.dok_matrix of shape (self.n, other.n). Note\n",
"        that a pair at distance 0 cannot be told apart from an unstored pair.\n",
"        \"\"\"\n",
"        result = scipy.sparse.dok_matrix((self.n, other.n))\n",
"\n",
"        def traverse(node1, rect1, node2, rect2):\n",
"            if rect1.min_distance_rectangle(rect2, p) > max_distance:\n",
"                return\n",
"            elif isinstance(node1, KDTree.leafnode):\n",
"                if isinstance(node2, KDTree.leafnode):\n",
"                    for i in node1.idx:\n",
"                        for j in node2.idx:\n",
"                            d = minkowski_distance(self.data[i], other.data[j], p)\n",
"                            if d <= max_distance:\n",
"                                result[i, j] = d\n",
"                else:\n",
"                    less, greater = rect2.split(node2.split_dim, node2.split)\n",
"                    traverse(node1, rect1, node2.less, less)\n",
"                    traverse(node1, rect1, node2.greater, greater)\n",
"            elif isinstance(node2, KDTree.leafnode):\n",
"                less, greater = rect1.split(node1.split_dim, node1.split)\n",
"                traverse(node1.less, less, node2, rect2)\n",
"                traverse(node1.greater, greater, node2, rect2)\n",
"            else:\n",
"                less1, greater1 = rect1.split(node1.split_dim, node1.split)\n",
"                less2, greater2 = rect2.split(node2.split_dim, node2.split)\n",
"                traverse(node1.less, less1, node2.less, less2)\n",
"                traverse(node1.less, less1, node2.greater, greater2)\n",
"                traverse(node1.greater, greater1, node2.less, less2)\n",
"                traverse(node1.greater, greater1, node2.greater, greater2)\n",
"\n",
"        traverse(self.tree, Rectangle(self.maxes, self.mins),\n",
"                 other.tree, Rectangle(other.maxes, other.mins))\n",
"\n",
"        return result\n",
"\n",
" \n",
"def distance_matrix(x, y, p=2, threshold=1000000):\n",
"    \"\"\"Compute the distance matrix.\n",
"\n",
"    Computes the matrix of all pairwise distances.\n",
"\n",
"    Parameters\n",
"    ==========\n",
"\n",
"    x : array-like, m by k\n",
"    y : array-like, n by k\n",
"    p : float 1<=p<=infinity\n",
"        Which Minkowski p-norm to use.\n",
"    threshold : positive integer\n",
"        If m*n*k>threshold use a loop instead of creating\n",
"        a very large temporary.\n",
"\n",
"    Returns\n",
"    =======\n",
"\n",
"    result : array-like, m by n\n",
"        result[i, j] is the distance between x[i] and y[j].\n",
"\n",
"    Raises\n",
"    ======\n",
"\n",
"    ValueError\n",
"        If x and y do not have the same number of columns.\n",
"    \"\"\"\n",
"    x = np.asarray(x)\n",
"    m, k = x.shape\n",
"    y = np.asarray(y)\n",
"    n, kk = y.shape\n",
"\n",
"    if k != kk:\n",
"        raise ValueError(\"x contains %d-dimensional vectors but y contains %d-dimensional vectors\" % (k, kk))\n",
"\n",
"    if m * n * k <= threshold:\n",
"        # Small enough: broadcast to an (m, n, k) temporary in one shot.\n",
"        return minkowski_distance(x[:, np.newaxis, :], y[np.newaxis, :, :], p)\n",
"    # Too large: fill the result one row or column at a time, looping over\n",
"    # the shorter input. The builtin float replaces the np.float alias that\n",
"    # was removed in NumPy 1.24 (both are float64).\n",
"    result = np.empty((m, n), dtype=float)  # FIXME: figure out the best dtype\n",
"    if m < n:\n",
"        for i in range(m):\n",
"            result[i, :] = minkowski_distance(x[i], y, p)\n",
"    else:\n",
"        for j in range(n):\n",
"            result[:, j] = minkowski_distance(x, y[j], p)\n",
"    return result"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 4
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# problem size: n points in m dimensions\n",
"n, m = 100, 5"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 5
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# n random points, one per row, drawn from a standard normal distribution\n",
"data = np.random.standard_normal(size=(n, m))"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 8
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# index the points; leafsize=10 is the constructor default\n",
"kdtree = KDTree(data, leafsize=10)"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 9
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# draw a random row index in [0, len(data))\n",
"np.random.randint(0, len(data))"
],
"language": "python",
"metadata": {},
"outputs": [
{
"metadata": {},
"output_type": "pyout",
"prompt_number": 10,
"text": [
"59"
]
}
],
"prompt_number": 10
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# 3 nearest neighbours of row 20; the row is itself indexed, so the first hit has distance 0\n",
"kdtree.query(data[20], k=3)"
],
"language": "python",
"metadata": {},
"outputs": [
{
"metadata": {},
"output_type": "pyout",
"prompt_number": 11,
"text": [
"(array([ 0. , 1.34949919, 1.6719412 ]), array([20, 26, 3]))"
]
}
],
"prompt_number": 11
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# show the coordinates of row 20\n",
"data[20, :]"
],
"language": "python",
"metadata": {},
"outputs": [
{
"metadata": {},
"output_type": "pyout",
"prompt_number": 12,
"text": [
"array([-0.08003008, 2.57704094, 0.13185724, -0.05317286, -1.54546621])"
]
}
],
"prompt_number": 12
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# 5 nearest neighbours of an arbitrary point\n",
"kdtree.query([0.03772218, 1.37762555, 1.31047274, -0.49921709, -0.12849282], k=5)"
],
"language": "python",
"metadata": {},
"outputs": [
{
"metadata": {},
"output_type": "pyout",
"prompt_number": 13,
"text": [
"(array([ 1.01809915, 1.28248616, 1.29438119, 1.3101769 , 1.34424653]),\n",
" array([ 0, 24, 88, 96, 98]))"
]
}
],
"prompt_number": 13
},
{
"cell_type": "code",
"collapsed": false,
"input": [],
"language": "python",
"metadata": {},
"outputs": []
}
],
"metadata": {}
}
]
}