extern struct desc_ptr idt_descr;
 extern gate_desc idt_table[];
+extern struct desc_ptr nmi_idt_descr;
+extern gate_desc nmi_idt_table[];
 
 struct gdt_page {
        struct desc_struct gdt[GDT_ENTRIES];
        desc->limit = (limit >> 16) & 0xf;
 }
 
+#ifdef CONFIG_X86_64
+static inline void set_nmi_gate(int gate, void *addr)
+{
+       gate_desc s;
+
+       pack_gate(&s, GATE_INTERRUPT, (unsigned long)addr, 0, 0, __KERNEL_CS);
+       write_idt_entry(nmi_idt_table, gate, &s);
+}
+#endif
+
 static inline void _set_gate(int gate, unsigned type, void *addr,
                             unsigned dpl, unsigned ist, unsigned seg)
 {
 
 DECLARE_PER_CPU(unsigned int, irq_count);
 extern unsigned long kernel_eflags;
 extern asmlinkage void ignore_sysret(void);
+int is_debug_stack(unsigned long addr);
+void debug_stack_set_zero(void);
+void debug_stack_reset(void);
 #else  /* X86_64 */
 #ifdef CONFIG_CC_STACKPROTECTOR
 /*
 };
 DECLARE_PER_CPU_ALIGNED(struct stack_canary, stack_canary);
 #endif
+static inline int is_debug_stack(unsigned long addr) { return 0; }
+static inline void debug_stack_set_zero(void) { }
+static inline void debug_stack_reset(void) { }
 #endif /* X86_64 */
 
 extern unsigned int xstate_size;
 
 
 #ifdef CONFIG_X86_64
 struct desc_ptr idt_descr = { NR_VECTORS * 16 - 1, (unsigned long) idt_table };
+struct desc_ptr nmi_idt_descr = { NR_VECTORS * 16 - 1,
+                                   (unsigned long) nmi_idt_table };
 
 DEFINE_PER_CPU_FIRST(union irq_stack_union,
                     irq_stack_union) __aligned(PAGE_SIZE);
  */
 DEFINE_PER_CPU(struct orig_ist, orig_ist);
 
+static DEFINE_PER_CPU(unsigned long, debug_stack_addr);
+
+int is_debug_stack(unsigned long addr)
+{
+       return addr <= __get_cpu_var(debug_stack_addr) &&
+               addr > (__get_cpu_var(debug_stack_addr) - DEBUG_STKSZ);
+}
+
+void debug_stack_set_zero(void)
+{
+       load_idt((const struct desc_ptr *)&nmi_idt_descr);
+}
+
+void debug_stack_reset(void)
+{
+       load_idt((const struct desc_ptr *)&idt_descr);
+}
+
 #else  /* CONFIG_X86_64 */
 
 DEFINE_PER_CPU(struct task_struct *, current_task) = &init_task;
                        estacks += exception_stack_sizes[v];
                        oist->ist[v] = t->x86_tss.ist[v] =
                                        (unsigned long)estacks;
+                       if (v == DEBUG_STACK-1)
+                               per_cpu(debug_stack_addr, cpu) = (unsigned long)estacks;
                }
        }
 
 
 ENTRY(idt_table)
        .skip IDT_ENTRIES * 16
 
+       .align L1_CACHE_BYTES
+ENTRY(nmi_idt_table)
+       .skip IDT_ENTRIES * 16
+
        __PAGE_ALIGNED_BSS
        .align PAGE_SIZE
 ENTRY(empty_zero_page)
 
 dotraplinkage notrace __kprobes void
 do_nmi(struct pt_regs *regs, long error_code)
 {
+       int update_debug_stack = 0;
+
+       /*
+        * If we interrupted a breakpoint, it is possible that
+        * the nmi handler will have breakpoints too. We need to
+        * change the IDT such that breakpoints that happen here
+        * continue to use the NMI stack.
+        */
+       if (unlikely(is_debug_stack(regs->sp))) {
+               debug_stack_set_zero();
+               update_debug_stack = 1;
+       }
        nmi_enter();
 
        inc_irq_stat(__nmi_count);
                default_do_nmi(regs);
 
        nmi_exit();
+
+       if (unlikely(update_debug_stack))
+               debug_stack_reset();
 }
 
 void stop_nmi(void)
 
        cpu_init();
 
        x86_init.irqs.trap_init();
+
+#ifdef CONFIG_X86_64
+       memcpy(&nmi_idt_table, &idt_table, IDT_ENTRIES * 16);
+       set_nmi_gate(1, &debug);
+       set_nmi_gate(3, &int3);
+#endif
 }