* switches, and waiting for our parent to respond.
         */
 __sys_trace:
-       mov     x1, sp
-       mov     w0, #0                          // trace entry
-       bl      syscall_trace
+       mov     x0, sp
+       bl      syscall_trace_enter
        adr     lr, __sys_trace_return          // return address
        uxtw    scno, w0                        // syscall number (possibly new)
        mov     x1, sp                          // pointer to regs
 
 __sys_trace_return:
        str     x0, [sp]                        // save returned x0
-       mov     x1, sp
-       mov     w0, #1                          // trace exit
-       bl      syscall_trace
+       mov     x0, sp
+       bl      syscall_trace_exit
        b       ret_to_user
 
 /*
 
        return ptrace_request(child, request, addr, data);
 }
 
-asmlinkage int syscall_trace(int dir, struct pt_regs *regs)
+enum ptrace_syscall_dir {
+       PTRACE_SYSCALL_ENTER = 0,
+       PTRACE_SYSCALL_EXIT,
+};
+
+static void tracehook_report_syscall(struct pt_regs *regs,
+                                    enum ptrace_syscall_dir dir)
 {
+       int regno;
        unsigned long saved_reg;
 
-       if (!test_thread_flag(TIF_SYSCALL_TRACE))
-               return regs->syscallno;
-
-       if (is_compat_task()) {
-               /* AArch32 uses ip (r12) for scratch */
-               saved_reg = regs->regs[12];
-               regs->regs[12] = dir;
-       } else {
-               /*
-                * Save X7. X7 is used to denote syscall entry/exit:
-                *   X7 = 0 -> entry, = 1 -> exit
-                */
-               saved_reg = regs->regs[7];
-               regs->regs[7] = dir;
-       }
+       /*
+        * A scratch register (ip(r12) on AArch32, x7 on AArch64) is
+        * used to denote syscall entry/exit:
+        */
+       regno = (is_compat_task() ? 12 : 7);
+       saved_reg = regs->regs[regno];
+       regs->regs[regno] = dir;
 
-       if (dir)
+       if (dir == PTRACE_SYSCALL_EXIT)
                tracehook_report_syscall_exit(regs, 0);
        else if (tracehook_report_syscall_entry(regs))
                regs->syscallno = ~0UL;
 
-       if (is_compat_task())
-               regs->regs[12] = saved_reg;
-       else
-               regs->regs[7] = saved_reg;
+       regs->regs[regno] = saved_reg;
+}
+
+asmlinkage int syscall_trace_enter(struct pt_regs *regs)
+{
+       if (test_thread_flag(TIF_SYSCALL_TRACE))
+               tracehook_report_syscall(regs, PTRACE_SYSCALL_ENTER);
 
        return regs->syscallno;
 }
+
+asmlinkage void syscall_trace_exit(struct pt_regs *regs)
+{
+       if (test_thread_flag(TIF_SYSCALL_TRACE))
+               tracehook_report_syscall(regs, PTRACE_SYSCALL_EXIT);
+}