#include <linux/string.h>
 #include <linux/zalloc.h>
 
-static void __machine__remove_thread(struct machine *machine, struct thread *th, bool lock);
+static void __machine__remove_thread(struct machine *machine, struct thread_rb_node *nd,
+                                    struct thread *th, bool lock);
 static int append_inlines(struct callchain_cursor *cursor, struct map_symbol *ms, u64 ip);
 
 static struct dso *machine__kernel_dso(struct machine *machine)
        }
 }
 
+/*
+ * rb_find() comparator: orders the tid pointed to by @key against the tid of
+ * the thread referenced by tree node @nd.
+ *
+ * Returns a negative value, zero or a positive value when the key is
+ * respectively smaller than, equal to or greater than the node's tid.  The
+ * values are compared rather than subtracted so that tids of opposite sign
+ * and large magnitude cannot overflow the int result.
+ */
+static int thread_rb_node__cmp_tid(const void *key, const struct rb_node *nd)
+{
+       pid_t to_find = *(const pid_t *)key;
+       pid_t tid = rb_entry(nd, struct thread_rb_node, rb_node)->thread->tid;
+
+       if (to_find < tid)
+               return -1;
+
+       return to_find > tid;
+}
+
+/*
+ * Find the node in @tree whose thread has the same tid as @th.
+ *
+ * Returns NULL when rb_find() reports no match.  The miss is handled here
+ * because rb_entry() on a NULL pointer only yields NULL if rb_node sits at
+ * offset 0 of struct thread_rb_node, which this function should not rely on.
+ */
+static struct thread_rb_node *thread_rb_node__find(const struct thread *th,
+                                                  struct rb_root *tree)
+{
+       struct rb_node *nd = rb_find(&th->tid, tree, thread_rb_node__cmp_tid);
+
+       if (nd == NULL)
+               return NULL;
+
+       return rb_entry(nd, struct thread_rb_node, rb_node);
+}
+
 static int machine__set_mmap_name(struct machine *machine)
 {
        if (machine__is_host(machine))
                down_write(&threads->lock);
                nd = rb_first_cached(&threads->entries);
                while (nd) {
-                       struct thread *t = rb_entry(nd, struct thread, rb_node);
+                       struct thread_rb_node *trb = rb_entry(nd, struct thread_rb_node, rb_node);
 
                        nd = rb_next(nd);
-                       __machine__remove_thread(machine, t, false);
+                       __machine__remove_thread(machine, trb, trb->thread, false);
                }
                up_write(&threads->lock);
        }
        struct rb_node **p = &threads->entries.rb_root.rb_node;
        struct rb_node *parent = NULL;
        struct thread *th;
+       struct thread_rb_node *nd;
        bool leftmost = true;
 
        th = threads__get_last_match(threads, machine, pid, tid);
 
        while (*p != NULL) {
                parent = *p;
-               th = rb_entry(parent, struct thread, rb_node);
+               th = rb_entry(parent, struct thread_rb_node, rb_node)->thread;
 
                if (th->tid == tid) {
                        threads__set_last_match(threads, th);
                return NULL;
 
        th = thread__new(pid, tid);
-       if (th != NULL) {
-               rb_link_node(&th->rb_node, parent, p);
-               rb_insert_color_cached(&th->rb_node, &threads->entries, leftmost);
+       if (th == NULL)
+               return NULL;
 
-               /*
-                * We have to initialize maps separately after rb tree is updated.
-                *
-                * The reason is that we call machine__findnew_thread
-                * within thread__init_maps to find the thread
-                * leader and that would screwed the rb tree.
-                */
-               if (thread__init_maps(th, machine)) {
-                       rb_erase_cached(&th->rb_node, &threads->entries);
-                       RB_CLEAR_NODE(&th->rb_node);
-                       thread__put(th);
-                       return NULL;
-               }
-               /*
-                * It is now in the rbtree, get a ref
-                */
-               thread__get(th);
-               threads__set_last_match(threads, th);
-               ++threads->nr;
+       nd = malloc(sizeof(*nd));
+       if (nd == NULL) {
+               thread__put(th);
+               return NULL;
+       }
+       nd->thread = th;
+
+       rb_link_node(&nd->rb_node, parent, p);
+       rb_insert_color_cached(&nd->rb_node, &threads->entries, leftmost);
+
+       /*
+        * We have to initialize maps separately after rb tree is updated.
+        *
+        * The reason is that we call machine__findnew_thread within
+        * thread__init_maps to find the thread leader and that would corrupt
+        * the rb tree.
+        */
+       if (thread__init_maps(th, machine)) {
+               rb_erase_cached(&nd->rb_node, &threads->entries);
+               RB_CLEAR_NODE(&nd->rb_node);
+               free(nd);
+               thread__put(th);
+               return NULL;
        }
+       /*
+        * It is now in the rbtree, get a ref
+        */
+       thread__get(th);
+       threads__set_last_match(threads, th);
+       ++threads->nr;
 
        return th;
 }
 
                for (nd = rb_first_cached(&threads->entries); nd;
                     nd = rb_next(nd)) {
-                       struct thread *pos = rb_entry(nd, struct thread, rb_node);
+                       struct thread *pos = rb_entry(nd, struct thread_rb_node, rb_node)->thread;
 
                        ret += thread__fprintf(pos, fp);
                }
        return 0;
 }
 
-static void __machine__remove_thread(struct machine *machine, struct thread *th, bool lock)
+static void __machine__remove_thread(struct machine *machine, struct thread_rb_node *nd,
+                                    struct thread *th, bool lock)
 {
        struct threads *threads = machine__threads(machine, th->tid);
 
+       if (!nd)
+               nd = thread_rb_node__find(th, &threads->entries.rb_root);
+
        if (threads->last_match == th)
                threads__set_last_match(threads, NULL);
 
 
        BUG_ON(refcount_read(&th->refcnt) == 0);
 
-       rb_erase_cached(&th->rb_node, &threads->entries);
-       RB_CLEAR_NODE(&th->rb_node);
+       thread__put(nd->thread);
+       rb_erase_cached(&nd->rb_node, &threads->entries);
+       RB_CLEAR_NODE(&nd->rb_node);
        --threads->nr;
 
-       thread__put(th);
+       free(nd);
 
        if (lock)
                up_write(&threads->lock);
 
 void machine__remove_thread(struct machine *machine, struct thread *th)
 {
-       return __machine__remove_thread(machine, th, true);
+       return __machine__remove_thread(machine, NULL, th, true);
 }
 
 int machine__process_fork_event(struct machine *machine, union perf_event *event,
 {
        struct threads *threads;
        struct rb_node *nd;
-       struct thread *thread;
        int rc = 0;
        int i;
 
                threads = &machine->threads[i];
                for (nd = rb_first_cached(&threads->entries); nd;
                     nd = rb_next(nd)) {
-                       thread = rb_entry(nd, struct thread, rb_node);
-                       rc = fn(thread, priv);
+                       struct thread_rb_node *trb = rb_entry(nd, struct thread_rb_node, rb_node);
+
+                       rc = fn(trb->thread, priv);
                        if (rc != 0)
                                return rc;
                }