#include <net/netfilter/nf_conntrack_zones.h>
 #include <linux/netfilter/nf_nat.h>
 
+#include "nf_internals.h"
+
 static spinlock_t nf_nat_locks[CONNTRACK_LOCKS];
 
 static DEFINE_MUTEX(nf_nat_proto_mutex);
                                                __read_mostly;
 static const struct nf_nat_l4proto __rcu **nf_nat_l4protos[NFPROTO_NUMPROTO]
                                                __read_mostly;
+static unsigned int nat_net_id __read_mostly;
 
 static struct hlist_head *nf_nat_bysource __read_mostly;
 static unsigned int nf_nat_htable_size __read_mostly;
 static unsigned int nf_nat_hash_rnd __read_mostly;
 
+struct nf_nat_lookup_hook_priv {
+       struct nf_hook_entries __rcu *entries;
+
+       struct rcu_head rcu_head;
+};
+
+struct nf_nat_hooks_net {
+       struct nf_hook_ops *nat_hook_ops;
+       unsigned int users;
+};
+
+struct nat_net {
+       struct nf_nat_hooks_net nat_proto_net[NFPROTO_NUMPROTO];
+};
+
 inline const struct nf_nat_l3proto *
 __nf_nat_l3proto_find(u8 family)
 {
        .expectfn       = nf_nat_follow_master,
 };
 
+int nf_nat_register_fn(struct net *net, const struct nf_hook_ops *ops,
+                      const struct nf_hook_ops *orig_nat_ops, unsigned int ops_count)
+{
+       struct nat_net *nat_net = net_generic(net, nat_net_id);
+       struct nf_nat_hooks_net *nat_proto_net;
+       struct nf_nat_lookup_hook_priv *priv;
+       unsigned int hooknum = ops->hooknum;
+       struct nf_hook_ops *nat_ops;
+       int i, ret;
+
+       if (WARN_ON_ONCE(ops->pf >= ARRAY_SIZE(nat_net->nat_proto_net)))
+               return -EINVAL;
+
+       nat_proto_net = &nat_net->nat_proto_net[ops->pf];
+
+       for (i = 0; i < ops_count; i++) {
+               if (WARN_ON(orig_nat_ops[i].pf != ops->pf))
+                       return -EINVAL;
+               if (orig_nat_ops[i].hooknum == hooknum) {
+                       hooknum = i;
+                       break;
+               }
+       }
+
+       if (WARN_ON_ONCE(i == ops_count))
+               return -EINVAL;
+
+       mutex_lock(&nf_nat_proto_mutex);
+       if (!nat_proto_net->nat_hook_ops) {
+               WARN_ON(nat_proto_net->users != 0);
+
+               nat_ops = kmemdup(orig_nat_ops, sizeof(*orig_nat_ops) * ops_count, GFP_KERNEL);
+               if (!nat_ops) {
+                       mutex_unlock(&nf_nat_proto_mutex);
+                       return -ENOMEM;
+               }
+
+               for (i = 0; i < ops_count; i++) {
+                       priv = kzalloc(sizeof(*priv), GFP_KERNEL);
+                       if (priv) {
+                               nat_ops[i].priv = priv;
+                               continue;
+                       }
+                       mutex_unlock(&nf_nat_proto_mutex);
+                       while (i)
+                               kfree(nat_ops[--i].priv);
+                       kfree(nat_ops);
+                       return -ENOMEM;
+               }
+
+               ret = nf_register_net_hooks(net, nat_ops, ops_count);
+               if (ret < 0) {
+                       mutex_unlock(&nf_nat_proto_mutex);
+                       for (i = 0; i < ops_count; i++)
+                               kfree(nat_ops[i].priv);
+                       kfree(nat_ops);
+                       return ret;
+               }
+
+               nat_proto_net->nat_hook_ops = nat_ops;
+       }
+
+       nat_ops = nat_proto_net->nat_hook_ops;
+       priv = nat_ops[hooknum].priv;
+       if (WARN_ON_ONCE(!priv)) {
+               mutex_unlock(&nf_nat_proto_mutex);
+               return -EOPNOTSUPP;
+       }
+
+       ret = nf_hook_entries_insert_raw(&priv->entries, ops);
+       if (ret == 0)
+               nat_proto_net->users++;
+
+       mutex_unlock(&nf_nat_proto_mutex);
+       return ret;
+}
+EXPORT_SYMBOL_GPL(nf_nat_register_fn);
+
+void nf_nat_unregister_fn(struct net *net, const struct nf_hook_ops *ops,
+                         unsigned int ops_count)
+{
+       struct nat_net *nat_net = net_generic(net, nat_net_id);
+       struct nf_nat_hooks_net *nat_proto_net;
+       struct nf_nat_lookup_hook_priv *priv;
+       struct nf_hook_ops *nat_ops;
+       int hooknum = ops->hooknum;
+       int i;
+
+       if (ops->pf >= ARRAY_SIZE(nat_net->nat_proto_net))
+               return;
+
+       nat_proto_net = &nat_net->nat_proto_net[ops->pf];
+
+       mutex_lock(&nf_nat_proto_mutex);
+       if (WARN_ON(nat_proto_net->users == 0))
+               goto unlock;
+
+       nat_proto_net->users--;
+
+       nat_ops = nat_proto_net->nat_hook_ops;
+       for (i = 0; i < ops_count; i++) {
+               if (nat_ops[i].hooknum == hooknum) {
+                       hooknum = i;
+                       break;
+               }
+       }
+       if (WARN_ON_ONCE(i == ops_count))
+               goto unlock;
+       priv = nat_ops[hooknum].priv;
+       nf_hook_entries_delete_raw(&priv->entries, ops);
+
+       if (nat_proto_net->users == 0) {
+               nf_unregister_net_hooks(net, nat_ops, ops_count);
+
+               for (i = 0; i < ops_count; i++) {
+                       priv = nat_ops[i].priv;
+                       kfree_rcu(priv, rcu_head);
+               }
+
+               nat_proto_net->nat_hook_ops = NULL;
+               kfree(nat_ops);
+       }
+unlock:
+       mutex_unlock(&nf_nat_proto_mutex);
+}
+EXPORT_SYMBOL_GPL(nf_nat_unregister_fn);
+
+static struct pernet_operations nat_net_ops = {
+       .id = &nat_net_id,
+       .size = sizeof(struct nat_net),
+};
+
 static int __init nf_nat_init(void)
 {
        int ret, i;
        for (i = 0; i < CONNTRACK_LOCKS; i++)
                spin_lock_init(&nf_nat_locks[i]);
 
+       ret = register_pernet_subsys(&nat_net_ops);
+       if (ret < 0) {
+               nf_ct_extend_unregister(&nat_extend);
+               return ret;
+       }
+
        nf_ct_helper_expectfn_register(&follow_master_nat);
 
        BUG_ON(nfnetlink_parse_nat_setup_hook != NULL);
                kfree(nf_nat_l4protos[i]);
        synchronize_net();
        nf_ct_free_hashtable(nf_nat_bysource, nf_nat_htable_size);
+       unregister_pernet_subsys(&nat_net_ops);
 }
 
 MODULE_LICENSE("GPL");